Browse Source

Merge pull request #261 from arnavdas88/master

Ongoing tf.keras.backend.cs
tags/v0.9
Haiping GitHub 6 years ago
parent
commit
0df701389c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 697 additions and 10 deletions
  1. +73
    -0
      src/TensorFlowNET.Core/Keras/BackendBase.cs
  2. +8
    -0
      src/TensorFlowNET.Core/Keras/GraphLearningPhase.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Keras/ImageDataFormat.cs
  4. +3
    -3
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  5. +78
    -7
      src/TensorFlowNET.Core/Keras/backend.cs
  6. +22
    -0
      src/TensorFlowNET.Core/Keras/defaultdict.cs
  7. +63
    -0
      src/TensorFlowNET.Core/Python.cs
  8. +424
    -0
      src/TensorFlowNET.Core/WeakKeyDicionary.cs
  9. +18
    -0
      test/TensorFlowNET.UnitTest/PythonBaseTests.cs

+ 73
- 0
src/TensorFlowNET.Core/Keras/BackendBase.cs View File

@@ -0,0 +1,73 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace Tensorflow.Keras
{
public abstract class BackendBase
{
TF_DataType _FLOATX = dtypes.float32;
float _EPSILON = 1e-7f;
ImageDataFormat _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last;


public float epsilon() => _EPSILON;

public void set_epsilon(float e) => _EPSILON = e;

public TF_DataType floatx() => _FLOATX;

public void set_floatx(TF_DataType floatx) => _FLOATX = floatx;

public NDArray cast_to_floatx(NDArray x) => np.array(x, dtype: _FLOATX.as_numpy_datatype());

public ImageDataFormat image_data_format() => _IMAGE_DATA_FORMAT;

public void set_image_data_format(ImageDataFormat data_format) => _IMAGE_DATA_FORMAT = data_format;

public ImageDataFormat normalize_data_format(object value = null)
{
if (value == null)
value = _IMAGE_DATA_FORMAT;
if (value.GetType() == typeof(ImageDataFormat))
return (ImageDataFormat)value;
else if (value.GetType() == typeof(string))
{
ImageDataFormat dataFormat;
if(Enum.TryParse((string)value, true, out dataFormat))
{
if (Enum.IsDefined(typeof(ImageDataFormat), dataFormat) | dataFormat.ToString().Contains(","))
return dataFormat;
}
}
throw new Exception("The `data_format` argument must be one of \"channels_first\", \"channels_last\". Received: " + value.ToString());
}

//Legacy Methods

public void set_image_dim_ordering(ImageDimOrder dim_ordering)
{
if (dim_ordering == ImageDimOrder.th)
_IMAGE_DATA_FORMAT = ImageDataFormat.channels_first;
else if (dim_ordering == ImageDimOrder.tf)
_IMAGE_DATA_FORMAT = ImageDataFormat.channels_last;
else
throw new Exception("Unknown dim_ordering:"+ dim_ordering);
}

public ImageDimOrder image_dim_ordering()
{
if (_IMAGE_DATA_FORMAT == ImageDataFormat.channels_first)
return ImageDimOrder.th;
else
return ImageDimOrder.tf;
}
}
public enum ImageDimOrder
{
tf,
th
}
}

+ 8
- 0
src/TensorFlowNET.Core/Keras/GraphLearningPhase.cs View File

@@ -0,0 +1,8 @@
namespace Tensorflow.Keras
{
public enum GraphLearningPhase
{
train_mode = 1,
test_mode = 0
}
}

+ 8
- 0
src/TensorFlowNET.Core/Keras/ImageDataFormat.cs View File

@@ -0,0 +1,8 @@
namespace Tensorflow.Keras
{
public enum ImageDataFormat
{
channels_last,
channels_first
}
}

+ 3
- 3
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -95,14 +95,14 @@ namespace Tensorflow.Keras.Utils
{
var graph = ops.get_default_graph();
Dictionary<(string, string), int> name_uid_map = null;
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph.graph_key))
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
{
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key];
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph];
}
else
{
name_uid_map = new Dictionary<(string, string), int>();
backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key] = name_uid_map;
backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map;
}

return name_uid_map;


+ 78
- 7
src/TensorFlowNET.Core/Keras/backend.cs View File

@@ -1,31 +1,51 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.CompilerServices;
using static Tensorflow.Python;

namespace Tensorflow.Keras
{
public class backend
public class backend : BackendBase
{
/* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */
public static Func<Array, double> py_sum = sum;
public static Func<Array, bool> py_all = all;
//Func<Array, bool> py_any = any;
//Func<double, double, double, IEnumerable<double>> py_slice = slice;

public static Session _SESSION = Tensorflow.tf.defaultSession;
public static Graph _GRAPH = null;
public static Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES;
//Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS;
public static bool _MANUAL_VAR_INIT = false;
public static List<string> _LOCAL_DEVICES = null;
/* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */

/// <summary>
/// A global dictionary mapping graph objects to an index of counters used
/// for various layer names in each graph.
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
/// </summary>
public static Dictionary<string, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<string, Dictionary<(string, string), int>>();
public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>();
public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();

public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();

public static void track_variable(RefVariable v)
{
var graph = v.graph;
_GRAPH_VARIABLES[graph.graph_key] = v;
}

public static Tensor placeholder(int[] shape = null,
int ndim = -1,
TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
public static Tensor placeholder(int[] shape = null,
int ndim = -1,
TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
string name = null)
{
if(sparse)
if (sparse)
{
throw new NotImplementedException("placeholder sparse is true");
}
@@ -39,5 +59,56 @@ namespace Tensorflow.Keras
{
return ops.get_default_graph();
}

public static int get_uid(string prefix, string @namespace = "")
{
var graph = tf.get_default_graph();
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1;

return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)];
}
public static int get_uid((string, string) name)
{
var graph = tf.get_default_graph();
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1;

return PER_GRAPH_LAYER_NAME_UIDS[graph][name];
}
public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
public static void clear_session()
{
ops.reset_default_graph();
reset_uids();
_SESSION = null;
var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
_GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
}
public static void manual_variable_initialization(bool value)
{
_MANUAL_VAR_INIT = value;
}
public static GraphLearningPhase learning_phase()
{
var graph = tf.get_default_graph();
if (_GRAPH_LEARNING_PHASES.ContainsKey(graph))
{
var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
_GRAPH_LEARNING_PHASES[graph] = 0;
}
return _GRAPH_LEARNING_PHASES[graph];
}
public static void set_learning_phase(bool value)
{
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
}


public class _DummyEagerGraph
{ }
}
}

+ 22
- 0
src/TensorFlowNET.Core/Keras/defaultdict.cs View File

@@ -0,0 +1,22 @@
using System.Collections.Generic;

namespace System.Collections.Generic
{
public class defaultdict<TKey, TValue> : Dictionary<TKey, TValue> where TValue : new()
{
public new TValue this[TKey key]
{
get
{
TValue val;
if(!TryGetValue(key, out val))
{
val = default(TValue);
Add(key, val);
}
return val;
}
set { base[key] = value; }
}
}
}

+ 63
- 0
src/TensorFlowNET.Core/Python.cs View File

@@ -184,6 +184,69 @@ namespace Tensorflow
return dictionary;
}


public static bool all(IEnumerable enumerable)
{
foreach (var e1 in enumerable)
{
if (!Convert.ToBoolean(e1))
return false;
}
return true;
}

public static bool any(IEnumerable enumerable)
{
foreach (var e1 in enumerable)
{
if (Convert.ToBoolean(e1))
return true;
}
return false;
}

public static double sum(IEnumerable enumerable)
{
var typedef = new Type[] { typeof(double), typeof(int), typeof(float) };
var sum = 0.0d;
foreach (var e1 in enumerable)
{
if (!typedef.Contains(e1.GetType()))
throw new Exception("Numeric array expected");
sum += (double)e1;
}
return sum;
}

public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
{
return sum(values.Keys);
}

public static IEnumerable<double> slice(double start, double end, double step = 1)
{
for (double i = start; i < end; i += step)
yield return i;
}

public static IEnumerable<float> slice(float start, float end, float step = 1)
{
for (float i = start; i < end; i += step)
yield return i;
}

public static IEnumerable<int> slice(int start, int end, int step = 1)
{
for (int i = start; i < end; i += step)
yield return i;
}

public static IEnumerable<int> slice(int range)
{
for (int i = 0; i < range; i++)
yield return i;
}

public static bool hasattr(object obj, string key)
{
var __type__ = (obj).GetType();


+ 424
- 0
src/TensorFlowNET.Core/WeakKeyDicionary.cs View File

@@ -0,0 +1,424 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Text;
using System.Runtime.InteropServices;

namespace Tensorflow
{
public class WeakKeyDictionary<TKey, TValue> : IDictionary<TKey, TValue>
{
private Dictionary<WeakKey, TValue> _internalDictionary;
private object _internalObject = new object();
private bool _finalized;
public WeakKeyDictionary()
{
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer());
}
public WeakKeyDictionary(int capacity)
{
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer());
}
public WeakKeyDictionary(IEqualityComparer<TKey> comparer)
{
_internalDictionary = new Dictionary<WeakKey, TValue>(new WeakComparer(comparer));
}
public WeakKeyDictionary(int capacity, IEqualityComparer<TKey> comparer)
{
_internalDictionary = new Dictionary<WeakKey, TValue>(capacity, new WeakComparer(comparer));
}
// FXCop: this is not empty; we need to mark this so we know if a key
// still has an active dictionary at its finalization.
[SuppressMessage("Microsoft.Performance", "CA1821:RemoveEmptyFinalizers")]
~WeakKeyDictionary()
{
_finalized = true;
}
public ICollection<TKey> Keys
{
get
{
List<TKey> list = new List<TKey>();
lock (_internalObject)
{
foreach (WeakKey key in _internalDictionary.Keys)
{
object TKey = key.Target;
if (TKey != null)
{
list.Add((TKey)TKey);
}
}
}
return list;
}
}
public ICollection<TValue> Values
{
get {
lock (_internalObject) {
return _internalDictionary.Values;
}
}
}
public int Count
{
get
{
// Ensure a fairly accurate count.
ScavangeLostKeys();
lock (_internalObject)
{
return _internalDictionary.Count;
}
}
}
public bool IsReadOnly
{
get {
return false;
}
}
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")]
public TValue this[TKey key]
{
get {
lock (_internalObject) {
return _internalDictionary[new WeakKey(key)];
}
}
set
{
WeakKey Tkey = new WeakKey(key);
lock (_internalObject)
{
//_internalDictionary[Tkey] = value;
_internalDictionary.Add(Tkey, value);
}
// This looks a bit weird but the purpose of the lost key finder is to execute
// code in some future garbage collection phase so we immediately create some garbage.
new LostKeyFinder(this, Tkey);
}
}
public bool TryGetValue(TKey key, out TValue value)
{
WeakKey tkey = new WeakKey(key);
lock (_internalObject)
{
return _internalDictionary.TryGetValue(tkey, out value);
}
}
[SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", Justification = "LostKeyFinder's purpose is to get garbage collected as soon as posible")]
public void Add(TKey key, TValue value)
{
WeakKey tkey = new WeakKey(key);
lock (_internalObject)
{
_internalDictionary.Add(tkey, value);
}
// This looks a bit weird but the purpose of the lost key finder is to execute
// code in some future garbage collection phase so we immediately create some garbage.
new LostKeyFinder(this, tkey);
}
public bool ContainsKey(TKey key)
{
return _internalDictionary.ContainsKey(new WeakKey(key));
}
public bool Remove(TKey key)
{
lock (_internalObject)
{
return _internalDictionary.Remove(new WeakKey(key));
}
}
public void Add(KeyValuePair<TKey, TValue> item)
{
Add(item.Key, item.Value);
}
public void Clear()
{
lock (_internalObject)
{
_internalDictionary.Clear();
}
}
public bool Contains(KeyValuePair<TKey, TValue> item)
{
TValue value;
bool result;
lock (_internalObject)
{
result = _internalDictionary.TryGetValue(new WeakKey(item.Key), out value);
}
if (result)
{
return value.Equals(item.Value);
}
else
{
return false;
}
}
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{
lock (_internalObject)
{
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary)
{
KeyValuePair<TKey, TValue> kv = new KeyValuePair<TKey, TValue>((TKey)item.Key.Target, item.Value);
array[arrayIndex] = kv;
arrayIndex++;
}
}
}
public bool Remove(KeyValuePair<TKey, TValue> item)
{
WeakKey key = new WeakKey(item.Key);
lock (_internalObject)
{
return _internalDictionary.Remove(key);
}
}
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
List<WeakKey> lostKeys = null;
lock (_internalObject)
{
foreach (KeyValuePair<WeakKey, TValue> item in _internalDictionary)
{
object TKey = item.Key.Target;
if (TKey != null)
{
yield return new KeyValuePair<TKey, TValue>((TKey)TKey, item.Value);
}
else
{
if (lostKeys == null)
{
lostKeys = new List<WeakKey>();
}
lostKeys.Add(item.Key);
}
}
}
// Recover any lost keys.
if (lostKeys != null)
{
lock (_internalObject)
{
foreach (WeakKey key in lostKeys)
{
_internalDictionary.Remove(key);
}
}
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private void ScavangeLostKeys()
{
List<WeakKey> lostKeys = null;
lock (_internalObject)
{
foreach (WeakKey key in _internalDictionary.Keys)
{
if (!key.IsAlive)
{
if (lostKeys == null)
{
lostKeys = new List<WeakKey>();
}
lostKeys.Add(key);
}
}
}
if (lostKeys != null)
{
lock (_internalObject)
{
foreach (WeakKey key in lostKeys)
{
_internalDictionary.Remove(key);
}
}
}
}

IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator()
{
return this.GetEnumerator();
}

private class WeakKey : WeakReference
{
private int _hashCode;
// private GCHandle _gcHandle;
public WeakKey(TKey key)
: base(key, true)
{
_hashCode = key.GetHashCode();
// Keep the key alive until it is explicitly collected
// _gcHandle = GCHandle.Alloc(this);
}
internal void Release()
{
// _gcHandle.Free();
}
public override int GetHashCode()
{
return _hashCode;
}
public override bool Equals(object obj)
{
if (obj == null)
{
return false;
}
if (obj.GetHashCode() != _hashCode)
{
return false;
}
if (obj != this && (!IsAlive || !obj.Equals(Target)))
{
return false;
}
return true;
}
}
private class WeakComparer : IEqualityComparer<WeakKey>
{
private IEqualityComparer<TKey> _comparer;
public WeakComparer()
{
}
public WeakComparer(IEqualityComparer<TKey> comparer)
{
_comparer = comparer;
}
public bool Equals(WeakKey x, WeakKey y)
{
if (x.GetHashCode() != y.GetHashCode())
{
return false;
}
if (object.ReferenceEquals(x, y))
{
return true;
}
object ref1 = x.Target;
if (ref1 == null)
{
return false;
}
object ref2 = y.Target;
if (ref2 == null)
{
return false;
}
if (_comparer != null)
{
return _comparer.Equals((TKey)ref1, (TKey)ref2);
}
else
{
return ref1.Equals(ref2);
}
}
public int GetHashCode(WeakKey obj)
{
return obj.GetHashCode();
}
}
private class LostKeyFinder
{
WeakKeyDictionary<TKey, TValue> _dictionary;
WeakKey _key;
public LostKeyFinder(WeakKeyDictionary<TKey, TValue> dictionary, WeakKey key)
{
_dictionary = dictionary;
_key = key;
}
~LostKeyFinder()
{
if (_dictionary._finalized || _key == null)
{
if (_key != null)
{
_key.Release();
_key = null;
}
return;
}
// if (!_key.IsAlive) {
if (_key.Target == null)
{
lock (_dictionary._internalObject)
{
_dictionary._internalDictionary.Remove(_key);
}
_key.Release();
_key = null;
}
else if (_dictionary._internalDictionary.ContainsKey(_key))
{
GC.ReRegisterForFinalize(this);
}
}
}
}
}

+ 18
- 0
test/TensorFlowNET.UnitTest/PythonBaseTests.cs View File

@@ -10,6 +10,24 @@ namespace TensorFlowNET.UnitTest
[TestClass]
public class PythonBaseTests : PythonTest
{
[Ignore]
[TestMethod]
public void weakKeyDictionary_test()
{
var weakKeyDict = new WeakKeyDictionary<int, char>();
for (int i = 0; i < 5; i++)
{
var c = (char)((int)'a' + i);
weakKeyDict[i] = c;
//Assert.AreEqual(weakKeyDict.Count, (int)(i + 1));
var v = (weakKeyDict.Count == i + 1);
Assert.IsTrue(v);
}
//Assert.AreEqual(weakKeyDict.Count, 0);
var b = (weakKeyDict.Count == 0);
Assert.IsTrue(b);
}

[TestMethod]
public void hasattr_getattr()
{


Loading…
Cancel
Save