@@ -25,6 +25,7 @@ namespace Tensorflow | |||||
/// size_t* => ref uint | /// size_t* => ref uint | ||||
/// void* => IntPtr | /// void* => IntPtr | ||||
/// string => IntPtr c_api.StringPiece(IntPtr) | /// string => IntPtr c_api.StringPiece(IntPtr) | ||||
/// unsigned char => byte | |||||
/// </summary> | /// </summary> | ||||
public static partial class c_api | public static partial class c_api | ||||
{ | { | ||||
@@ -0,0 +1,19 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class ValueError : Exception | |||||
{ | |||||
public ValueError() : base() | |||||
{ | |||||
} | |||||
public ValueError(string message) : base(message) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -27,6 +27,11 @@ namespace Tensorflow | |||||
public string _graph_key; | public string _graph_key; | ||||
public Status Status { get; } | public Status Status { get; } | ||||
/// <summary> | |||||
/// Arbitrary collections of objects. | |||||
/// </summary> | |||||
private Dictionary<string, object> _collections = new Dictionary<string, object>(); | |||||
public Graph() | public Graph() | ||||
{ | { | ||||
_handle = c_api.TF_NewGraph(); | _handle = c_api.TF_NewGraph(); | ||||
@@ -86,6 +91,11 @@ namespace Tensorflow | |||||
throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}."); | ||||
} | } | ||||
public void add_to_collection(string name, object value) | |||||
{ | |||||
_collections[name] = value; | |||||
} | |||||
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, | ||||
TF_DataType[] input_types = null, string name = "", | TF_DataType[] input_types = null, string name = "", | ||||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | ||||
@@ -221,6 +231,11 @@ namespace Tensorflow | |||||
return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
} | } | ||||
public Dictionary<string, object> get_collection(string name) | |||||
{ | |||||
return _collections; | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
c_api.TF_DeleteGraph(_handle); | c_api.TF_DeleteGraph(_handle); | ||||
@@ -49,6 +49,7 @@ namespace Tensorflow | |||||
} | } | ||||
var var_list = variables.trainable_variables(); | |||||
return null; | return null; | ||||
} | } | ||||
} | } | ||||
@@ -4,7 +4,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class RefVariable : Variable | |||||
public class RefVariable : VariableV1 | |||||
{ | { | ||||
public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
public Tensor _initial_value; | public Tensor _initial_value; |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class VariableScope | |||||
{ | |||||
public bool? use_resource { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public enum VariableSynchronization | |||||
{ | |||||
AUTO = 0, | |||||
NONE = 1, | |||||
ON_WRITE = 2, | |||||
ON_READ = 3 | |||||
} | |||||
} |
@@ -14,9 +14,9 @@ namespace Tensorflow | |||||
/// the variable are fixed. The value can be changed using one of the assign methods. | /// the variable are fixed. The value can be changed using one of the assign methods. | ||||
/// https://tensorflow.org/guide/variables | /// https://tensorflow.org/guide/variables | ||||
/// </summary> | /// </summary> | ||||
public class Variable | |||||
public class VariableV1 | |||||
{ | { | ||||
public Variable(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||||
public VariableV1(object initial_value, string name = "", TF_DataType trainable = TF_DataType.DtInvalid, bool validate_shape = true) | |||||
{ | { | ||||
} | } |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class _VariableScopeStore | |||||
{ | |||||
public VariableScope current_scope { get; set; } | |||||
public _VariableScopeStore() | |||||
{ | |||||
current_scope = new VariableScope(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,74 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class variable_scope | |||||
{ | |||||
public static string _VARSCOPESTORE_KEY = "__varscope"; | |||||
public static bool _DEFAULT_USE_RESOURCE = false; | |||||
public static RefVariable default_variable_creator(object initial_value, string name = "", TF_DataType dtype = TF_DataType.DtInvalid, bool ? use_resource = null, VariableSynchronization synchronization = VariableSynchronization.AUTO) | |||||
{ | |||||
var trainable = _get_trainable_value(synchronization); | |||||
if (!use_resource.HasValue) | |||||
{ | |||||
use_resource = get_variable_scope().use_resource; | |||||
} | |||||
if(!use_resource.HasValue) | |||||
use_resource = _DEFAULT_USE_RESOURCE; | |||||
if (use_resource.Value) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
else | |||||
{ | |||||
return new RefVariable(initial_value); | |||||
} | |||||
} | |||||
public static VariableScope get_variable_scope() | |||||
{ | |||||
return get_variable_scope_store().current_scope; | |||||
} | |||||
public static _VariableScopeStore get_variable_scope_store() | |||||
{ | |||||
var scope_store = ops.get_collection(_VARSCOPESTORE_KEY); | |||||
if (scope_store == null) | |||||
{ | |||||
scope_store = new _VariableScopeStore(); | |||||
ops.add_to_collection(_VARSCOPESTORE_KEY, scope_store); | |||||
} | |||||
else | |||||
{ | |||||
// scope_store = scope_store[0]; | |||||
} | |||||
return scope_store; | |||||
} | |||||
public static bool _get_trainable_value(VariableSynchronization synchronization, bool? trainable = null) | |||||
{ | |||||
if(synchronization == VariableSynchronization.ON_READ) | |||||
{ | |||||
if (trainable.Value) | |||||
throw new ValueError("Synchronization value can be set to " + | |||||
"VariableSynchronization.ON_READ only for non-trainable variables. " + | |||||
"You have specified trainable=True and " + | |||||
"synchronization=VariableSynchronization.ON_READ."); | |||||
else | |||||
trainable = false; | |||||
} | |||||
else if (!trainable.HasValue) | |||||
{ | |||||
trainable = true; | |||||
} | |||||
return trainable.Value; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,18 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class variables | |||||
{ | |||||
/// <summary> | |||||
/// Returns all variables created with `trainable=True` | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static object trainable_variables() | |||||
{ | |||||
return ops.get_collection(ops.GraphKey.TRAINABLE_VARIABLES); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class ops | |||||
{ | |||||
/// <summary> | |||||
/// Standard names to use for graph collections. | |||||
/// The standard library uses various well-known names to collect and | |||||
/// retrieve values associated with a graph. For example, the | |||||
/// `tf.Optimizer` subclasses default to optimizing the variables | |||||
/// collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is | |||||
/// specified, but it is also possible to pass an explicit list of | |||||
/// variables. | |||||
/// </summary> | |||||
public static class GraphKey | |||||
{ | |||||
/// <summary> | |||||
/// the subset of `Variable` objects that will be trained by an optimizer. | |||||
/// </summary> | |||||
public static string TRAINABLE_VARIABLES = "trainable_variables"; | |||||
} | |||||
} | |||||
} |
@@ -10,8 +10,19 @@ using System.Linq; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static class ops | |||||
public partial class ops | |||||
{ | { | ||||
public static void add_to_collection(string name, object value) | |||||
{ | |||||
var graph = tf.get_default_graph(); | |||||
graph.add_to_collection(name, value); | |||||
} | |||||
public static _VariableScopeStore get_collection(string key) | |||||
{ | |||||
return null;// get_default_graph().get_collection(key); | |||||
} | |||||
public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
{ | { | ||||
return tf.Graph(); | return tf.Graph(); |
@@ -22,7 +22,7 @@ namespace Tensorflow | |||||
public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) | public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid) | ||||
{ | { | ||||
return new RefVariable(data, name, dtype); | |||||
return variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); | |||||
} | } | ||||
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | ||||
@@ -43,18 +43,13 @@ namespace TensorFlowNET.Examples | |||||
var sub = pred - Y; | var sub = pred - Y; | ||||
var pow = tf.pow(sub, 2); | var pow = tf.pow(sub, 2); | ||||
var reduce = tf.reduce_sum(pow); | var reduce = tf.reduce_sum(pow); | ||||
var cost = reduce / (2d * n_samples); | var cost = reduce / (2d * n_samples); | ||||
// radient descent | // radient descent | ||||
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default | ||||
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||||
var optimizer = tf.train.GradientDescentOptimizer(learning_rate); | |||||
optimizer.minimize(cost); | |||||
} | } | ||||
} | } | ||||
} | } |