From bf45277be88c3a96adf335e63ba93b5889fdd7cf Mon Sep 17 00:00:00 2001
From: haiping008
Date: Thu, 7 Feb 2019 17:20:16 -0600
Subject: [PATCH] add VariableScope and _VariableStore

---
 .../Operations/IInitializer.cs                | 12 +++
 .../Operations/tf.init_ops.cs                 | 34 ++++++++
 src/TensorFlowNET.Core/Tensors/dtypes.cs      |  5 ++
 .../Variables/VariableAggregation.cs          | 14 ++++
 .../Variables/VariableScope.cs                | 31 ++++++-
 .../Variables/_ReuseMode.cs                   | 16 ++++
 .../Variables/_VariableStore.cs               | 80 +++++++++++++++++++
 .../Variables/tf.variable.cs                  | 10 +++
 .../Variables/variable_scope.py.cs            | 28 ++++---
 test/TensorFlowNET.UnitTest/TrainSaverTest.cs | 21 +++++
 .../python/train_saver.py                     | 26 ++++++
 11 files changed, 265 insertions(+), 12 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/Operations/IInitializer.cs
 create mode 100644 src/TensorFlowNET.Core/Operations/tf.init_ops.cs
 create mode 100644 src/TensorFlowNET.Core/Variables/VariableAggregation.cs
 create mode 100644 src/TensorFlowNET.Core/Variables/_ReuseMode.cs
 create mode 100644 src/TensorFlowNET.Core/Variables/_VariableStore.cs
 create mode 100644 test/TensorFlowNET.UnitTest/TrainSaverTest.cs
 create mode 100644 test/TensorFlowNET.UnitTest/python/train_saver.py

diff --git a/src/TensorFlowNET.Core/Operations/IInitializer.cs b/src/TensorFlowNET.Core/Operations/IInitializer.cs
new file mode 100644
index 00000000..6382e3e0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/IInitializer.cs
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Contract implemented by initializers such as tf.Zeros.
+    /// </summary>
+    public interface IInitializer
+    {
+        /// <summary>
+        /// Builds a tensor for the given shape and element type.
+        /// </summary>
+        Tensor call(TensorShape shape, TF_DataType dtype);
+
+        /// <summary>
+        /// Returns an object describing this initializer's configuration.
+        /// </summary>
+        object get_config();
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/tf.init_ops.cs b/src/TensorFlowNET.Core/Operations/tf.init_ops.cs
new file mode 100644
index 00000000..584f3b01
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/tf.init_ops.cs
@@ -0,0 +1,47 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    public static partial class tf
+    {
+        /// <summary>
+        /// Returns a new <see cref="Zeros"/> initializer on every access.
+        /// </summary>
+        public static IInitializer zeros_initializer => new Zeros();
+
+        /// <summary>
+        /// Initializer that produces its tensors through array_ops.zeros.
+        /// </summary>
+        public class Zeros : IInitializer
+        {
+            private readonly TF_DataType dtype;
+
+            public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
+            {
+                this.dtype = dtype;
+            }
+
+            /// <summary>
+            /// Creates a zero tensor of <paramref name="shape"/>. Passing DtInvalid as
+            /// <paramref name="dtype"/> selects the dtype given to the constructor.
+            /// </summary>
+            public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
+            {
+                if (dtype == TF_DataType.DtInvalid)
+                    dtype = this.dtype;
+
+                return array_ops.zeros(shape, dtype);
+            }
+
+            /// <summary>
+            /// Returns an anonymous object whose dtype member is the result of dtype.name().
+            /// </summary>
+            public object get_config()
+            {
+                return new { dtype = dtype.name() };
+            }
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index 286ea603..af429ee1 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -71,6 +71,16 @@ namespace Tensorflow
                 type;
         }
 
+        /// <summary>
+        /// Returns the numeric value of <paramref name="type"/>.
+        /// NOTE(review): this is the enum's integer code, not a textual name such as
+        /// "float32" - confirm that is what callers like Zeros.get_config expect.
+        /// </summary>
+        public static int name(this TF_DataType type)
+        {
+            return (int)type;
+        }
+
         public static DataType as_base_dtype(this DataType type)
         {
             return (int)type > 100 ?
diff --git a/src/TensorFlowNET.Core/Variables/VariableAggregation.cs b/src/TensorFlowNET.Core/Variables/VariableAggregation.cs
new file mode 100644
index 00000000..3f7e4ff0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/VariableAggregation.cs
@@ -0,0 +1,17 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Aggregation modes that can be requested when getting a variable.
+    /// </summary>
+    public enum VariableAggregation
+    {
+        NONE = 0,
+        SUM = 1,
+        MEAN = 2,
+        ONLY_FIRST_REPLICA = 3 // ONLY_FIRST_TOWER
+    }
+}
diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs
index 025660c8..2d2efdfa 100644
--- a/src/TensorFlowNET.Core/Variables/VariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs
@@ -6,6 +6,53 @@ namespace Tensorflow
 {
     public class VariableScope
     {
-        public bool? use_resource { get; set; }
+        public bool use_resource { get; set; }
+        private _ReuseMode _reuse { get; set; }
+
+        private object _regularizer;
+        private TF_DataType _dtype;
+        public string name { get; set; }
+
+        /// <summary>
+        /// Creates a scope in AUTO_REUSE mode whose fallback dtype is TF_FLOAT.
+        /// </summary>
+        public VariableScope()
+        {
+            _reuse = _ReuseMode.AUTO_REUSE;
+            // Give the DtInvalid fallback in get_variable a usable value instead of
+            // leaving the field at default(TF_DataType).
+            _dtype = TF_DataType.TF_FLOAT;
+        }
+
+        /// <summary>
+        /// Fetches a variable from <paramref name="var_store"/> under this scope.
+        /// </summary>
+        /// <param name="var_store">Store that is asked for the variable.</param>
+        /// <param name="name">Variable name; prefixed with "scope name/" when this scope has a name.</param>
+        /// <param name="shape">Forwarded to the store unchanged.</param>
+        /// <param name="dtype">Element type; DtInvalid selects this scope's default dtype.</param>
+        /// <param name="synchronization">Forwarded to the store unchanged.</param>
+        /// <param name="aggregation">Forwarded to the store unchanged.</param>
+        /// <returns>Whatever the store's get_variable returns.</returns>
+        public RefVariable get_variable(_VariableStore var_store,
+            string name,
+            TensorShape shape = null,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            VariableSynchronization synchronization = VariableSynchronization.AUTO,
+            VariableAggregation aggregation = VariableAggregation.NONE)
+        {
+            string full_name = !string.IsNullOrEmpty(this.name) ? this.name + "/" + name : name;
+            return Python.with(new ops.name_scope(""), scope =>
+            {
+                if (dtype == TF_DataType.DtInvalid)
+                    dtype = _dtype;
+
+                return var_store.get_variable(full_name,
+                    shape: shape,
+                    dtype: dtype,
+                    synchronization: synchronization,
+                    aggregation: aggregation);
+            });
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Variables/_ReuseMode.cs b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs
new file mode 100644
index 00000000..f2717310
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/_ReuseMode.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Mode for variable access within a variable scope.
+    /// </summary>
+    public enum _ReuseMode
+    {
+        // Indicates that variables are to be fetched if they already exist or
+        // otherwise created.
+        AUTO_REUSE = 1
+    }
+}
diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
new file mode 100644
index 00000000..27f16c55
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs
@@ -0,0 +1,100 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Variable store that carries a number of named Variables.
+    /// </summary>
+    public class _VariableStore
+    {
+        // NOTE(review): the generic arguments were lost from this patch text; string keys
+        // follow from ContainsKey(name) below, the object value type is an assumption.
+        private Dictionary<string, object> _vars;
+        private Dictionary<string, object> _partitioned_vars;
+        private bool _store_eager_variables;
+
+        public _VariableStore()
+        {
+            _vars = new Dictionary<string, object>();
+            _partitioned_vars = new Dictionary<string, object>();
+            _store_eager_variables = false;
+        }
+
+        /// <summary>
+        /// Entry point for fetching the variable called <paramref name="name"/>.
+        /// </summary>
+        /// <remarks>
+        /// Reduces the dtype with as_base_dtype and validates the trainable flag against
+        /// the synchronization mode, then delegates to the getter with all arguments.
+        /// </remarks>
+        public RefVariable get_variable(string name,
+            TensorShape shape = null,
+            TF_DataType dtype = TF_DataType.TF_FLOAT,
+            IInitializer initializer = null,
+            bool trainable = false,
+            bool validate_shape = true,
+            VariableSynchronization synchronization = VariableSynchronization.AUTO,
+            VariableAggregation aggregation = VariableAggregation.NONE)
+        {
+            dtype = dtype.as_base_dtype();
+            trainable = variable_scope._get_trainable_value(synchronization, trainable);
+
+            return _true_getter(name,
+                shape: shape,
+                dtype: dtype,
+                initializer: initializer,
+                trainable: trainable,
+                validate_shape: validate_shape,
+                synchronization: synchronization,
+                aggregation: aggregation);
+        }
+
+        /// <summary>
+        /// Default getter: hands every argument on to _get_single_variable.
+        /// </summary>
+        private RefVariable _true_getter(string name,
+            TensorShape shape = null,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            IInitializer initializer = null,
+            bool trainable = false,
+            bool validate_shape = true,
+            VariableSynchronization synchronization = VariableSynchronization.AUTO,
+            VariableAggregation aggregation = VariableAggregation.NONE)
+        {
+            return _get_single_variable(name: name,
+                shape: shape,
+                dtype: dtype,
+                initializer: initializer,
+                trainable: trainable,
+                validate_shape: validate_shape,
+                synchronization: synchronization,
+                aggregation: aggregation);
+        }
+
+        /// <summary>
+        /// Placeholder for single-variable lookup/creation: every path currently
+        /// throws NotImplementedException.
+        /// </summary>
+        private RefVariable _get_single_variable(string name,
+            TensorShape shape = null,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            IInitializer initializer = null,
+            bool reuse = false,
+            bool trainable = false,
+            bool validate_shape = false,
+            VariableSynchronization synchronization = VariableSynchronization.AUTO,
+            VariableAggregation aggregation = VariableAggregation.NONE)
+        {
+            if (_vars.ContainsKey(name))
+            {
+                // TODO: return the stored variable, taking `reuse` into account.
+                throw new NotImplementedException("_get_single_variable");
+            }
+
+            // TODO: create the variable and register it in _vars.
+            throw new NotImplementedException("_get_single_variable");
+        }
+    }
+}
diff --git 
a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs
index 06680b05..6515399a 100644
--- a/src/TensorFlowNET.Core/Variables/tf.variable.cs
+++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs
@@ -11,5 +11,28 @@ namespace Tensorflow
             var g = variables.global_variables();
             return variables.variables_initializer(g.ToArray());
         }
+
+        /// <summary>
+        /// Gets the variable <paramref name="name"/> from the default variable store
+        /// through the current variable scope.
+        /// </summary>
+        /// <remarks>
+        /// NOTE(review): <paramref name="initializer"/> is accepted but not forwarded;
+        /// VariableScope.get_variable has no initializer parameter to receive it.
+        /// </remarks>
+        public static RefVariable get_variable(string name,
+            TensorShape shape = null,
+            IInitializer initializer = null,
+            VariableSynchronization synchronization = VariableSynchronization.AUTO,
+            VariableAggregation aggregation = VariableAggregation.NONE)
+        {
+            var store = variable_scope._get_default_variable_store();
+            var scope = variable_scope.get_variable_scope();
+            return scope.get_variable(store,
+                name,
+                shape: shape,
+                synchronization: synchronization,
+                aggregation: aggregation);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
index 1c64d591..b7d2662a 100644
--- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
@@ -6,6 +6,7 @@ namespace Tensorflow
 {
     public class variable_scope
     {
+        public static string _VARSTORE_KEY = "__variable_store";
         public static string _VARSCOPESTORE_KEY = "__varscope";
         public static bool _DEFAULT_USE_RESOURCE = false;
 
@@ -32,6 +33,23 @@ namespace Tensorflow
             }
         }
 
+        /// <summary>
+        /// Returns the first store in the _VARSTORE_KEY collection, creating and
+        /// registering a new one when the collection is missing or empty.
+        /// </summary>
+        public static _VariableStore _get_default_variable_store()
+        {
+            // The collection comes back untyped; only trust it when it really is a
+            // non-empty list of stores, otherwise fall through and register a new one.
+            var stores = ops.get_collection(_VARSTORE_KEY) as List<_VariableStore>;
+            if (stores != null && stores.Count > 0)
+                return stores[0];
+
+            var store = new _VariableStore();
+            ops.add_to_collection(_VARSTORE_KEY, store);
+            return store;
+        }
+
         public static VariableScope get_variable_scope()
         {
             return get_variable_scope_store().current_scope;
@@ -65,24 +83,23 @@ namespace Tensorflow
             return ret;
         }
 
-        public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)
+        /// <summary>
+        /// Returns <paramref name="trainable"/> after checking it against the synchronization mode.
+        /// Because the default is true, ON_READ callers must pass trainable: false explicitly.
+        /// </summary>
+        /// <exception cref="ValueError">trainable is true and synchronization is ON_READ.</exception>
+        public static bool _get_trainable_value(VariableSynchronization synchronization, bool trainable = true)
         {
-            if(synchronization == VariableSynchronization.ON_READ)
+            if (synchronization == VariableSynchronization.ON_READ)
             {
-                if (trainable.Value)
+                if (trainable)
                     throw new ValueError("Synchronization value can be set to " +
                         "VariableSynchronization.ON_READ only for non-trainable variables. " +
                         "You have specified trainable=True and " +
                         "synchronization=VariableSynchronization.ON_READ.");
-                else
-                    trainable = false;
             }
-            else if (!trainable.HasValue)
-            {
-                trainable = true;
-            }
-
-            return trainable.Value;
+
+            return trainable;
         }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/TrainSaverTest.cs b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
new file mode 100644
index 00000000..1bb5fc3d
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/TrainSaverTest.cs
@@ -0,0 +1,22 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest
+{
+    [TestClass]
+    public class TrainSaverTest
+    {
+        [TestMethod]
+        public void Save()
+        {
+            // Mirrors the variable creation at the top of python/train_saver.py.
+            var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer);
+            var v2 = tf.get_variable("v2", shape: new TensorShape(5), initializer: tf.zeros_initializer);
+
+
+        }
+    }
+}
diff --git a/test/TensorFlowNET.UnitTest/python/train_saver.py b/test/TensorFlowNET.UnitTest/python/train_saver.py
new file mode 100644
index 00000000..47ffd6a1
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/python/train_saver.py
@@ -0,0 +1,26 @@
+
+import tensorflow as tf
+
+# Create some variables.
+v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
+v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)
+
+inc_v1 = v1.assign(v1+1)
+dec_v2 = v2.assign(v2-1)
+
+# Add an op to initialize the variables.
+init_op = tf.global_variables_initializer()
+
+# Add ops to save and restore all the variables.
+saver = tf.train.Saver()
+
+# Later, launch the model, initialize the variables, do some work, and save the
+# variables to disk.
+with tf.Session() as sess:
+  sess.run(init_op)
+  # Do some work with the model.
+  inc_v1.op.run()
+  dec_v2.op.run()
+  # Save the variables to disk.
+  save_path = saver.save(sess, "/tmp/model.ckpt")
+  print("Model saved in path: %s" % save_path)