Browse Source

string and scalar variable.

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
8ce3caa3f5
7 changed files with 61 additions and 8 deletions
  1. +25
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  3. +14
    -0
      src/TensorFlowNET.Core/Operations/ops.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Tensors/RefVariable.cs
  5. +7
    -1
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  6. +2
    -2
      src/TensorFlowNET.Core/tf.cs
  7. +10
    -2
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 25
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -24,6 +24,8 @@ namespace Tensorflow
private int _next_id_counter;
private List<String> _unfetchable_ops = new List<string>();

private string _name_stack;

public Graph(IntPtr graph)
{
this._c_graph = graph;
@@ -126,8 +128,31 @@ namespace Tensorflow
return false;
}

public string name_scope(string name)
{
string new_stack = "";

if (name.EndsWith("/"))
{
new_stack = ops._name_from_scope_name(name);
}
else
{
new_stack = unique_name(name);
}

_name_stack = new_stack;

return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
}

public string unique_name(string name)
{
if (!String.IsNullOrEmpty(_name_stack))
{
name = _name_stack + "/" + name;
}

var name_key = name.ToLower();
if (_names_in_use.ContainsKey(name_key))
{
@@ -138,7 +163,6 @@ namespace Tensorflow
_names_in_use[name_key] = 1;
return name;
}

return $"{name}_{_names_in_use[name_key]}";
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -38,7 +38,7 @@ namespace Tensorflow
private static OpDefLibrary _InitOpDefLibrary()
{
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle);
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin");
var bytes = File.ReadAllBytes("Operations/op_list_proto_array.bin");
var op_list = OpList.Parser.ParseFrom(bytes);
var op_def_lib = new OpDefLibrary();
op_def_lib.add_op_list(op_list);


+ 14
- 0
src/TensorFlowNET.Core/Operations/ops.cs View File

@@ -71,6 +71,20 @@ namespace Tensorflow
return node_def;
}

public static string name_scope(string name, string default_name = "", object values = null)
{
string _name = "";
if (String.IsNullOrEmpty(name))
{
_name = default_name;
}

var g = get_default_graph();
var _name_scope = g.name_scope(_name);

return _name_scope;
}

public static string _name_from_scope_name(string name)
{
if (name.EndsWith("/"))


+ 2
- 1
src/TensorFlowNET.Core/Tensors/RefVariable.cs View File

@@ -14,12 +14,13 @@ namespace Tensorflow
bool validate_shape = true) :
base(initial_value, trainable, validate_shape)
{
_init_from_args(initial_value, trainable);
}

private void _init_from_args(object initial_value,
TF_DataType trainable)
{
var name = ops.name_scope("", "Variable", initial_value);
_initial_value = ops.convert_to_tensor(initial_value, name: "initial_value");
}
}


+ 7
- 1
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -33,7 +33,13 @@ namespace Tensorflow
var attrs = new Dictionary<string, AttrValue>();
attrs["dtype"] = dtype_value;
attrs["value"] = tensor_value;
var const_tensor = g.create_op("Const", null, new TF_DataType[] { (TF_DataType)dtype_value.Type }, attrs: attrs).outputs[0];

var const_tensor = g.create_op("Const",
null,
new TF_DataType[] { (TF_DataType)dtype_value.Type },
attrs: attrs,
name: name).outputs[0];

const_tensor.value = nd.Data();

return const_tensor;


+ 2
- 2
src/TensorFlowNET.Core/tf.cs View File

@@ -17,9 +17,9 @@ namespace Tensorflow

public static Graph g = new Graph(c_api.TF_NewGraph());

public static object Variable<T>(T data, TF_DataType dtype)
public static object Variable<T>(T data, TF_DataType dtype = TF_DataType.DtInvalid)
{
return new Variable(null, TF_DataType.DtInvalid);
return new RefVariable(data, dtype);
}

public static unsafe Tensor add(Tensor a, Tensor b)


+ 10
- 2
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -10,9 +10,17 @@ namespace TensorFlowNET.UnitTest
public class VariableTest
{
[TestMethod]
public void Creating()
public void StringVar()
{
var mammal = tf.Variable("Elephant", tf.chars);
var mammal1 = tf.Variable("Elephant", tf.chars);
var mammal2 = tf.Variable("Tiger");
}

[TestMethod]
public void ScalarVar()
{
var x = tf.Variable(3);
var y = tf.Variable(6f);
}
}
}

Loading…
Cancel
Save