diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index fd5952c8..a983d033 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -25,6 +25,7 @@ namespace Tensorflow /// size_t* => ref uint /// void* => IntPtr /// string => IntPtr c_api.StringPiece(IntPtr) + /// unsigned char => byte /// public static partial class c_api { diff --git a/src/TensorFlowNET.Core/Exceptions/ValueError.cs b/src/TensorFlowNET.Core/Exceptions/ValueError.cs new file mode 100644 index 00000000..92955d27 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/ValueError.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class ValueError : Exception + { + public ValueError() : base() + { + + } + + public ValueError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index fcc335fc..b36480db 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -27,6 +27,11 @@ namespace Tensorflow public string _graph_key; public Status Status { get; } + /// + /// Arbitrary collections of objects. + /// + private Dictionary _collections = new Dictionary(); + public Graph() { _handle = c_api.TF_NewGraph(); @@ -86,6 +91,11 @@ namespace Tensorflow throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); } + public void add_to_collection(string name, object value) + { + _collections[name] = value; + } + public unsafe Operation create_op(string op_type, List inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "", Dictionary attrs = null, OpDef op_def = null) @@ -221,6 +231,11 @@ namespace Tensorflow return _nodes_by_name.Values.Select(x => x).ToArray(); } + public Dictionary get_collection(string name) + { + return _collections; + } + public void Dispose() { c_api.TF_DeleteGraph(_handle); diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs index 83c2e4d9..e006eec9 100644 --- a/src/TensorFlowNET.Core/Train/Optimizer.cs +++ b/src/TensorFlowNET.Core/Train/Optimizer.cs @@ -49,6 +49,7 @@ namespace Tensorflow } + var var_list = variables.trainable_variables(); return null; } } diff --git a/src/TensorFlowNET.Core/Tensors/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs similarity index 95% rename from src/TensorFlowNET.Core/Tensors/RefVariable.cs rename to src/TensorFlowNET.Core/Variables/RefVariable.cs index f65b8e9a..6f129767 100644 --- a/src/TensorFlowNET.Core/Tensors/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow { - public class RefVariable : Variable + public class RefVariable : VariableV1 { public bool _in_graph_mode = true; public Tensor _initial_value; diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs new file mode 100644 index 00000000..025660c8 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class VariableScope + { + public bool? use_resource { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs new file mode 100644 index 00000000..9d184cff --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public enum VariableSynchronization + { + AUTO = 0, + NONE = 1, + ON_WRITE = 2, + ON_READ = 3 + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Variable.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs similarity index 81% rename from src/TensorFlowNET.Core/Tensors/Variable.cs rename to src/TensorFlowNET.Core/Variables/VariableV1.cs index b8031490..efd866fd 100644 --- a/src/TensorFlowNET.Core/Tensors/Variable.cs +++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs @@ -14,9 +14,9 @@ namespace Tensorflow /// the variable are fixed. The value can be changed using one of the assign methods. /// https://tensorflow.org/guide/variables /// - public class Variable + public class VariableV1 { - public Variable(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) + public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) { } diff --git a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs new file mode 100644 index 00000000..a7b3e3b5 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class _VariableScopeStore + { + public VariableScope current_scope { get; set; } + + public _VariableScopeStore() + { + current_scope = new VariableScope(); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs new file mode 100644 index 00000000..b794c0f1 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -0,0 +1,74 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class variable_scope + { + public static string _VARSCOPESTORE_KEY = "__varscope"; + public static bool _DEFAULT_USE_RESOURCE = false; + + public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO) + { + var trainable = _get_trainable_value(synchronization); + if (!use_resource.HasValue) + { + use_resource = get_variable_scope().use_resource; + } + + if(!use_resource.HasValue) + use_resource = _DEFAULT_USE_RESOURCE; + + if (use_resource.Value) + { + throw new NotImplementedException(); + } + else + { + return new RefVariable(initial_value); + } + } + + public static VariableScope get_variable_scope() + { + return get_variable_scope_store().current_scope; + } + + public static _VariableScopeStore get_variable_scope_store() + { + var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); + if (scope_store == null) + { + scope_store = new _VariableScopeStore(); + ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); + } + else + { + // scope_store = scope_store[0]; + } + + return scope_store; + } + + public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) + { + if(synchronization == VariableSynchronization.ON_READ) + { + if (trainable.Value) + throw new ValueError("Synchronization value can be set to " + + "VariableSynchronization.ON_READ only for non-trainable variables. " + + "You have specified trainable=True and " + + "synchronization=VariableSynchronization.ON_READ."); + else + trainable = false; + } + else if (!trainable.HasValue) + { + trainable = true; + } + + return trainable.Value; + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs new file mode 100644 index 00000000..9a2602b2 --- /dev/null +++ b/src/TensorFlowNET.Core/Variables/variables.py.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class variables + { + /// + /// Returns all variables created with `trainable=True` + /// + /// + public static object trainable_variables() + { + return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); + } + } +} diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.py.cs b/src/TensorFlowNET.Core/ops.GraphKeys.py.cs new file mode 100644 index 00000000..a7f03cf9 --- /dev/null +++ b/src/TensorFlowNET.Core/ops.GraphKeys.py.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public partial class ops + { + /// + /// Standard names to use for graph collections. + /// The standard library uses various well-known names to collect and + /// retrieve values associated with a graph. For example, the + /// `tf.Optimizer` subclasses default to optimizing the variables + /// collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is + /// specified, but it is also possible to pass an explicit list of + /// variables. + /// + public static class GraphKey + { + /// + /// the subset of `Variable` objects that will be trained by an optimizer. + /// + public static string TRAINABLE_VARIABLES = "trainable_variables"; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/ops.py.cs similarity index 89% rename from src/TensorFlowNET.Core/Operations/ops.cs rename to src/TensorFlowNET.Core/ops.py.cs index 02b7ca11..a0dfc52a 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -10,8 +10,19 @@ using System.Linq; namespace Tensorflow { - public static class ops + public partial class ops { + public static void add_to_collection(string name, object value) + { + var graph = tf.get_default_graph(); + graph.add_to_collection(name, value); + } + + public static _VariableScopeStore get_collection(string key) + { + return null;// get_default_graph().get_collection(key); + } + public static Graph get_default_graph() { return tf.Graph(); diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 2ed1e223..8c4d5611 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -22,7 +22,7 @@ namespace Tensorflow public static RefVariable Variable(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) { - return new RefVariable(data, name, dtype); + return variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); } public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index ae303fe9..2dc2d297 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -43,18 +43,13 @@ namespace TensorFlowNET.Examples var sub = pred - Y; var pow = tf.pow(sub, 2); - - - - - - var reduce = tf.reduce_sum(pow); var cost = reduce / (2d * n_samples); // radient descent // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default - var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); + var optimizer = tf.train.GradientDescentOptimizer(learning_rate); + optimizer.minimize(cost); } } }