Browse Source

string and scalar variable.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
8ce3caa3f5
7 changed files with 61 additions and 8 deletions
  1. +25
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  3. +14
    -0
      src/TensorFlowNET.Core/Operations/ops.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Tensors/RefVariable.cs
  5. +7
    -1
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  6. +2
    -2
      src/TensorFlowNET.Core/tf.cs
  7. +10
    -2
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 25
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -24,6 +24,8 @@ namespace Tensorflow
private int _next_id_counter; private int _next_id_counter;
private List<String> _unfetchable_ops = new List<string>(); private List<String> _unfetchable_ops = new List<string>();


private string _name_stack;

public Graph(IntPtr graph) public Graph(IntPtr graph)
{ {
this._c_graph = graph; this._c_graph = graph;
@@ -126,8 +128,31 @@ namespace Tensorflow
return false; return false;
} }


public string name_scope(string name)
{
string new_stack = "";

if (name.EndsWith("/"))
{
new_stack = ops._name_from_scope_name(name);
}
else
{
new_stack = unique_name(name);
}

_name_stack = new_stack;

return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
}

public string unique_name(string name) public string unique_name(string name)
{ {
if (!String.IsNullOrEmpty(_name_stack))
{
name = _name_stack + "/" + name;
}

var name_key = name.ToLower(); var name_key = name.ToLower();
if (_names_in_use.ContainsKey(name_key)) if (_names_in_use.ContainsKey(name_key))
{ {
@@ -138,7 +163,6 @@ namespace Tensorflow
_names_in_use[name_key] = 1; _names_in_use[name_key] = 1;
return name; return name;
} }


return $"{name}_{_names_in_use[name_key]}"; return $"{name}_{_names_in_use[name_key]}";
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow
private static OpDefLibrary _InitOpDefLibrary() private static OpDefLibrary _InitOpDefLibrary()
{ {
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle);
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin");
var bytes = File.ReadAllBytes("Operations/op_list_proto_array.bin");
var op_list = OpList.Parser.ParseFrom(bytes); var op_list = OpList.Parser.ParseFrom(bytes);
var op_def_lib = new OpDefLibrary(); var op_def_lib = new OpDefLibrary();
op_def_lib.add_op_list(op_list); op_def_lib.add_op_list(op_list);


+ 14
- 0
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -71,6 +71,20 @@ namespace Tensorflow
return node_def; return node_def;
} }


public static string name_scope(string name, string default_name = "", object values = null)
{
string _name = "";
if (String.IsNullOrEmpty(name))
{
_name = default_name;
}

var g = get_default_graph();
var _name_scope = g.name_scope(_name);

return _name_scope;
}

public static string _name_from_scope_name(string name) public static string _name_from_scope_name(string name)
{ {
if (name.EndsWith("/")) if (name.EndsWith("/"))


+ 2
- 1
src/TensorFlowNET.Core/Tensors/RefVariable.cs View File

@@ -14,12 +14,13 @@ namespace Tensorflow
bool validate_shape = true) : bool validate_shape = true) :
base(initial_value, trainable, validate_shape) base(initial_value, trainable, validate_shape)
{ {
_init_from_args(initial_value, trainable);
} }


private void _init_from_args(object initial_value, private void _init_from_args(object initial_value,
TF_DataType trainable) TF_DataType trainable)
{ {
var name = ops.name_scope("", "Variable", initial_value);
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value"); _initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
} }
} }


+ 7
- 1
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -33,7 +33,13 @@ namespace Tensorflow
var attrs = new Dictionary<string, AttrValue>(); var attrs = new Dictionary<string, AttrValue>();
attrs["dtype"] = dtype_value; attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value; attrs["value"] = tensor_value;
var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0];

var const_tensor = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
attrs: attrs,
name: name).outputs[0];

const_tensor.value = nd.Data(); const_tensor.value = nd.Data();


return const_tensor; return const_tensor;


+ 2
- 2
src/TensorFlowNET.Core/tf.cs View File

@@ -17,9 +17,9 @@ namespace Tensorflow


public static Graph g = new Graph(c_api.TF_NewGraph()); public static Graph g = new Graph(c_api.TF_NewGraph());


public static object Variable<T>(T data, TF_DataType dtype)
public static object Variable<T>(T data, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
return new Variable(null, TF_DataType.DtInvalid);
return new RefVariable(data, dtype);
} }


public static unsafe Tensor add(Tensor a, Tensor b) public static unsafe Tensor add(Tensor a, Tensor b)


+ 10
- 2
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -10,9 +10,17 @@ namespace TensorFlowNET.UnitTest
public class VariableTest public class VariableTest
{ {
[TestMethod] [TestMethod]
public void Creating()
public void StringVar()
{ {
var mammal = tf.Variable("Elephant", tf.chars);
var mammal1 = tf.Variable("Elephant", tf.chars);
var mammal2 = tf.Variable("Tiger");
}

[TestMethod]
public void ScalarVar()
{
var x = tf.Variable(3);
var y = tf.Variable(6f);
} }
} }
} }

Loading…
Cancel
Save