@@ -24,6 +24,8 @@ namespace Tensorflow | |||||
private int _next_id_counter; | private int _next_id_counter; | ||||
private List<String> _unfetchable_ops = new List<string>(); | private List<String> _unfetchable_ops = new List<string>(); | ||||
private string _name_stack; | |||||
public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
{ | { | ||||
this._c_graph = graph; | this._c_graph = graph; | ||||
@@ -126,8 +128,31 @@ namespace Tensorflow | |||||
return false; | return false; | ||||
} | } | ||||
public string name_scope(string name) | |||||
{ | |||||
string new_stack = ""; | |||||
if (name.EndsWith("/")) | |||||
{ | |||||
new_stack = ops._name_from_scope_name(name); | |||||
} | |||||
else | |||||
{ | |||||
new_stack = unique_name(name); | |||||
} | |||||
_name_stack = new_stack; | |||||
return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; | |||||
} | |||||
public string unique_name(string name) | public string unique_name(string name) | ||||
{ | { | ||||
if (!String.IsNullOrEmpty(_name_stack)) | |||||
{ | |||||
name = _name_stack + "/" + name; | |||||
} | |||||
var name_key = name.ToLower(); | var name_key = name.ToLower(); | ||||
if (_names_in_use.ContainsKey(name_key)) | if (_names_in_use.ContainsKey(name_key)) | ||||
{ | { | ||||
@@ -138,7 +163,6 @@ namespace Tensorflow | |||||
_names_in_use[name_key] = 1; | _names_in_use[name_key] = 1; | ||||
return name; | return name; | ||||
} | } | ||||
return $"{name}_{_names_in_use[name_key]}"; | return $"{name}_{_names_in_use[name_key]}"; | ||||
} | } | ||||
@@ -38,7 +38,7 @@ namespace Tensorflow | |||||
private static OpDefLibrary _InitOpDefLibrary() | private static OpDefLibrary _InitOpDefLibrary() | ||||
{ | { | ||||
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | ||||
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); | |||||
var bytes = File.ReadAllBytes("Operations/op_list_proto_array.bin"); | |||||
var op_list = OpList.Parser.ParseFrom(bytes); | var op_list = OpList.Parser.ParseFrom(bytes); | ||||
var op_def_lib = new OpDefLibrary(); | var op_def_lib = new OpDefLibrary(); | ||||
op_def_lib.add_op_list(op_list); | op_def_lib.add_op_list(op_list); | ||||
@@ -71,6 +71,20 @@ namespace Tensorflow | |||||
return node_def; | return node_def; | ||||
} | } | ||||
public static string name_scope(string name, string default_name = "", object values = null) | |||||
{ | |||||
string _name = ""; | |||||
if (String.IsNullOrEmpty(name)) | |||||
{ | |||||
_name = default_name; | |||||
} | |||||
var g = get_default_graph(); | |||||
var _name_scope = g.name_scope(_name); | |||||
return _name_scope; | |||||
} | |||||
public static string _name_from_scope_name(string name) | public static string _name_from_scope_name(string name) | ||||
{ | { | ||||
if (name.EndsWith("/")) | if (name.EndsWith("/")) | ||||
@@ -14,12 +14,13 @@ namespace Tensorflow | |||||
bool validate_shape = true) : | bool validate_shape = true) : | ||||
base(initial_value, trainable, validate_shape) | base(initial_value, trainable, validate_shape) | ||||
{ | { | ||||
_init_from_args(initial_value, trainable); | |||||
} | } | ||||
private void _init_from_args(object initial_value, | private void _init_from_args(object initial_value, | ||||
TF_DataType trainable) | TF_DataType trainable) | ||||
{ | { | ||||
var name = ops.name_scope("", "Variable", initial_value); | |||||
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); | ||||
} | } | ||||
} | } | ||||
@@ -33,7 +33,13 @@ namespace Tensorflow | |||||
var attrs = new Dictionary<string, AttrValue>(); | var attrs = new Dictionary<string, AttrValue>(); | ||||
attrs["dtype"] = dtype_value; | attrs["dtype"] = dtype_value; | ||||
attrs["value"] = tensor_value; | attrs["value"] = tensor_value; | ||||
var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; | |||||
var const_tensor = g.create_op("Const", | |||||
null, | |||||
new TF_DataType[] { (TF_DataType)dtype_value.Type }, | |||||
attrs: attrs, | |||||
name: name).outputs[0]; | |||||
const_tensor.value = nd.Data(); | const_tensor.value = nd.Data(); | ||||
return const_tensor; | return const_tensor; | ||||
@@ -17,9 +17,9 @@ namespace Tensorflow | |||||
public static Graph g = new Graph(c_api.TF_NewGraph()); | public static Graph g = new Graph(c_api.TF_NewGraph()); | ||||
public static object Variable<T>(T data, TF_DataType dtype) | |||||
public static object Variable<T>(T data, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | { | ||||
return new Variable(null, TF_DataType.DtInvalid); | |||||
return new RefVariable(data, dtype); | |||||
} | } | ||||
public static unsafe Tensor add(Tensor a, Tensor b) | public static unsafe Tensor add(Tensor a, Tensor b) | ||||
@@ -10,9 +10,17 @@ namespace TensorFlowNET.UnitTest | |||||
public class VariableTest | public class VariableTest | ||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void Creating() | |||||
public void StringVar() | |||||
{ | { | ||||
var mammal = tf.Variable("Elephant", tf.chars); | |||||
var mammal1 = tf.Variable("Elephant", tf.chars); | |||||
var mammal2 = tf.Variable("Tiger"); | |||||
} | |||||
[TestMethod] | |||||
public void ScalarVar() | |||||
{ | |||||
var x = tf.Variable(3); | |||||
var y = tf.Variable(6f); | |||||
} | } | ||||
} | } | ||||
} | } |