diff --git a/src/TensorFlowNET.Core/WeakKeyDicionary.cs b/src/TensorFlowNET.Core/WeakKeyDicionary.cs new file mode 100644 index 00000000..98df4f30 --- /dev/null +++ b/src/TensorFlowNET.Core/WeakKeyDicionary.cs @@ -0,0 +1,424 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public class WeakKeyDictionary : IDictionary + { + + private Dictionary _internalDictionary; + private object _internalObject = new object(); + private bool _finalized; + + public WeakKeyDictionary() + { + _internalDictionary = new Dictionary(new WeakComparer()); + } + + public WeakKeyDictionary(int capacity) + { + _internalDictionary = new Dictionary(capacity, new WeakComparer()); + } + + public WeakKeyDictionary(IEqualityComparer comparer) + { + _internalDictionary = new Dictionary(new WeakComparer(comparer)); + } + + public WeakKeyDictionary(int capacity, IEqualityComparer comparer) + { + _internalDictionary = new Dictionary(capacity, new WeakComparer(comparer)); + } + + // FXCop: this is not empty; we need to mark this so we know if a key + // still has an active dictionary at its finalization. + [SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] + ~WeakKeyDictionary() + { + _finalized = true; + } + + public ICollection Keys + { + get + { + List list = new List(); + lock (_internalObject) + { + foreach (WeakKey key in _internalDictionary.Keys) + { + object TKey = key.Target; + if (TKey != null) + { + list.Add((TKey)TKey); + } + } + } + return list; + } + } + + public ICollection Values + { + get { + lock (_internalObject) { + return _internalDictionary.Values; + } + } + } + + public int Count + { + get + { + // Ensure a fairly accurate count. + ScavangeLostKeys(); + lock (_internalObject) + { + return _internalDictionary.Count; + } + } + } + + public bool IsReadOnly + { + get { + return false; + } + } + + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] + public TValue this[TKey key] + { + get { + lock (_internalObject) { + return _internalDictionary[new WeakKey(key)]; + } + } + set + { + WeakKey Tkey = new WeakKey(key); + lock (_internalObject) + { + //_internalDictionary[Tkey] = value; + _internalDictionary.Add(Tkey, value); + } + // This looks a bit weird but the purpose of the lost key finder is to execute + // code in some future garbage collection phase so we immediately create some garbage. + new LostKeyFinder(this, Tkey); + } + } + + + + + + public bool TryGetValue(TKey key, out TValue value) + { + WeakKey tkey = new WeakKey(key); + lock (_internalObject) + { + return _internalDictionary.TryGetValue(tkey, out value); + } + } + + + [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] + public void Add(TKey key, TValue value) + { + WeakKey tkey = new WeakKey(key); + lock (_internalObject) + { + _internalDictionary.Add(tkey, value); + } + // This looks a bit weird but the purpose of the lost key finder is to execute + // code in some future garbage collection phase so we immediately create some garbage. + new LostKeyFinder(this, tkey); + + } + + public bool ContainsKey(TKey key) + { + return _internalDictionary.ContainsKey(new WeakKey(key)); + } + + public bool Remove(TKey key) + { + lock (_internalObject) + { + return _internalDictionary.Remove(new WeakKey(key)); + } + } + + public void Add(KeyValuePair item) + { + Add(item.Key, item.Value); + } + + public void Clear() + { + lock (_internalObject) + { + _internalDictionary.Clear(); + } + } + + public bool Contains(KeyValuePair item) + { + TValue value; + bool result; + lock (_internalObject) + { + result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); + } + if (result) + { + return value.Equals(item.Value); + } + else + { + return false; + } + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + lock (_internalObject) + { + foreach (KeyValuePair item in _internalDictionary) + { + KeyValuePair kv = new KeyValuePair((TKey)item.Key.Target, item.Value); + array[arrayIndex] = kv; + arrayIndex++; + } + } + } + + public bool Remove(KeyValuePair item) + { + WeakKey key = new WeakKey(item.Key); + lock (_internalObject) + { + return _internalDictionary.Remove(key); + } + } + + + + + + public IEnumerator> GetEnumerator() + { + List lostKeys = null; + lock (_internalObject) + { + foreach (KeyValuePair item in _internalDictionary) + { + object TKey = item.Key.Target; + if (TKey != null) + { + yield return new KeyValuePair((TKey)TKey, item.Value); + } + else + { + if (lostKeys == null) + { + lostKeys = new List(); + } + lostKeys.Add(item.Key); + } + } + } + // Recover any lost keys. + if (lostKeys != null) + { + lock (_internalObject) + { + foreach (WeakKey key in lostKeys) + { + _internalDictionary.Remove(key); + } + } + } + } + + + + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + + + private void ScavangeLostKeys() + { + List lostKeys = null; + lock (_internalObject) + { + foreach (WeakKey key in _internalDictionary.Keys) + { + if (!key.IsAlive) + { + if (lostKeys == null) + { + lostKeys = new List(); + } + lostKeys.Add(key); + } + } + } + if (lostKeys != null) + { + lock (_internalObject) + { + foreach (WeakKey key in lostKeys) + { + _internalDictionary.Remove(key); + } + } + } + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return this.GetEnumerator(); + } + + private class WeakKey : WeakReference + { + private int _hashCode; + // private GCHandle _gcHandle; + + public WeakKey(TKey key) + : base(key, true) + { + _hashCode = key.GetHashCode(); + // Keep the key alive until it is explicitly collected + // _gcHandle = GCHandle.Alloc(this); + } + + internal void Release() + { + // _gcHandle.Free(); + } + + public override int GetHashCode() + { + return _hashCode; + } + + public override bool Equals(object obj) + { + if (obj == null) + { + return false; + } + if (obj.GetHashCode() != _hashCode) + { + return false; + } + if (obj != this && (!IsAlive || !obj.Equals(Target))) + { + return false; + } + return true; + } + } + + private class WeakComparer : IEqualityComparer + { + + private IEqualityComparer _comparer; + public WeakComparer() + { + } + + public WeakComparer(IEqualityComparer comparer) + { + _comparer = comparer; + } + + public bool Equals(WeakKey x, WeakKey y) + { + if (x.GetHashCode() != y.GetHashCode()) + { + return false; + } + if (object.ReferenceEquals(x, y)) + { + return true; + } + object ref1 = x.Target; + if (ref1 == null) + { + return false; + } + object ref2 = y.Target; + if (ref2 == null) + { + return false; + } + + if (_comparer != null) + { + return _comparer.Equals((TKey)ref1, (TKey)ref2); + } + else + { + return ref1.Equals(ref2); + } + } + + public int GetHashCode(WeakKey obj) + { + return obj.GetHashCode(); + } + } + + private class LostKeyFinder + { + WeakKeyDictionary _dictionary; + WeakKey _key; + + public LostKeyFinder(WeakKeyDictionary dictionary, WeakKey key) + { + _dictionary = dictionary; + _key = key; + } + + ~LostKeyFinder() + { + if (_dictionary._finalized || _key == null) + { + if (_key != null) + { + _key.Release(); + _key = null; + } + return; + } + // if (!_key.IsAlive) { + if (_key.Target == null) + { + lock (_dictionary._internalObject) + { + _dictionary._internalDictionary.Remove(_key); + } + _key.Release(); + _key = null; + } + else if (_dictionary._internalDictionary.ContainsKey(_key)) + { + GC.ReRegisterForFinalize(this); + } + } + } + } +} + \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs index 765a71c2..c5010923 100644 --- a/test/TensorFlowNET.UnitTest/PythonBaseTests.cs +++ b/test/TensorFlowNET.UnitTest/PythonBaseTests.cs @@ -10,6 +10,24 @@ namespace TensorFlowNET.UnitTest [TestClass] public class PythonBaseTests : PythonTest { + [Ignore] + [TestMethod] + public void weakKeyDictionary_test() + { + var weakKeyDict = new WeakKeyDictionary(); + for (int i = 0; i < 5; i++) + { + var c = (char)((int)'a' + i); + weakKeyDict[i] = c; + //Assert.AreEqual(weakKeyDict.Count, (int)(i + 1)); + var v = (weakKeyDict.Count == i + 1); + Assert.IsTrue(v); + } + //Assert.AreEqual(weakKeyDict.Count, 0); + var b = (weakKeyDict.Count == 0); + Assert.IsTrue(b); + } + [TestMethod] public void hasattr_getattr() {