Browse Source

Merge remote-tracking branch 'upstream/master'

# Conflicts:
#	src/TensorFlowNET.Core/Operations/Operation.cs
#	test/TensorFlowNET.UnitTest/PythonTest.cs
#	test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
fe3440f2f9
17 changed files with 1672 additions and 396 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +18
    -3
      src/TensorFlowNET.Core/Framework/meta_graph.py.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +40
    -8
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  5. +32
    -2
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  7. +308
    -309
      src/TensorFlowNET.Core/Operations/Operation.cs
  8. +4
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  10. +1172
    -0
      src/TensorFlowNET.Core/Protobuf/ControlFlow.cs
  11. +4
    -4
      src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs
  12. +5
    -2
      src/TensorFlowNET.Core/Protobuf/README.md
  13. +2
    -2
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  14. +1
    -1
      src/TensorFlowNET.Core/ops.py.cs
  15. +21
    -43
      test/TensorFlowNET.UnitTest/PythonTest.cs
  16. +53
    -11
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
  17. +1
    -1
      test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs

+ 2
- 2
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow
public static Tensor asin(Tensor x, string name = null) public static Tensor asin(Tensor x, string name = null)
=> gen_math_ops.asin(x, name); => gen_math_ops.asin(x, name);


public static Tensor add(Tensor a, Tensor b)
public static Tensor add<Tx, Ty>(Tx a, Ty b)
=> gen_math_ops.add(a, b); => gen_math_ops.add(a, b);


/// <summary> /// <summary>
@@ -251,7 +251,7 @@ namespace Tensorflow
public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null)
=> gen_math_ops.minimum(x, y, name: name); => gen_math_ops.minimum(x, y, name: name);


public static Tensor multiply(Tensor x, Tensor y)
public static Tensor multiply<Tx, Ty>(Tx x, Ty y)
=> gen_math_ops.mul(x, y); => gen_math_ops.mul(x, y);


public static Tensor negative(Tensor x, string name = null) public static Tensor negative(Tensor x, string name = null)


+ 18
- 3
src/TensorFlowNET.Core/Framework/meta_graph.py.cs View File

@@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Operations;
using static Tensorflow.CollectionDef; using static Tensorflow.CollectionDef;
using static Tensorflow.MetaGraphDef.Types; using static Tensorflow.MetaGraphDef.Types;


@@ -95,15 +96,29 @@ namespace Tensorflow
} }
else else
{ {
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
foreach(var value in col.Value.BytesList.Value)
{
switch (col.Key)
{
case "cond_context":
var proto = CondContextDef.Parser.ParseFrom(value);
var condContext = new CondContext().from_proto(proto, import_scope);
graph.add_to_collection(col.Key, condContext);
break;
default:
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
}
}
} }
break; break;
default:
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
} }
} }


var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names) as List<RefVariable>;
var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES,
scope: scope_to_prepend_to_names);
var var_list = new Dictionary<string, RefVariable>(); var var_list = new Dictionary<string, RefVariable>();
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);




+ 5
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -412,6 +412,11 @@ namespace Tensorflow
return _collections.ContainsKey(name) ? _collections[name] : null; return _collections.ContainsKey(name) ? _collections[name] : null;
} }


public List<T> get_collection<T>(string name, string scope = null)
{
return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>();
}

public object get_collection_ref(string name) public object get_collection_ref(string name)
{ {
if (!_collections.ContainsKey(name)) if (!_collections.ContainsKey(name))


+ 40
- 8
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow.Operations
/// <summary> /// <summary>
/// The context for the conditional construct. /// The context for the conditional construct.
/// </summary> /// </summary>
public class CondContext : ControlFlowContext
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
{ {




@@ -35,16 +35,20 @@ namespace Tensorflow.Operations
/// <param name="name">Name of the `CondContext` python object.</param> /// <param name="name">Name of the `CondContext` python object.</param>
/// <param name="context_def"></param> /// <param name="context_def"></param>
/// <param name="import_scope"></param> /// <param name="import_scope"></param>
public CondContext(Tensor pred,
Tensor pivot,
int branch,
public CondContext(Tensor pred = null,
Tensor pivot = null,
int? branch = null,
string name = "cond_text", string name = "cond_text",
object context_def = null,
CondContextDef context_def = null,
string import_scope = null) string import_scope = null)
{ {
if (pred == null && context_def == null) return;

_name = ops.get_default_graph().unique_name(name); _name = ops.get_default_graph().unique_name(name);
if (context_def != null)
throw new NotImplementedException("CondContext context_def is not null");
if (context_def != null)
{
_init_from_proto(context_def, import_scope: import_scope);
}
else else
{ {
// Initializes the default fields. // Initializes the default fields.
@@ -61,6 +65,18 @@ namespace Tensorflow.Operations
} }
} }


private void _init_from_proto(CondContextDef context_def, string import_scope = null)
{
var g = ops.get_default_graph();
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope);
_pred = g.as_graph_element(p1) as Tensor;
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope);
_pivot = g.as_graph_element(p2) as Tensor;
_branch = context_def.Branch;
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
}

/// <summary> /// <summary>
/// Add `val` to the current context and its outer context recursively. /// Add `val` to the current context and its outer context recursively.
/// </summary> /// </summary>
@@ -230,6 +246,22 @@ namespace Tensorflow.Operations
public override void AddInnerOp(Operation resultOp) public override void AddInnerOp(Operation resultOp)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
}
}
public CondContextDef to_proto(string export_scope)
{
throw new NotImplementedException();
}
public CondContext from_proto(CondContextDef proto, string import_scope)
{
var ret = new CondContext(context_def: proto, import_scope: import_scope);
ret.Enter();
foreach (var nested_def in proto.NestedContexts)
throw new NotImplementedException("");
ret.Exit();
return ret;
}
} }
} }

+ 32
- 2
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -32,6 +32,8 @@ namespace Tensorflow.Operations
protected Stack<IControlFlowContext> _context_stack; protected Stack<IControlFlowContext> _context_stack;
protected IControlFlowContext _outer_context; protected IControlFlowContext _outer_context;


protected Dictionary<string, ITensorOrOperation> _external_values;

public ControlFlowContext() public ControlFlowContext()
{ {
_context_stack = new Stack<IControlFlowContext>(); _context_stack = new Stack<IControlFlowContext>();
@@ -40,15 +42,43 @@ namespace Tensorflow.Operations
public string name { get => _name; } public string name { get => _name; }
protected string _name; protected string _name;


public void __init__()
public void __init__(ValuesDef values_def = null, string import_scope = null)
{ {

_outer_context = ops.get_default_graph()._get_control_flow_context();
if (values_def != null)
_init_values_from_proto(values_def, import_scope: import_scope);
} }


public void __enter__() public void __enter__()
{ {
} }


/// <summary>
/// Initializes values and external_values from `ValuesDef` protocol buffer.
/// </summary>
/// <param name="values_def"></param>
/// <param name="import_scope"></param>
protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null)
{
_external_values = new Dictionary<string, ITensorOrOperation>();
foreach (var value in values_def.Values)
_values.Add(value);
var g = ops.get_default_graph();
foreach(var value in values_def.ExternalValues)
{
var k = ops.prepend_name_scope(value.Key, import_scope);
var v = value.Value;
_external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope));
}
var op_names = _values.Where(x => !_external_values.ContainsKey(x))
.Select(x => x.Split(':')[0])
.ToArray();
foreach (var op in op_names)
(g.as_graph_element(op) as Operation)._set_control_flow_context(this);
}

public void __exit__() public void __exit__()
{ {
} }


+ 2
- 2
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -42,8 +42,8 @@ namespace Tensorflow
if (NumControlOutputs > 0) if (NumControlOutputs > 0)
{ {
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs);
for (int i = 0; i < NumControlOutputs; i++)
{ {
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
control_outputs[i] = new Operation(*(IntPtr*)handle); control_outputs[i] = new Operation(*(IntPtr*)handle);


+ 308
- 309
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,319 +1,318 @@
using Google.Protobuf.Collections;
//using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
using Google.Protobuf.Collections;
//using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
{ {
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
/// </summary>
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
private readonly IntPtr _operDesc;

private Graph _graph;
//[JsonIgnore]
public Graph graph => _graph;
//[JsonIgnore]
public int _id => _id_value;
//[JsonIgnore]
public int _id_value;

public string type => OpType;
//[JsonIgnore]
public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status();

public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

private NodeDef _node_def;
public NodeDef node_def
{
get
{
if(_node_def == null)
_node_def = GetNodeDef();

return _node_def;
}
}

public Operation(IntPtr handle, Graph g=null)
{
if (handle == IntPtr.Zero)
return;

_handle = handle;
_graph = g ?? ops.get_default_graph();
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();

// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
}

public Operation(Graph g, string opType, string oper_name)
{
_graph = g;

_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
_handle = c_api.TF_FinishOperation(_operDesc, status);

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
}

/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
/// <param name="g">`Graph`. The parent graph.</param>
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
/// <param name="output_types">list of `DType` objects.</param>
/// <param name="control_inputs">
/// list of operations or tensors from which to have a
/// control dependency.
/// </param>
/// <param name="input_types">
/// List of `DType` objects representing the
/// types of the tensors accepted by the `Operation`. By default
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
/// reference-typed inputs must specify these explicitly.
/// </param>
/// <param name="original_op"></param>
/// <param name="op_def"></param>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
_graph = g;

// Build the list of control inputs.
var control_input_ops = new List<Operation>();
if(control_inputs != null)
{
foreach(var c in control_inputs)
{
switch (c)
{
case Operation c1:
control_input_ops.Add(c1);
break;
case Tensor tensor:
control_input_ops.Add(tensor.op);
break;
// TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
//case IndexedSlices islices:
// control_input_ops.Add(islices.op);
// break;
default:
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
}
}
}

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();

// This will be set by self.inputs.
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);

var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

// Initialize self._outputs.
output_types = new TF_DataType[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
output_types[i] = OutputType(i);

_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));

graph._add_op(this);

if (_handle != IntPtr.Zero)
_control_flow_post_processing();
}

public void run(FeedItem[] feed_dict = null, Session session = null)
{
ops._run_using_default_session(this, feed_dict, graph, session);
}

private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
{
var grouped_inputs = new List<object>();
int i = 0;
int input_len = 0;
bool is_sequence = false;
foreach (var input_arg in op_def.InputArg)
{
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
input_len = (int)attrs[input_arg.NumberAttr].I;
is_sequence = true;
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{
input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
is_sequence = true;
}
else
{
input_len = 1;
is_sequence = false;
}

if (is_sequence)
grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray());
else
grouped_inputs.Add(inputs[i]);

i += input_len;
}

return grouped_inputs.ToArray();
}

public object get_attr(string name)
{
AttrValue x = null;

using (var buf = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
x = AttrValue.Parser.ParseFrom(buf);
}

string oneof_value = x.ValueCase.ToString();
if (string.IsNullOrEmpty(oneof_value))
return null;

if(oneof_value == "list")
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");

if (oneof_value == "type")
return x.Type;

object result = x.GetType().GetProperty(oneof_value).GetValue(x);
if (result is Google.Protobuf.ByteString byteString)
return byteString.ToStringUtf8();
return result;
}

public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
{
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
}

private NodeDef GetNodeDef()
{
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_OperationToNodeDef(_handle, buffer, s);
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
}
}

public override string ToString()
{
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
}

public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;

public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}

return base.Equals(obj);
/// </summary>
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
private readonly IntPtr _operDesc;
private Graph _graph;
//[JsonIgnore]
public Graph graph => _graph;
//[JsonIgnore]
public int _id => _id_value;
//[JsonIgnore]
public int _id_value;
public string type => OpType;
//[JsonIgnore]
public Operation op => this;
public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status();
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
private NodeDef _node_def;
public NodeDef node_def
{
get
{
if(_node_def == null)
_node_def = GetNodeDef();
return _node_def;
}
}
public Operation(IntPtr handle, Graph g=null)
{
if (handle == IntPtr.Zero)
return;
_handle = handle;
_graph = g ?? ops.get_default_graph();
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor.
}
public Operation(Graph g, string opType, string oper_name)
{
_graph = g;
_operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
_handle = c_api.TF_FinishOperation(_operDesc, status);
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
} }
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only.Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
_assert_same_graph(tensor);

var input = _tf_input(index);
/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
/// <param name="g">`Graph`. The parent graph.</param>
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
/// <param name="output_types">list of `DType` objects.</param>
/// <param name="control_inputs">
/// list of operations or tensors from which to have a
/// control dependency.
/// </param>
/// <param name="input_types">
/// List of `DType` objects representing the
/// types of the tensors accepted by the `Operation`. By default
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
/// reference-typed inputs must specify these explicitly.
/// </param>
/// <param name="original_op"></param>
/// <param name="op_def"></param>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
_graph = g;
// Build the list of control inputs.
var control_input_ops = new List<Operation>();
if(control_inputs != null)
{
foreach(var c in control_inputs)
{
switch (c)
{
case Operation c1:
control_input_ops.Add(c1);
break;
case Tensor tensor:
control_input_ops.Add(tensor.op);
break;
// TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
//case IndexedSlices islices:
// control_input_ops.Add(islices.op);
// break;
default:
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
}
}
}
// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context();
// This will be set by self.inputs.
if (op_def == null)
op_def = g.GetOpDef(node_def.Op);
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
// Initialize self._outputs.
output_types = new TF_DataType[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
output_types[i] = OutputType(i);
_outputs = new Tensor[NumOutputs];
for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
graph._add_op(this);
if (_handle != IntPtr.Zero)
_control_flow_post_processing();
}
public void run(FeedItem[] feed_dict = null, Session session = null)
{
ops._run_using_default_session(this, feed_dict, graph, session);
}
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
{
var grouped_inputs = new List<object>();
int i = 0;
int input_len = 0;
bool is_sequence = false;
foreach (var input_arg in op_def.InputArg)
{
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{
input_len = (int)attrs[input_arg.NumberAttr].I;
is_sequence = true;
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{
input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
is_sequence = true;
}
else
{
input_len = 1;
is_sequence = false;
}
if (is_sequence)
grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray());
else
grouped_inputs.Add(inputs[i]);
i += input_len;
}
return grouped_inputs.ToArray();
}
public object get_attr(string name)
{
AttrValue x = null;
using (var buf = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
x = AttrValue.Parser.ParseFrom(buf);
}
string oneof_value = x.ValueCase.ToString();
if (string.IsNullOrEmpty(oneof_value))
return null;
if(oneof_value == "list")
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
if (oneof_value == "type")
return x.Type;
object result = x.GetType().GetProperty(oneof_value).GetValue(x);
if (result is Google.Protobuf.ByteString byteString)
return byteString.ToStringUtf8();
return result;
}
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
{
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
}
private NodeDef GetNodeDef()
{
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_OperationToNodeDef(_handle, buffer, s);
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
}
}
public override string ToString()
{
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
}
public static implicit operator Operation(IntPtr handle) => new Operation(handle);
public static implicit operator IntPtr(Operation op) => op._handle;
public override bool Equals(object obj)
{
switch (obj)
{
case IntPtr val:
return val == _handle;
case Operation val:
return val._handle == _handle;
}
return base.Equals(obj);
}
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only.Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
_assert_same_graph(tensor);
var input = _tf_input(index);
var output = tensor._as_tf_output(); var output = tensor._as_tf_output();
// Reset cached inputs. // Reset cached inputs.
_inputs = null; _inputs = null;
// after the c_api call next time _inputs is accessed // after the c_api call next time _inputs is accessed
// the updated inputs are reloaded from the c_api
c_api.TF_UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;

}

private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}

/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
}

/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
}
}
}
// the updated inputs are reloaded from the c_api
c_api.TF_UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
}
private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}
/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
}
/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
}
}
}

+ 4
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -308,7 +308,7 @@ namespace Tensorflow
tensor.op.graph.prevent_fetching(tensor.op); tensor.op.graph.prevent_fetching(tensor.op);


// Build the graph for the true branch in a new context. // Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
ITensorOrOperation orig_res_t; ITensorOrOperation orig_res_t;
Tensor res_t; Tensor res_t;
try try
@@ -321,7 +321,7 @@ namespace Tensorflow
context_t.Exit(); context_t.Exit();
} }
// Build the graph for the false branch in a new context. // Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
ITensorOrOperation orig_res_f; ITensorOrOperation orig_res_f;
Tensor res_f; Tensor res_f;
try try
@@ -389,13 +389,13 @@ namespace Tensorflow
tensor.op.graph.prevent_fetching(tensor.op); tensor.op.graph.prevent_fetching(tensor.op);


// Build the graph for the true branch in a new context. // Build the graph for the true branch in a new context.
var context_t = new CondContext(pred, pivot_1, branch: 1);
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
context_t.Enter(); context_t.Enter();
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
context_t.Exit(); context_t.Exit();


// Build the graph for the false branch in a new context. // Build the graph for the false branch in a new context.
var context_f = new CondContext(pred, pivot_2, branch: 0);
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
context_f.Enter(); context_f.Enter();
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
context_f.Exit(); context_f.Exit();


+ 2
- 2
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -80,7 +80,7 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor add(Tensor x, Tensor y, string name = null)
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y });
@@ -300,7 +300,7 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }
public static Tensor mul(Tensor x, Tensor y, string name = null)
public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
{ {
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });


+ 1172
- 0
src/TensorFlowNET.Core/Protobuf/ControlFlow.cs
File diff suppressed because it is too large
View File


+ 4
- 4
src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
/// In order for a object to be serialized to and from MetaGraphDef, /// In order for a object to be serialized to and from MetaGraphDef,
/// the class must implement to_proto() and from_proto() methods /// the class must implement to_proto() and from_proto() methods
/// </summary> /// </summary>
public interface IProtoBuf
public interface IProtoBuf<TProtoDef, TDef>
{ {
string name { get; } string name { get; }


@@ -17,15 +17,15 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="export_scope"></param> /// <param name="export_scope"></param>
/// <returns></returns> /// <returns></returns>
VariableDef to_proto(string export_scope);
TProtoDef to_proto(string export_scope);


/// <summary> /// <summary>
/// Returns a `Variable` object created from `variable_def`. /// Returns a `Variable` object created from `variable_def`.
/// </summary> /// </summary>
/// <typeparam name="T"></typeparam> /// <typeparam name="T"></typeparam>
/// <param name="variable_def"></param>
/// <param name="proto"></param>
/// <param name="import_scope"></param> /// <param name="import_scope"></param>
/// <returns></returns> /// <returns></returns>
T from_proto<T>(VariableDef variable_def, string import_scope);
TDef from_proto(TProtoDef proto, string import_scope);
} }
} }

+ 5
- 2
src/TensorFlowNET.Core/Protobuf/README.md View File

@@ -1,10 +1,12 @@
### Download compiler from https://github.com/protocolbuffers/protobuf/releases ### Download compiler from https://github.com/protocolbuffers/protobuf/releases
Work in command line

```shell ```shell
cd tensorflow

set SRC_DIR=D:/Projects/tensorflow set SRC_DIR=D:/Projects/tensorflow
set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf


cd tensorflow

protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto
@@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto
``` ```



+ 2
- 2
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -7,7 +7,7 @@ using System.Text;


namespace Tensorflow namespace Tensorflow
{ {
public partial class RefVariable : VariableV1, IProtoBuf
public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable>
{ {
public bool _in_graph_mode = true; public bool _in_graph_mode = true;
public Tensor _initial_value; public Tensor _initial_value;
@@ -288,7 +288,7 @@ namespace Tensorflow
throw new NotImplementedException("to_proto RefVariable"); throw new NotImplementedException("to_proto RefVariable");
} }


public T from_proto<T>(VariableDef variable_def, string import_scope)
public RefVariable from_proto(VariableDef proto, string import_scope)
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }


+ 1
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -376,7 +376,7 @@ namespace Tensorflow
if (import_scope.EndsWith("/")) if (import_scope.EndsWith("/"))
import_scope = import_scope.Substring(0, import_scope.Length - 1); import_scope = import_scope.Substring(0, import_scope.Length - 1);


throw new NotImplementedException("prepend_name_scope");
return $"{import_scope}/{name}";
} }
else else
return name; return name;


+ 21
- 43
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -132,10 +132,11 @@ namespace TensorFlowNET.UnitTest
} }
/// <summary> /// <summary>
/// Evaluates tensors and returns a dictionary of {name:result, ...}.
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
/// This function is used in many original tensorflow unit tests to evaluate tensors
/// in a test session with special settings (for instance constant folding off)
///
/// </summary> /// </summary>
public Dictionary<string, NDArray> evaluate(params Tensor[] tensors)
public T evaluate<T>(Tensor tensor)
{ {
var results = new Dictionary<string, NDArray>(); var results = new Dictionary<string, NDArray>();
// if context.executing_eagerly(): // if context.executing_eagerly():
@@ -145,49 +146,26 @@ namespace TensorFlowNET.UnitTest
var sess = ops.get_default_session(); var sess = ops.get_default_session();
if (sess == null) if (sess == null)
sess = self.session(); sess = self.session();
with<Session>(sess, s =>
{
foreach (var t in tensors)
results[t.name] = t.eval();
});
return results;
}
}
public NDArray evaluate(Tensor tensor)
{
NDArray result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == null)
sess = self.session();
with<Session>(sess, s =>
{
result = tensor.eval();
});
return result;
}
}
public object eval_scalar(Tensor tensor)
{
NDArray result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == null)
sess = self.session();
T t_result = (T)(object)null;
with<Session>(sess, s => with<Session>(sess, s =>
{ {
result = tensor.eval();
var ndarray=tensor.eval();
if (typeof(T) == typeof(double))
{
double d = ndarray;
t_result = (T)(object)d;
}
else if (typeof(T) == typeof(int))
{
int d = ndarray;
t_result = (T) (object) d;
}
else
{
t_result = (T)(object)ndarray;
}
}); });
return result.Array.GetValue(0);
return t_result;
} }
} }


+ 53
- 11
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow; using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test namespace TensorFlowNET.UnitTest.control_flow_ops_test
@@ -9,32 +10,73 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestClass] [TestClass]
public class CondTestCases : PythonTest public class CondTestCases : PythonTest
{ {
[TestMethod] [TestMethod]
public void testCondTrue() public void testCondTrue()
{ {
with(tf.Graph().as_default(), g =>
var graph = tf.Graph().as_default();
with(tf.Session(graph), sess =>
{ {
var x = tf.constant(2); var x = tf.constant(2);
var y = tf.constant(5); var y = tf.constant(5);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
self.assertEquals(eval_scalar(z), 34);
var pred = tf.less(x, y);
Func<ITensorOrOperation> if_true = delegate
{
return tf.multiply(x, 17);
};
Func<ITensorOrOperation> if_false = delegate
{
return tf.add(y, 23);
};
var z = control_flow_ops.cond(pred, if_true, if_false);
int result = z.eval(sess);
assertEquals(result, 34);
}); });
} }
//[Ignore("This Test Fails due to missing edges in the graph!")]
[TestMethod] [TestMethod]
public void testCondFalse() public void testCondFalse()
{ {
with(tf.Graph().as_default(), g =>
/* python
* import tensorflow as tf
from tensorflow.python.framework import ops
def if_true():
return tf.math.multiply(x, 17)
def if_false():
return tf.math.add(y, 23)
with tf.Session() as sess:
x = tf.constant(2)
y = tf.constant(1)
pred = tf.math.less(x,y)
z = tf.cond(pred, if_true, if_false)
result = z.eval()
print(result == 24) */
with(tf.Session(), sess =>
{ {
var x = tf.constant(2); var x = tf.constant(2);
var y = tf.constant(1); var y = tf.constant(1);
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
() => tf.add(y, tf.constant(23)));
self.assertEquals(eval_scalar(z), 24);
var pred = tf.less(x, y);
Func<ITensorOrOperation> if_true = delegate
{
return tf.multiply(x, 17);
};
Func<ITensorOrOperation> if_false = delegate
{
return tf.add(y, 23);
};
var z = control_flow_ops.cond(pred, if_true, if_false);
int result = z.eval(sess);
assertEquals(result, 24);
}); });
} }


+ 1
- 1
test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs View File

@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest.ops_test
{ {
var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
}); });
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
} }


Loading…
Cancel
Save