From 8ce3caa3f5d9dac5b6e62a2fe866bd429c1da5f3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 28 Dec 2018 01:00:32 -0600 Subject: [PATCH] string and scalar variable. --- src/TensorFlowNET.Core/Graphs/Graph.cs | 26 ++++++++++++++++++- .../Operations/gen_array_ops.cs | 2 +- src/TensorFlowNET.Core/Operations/ops.cs | 14 ++++++++++ src/TensorFlowNET.Core/Tensors/RefVariable.cs | 3 ++- src/TensorFlowNET.Core/Tensors/constant_op.cs | 8 +++++- src/TensorFlowNET.Core/tf.cs | 4 +-- test/TensorFlowNET.UnitTest/VariableTest.cs | 12 +++++++-- 7 files changed, 61 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 96131b8a..2109c964 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -24,6 +24,8 @@ namespace Tensorflow private int _next_id_counter; private List _unfetchable_ops = new List(); + private string _name_stack; + public Graph(IntPtr graph) { this._c_graph = graph; @@ -126,8 +128,31 @@ namespace Tensorflow return false; } + public string name_scope(string name) + { + string new_stack = ""; + + if (name.EndsWith("/")) + { + new_stack = ops._name_from_scope_name(name); + } + else + { + new_stack = unique_name(name); + } + + _name_stack = new_stack; + + return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/"; + } + public string unique_name(string name) { + if (!String.IsNullOrEmpty(_name_stack)) + { + name = _name_stack + "/" + name; + } + var name_key = name.ToLower(); if (_names_in_use.ContainsKey(name_key)) { @@ -138,7 +163,6 @@ namespace Tensorflow _names_in_use[name_key] = 1; return name; } - return $"{name}_{_names_in_use[name_key]}"; } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index b4bb76bf..1a52b3cf 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -38,7 +38,7 @@ namespace Tensorflow private static OpDefLibrary _InitOpDefLibrary() { // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); - var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); + var bytes = File.ReadAllBytes("Operations/op_list_proto_array.bin"); var op_list = OpList.Parser.ParseFrom(bytes); var op_def_lib = new OpDefLibrary(); op_def_lib.add_op_list(op_list); diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index a0ea4531..7b6dfcf7 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -71,6 +71,20 @@ namespace Tensorflow return node_def; } + public static string name_scope(string name, string default_name = "", object values = null) + { + string _name = ""; + if (String.IsNullOrEmpty(name)) + { + _name = default_name; + } + + var g = get_default_graph(); + var _name_scope = g.name_scope(_name); + + return _name_scope; + } + public static string _name_from_scope_name(string name) { if (name.EndsWith("/")) diff --git a/src/TensorFlowNET.Core/Tensors/RefVariable.cs b/src/TensorFlowNET.Core/Tensors/RefVariable.cs index bbe5996a..083a978d 100644 --- a/src/TensorFlowNET.Core/Tensors/RefVariable.cs +++ b/src/TensorFlowNET.Core/Tensors/RefVariable.cs @@ -14,12 +14,13 @@ namespace Tensorflow bool validate_shape = true) : base(initial_value, trainable, validate_shape) { - + _init_from_args(initial_value, trainable); } private void _init_from_args(object initial_value, TF_DataType trainable) { + var name = ops.name_scope("", "Variable", initial_value); _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); } } diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 0beda95e..9b9eba0b 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -33,7 +33,13 @@ namespace Tensorflow var attrs = new Dictionary(); attrs["dtype"] = dtype_value; attrs["value"] = tensor_value; - var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0]; + + var const_tensor = g.create_op("Const", + null, + new TF_DataType[] { (TF_DataType)dtype_value.Type }, + attrs: attrs, + name: name).outputs[0]; + const_tensor.value = nd.Data(); return const_tensor; diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index d50ccdb1..2e860410 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -17,9 +17,9 @@ namespace Tensorflow public static Graph g = new Graph(c_api.TF_NewGraph()); - public static object Variable(T data, TF_DataType dtype) + public static object Variable(T data, TF_DataType dtype = TF_DataType.DtInvalid) { - return new Variable(null, TF_DataType.DtInvalid); + return new RefVariable(data, dtype); } public static unsafe Tensor add(Tensor a, Tensor b) diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 970a5670..2dbe3816 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -10,9 +10,17 @@ namespace TensorFlowNET.UnitTest public class VariableTest { [TestMethod] - public void Creating() + public void StringVar() { - var mammal = tf.Variable("Elephant", tf.chars); + var mammal1 = tf.Variable("Elephant", tf.chars); + var mammal2 = tf.Variable("Tiger"); + } + + [TestMethod] + public void ScalarVar() + { + var x = tf.Variable(3); + var y = tf.Variable(6f); } } }