diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs
index fd5952c8..a983d033 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.cs
@@ -25,6 +25,7 @@ namespace Tensorflow
/// size_t* => ref uint
/// void* => IntPtr
/// string => IntPtr c_api.StringPiece(IntPtr)
+ /// unsigned char => byte
///
public static partial class c_api
{
diff --git a/src/TensorFlowNET.Core/Exceptions/ValueError.cs b/src/TensorFlowNET.Core/Exceptions/ValueError.cs
new file mode 100644
index 00000000..92955d27
--- /dev/null
+++ b/src/TensorFlowNET.Core/Exceptions/ValueError.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class ValueError : Exception
+ {
+ public ValueError() : base()
+ {
+
+ }
+
+ public ValueError(string message) : base(message)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index fcc335fc..b36480db 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -27,6 +27,11 @@ namespace Tensorflow
public string _graph_key;
public Status Status { get; }
+ ///
+ /// Arbitrary collections of objects.
+ ///
+ private Dictionary _collections = new Dictionary();
+
public Graph()
{
_handle = c_api.TF_NewGraph();
@@ -86,6 +91,11 @@ namespace Tensorflow
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
}
+ public void add_to_collection(string name, object value)
+ {
+ _collections[name] = value;
+ }
+
public unsafe Operation create_op(string op_type, List inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = "",
Dictionary attrs = null, OpDef op_def = null)
@@ -221,6 +231,11 @@ namespace Tensorflow
return _nodes_by_name.Values.Select(x => x).ToArray();
}
+ public Dictionary get_collection(string name)
+ {
+ return _collections;
+ }
+
public void Dispose()
{
c_api.TF_DeleteGraph(_handle);
diff --git a/src/TensorFlowNET.Core/Train/Optimizer.cs b/src/TensorFlowNET.Core/Train/Optimizer.cs
index 83c2e4d9..e006eec9 100644
--- a/src/TensorFlowNET.Core/Train/Optimizer.cs
+++ b/src/TensorFlowNET.Core/Train/Optimizer.cs
@@ -49,6 +49,7 @@ namespace Tensorflow
}
+ var var_list = variables.trainable_variables();
return null;
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
similarity index 95%
rename from src/TensorFlowNET.Core/Tensors/RefVariable.cs
rename to src/TensorFlowNET.Core/Variables/RefVariable.cs
index f65b8e9a..6f129767 100644
--- a/src/TensorFlowNET.Core/Tensors/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -4,7 +4,7 @@ using System.Text;
namespace Tensorflow
{
- public class RefVariable : Variable
+ public class RefVariable : VariableV1
{
public bool _in_graph_mode = true;
public Tensor _initial_value;
diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs
new file mode 100644
index 00000000..025660c8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class VariableScope
+ {
+ public bool? use_resource { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs
new file mode 100644
index 00000000..9d184cff
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/VariableSynchronization.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public enum VariableSynchronization
+ {
+ AUTO = 0,
+ NONE = 1,
+ ON_WRITE = 2,
+ ON_READ = 3
+ }
+}
diff --git a/src/TensorFlowNET.Core/Tensors/Variable.cs b/src/TensorFlowNET.Core/Variables/VariableV1.cs
similarity index 81%
rename from src/TensorFlowNET.Core/Tensors/Variable.cs
rename to src/TensorFlowNET.Core/Variables/VariableV1.cs
index b8031490..efd866fd 100644
--- a/src/TensorFlowNET.Core/Tensors/Variable.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableV1.cs
@@ -14,9 +14,9 @@ namespace Tensorflow
/// the variable are fixed. The value can be changed using one of the assign methods.
/// https://tensorflow.org/guide/variables
///
- public class Variable
+ public class VariableV1
{
- public Variable(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true)
+ public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true)
{
}
diff --git a/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
new file mode 100644
index 00000000..a7b3e3b5
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/_VariableScopeStore.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class _VariableScopeStore
+ {
+ public VariableScope current_scope { get; set; }
+
+ public _VariableScopeStore()
+ {
+ current_scope = new VariableScope();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
new file mode 100644
index 00000000..b794c0f1
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
@@ -0,0 +1,74 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class variable_scope
+ {
+ public static string _VARSCOPESTORE_KEY = "__varscope";
+ public static bool _DEFAULT_USE_RESOURCE = false;
+
+ public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO)
+ {
+ var trainable = _get_trainable_value(synchronization);
+ if (!use_resource.HasValue)
+ {
+ use_resource = get_variable_scope().use_resource;
+ }
+
+ if(!use_resource.HasValue)
+ use_resource = _DEFAULT_USE_RESOURCE;
+
+ if (use_resource.Value)
+ {
+ throw new NotImplementedException();
+ }
+ else
+ {
+ return new RefVariable(initial_value);
+ }
+ }
+
+ public static VariableScope get_variable_scope()
+ {
+ return get_variable_scope_store().current_scope;
+ }
+
+ public static _VariableScopeStore get_variable_scope_store()
+ {
+ var scope_store = ops.get_collection(_VARSCOPESTORE_KEY);
+ if (scope_store == null)
+ {
+ scope_store = new _VariableScopeStore();
+ ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store);
+ }
+ else
+ {
+ // scope_store = scope_store[0];
+ }
+
+ return scope_store;
+ }
+
+ public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null)
+ {
+ if(synchronization == VariableSynchronization.ON_READ)
+ {
+ if (trainable.Value)
+ throw new ValueError("Synchronization value can be set to " +
+ "VariableSynchronization.ON_READ only for non-trainable variables. " +
+ "You have specified trainable=True and " +
+ "synchronization=VariableSynchronization.ON_READ.");
+ else
+ trainable = false;
+ }
+ else if (!trainable.HasValue)
+ {
+ trainable = true;
+ }
+
+ return trainable.Value;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Variables/variables.py.cs b/src/TensorFlowNET.Core/Variables/variables.py.cs
new file mode 100644
index 00000000..9a2602b2
--- /dev/null
+++ b/src/TensorFlowNET.Core/Variables/variables.py.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class variables
+ {
+ ///
+ /// Returns all variables created with `trainable=True`
+ ///
+ ///
+ public static object trainable_variables()
+ {
+ return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.py.cs b/src/TensorFlowNET.Core/ops.GraphKeys.py.cs
new file mode 100644
index 00000000..a7f03cf9
--- /dev/null
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.py.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public partial class ops
+ {
+ ///
+ /// Standard names to use for graph collections.
+ /// The standard library uses various well-known names to collect and
+ /// retrieve values associated with a graph. For example, the
+ /// `tf.Optimizer` subclasses default to optimizing the variables
+ /// collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is
+ /// specified, but it is also possible to pass an explicit list of
+ /// variables.
+ ///
+ public static class GraphKey
+ {
+ ///
+ /// the subset of `Variable` objects that will be trained by an optimizer.
+ ///
+ public static string TRAINABLE_VARIABLES = "trainable_variables";
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/ops.py.cs
similarity index 89%
rename from src/TensorFlowNET.Core/Operations/ops.cs
rename to src/TensorFlowNET.Core/ops.py.cs
index 02b7ca11..a0dfc52a 100644
--- a/src/TensorFlowNET.Core/Operations/ops.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -10,8 +10,19 @@ using System.Linq;
namespace Tensorflow
{
- public static class ops
+ public partial class ops
{
+ public static void add_to_collection(string name, object value)
+ {
+ var graph = tf.get_default_graph();
+ graph.add_to_collection(name, value);
+ }
+
+ public static _VariableScopeStore get_collection(string key)
+ {
+ return null;// get_default_graph().get_collection(key);
+ }
+
public static Graph get_default_graph()
{
return tf.Graph();
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index 2ed1e223..8c4d5611 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -22,7 +22,7 @@ namespace Tensorflow
public static RefVariable Variable(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid)
{
- return new RefVariable(data, name, dtype);
+ return variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid);
}
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs
index ae303fe9..2dc2d297 100644
--- a/test/TensorFlowNET.Examples/LinearRegression.cs
+++ b/test/TensorFlowNET.Examples/LinearRegression.cs
@@ -43,18 +43,13 @@ namespace TensorFlowNET.Examples
var sub = pred - Y;
var pow = tf.pow(sub, 2);
-
-
-
-
-
-
var reduce = tf.reduce_sum(pow);
var cost = reduce / (2d * n_samples);
// radient descent
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
- var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
+ var optimizer = tf.train.GradientDescentOptimizer(learning_rate);
+ optimizer.minimize(cost);
}
}
}