@@ -0,0 +1,73 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow; | |||
namespace Tensorflow.Keras | |||
{ | |||
public abstract class BackendBase | |||
{ | |||
TF_DataType _FLOATX = dtypes.float32; | |||
float _EPSILON = 1e-7f; | |||
ImageDataFormat _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; | |||
public float epsilon() => _EPSILON; | |||
public void set_epsilon(float e) => _EPSILON = e; | |||
public TF_DataType floatx() => _FLOATX; | |||
public void set_floatx(TF_DataType floatx) => _FLOATX = floatx; | |||
public NDArray cast_to_floatx(NDArray x) => np.array(x, dtype: _FLOATX.as_numpy_datatype()); | |||
public ImageDataFormat image_data_format() => _IMAGE_DATA_FORMAT; | |||
public void set_image_data_format(ImageDataFormat data_format) => _IMAGE_DATA_FORMAT = data_format; | |||
public ImageDataFormat normalize_data_format(object value = null) | |||
{ | |||
if (value == null) | |||
value = _IMAGE_DATA_FORMAT; | |||
if (value.GetType() == typeof(ImageDataFormat)) | |||
return (ImageDataFormat)value; | |||
else if (value.GetType() == typeof(string)) | |||
{ | |||
ImageDataFormat dataFormat; | |||
if(Enum.TryParse((string)value, true, out dataFormat)) | |||
{ | |||
if (Enum.IsDefined(typeof(ImageDataFormat), dataFormat) | dataFormat.ToString().Contains(",")) | |||
return dataFormat; | |||
} | |||
} | |||
throw new Exception("The `data_format` argument must be one of \"channels_first\", \"channels_last\". Received: " + value.ToString()); | |||
} | |||
//Legacy Methods | |||
public void set_image_dim_ordering(ImageDimOrder dim_ordering) | |||
{ | |||
if (dim_ordering == ImageDimOrder.th) | |||
_IMAGE_DATA_FORMAT = ImageDataFormat.channels_first; | |||
else if (dim_ordering == ImageDimOrder.tf) | |||
_IMAGE_DATA_FORMAT = ImageDataFormat.channels_last; | |||
else | |||
throw new Exception("Unknown dim_ordering:"+ dim_ordering); | |||
} | |||
public ImageDimOrder image_dim_ordering() | |||
{ | |||
if (_IMAGE_DATA_FORMAT == ImageDataFormat.channels_first) | |||
return ImageDimOrder.th; | |||
else | |||
return ImageDimOrder.tf; | |||
} | |||
} | |||
public enum ImageDimOrder | |||
{ | |||
tf, | |||
th | |||
} | |||
} |
@@ -0,0 +1,8 @@ | |||
namespace Tensorflow.Keras | |||
{ | |||
public enum GraphLearningPhase | |||
{ | |||
train_mode = 1, | |||
test_mode = 0 | |||
} | |||
} |
@@ -0,0 +1,8 @@ | |||
namespace Tensorflow.Keras | |||
{ | |||
public enum ImageDataFormat | |||
{ | |||
channels_last, | |||
channels_first | |||
} | |||
} |
@@ -95,14 +95,14 @@ namespace Tensorflow.Keras.Utils | |||
{ | |||
var graph = ops.get_default_graph(); | |||
Dictionary<(string, string), int> name_uid_map = null; | |||
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph.graph_key)) | |||
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||
{ | |||
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key]; | |||
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; | |||
} | |||
else | |||
{ | |||
name_uid_map = new Dictionary<(string, string), int>(); | |||
backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key] = name_uid_map; | |||
backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; | |||
} | |||
return name_uid_map; | |||
@@ -1,31 +1,51 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.Runtime.CompilerServices; | |||
using static Tensorflow.Python; | |||
namespace Tensorflow.Keras | |||
{ | |||
public class backend | |||
public class backend : BackendBase | |||
{ | |||
/* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */ | |||
public static Func<Array, double> py_sum = sum; | |||
public static Func<Array, bool> py_all = all; | |||
//Func<Array, bool> py_any = any; | |||
//Func<double, double, double, IEnumerable<double>> py_slice = slice; | |||
public static Session _SESSION = Tensorflow.tf.defaultSession; | |||
public static Graph _GRAPH = null; | |||
public static Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES; | |||
//Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS; | |||
public static bool _MANUAL_VAR_INIT = false; | |||
public static List<string> _LOCAL_DEVICES = null; | |||
/* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */ | |||
/// <summary> | |||
/// A global dictionary mapping graph objects to an index of counters used | |||
/// for various layer names in each graph. | |||
/// Allows to give unique autogenerated names to layers, in a graph-specific way. | |||
/// </summary> | |||
public static Dictionary<string, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<string, Dictionary<(string, string), int>>(); | |||
public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||
public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>(); | |||
public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | |||
public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph(); | |||
public static void track_variable(RefVariable v) | |||
{ | |||
var graph = v.graph; | |||
_GRAPH_VARIABLES[graph.graph_key] = v; | |||
} | |||
public static Tensor placeholder(int[] shape = null, | |||
int ndim = -1, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
bool sparse = false, | |||
public static Tensor placeholder(int[] shape = null, | |||
int ndim = -1, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
bool sparse = false, | |||
string name = null) | |||
{ | |||
if(sparse) | |||
if (sparse) | |||
{ | |||
throw new NotImplementedException("placeholder sparse is true"); | |||
} | |||
@@ -39,5 +59,56 @@ namespace Tensorflow.Keras | |||
{ | |||
return ops.get_default_graph(); | |||
} | |||
public static int get_uid(string prefix, string @namespace = "") | |||
{ | |||
var graph = tf.get_default_graph(); | |||
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||
PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1; | |||
return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)]; | |||
} | |||
public static int get_uid((string, string) name) | |||
{ | |||
var graph = tf.get_default_graph(); | |||
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||
PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1; | |||
return PER_GRAPH_LAYER_NAME_UIDS[graph][name]; | |||
} | |||
public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||
public static void clear_session() | |||
{ | |||
ops.reset_default_graph(); | |||
reset_uids(); | |||
_SESSION = null; | |||
var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase"); | |||
_GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>(); | |||
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0; | |||
} | |||
public static void manual_variable_initialization(bool value) | |||
{ | |||
_MANUAL_VAR_INIT = value; | |||
} | |||
public static GraphLearningPhase learning_phase() | |||
{ | |||
var graph = tf.get_default_graph(); | |||
if (_GRAPH_LEARNING_PHASES.ContainsKey(graph)) | |||
{ | |||
var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase"); | |||
_GRAPH_LEARNING_PHASES[graph] = 0; | |||
} | |||
return _GRAPH_LEARNING_PHASES[graph]; | |||
} | |||
public static void set_learning_phase(bool value) | |||
{ | |||
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); | |||
} | |||
public class _DummyEagerGraph | |||
{ } | |||
} | |||
} |
@@ -0,0 +1,22 @@ | |||
using System.Collections.Generic; | |||
namespace System.Collections.Generic | |||
{ | |||
public class defaultdict<TKey, TValue> : Dictionary<TKey, TValue> where TValue : new() | |||
{ | |||
public new TValue this[TKey key] | |||
{ | |||
get | |||
{ | |||
TValue val; | |||
if(!TryGetValue(key, out val)) | |||
{ | |||
val = default(TValue); | |||
Add(key, val); | |||
} | |||
return val; | |||
} | |||
set { base[key] = value; } | |||
} | |||
} | |||
} |
@@ -184,6 +184,69 @@ namespace Tensorflow | |||
return dictionary; | |||
} | |||
public static bool all(IEnumerable enumerable) | |||
{ | |||
foreach (var e1 in enumerable) | |||
{ | |||
if (!Convert.ToBoolean(e1)) | |||
return false; | |||
} | |||
return true; | |||
} | |||
public static bool any(IEnumerable enumerable) | |||
{ | |||
foreach (var e1 in enumerable) | |||
{ | |||
if (Convert.ToBoolean(e1)) | |||
return true; | |||
} | |||
return false; | |||
} | |||
public static double sum(IEnumerable enumerable) | |||
{ | |||
var typedef = new Type[] { typeof(double), typeof(int), typeof(float) }; | |||
var sum = 0.0d; | |||
foreach (var e1 in enumerable) | |||
{ | |||
if (!typedef.Contains(e1.GetType())) | |||
throw new Exception("Numeric array expected"); | |||
sum += (double)e1; | |||
} | |||
return sum; | |||
} | |||
public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values) | |||
{ | |||
return sum(values.Keys); | |||
} | |||
public static IEnumerable<double> slice(double start, double end, double step = 1) | |||
{ | |||
for (double i = start; i < end; i += step) | |||
yield return i; | |||
} | |||
public static IEnumerable<float> slice(float start, float end, float step = 1) | |||
{ | |||
for (float i = start; i < end; i += step) | |||
yield return i; | |||
} | |||
public static IEnumerable<int> slice(int start, int end, int step = 1) | |||
{ | |||
for (int i = start; i < end; i += step) | |||
yield return i; | |||
} | |||
public static IEnumerable<int> slice(int range) | |||
{ | |||
for (int i = 0; i < range; i++) | |||
yield return i; | |||
} | |||
public static bool hasattr(object obj, string key) | |||
{ | |||
var __type__ = (obj).GetType(); | |||
@@ -0,0 +1,424 @@ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Text; | |||
using System.Runtime.InteropServices; | |||
namespace Tensorflow | |||
{ | |||
public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue> | |||
{ | |||
private Dictionary<WeakKey, TValue> _internalDictionary; | |||
private object _internalObject = new object(); | |||
private bool _finalized; | |||
public WeakKeyDictionary() | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer()); | |||
} | |||
public WeakKeyDictionary(int capacity) | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer()); | |||
} | |||
public WeakKeyDictionary(IEqualityComparer<TKey> comparer) | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer)); | |||
} | |||
public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer) | |||
{ | |||
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer)); | |||
} | |||
// FXCop: this is not empty; we need to mark this so we know if a key | |||
// still has an active dictionary at its finalization. | |||
[SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")] | |||
~WeakKeyDictionary() | |||
{ | |||
_finalized = true; | |||
} | |||
public ICollection<TKey> Keys | |||
{ | |||
get | |||
{ | |||
List<TKey> list = new List<TKey>(); | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in _internalDictionary.Keys) | |||
{ | |||
object TKey = key.Target; | |||
if (TKey != null) | |||
{ | |||
list.Add((TKey)TKey); | |||
} | |||
} | |||
} | |||
return list; | |||
} | |||
} | |||
public ICollection<TValue> Values | |||
{ | |||
get { | |||
lock (_internalObject) { | |||
return _internalDictionary.Values; | |||
} | |||
} | |||
} | |||
public int Count | |||
{ | |||
get | |||
{ | |||
// Ensure a fairly accurate count. | |||
ScavangeLostKeys(); | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.Count; | |||
} | |||
} | |||
} | |||
public bool IsReadOnly | |||
{ | |||
get { | |||
return false; | |||
} | |||
} | |||
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||
public TValue this[TKey key] | |||
{ | |||
get { | |||
lock (_internalObject) { | |||
return _internalDictionary[new WeakKey(key)]; | |||
} | |||
} | |||
set | |||
{ | |||
WeakKey Tkey = new WeakKey(key); | |||
lock (_internalObject) | |||
{ | |||
//_internalDictionary[Tkey] = value; | |||
_internalDictionary.Add(Tkey, value); | |||
} | |||
// This looks a bit weird but the purpose of the lost key finder is to execute | |||
// code in some future garbage collection phase so we immediately create some garbage. | |||
new LostKeyFinder(this, Tkey); | |||
} | |||
} | |||
public bool TryGetValue(TKey key, out TValue value) | |||
{ | |||
WeakKey tkey = new WeakKey(key); | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.TryGetValue(tkey, out value); | |||
} | |||
} | |||
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")] | |||
public void Add(TKey key, TValue value) | |||
{ | |||
WeakKey tkey = new WeakKey(key); | |||
lock (_internalObject) | |||
{ | |||
_internalDictionary.Add(tkey, value); | |||
} | |||
// This looks a bit weird but the purpose of the lost key finder is to execute | |||
// code in some future garbage collection phase so we immediately create some garbage. | |||
new LostKeyFinder(this, tkey); | |||
} | |||
public bool ContainsKey(TKey key) | |||
{ | |||
return _internalDictionary.ContainsKey(new WeakKey(key)); | |||
} | |||
public bool Remove(TKey key) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.Remove(new WeakKey(key)); | |||
} | |||
} | |||
public void Add(KeyValuePair<TKey, TValue> item) | |||
{ | |||
Add(item.Key, item.Value); | |||
} | |||
public void Clear() | |||
{ | |||
lock (_internalObject) | |||
{ | |||
_internalDictionary.Clear(); | |||
} | |||
} | |||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||
{ | |||
TValue value; | |||
bool result; | |||
lock (_internalObject) | |||
{ | |||
result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value); | |||
} | |||
if (result) | |||
{ | |||
return value.Equals(item.Value); | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||
{ | |||
KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value); | |||
array[arrayIndex] = kv; | |||
arrayIndex++; | |||
} | |||
} | |||
} | |||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||
{ | |||
WeakKey key = new WeakKey(item.Key); | |||
lock (_internalObject) | |||
{ | |||
return _internalDictionary.Remove(key); | |||
} | |||
} | |||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
{ | |||
List<WeakKey> lostKeys = null; | |||
lock (_internalObject) | |||
{ | |||
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary) | |||
{ | |||
object TKey = item.Key.Target; | |||
if (TKey != null) | |||
{ | |||
yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value); | |||
} | |||
else | |||
{ | |||
if (lostKeys == null) | |||
{ | |||
lostKeys = new List<WeakKey>(); | |||
} | |||
lostKeys.Add(item.Key); | |||
} | |||
} | |||
} | |||
// Recover any lost keys. | |||
if (lostKeys != null) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in lostKeys) | |||
{ | |||
_internalDictionary.Remove(key); | |||
} | |||
} | |||
} | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
private void ScavangeLostKeys() | |||
{ | |||
List<WeakKey> lostKeys = null; | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in _internalDictionary.Keys) | |||
{ | |||
if (!key.IsAlive) | |||
{ | |||
if (lostKeys == null) | |||
{ | |||
lostKeys = new List<WeakKey>(); | |||
} | |||
lostKeys.Add(key); | |||
} | |||
} | |||
} | |||
if (lostKeys != null) | |||
{ | |||
lock (_internalObject) | |||
{ | |||
foreach (WeakKey key in lostKeys) | |||
{ | |||
_internalDictionary.Remove(key); | |||
} | |||
} | |||
} | |||
} | |||
IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() | |||
{ | |||
return this.GetEnumerator(); | |||
} | |||
private class WeakKey : WeakReference | |||
{ | |||
private int _hashCode; | |||
// private GCHandle _gcHandle; | |||
public WeakKey(TKey key) | |||
: base(key, true) | |||
{ | |||
_hashCode = key.GetHashCode(); | |||
// Keep the key alive until it is explicitly collected | |||
// _gcHandle = GCHandle.Alloc(this); | |||
} | |||
internal void Release() | |||
{ | |||
// _gcHandle.Free(); | |||
} | |||
public override int GetHashCode() | |||
{ | |||
return _hashCode; | |||
} | |||
public override bool Equals(object obj) | |||
{ | |||
if (obj == null) | |||
{ | |||
return false; | |||
} | |||
if (obj.GetHashCode() != _hashCode) | |||
{ | |||
return false; | |||
} | |||
if (obj != this && (!IsAlive || !obj.Equals(Target))) | |||
{ | |||
return false; | |||
} | |||
return true; | |||
} | |||
} | |||
private class WeakComparer : IEqualityComparer<WeakKey> | |||
{ | |||
private IEqualityComparer<TKey> _comparer; | |||
public WeakComparer() | |||
{ | |||
} | |||
public WeakComparer(IEqualityComparer<TKey> comparer) | |||
{ | |||
_comparer = comparer; | |||
} | |||
public bool Equals(WeakKey x, WeakKey y) | |||
{ | |||
if (x.GetHashCode() != y.GetHashCode()) | |||
{ | |||
return false; | |||
} | |||
if (object.ReferenceEquals(x, y)) | |||
{ | |||
return true; | |||
} | |||
object ref1 = x.Target; | |||
if (ref1 == null) | |||
{ | |||
return false; | |||
} | |||
object ref2 = y.Target; | |||
if (ref2 == null) | |||
{ | |||
return false; | |||
} | |||
if (_comparer != null) | |||
{ | |||
return _comparer.Equals((TKey)ref1, (TKey)ref2); | |||
} | |||
else | |||
{ | |||
return ref1.Equals(ref2); | |||
} | |||
} | |||
public int GetHashCode(WeakKey obj) | |||
{ | |||
return obj.GetHashCode(); | |||
} | |||
} | |||
private class LostKeyFinder | |||
{ | |||
WeakKeyDictionary<TKey, TValue> _dictionary; | |||
WeakKey _key; | |||
public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key) | |||
{ | |||
_dictionary = dictionary; | |||
_key = key; | |||
} | |||
~LostKeyFinder() | |||
{ | |||
if (_dictionary._finalized || _key == null) | |||
{ | |||
if (_key != null) | |||
{ | |||
_key.Release(); | |||
_key = null; | |||
} | |||
return; | |||
} | |||
// if (!_key.IsAlive) { | |||
if (_key.Target == null) | |||
{ | |||
lock (_dictionary._internalObject) | |||
{ | |||
_dictionary._internalDictionary.Remove(_key); | |||
} | |||
_key.Release(); | |||
_key = null; | |||
} | |||
else if (_dictionary._internalDictionary.ContainsKey(_key)) | |||
{ | |||
GC.ReRegisterForFinalize(this); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
@@ -10,6 +10,24 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class PythonBaseTests : PythonTest | |||
{ | |||
[Ignore] | |||
[TestMethod] | |||
public void weakKeyDictionary_test() | |||
{ | |||
var weakKeyDict = new WeakKeyDictionary<int, char>(); | |||
for (int i = 0; i < 5; i++) | |||
{ | |||
var c = (char)((int)'a' + i); | |||
weakKeyDict[i] = c; | |||
//Assert.AreEqual(weakKeyDict.Count, (int)(i + 1)); | |||
var v = (weakKeyDict.Count == i + 1); | |||
Assert.IsTrue(v); | |||
} | |||
//Assert.AreEqual(weakKeyDict.Count, 0); | |||
var b = (weakKeyDict.Count == 0); | |||
Assert.IsTrue(b); | |||
} | |||
[TestMethod] | |||
public void hasattr_getattr() | |||
{ | |||