diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index bbf240e3..fbf3dd00 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -27,7 +27,7 @@ namespace Tensorflow public static Tensor asin(Tensor x, string name = null) => gen_math_ops.asin(x, name); - public static Tensor add(Tensor a, Tensor b) + public static Tensor add(Tx a, Ty b) => gen_math_ops.add(a, b); /// @@ -251,7 +251,7 @@ namespace Tensorflow public static Tensor minimum(T1 x, T2 y, string name = null) => gen_math_ops.minimum(x, y, name: name); - public static Tensor multiply(Tensor x, Tensor y) + public static Tensor multiply(Tx x, Ty y) => gen_math_ops.mul(x, y); public static Tensor negative(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs index c7af7051..ceebdc6e 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.py.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.py.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using Tensorflow.Operations; using static Tensorflow.CollectionDef; using static Tensorflow.MetaGraphDef.Types; @@ -95,15 +96,29 @@ namespace Tensorflow } else { - throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + foreach(var value in col.Value.BytesList.Value) + { + switch (col.Key) + { + case "cond_context": + var proto = CondContextDef.Parser.ParseFrom(value); + var condContext = new CondContext().from_proto(proto, import_scope); + graph.add_to_collection(col.Key, condContext); + break; + default: + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); + } + } } break; + default: + throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); } } - var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, - scope: scope_to_prepend_to_names) as List; + var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, + scope: scope_to_prepend_to_names); var var_list = new Dictionary(); variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 081893c2..f1a33371 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -412,6 +412,11 @@ namespace Tensorflow return _collections.ContainsKey(name) ? _collections[name] : null; } + public List get_collection(string name, string scope = null) + { + return _collections.ContainsKey(name) ? _collections[name] as List : new List(); + } + public object get_collection_ref(string name) { if (!_collections.ContainsKey(name)) diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs index 1bfa81f2..c00e2c0e 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Operations /// /// The context for the conditional construct. /// - public class CondContext : ControlFlowContext + public class CondContext : ControlFlowContext, IProtoBuf { @@ -35,16 +35,20 @@ namespace Tensorflow.Operations /// Name of the `CondContext` python object. /// /// - public CondContext(Tensor pred, - Tensor pivot, - int branch, + public CondContext(Tensor pred = null, + Tensor pivot = null, + int? branch = null, string name = "cond_text", - object context_def = null, + CondContextDef context_def = null, string import_scope = null) { + if (pred == null && context_def == null) return; + _name = ops.get_default_graph().unique_name(name); - if (context_def != null) - throw new NotImplementedException("CondContext context_def is not null"); + if (context_def != null) + { + _init_from_proto(context_def, import_scope: import_scope); + } else { // Initializes the default fields. @@ -61,6 +65,18 @@ namespace Tensorflow.Operations } } + private void _init_from_proto(CondContextDef context_def, string import_scope = null) + { + var g = ops.get_default_graph(); + _name = ops.prepend_name_scope(context_def.ContextName, import_scope); + var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); + _pred = g.as_graph_element(p1) as Tensor; + var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); + _pivot = g.as_graph_element(p2) as Tensor; + _branch = context_def.Branch; + __init__(values_def: context_def.ValuesDef, import_scope: import_scope); + } + /// /// Add `val` to the current context and its outer context recursively. /// @@ -230,6 +246,22 @@ namespace Tensorflow.Operations public override void AddInnerOp(Operation resultOp) { throw new NotImplementedException(); - } + } + + public CondContextDef to_proto(string export_scope) + { + throw new NotImplementedException(); + } + + public CondContext from_proto(CondContextDef proto, string import_scope) + { + var ret = new CondContext(context_def: proto, import_scope: import_scope); + + ret.Enter(); + foreach (var nested_def in proto.NestedContexts) + throw new NotImplementedException(""); + ret.Exit(); + return ret; + } } } \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index fef79c8d..86452e50 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -32,6 +32,8 @@ namespace Tensorflow.Operations protected Stack _context_stack; protected IControlFlowContext _outer_context; + protected Dictionary _external_values; + public ControlFlowContext() { _context_stack = new Stack(); @@ -40,15 +42,43 @@ namespace Tensorflow.Operations public string name { get => _name; } protected string _name; - public void __init__() + public void __init__(ValuesDef values_def = null, string import_scope = null) { - + _outer_context = ops.get_default_graph()._get_control_flow_context(); + if (values_def != null) + _init_values_from_proto(values_def, import_scope: import_scope); } public void __enter__() { } + /// + /// Initializes values and external_values from `ValuesDef` protocol buffer. + /// + /// + /// + protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) + { + _external_values = new Dictionary(); + foreach (var value in values_def.Values) + _values.Add(value); + var g = ops.get_default_graph(); + foreach(var value in values_def.ExternalValues) + { + var k = ops.prepend_name_scope(value.Key, import_scope); + var v = value.Value; + _external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); + } + + var op_names = _values.Where(x => !_external_values.ContainsKey(x)) + .Select(x => x.Split(':')[0]) + .ToArray(); + + foreach (var op in op_names) + (g.as_graph_element(op) as Operation)._set_control_flow_context(this); + } + public void __exit__() { } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 5b0b43b3..d7e975bb 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -42,8 +42,8 @@ namespace Tensorflow if (NumControlOutputs > 0) { IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf() * NumControlOutputs); - c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); - for (int i = 0; i < NumControlInputs; i++) + c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs); + for (int i = 0; i < NumControlOutputs; i++) { var handle = control_output_handle + Marshal.SizeOf() * i; control_outputs[i] = new Operation(*(IntPtr*)handle); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index e1bf9246..6c49c4a9 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,319 +1,318 @@ -using Google.Protobuf.Collections; -//using Newtonsoft.Json; -using System; -using System.Collections.Generic; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; - -namespace Tensorflow +using Google.Protobuf.Collections; +//using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow { - /// - /// Represents a graph node that performs computation on tensors. - /// - /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or - /// more `Tensor` objects as input, and produces zero or more `Tensor` - /// objects as output. Objects of type `Operation` are created by - /// calling an op constructor(such as `tf.matmul`) - /// or `tf.Graph.create_op`. - /// - /// For example `c = tf.matmul(a, b)` creates an `Operation` of type - /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` - /// as output. - /// - /// After the graph has been launched in a session, an `Operation` can - /// be executed by passing it to - /// `tf.Session.run`. + /// + /// Represents a graph node that performs computation on tensors. + /// + /// An `Operation` is a node in a TensorFlow `Graph` that takes zero or + /// more `Tensor` objects as input, and produces zero or more `Tensor` + /// objects as output. Objects of type `Operation` are created by + /// calling an op constructor(such as `tf.matmul`) + /// or `tf.Graph.create_op`. + /// + /// For example `c = tf.matmul(a, b)` creates an `Operation` of type + /// "MatMul" that takes tensors `a` and `b` as input, and produces `c` + /// as output. + /// + /// After the graph has been launched in a session, an `Operation` can + /// be executed by passing it to + /// `tf.Session.run`. /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. - /// - public partial class Operation : ITensorOrOperation - { - private readonly IntPtr _handle; // _c_op in python - private readonly IntPtr _operDesc; - - private Graph _graph; - //[JsonIgnore] - public Graph graph => _graph; - //[JsonIgnore] - public int _id => _id_value; - //[JsonIgnore] - public int _id_value; - - public string type => OpType; - //[JsonIgnore] - public Operation op => this; - public TF_DataType dtype => TF_DataType.DtInvalid; - private Status status = new Status(); - - public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); - public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); - public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); - - private NodeDef _node_def; - public NodeDef node_def - { - get - { - if(_node_def == null) - _node_def = GetNodeDef(); - - return _node_def; - } - } - - public Operation(IntPtr handle, Graph g=null) - { - if (handle == IntPtr.Zero) - return; - - _handle = handle; - _graph = g ?? ops.get_default_graph(); - _outputs = new Tensor[NumOutputs]; - for (int i = 0; i < NumOutputs; i++) - _outputs[i] = new Tensor(this, i, OutputType(i)); - - // Dict mapping op name to file and line information for op colocation - // context managers. - _control_flow_context = graph._get_control_flow_context(); - - // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. - } - - public Operation(Graph g, string opType, string oper_name) - { - _graph = g; - - _operDesc = c_api.TF_NewOperation(g, opType, oper_name); - c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); - _handle = c_api.TF_FinishOperation(_operDesc, status); - - // Dict mapping op name to file and line information for op colocation - // context managers. - _control_flow_context = graph._get_control_flow_context(); - } - - /// - /// Creates an `Operation`. - /// - /// `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. - /// `Graph`. The parent graph. - /// list of `Tensor` objects. The inputs to this `Operation`. - /// list of `DType` objects. - /// - /// list of operations or tensors from which to have a - /// control dependency. - /// - /// - /// List of `DType` objects representing the - /// types of the tensors accepted by the `Operation`. By default - /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect - /// reference-typed inputs must specify these explicitly. - /// - /// - /// - public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) - { - _graph = g; - - // Build the list of control inputs. - var control_input_ops = new List(); - if(control_inputs != null) - { - foreach(var c in control_inputs) - { - switch (c) - { - case Operation c1: - control_input_ops.Add(c1); - break; - case Tensor tensor: - control_input_ops.Add(tensor.op); - break; - // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented - //case IndexedSlices islices: - // control_input_ops.Add(islices.op); - // break; - default: - throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); - } - } - } - - // Dict mapping op name to file and line information for op colocation - // context managers. - _control_flow_context = graph._get_control_flow_context(); - - // This will be set by self.inputs. - if (op_def == null) - op_def = g.GetOpDef(node_def.Op); - - var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); - (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); - - // Initialize self._outputs. - output_types = new TF_DataType[NumOutputs]; - for (int i = 0; i < NumOutputs; i++) - output_types[i] = OutputType(i); - - _outputs = new Tensor[NumOutputs]; - for (int i = 0; i < NumOutputs; i++) - _outputs[i] = new Tensor(this, i, OutputType(i)); - - graph._add_op(this); - - if (_handle != IntPtr.Zero) - _control_flow_post_processing(); - } - - public void run(FeedItem[] feed_dict = null, Session session = null) - { - ops._run_using_default_session(this, feed_dict, graph, session); - } - - private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs) - { - var grouped_inputs = new List(); - int i = 0; - int input_len = 0; - bool is_sequence = false; - foreach (var input_arg in op_def.InputArg) - { - if (!string.IsNullOrEmpty(input_arg.NumberAttr)) - { - input_len = (int)attrs[input_arg.NumberAttr].I; - is_sequence = true; - } - else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) - { - input_len = attrs[input_arg.TypeListAttr].List.Type.Count; - is_sequence = true; - } - else - { - input_len = 1; - is_sequence = false; - } - - if (is_sequence) - grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); - else - grouped_inputs.Add(inputs[i]); - - i += input_len; - } - - return grouped_inputs.ToArray(); - } - - public object get_attr(string name) - { - AttrValue x = null; - - using (var buf = new Buffer()) - { - c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); - status.Check(true); - x = AttrValue.Parser.ParseFrom(buf); - } - - string oneof_value = x.ValueCase.ToString(); - if (string.IsNullOrEmpty(oneof_value)) - return null; - - if(oneof_value == "list") - throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); - - if (oneof_value == "type") - return x.Type; - - object result = x.GetType().GetProperty(oneof_value).GetValue(x); - if (result is Google.Protobuf.ByteString byteString) - return byteString.ToStringUtf8(); - return result; - } - - public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) - { - return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); - } - - private NodeDef GetNodeDef() - { - using (var s = new Status()) - using (var buffer = new Buffer()) - { - c_api.TF_OperationToNodeDef(_handle, buffer, s); - s.Check(); - return NodeDef.Parser.ParseFrom(buffer); - } - } - - public override string ToString() - { - return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; - } - - public static implicit operator Operation(IntPtr handle) => new Operation(handle); - public static implicit operator IntPtr(Operation op) => op._handle; - - public override bool Equals(object obj) - { - switch (obj) - { - case IntPtr val: - return val == _handle; - case Operation val: - return val._handle == _handle; - } - - return base.Equals(obj); + /// + public partial class Operation : ITensorOrOperation + { + private readonly IntPtr _handle; // _c_op in python + private readonly IntPtr _operDesc; + + private Graph _graph; + //[JsonIgnore] + public Graph graph => _graph; + //[JsonIgnore] + public int _id => _id_value; + //[JsonIgnore] + public int _id_value; + + public string type => OpType; + //[JsonIgnore] + public Operation op => this; + public TF_DataType dtype => TF_DataType.DtInvalid; + private Status status = new Status(); + + public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); + public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); + public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); + + private NodeDef _node_def; + public NodeDef node_def + { + get + { + if(_node_def == null) + _node_def = GetNodeDef(); + + return _node_def; + } + } + + public Operation(IntPtr handle, Graph g=null) + { + if (handle == IntPtr.Zero) + return; + + _handle = handle; + _graph = g ?? ops.get_default_graph(); + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); + + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); + + // Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. + } + + public Operation(Graph g, string opType, string oper_name) + { + _graph = g; + + _operDesc = c_api.TF_NewOperation(g, opType, oper_name); + c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); + _handle = c_api.TF_FinishOperation(_operDesc, status); + + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); } - /// - /// Update the input to this operation at the given index. - /// - /// NOTE: This is for TF internal use only.Please don't use it. - /// - /// the index of the input to update. - /// the Tensor to be used as the input at the given index. - public void _update_input(int index, Tensor tensor) - { - _assert_same_graph(tensor); - - var input = _tf_input(index); + /// + /// Creates an `Operation`. + /// + /// `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. + /// `Graph`. The parent graph. + /// list of `Tensor` objects. The inputs to this `Operation`. + /// list of `DType` objects. + /// + /// list of operations or tensors from which to have a + /// control dependency. + /// + /// + /// List of `DType` objects representing the + /// types of the tensors accepted by the `Operation`. By default + /// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect + /// reference-typed inputs must specify these explicitly. + /// + /// + /// + public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) + { + _graph = g; + + // Build the list of control inputs. + var control_input_ops = new List(); + if(control_inputs != null) + { + foreach(var c in control_inputs) + { + switch (c) + { + case Operation c1: + control_input_ops.Add(c1); + break; + case Tensor tensor: + control_input_ops.Add(tensor.op); + break; + // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented + //case IndexedSlices islices: + // control_input_ops.Add(islices.op); + // break; + default: + throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); + } + } + } + + // Dict mapping op name to file and line information for op colocation + // context managers. + _control_flow_context = graph._get_control_flow_context(); + + // This will be set by self.inputs. + if (op_def == null) + op_def = g.GetOpDef(node_def.Op); + + var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); + (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); + + // Initialize self._outputs. + output_types = new TF_DataType[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + output_types[i] = OutputType(i); + + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) + _outputs[i] = new Tensor(this, i, OutputType(i)); + + graph._add_op(this); + + if (_handle != IntPtr.Zero) + _control_flow_post_processing(); + } + + public void run(FeedItem[] feed_dict = null, Session session = null) + { + ops._run_using_default_session(this, feed_dict, graph, session); + } + + private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs) + { + var grouped_inputs = new List(); + int i = 0; + int input_len = 0; + bool is_sequence = false; + foreach (var input_arg in op_def.InputArg) + { + if (!string.IsNullOrEmpty(input_arg.NumberAttr)) + { + input_len = (int)attrs[input_arg.NumberAttr].I; + is_sequence = true; + } + else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) + { + input_len = attrs[input_arg.TypeListAttr].List.Type.Count; + is_sequence = true; + } + else + { + input_len = 1; + is_sequence = false; + } + + if (is_sequence) + grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); + else + grouped_inputs.Add(inputs[i]); + + i += input_len; + } + + return grouped_inputs.ToArray(); + } + + public object get_attr(string name) + { + AttrValue x = null; + + using (var buf = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + x = AttrValue.Parser.ParseFrom(buf); + } + + string oneof_value = x.ValueCase.ToString(); + if (string.IsNullOrEmpty(oneof_value)) + return null; + + if(oneof_value == "list") + throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); + + if (oneof_value == "type") + return x.Type; + + object result = x.GetType().GetProperty(oneof_value).GetValue(x); + if (result is Google.Protobuf.ByteString byteString) + return byteString.ToStringUtf8(); + return result; + } + + public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) + { + return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); + } + + private NodeDef GetNodeDef() + { + using (var s = new Status()) + using (var buffer = new Buffer()) + { + c_api.TF_OperationToNodeDef(_handle, buffer, s); + s.Check(); + return NodeDef.Parser.ParseFrom(buffer); + } + } + + public override string ToString() + { + return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; + } + + public static implicit operator Operation(IntPtr handle) => new Operation(handle); + public static implicit operator IntPtr(Operation op) => op._handle; + + public override bool Equals(object obj) + { + switch (obj) + { + case IntPtr val: + return val == _handle; + case Operation val: + return val._handle == _handle; + } + + return base.Equals(obj); + } + + /// + /// Update the input to this operation at the given index. + /// + /// NOTE: This is for TF internal use only.Please don't use it. + /// + /// the index of the input to update. + /// the Tensor to be used as the input at the given index. + public void _update_input(int index, Tensor tensor) + { + _assert_same_graph(tensor); + + var input = _tf_input(index); var output = tensor._as_tf_output(); // Reset cached inputs. _inputs = null; // after the c_api call next time _inputs is accessed - // the updated inputs are reloaded from the c_api - c_api.TF_UpdateEdge(_graph, output, input, status); - //var updated_inputs = inputs; - - } - - private void _assert_same_graph(Tensor tensor) - { - //TODO: implement - } - - /// - /// Create and return a new TF_Output for output_idx'th output of this op. - /// - public TF_Output _tf_output(int output_idx) - { - var tf_output = new TF_Output(op, output_idx); - return tf_output; - } - - /// - /// Create and return a new TF_Input for input_idx'th input of this op. - /// - public TF_Input _tf_input(int input_idx) - { - var tf_input = new TF_Input(op, input_idx); - return tf_input; - } - } -} + // the updated inputs are reloaded from the c_api + c_api.TF_UpdateEdge(_graph, output, input, status); + //var updated_inputs = inputs; + } + + private void _assert_same_graph(Tensor tensor) + { + //TODO: implement + } + + /// + /// Create and return a new TF_Output for output_idx'th output of this op. + /// + public TF_Output _tf_output(int output_idx) + { + var tf_output = new TF_Output(op, output_idx); + return tf_output; + } + + /// + /// Create and return a new TF_Input for input_idx'th input of this op. + /// + public TF_Input _tf_input(int input_idx) + { + var tf_input = new TF_Input(op, input_idx); + return tf_input; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index f1c799e9..dbb7a96e 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -308,7 +308,7 @@ namespace Tensorflow tensor.op.graph.prevent_fetching(tensor.op); // Build the graph for the true branch in a new context. - var context_t = new CondContext(pred, pivot_1, branch: 1); + var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); ITensorOrOperation orig_res_t; Tensor res_t; try @@ -321,7 +321,7 @@ namespace Tensorflow context_t.Exit(); } // Build the graph for the false branch in a new context. - var context_f = new CondContext(pred, pivot_2, branch: 0); + var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); ITensorOrOperation orig_res_f; Tensor res_f; try @@ -389,13 +389,13 @@ namespace Tensorflow tensor.op.graph.prevent_fetching(tensor.op); // Build the graph for the true branch in a new context. - var context_t = new CondContext(pred, pivot_1, branch: 1); + var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); context_t.Enter(); var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); context_t.Exit(); // Build the graph for the false branch in a new context. - var context_f = new CondContext(pred, pivot_2, branch: 0); + var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); context_f.Enter(); var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); context_f.Exit(); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 5e58df45..56477442 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -80,7 +80,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor add(Tensor x, Tensor y, string name = null) + public static Tensor add(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); @@ -300,7 +300,7 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor mul(Tensor x, Tensor y, string name = null) + public static Tensor mul(Tx x, Ty y, string name = null) { var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); diff --git a/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs b/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs new file mode 100644 index 00000000..108bcc45 --- /dev/null +++ b/src/TensorFlowNET.Core/Protobuf/ControlFlow.cs @@ -0,0 +1,1172 @@ +// +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensorflow/core/protobuf/control_flow.proto +// +#pragma warning disable 1591, 0612, 3021 +#region Designer generated code + +using pb = global::Google.Protobuf; +using pbc = global::Google.Protobuf.Collections; +using pbr = global::Google.Protobuf.Reflection; +using scg = global::System.Collections.Generic; +namespace Tensorflow { + + /// Holder for reflection information generated from tensorflow/core/protobuf/control_flow.proto + public static partial class ControlFlowReflection { + + #region Descriptor + /// File descriptor for tensorflow/core/protobuf/control_flow.proto + public static pbr::FileDescriptor Descriptor { + get { return descriptor; } + } + private static pbr::FileDescriptor descriptor; + + static ControlFlowReflection() { + byte[] descriptorData = global::System.Convert.FromBase64String( + string.Concat( + "Cit0ZW5zb3JmbG93L2NvcmUvcHJvdG9idWYvY29udHJvbF9mbG93LnByb3Rv", + "Egp0ZW5zb3JmbG93IpYBCglWYWx1ZXNEZWYSDgoGdmFsdWVzGAEgAygJEkIK", + "D2V4dGVybmFsX3ZhbHVlcxgCIAMoCzIpLnRlbnNvcmZsb3cuVmFsdWVzRGVm", + "LkV4dGVybmFsVmFsdWVzRW50cnkaNQoTRXh0ZXJuYWxWYWx1ZXNFbnRyeRIL", + "CgNrZXkYASABKAkSDQoFdmFsdWUYAiABKAk6AjgBIoMBChVDb250cm9sRmxv", + "d0NvbnRleHREZWYSLwoJY29uZF9jdHh0GAEgASgLMhoudGVuc29yZmxvdy5D", + "b25kQ29udGV4dERlZkgAEjEKCndoaWxlX2N0eHQYAiABKAsyGy50ZW5zb3Jm", + "bG93LldoaWxlQ29udGV4dERlZkgAQgYKBGN0eHQixAEKDkNvbmRDb250ZXh0", + "RGVmEhQKDGNvbnRleHRfbmFtZRgBIAEoCRIRCglwcmVkX25hbWUYAiABKAkS", + "EgoKcGl2b3RfbmFtZRgDIAEoCRIOCgZicmFuY2gYBCABKAUSKQoKdmFsdWVz", + "X2RlZhgFIAEoCzIVLnRlbnNvcmZsb3cuVmFsdWVzRGVmEjoKD25lc3RlZF9j", + "b250ZXh0cxgGIAMoCzIhLnRlbnNvcmZsb3cuQ29udHJvbEZsb3dDb250ZXh0", + "RGVmIvUCCg9XaGlsZUNvbnRleHREZWYSFAoMY29udGV4dF9uYW1lGAEgASgJ", + "EhsKE3BhcmFsbGVsX2l0ZXJhdGlvbnMYAiABKAUSEQoJYmFja19wcm9wGAMg", + "ASgIEhMKC3N3YXBfbWVtb3J5GAQgASgIEhIKCnBpdm90X25hbWUYBSABKAkS", + "GwoTcGl2b3RfZm9yX3ByZWRfbmFtZRgGIAEoCRIbChNwaXZvdF9mb3JfYm9k", + "eV9uYW1lGAcgASgJEhcKD2xvb3BfZXhpdF9uYW1lcxgIIAMoCRIYChBsb29w", + "X2VudGVyX25hbWVzGAogAygJEikKCnZhbHVlc19kZWYYCSABKAsyFS50ZW5z", + "b3JmbG93LlZhbHVlc0RlZhIfChdtYXhpbXVtX2l0ZXJhdGlvbnNfbmFtZRgL", + "IAEoCRI6Cg9uZXN0ZWRfY29udGV4dHMYDCADKAsyIS50ZW5zb3JmbG93LkNv", + "bnRyb2xGbG93Q29udGV4dERlZkJwChhvcmcudGVuc29yZmxvdy5mcmFtZXdv", + "cmtCEUNvbnRyb2xGbG93UHJvdG9zUAFaPGdpdGh1Yi5jb20vdGVuc29yZmxv", + "dy90ZW5zb3JmbG93L3RlbnNvcmZsb3cvZ28vY29yZS9wcm90b2J1ZvgBAWIG", + "cHJvdG8z")); + descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, + new pbr::FileDescriptor[] { }, + new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ValuesDef), global::Tensorflow.ValuesDef.Parser, new[]{ "Values", "ExternalValues" }, null, null, new pbr::GeneratedClrTypeInfo[] { null, }), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.ControlFlowContextDef), global::Tensorflow.ControlFlowContextDef.Parser, new[]{ "CondCtxt", "WhileCtxt" }, new[]{ "Ctxt" }, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.CondContextDef), global::Tensorflow.CondContextDef.Parser, new[]{ "ContextName", "PredName", "PivotName", "Branch", "ValuesDef", "NestedContexts" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::Tensorflow.WhileContextDef), global::Tensorflow.WhileContextDef.Parser, new[]{ "ContextName", "ParallelIterations", "BackProp", "SwapMemory", "PivotName", "PivotForPredName", "PivotForBodyName", "LoopExitNames", "LoopEnterNames", "ValuesDef", "MaximumIterationsName", "NestedContexts" }, null, null, null) + })); + } + #endregion + + } + #region Messages + /// + /// Protocol buffer representing the values in ControlFlowContext. + /// + public sealed partial class ValuesDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ValuesDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValuesDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValuesDef(ValuesDef other) : this() { + values_ = other.values_.Clone(); + externalValues_ = other.externalValues_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ValuesDef Clone() { + return new ValuesDef(this); + } + + /// Field number for the "values" field. + public const int ValuesFieldNumber = 1; + private static readonly pb::FieldCodec _repeated_values_codec + = pb::FieldCodec.ForString(10); + private readonly pbc::RepeatedField values_ = new pbc::RepeatedField(); + /// + /// Value names that have been seen in this context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField Values { + get { return values_; } + } + + /// Field number for the "external_values" field. + public const int ExternalValuesFieldNumber = 2; + private static readonly pbc::MapField.Codec _map_externalValues_codec + = new pbc::MapField.Codec(pb::FieldCodec.ForString(10), pb::FieldCodec.ForString(18), 18); + private readonly pbc::MapField externalValues_ = new pbc::MapField(); + /// + /// Value names referenced by but external to this context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::MapField ExternalValues { + get { return externalValues_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ValuesDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ValuesDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if(!values_.Equals(other.values_)) return false; + if (!ExternalValues.Equals(other.ExternalValues)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + hash ^= values_.GetHashCode(); + hash ^= ExternalValues.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + values_.WriteTo(output, _repeated_values_codec); + externalValues_.WriteTo(output, _map_externalValues_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + size += values_.CalculateSize(_repeated_values_codec); + size += externalValues_.CalculateSize(_map_externalValues_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ValuesDef other) { + if (other == null) { + return; + } + values_.Add(other.values_); + externalValues_.Add(other.externalValues_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + values_.AddEntriesFrom(input, _repeated_values_codec); + break; + } + case 18: { + externalValues_.AddEntriesFrom(input, _map_externalValues_codec); + break; + } + } + } + } + + } + + /// + /// Container for any kind of control flow context. Any other control flow + /// contexts that are added below should also be added here. + /// + public sealed partial class ControlFlowContextDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new ControlFlowContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[1]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ControlFlowContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ControlFlowContextDef(ControlFlowContextDef other) : this() { + switch (other.CtxtCase) { + case CtxtOneofCase.CondCtxt: + CondCtxt = other.CondCtxt.Clone(); + break; + case CtxtOneofCase.WhileCtxt: + WhileCtxt = other.WhileCtxt.Clone(); + break; + } + + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public ControlFlowContextDef Clone() { + return new ControlFlowContextDef(this); + } + + /// Field number for the "cond_ctxt" field. + public const int CondCtxtFieldNumber = 1; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.CondContextDef CondCtxt { + get { return ctxtCase_ == CtxtOneofCase.CondCtxt ? (global::Tensorflow.CondContextDef) ctxt_ : null; } + set { + ctxt_ = value; + ctxtCase_ = value == null ? CtxtOneofCase.None : CtxtOneofCase.CondCtxt; + } + } + + /// Field number for the "while_ctxt" field. + public const int WhileCtxtFieldNumber = 2; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.WhileContextDef WhileCtxt { + get { return ctxtCase_ == CtxtOneofCase.WhileCtxt ? (global::Tensorflow.WhileContextDef) ctxt_ : null; } + set { + ctxt_ = value; + ctxtCase_ = value == null ? CtxtOneofCase.None : CtxtOneofCase.WhileCtxt; + } + } + + private object ctxt_; + /// Enum of possible cases for the "ctxt" oneof. + public enum CtxtOneofCase { + None = 0, + CondCtxt = 1, + WhileCtxt = 2, + } + private CtxtOneofCase ctxtCase_ = CtxtOneofCase.None; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CtxtOneofCase CtxtCase { + get { return ctxtCase_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void ClearCtxt() { + ctxtCase_ = CtxtOneofCase.None; + ctxt_ = null; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as ControlFlowContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(ControlFlowContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (!object.Equals(CondCtxt, other.CondCtxt)) return false; + if (!object.Equals(WhileCtxt, other.WhileCtxt)) return false; + if (CtxtCase != other.CtxtCase) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ctxtCase_ == CtxtOneofCase.CondCtxt) hash ^= CondCtxt.GetHashCode(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) hash ^= WhileCtxt.GetHashCode(); + hash ^= (int) ctxtCase_; + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + output.WriteRawTag(10); + output.WriteMessage(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + output.WriteRawTag(18); + output.WriteMessage(WhileCtxt); + } + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(CondCtxt); + } + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(WhileCtxt); + } + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(ControlFlowContextDef other) { + if (other == null) { + return; + } + switch (other.CtxtCase) { + case CtxtOneofCase.CondCtxt: + if (CondCtxt == null) { + CondCtxt = new global::Tensorflow.CondContextDef(); + } + CondCtxt.MergeFrom(other.CondCtxt); + break; + case CtxtOneofCase.WhileCtxt: + if (WhileCtxt == null) { + WhileCtxt = new global::Tensorflow.WhileContextDef(); + } + WhileCtxt.MergeFrom(other.WhileCtxt); + break; + } + + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + global::Tensorflow.CondContextDef subBuilder = new global::Tensorflow.CondContextDef(); + if (ctxtCase_ == CtxtOneofCase.CondCtxt) { + subBuilder.MergeFrom(CondCtxt); + } + input.ReadMessage(subBuilder); + CondCtxt = subBuilder; + break; + } + case 18: { + global::Tensorflow.WhileContextDef subBuilder = new global::Tensorflow.WhileContextDef(); + if (ctxtCase_ == CtxtOneofCase.WhileCtxt) { + subBuilder.MergeFrom(WhileCtxt); + } + input.ReadMessage(subBuilder); + WhileCtxt = subBuilder; + break; + } + } + } + } + + } + + /// + /// Protocol buffer representing a CondContext object. + /// + public sealed partial class CondContextDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new CondContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CondContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CondContextDef(CondContextDef other) : this() { + contextName_ = other.contextName_; + predName_ = other.predName_; + pivotName_ = other.pivotName_; + branch_ = other.branch_; + valuesDef_ = other.valuesDef_ != null ? other.valuesDef_.Clone() : null; + nestedContexts_ = other.nestedContexts_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public CondContextDef Clone() { + return new CondContextDef(this); + } + + /// Field number for the "context_name" field. + public const int ContextNameFieldNumber = 1; + private string contextName_ = ""; + /// + /// Name of the context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ContextName { + get { return contextName_; } + set { + contextName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pred_name" field. + public const int PredNameFieldNumber = 2; + private string predName_ = ""; + /// + /// Name of the pred tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PredName { + get { return predName_; } + set { + predName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_name" field. + public const int PivotNameFieldNumber = 3; + private string pivotName_ = ""; + /// + /// Name of the pivot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotName { + get { return pivotName_; } + set { + pivotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "branch" field. + public const int BranchFieldNumber = 4; + private int branch_; + /// + /// Branch prediction. 0 or 1. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Branch { + get { return branch_; } + set { + branch_ = value; + } + } + + /// Field number for the "values_def" field. + public const int ValuesDefFieldNumber = 5; + private global::Tensorflow.ValuesDef valuesDef_; + /// + /// Values and external values in control flow context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.ValuesDef ValuesDef { + get { return valuesDef_; } + set { + valuesDef_ = value; + } + } + + /// Field number for the "nested_contexts" field. + public const int NestedContextsFieldNumber = 6; + private static readonly pb::FieldCodec _repeated_nestedContexts_codec + = pb::FieldCodec.ForMessage(50, global::Tensorflow.ControlFlowContextDef.Parser); + private readonly pbc::RepeatedField nestedContexts_ = new pbc::RepeatedField(); + /// + /// Contexts contained inside this context (e.g. nested conds). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField NestedContexts { + get { return nestedContexts_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as CondContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(CondContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ContextName != other.ContextName) return false; + if (PredName != other.PredName) return false; + if (PivotName != other.PivotName) return false; + if (Branch != other.Branch) return false; + if (!object.Equals(ValuesDef, other.ValuesDef)) return false; + if(!nestedContexts_.Equals(other.nestedContexts_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); + if (PredName.Length != 0) hash ^= PredName.GetHashCode(); + if (PivotName.Length != 0) hash ^= PivotName.GetHashCode(); + if (Branch != 0) hash ^= Branch.GetHashCode(); + if (valuesDef_ != null) hash ^= ValuesDef.GetHashCode(); + hash ^= nestedContexts_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (PredName.Length != 0) { + output.WriteRawTag(18); + output.WriteString(PredName); + } + if (PivotName.Length != 0) { + output.WriteRawTag(26); + output.WriteString(PivotName); + } + if (Branch != 0) { + output.WriteRawTag(32); + output.WriteInt32(Branch); + } + if (valuesDef_ != null) { + output.WriteRawTag(42); + output.WriteMessage(ValuesDef); + } + nestedContexts_.WriteTo(output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ContextName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ContextName); + } + if (PredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PredName); + } + if (PivotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotName); + } + if (Branch != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Branch); + } + if (valuesDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValuesDef); + } + size += nestedContexts_.CalculateSize(_repeated_nestedContexts_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(CondContextDef other) { + if (other == null) { + return; + } + if (other.ContextName.Length != 0) { + ContextName = other.ContextName; + } + if (other.PredName.Length != 0) { + PredName = other.PredName; + } + if (other.PivotName.Length != 0) { + PivotName = other.PivotName; + } + if (other.Branch != 0) { + Branch = other.Branch; + } + if (other.valuesDef_ != null) { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + ValuesDef.MergeFrom(other.ValuesDef); + } + nestedContexts_.Add(other.nestedContexts_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 18: { + PredName = input.ReadString(); + break; + } + case 26: { + PivotName = input.ReadString(); + break; + } + case 32: { + Branch = input.ReadInt32(); + break; + } + case 42: { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(valuesDef_); + break; + } + case 50: { + nestedContexts_.AddEntriesFrom(input, _repeated_nestedContexts_codec); + break; + } + } + } + } + + } + + /// + /// Protocol buffer representing a WhileContext object. + /// + public sealed partial class WhileContextDef : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new WhileContextDef()); + private pb::UnknownFieldSet _unknownFields; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::Tensorflow.ControlFlowReflection.Descriptor.MessageTypes[3]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WhileContextDef() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WhileContextDef(WhileContextDef other) : this() { + contextName_ = other.contextName_; + parallelIterations_ = other.parallelIterations_; + backProp_ = other.backProp_; + swapMemory_ = other.swapMemory_; + pivotName_ = other.pivotName_; + pivotForPredName_ = other.pivotForPredName_; + pivotForBodyName_ = other.pivotForBodyName_; + loopExitNames_ = other.loopExitNames_.Clone(); + loopEnterNames_ = other.loopEnterNames_.Clone(); + valuesDef_ = other.valuesDef_ != null ? other.valuesDef_.Clone() : null; + maximumIterationsName_ = other.maximumIterationsName_; + nestedContexts_ = other.nestedContexts_.Clone(); + _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public WhileContextDef Clone() { + return new WhileContextDef(this); + } + + /// Field number for the "context_name" field. + public const int ContextNameFieldNumber = 1; + private string contextName_ = ""; + /// + /// Name of the context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string ContextName { + get { return contextName_; } + set { + contextName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "parallel_iterations" field. + public const int ParallelIterationsFieldNumber = 2; + private int parallelIterations_; + /// + /// The number of iterations allowed to run in parallel. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int ParallelIterations { + get { return parallelIterations_; } + set { + parallelIterations_ = value; + } + } + + /// Field number for the "back_prop" field. + public const int BackPropFieldNumber = 3; + private bool backProp_; + /// + /// Whether backprop is enabled for this while loop. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool BackProp { + get { return backProp_; } + set { + backProp_ = value; + } + } + + /// Field number for the "swap_memory" field. + public const int SwapMemoryFieldNumber = 4; + private bool swapMemory_; + /// + /// Whether GPU-CPU memory swap is enabled for this loop. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool SwapMemory { + get { return swapMemory_; } + set { + swapMemory_ = value; + } + } + + /// Field number for the "pivot_name" field. + public const int PivotNameFieldNumber = 5; + private string pivotName_ = ""; + /// + /// Name of the pivot tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotName { + get { return pivotName_; } + set { + pivotName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_for_pred_name" field. + public const int PivotForPredNameFieldNumber = 6; + private string pivotForPredName_ = ""; + /// + /// Name of the pivot_for_pred tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotForPredName { + get { return pivotForPredName_; } + set { + pivotForPredName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "pivot_for_body_name" field. + public const int PivotForBodyNameFieldNumber = 7; + private string pivotForBodyName_ = ""; + /// + /// Name of the pivot_for_body tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string PivotForBodyName { + get { return pivotForBodyName_; } + set { + pivotForBodyName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "loop_exit_names" field. + public const int LoopExitNamesFieldNumber = 8; + private static readonly pb::FieldCodec _repeated_loopExitNames_codec + = pb::FieldCodec.ForString(66); + private readonly pbc::RepeatedField loopExitNames_ = new pbc::RepeatedField(); + /// + /// List of names for exit tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField LoopExitNames { + get { return loopExitNames_; } + } + + /// Field number for the "loop_enter_names" field. + public const int LoopEnterNamesFieldNumber = 10; + private static readonly pb::FieldCodec _repeated_loopEnterNames_codec + = pb::FieldCodec.ForString(82); + private readonly pbc::RepeatedField loopEnterNames_ = new pbc::RepeatedField(); + /// + /// List of names for enter tensors. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField LoopEnterNames { + get { return loopEnterNames_; } + } + + /// Field number for the "values_def" field. + public const int ValuesDefFieldNumber = 9; + private global::Tensorflow.ValuesDef valuesDef_; + /// + /// Values and external values in control flow context. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::Tensorflow.ValuesDef ValuesDef { + get { return valuesDef_; } + set { + valuesDef_ = value; + } + } + + /// Field number for the "maximum_iterations_name" field. + public const int MaximumIterationsNameFieldNumber = 11; + private string maximumIterationsName_ = ""; + /// + /// Optional name of the maximum_iterations tensor. + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string MaximumIterationsName { + get { return maximumIterationsName_; } + set { + maximumIterationsName_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + /// Field number for the "nested_contexts" field. + public const int NestedContextsFieldNumber = 12; + private static readonly pb::FieldCodec _repeated_nestedContexts_codec + = pb::FieldCodec.ForMessage(98, global::Tensorflow.ControlFlowContextDef.Parser); + private readonly pbc::RepeatedField nestedContexts_ = new pbc::RepeatedField(); + /// + /// Contexts contained inside this context (e.g. nested whiles). + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public pbc::RepeatedField NestedContexts { + get { return nestedContexts_; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as WhileContextDef); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(WhileContextDef other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (ContextName != other.ContextName) return false; + if (ParallelIterations != other.ParallelIterations) return false; + if (BackProp != other.BackProp) return false; + if (SwapMemory != other.SwapMemory) return false; + if (PivotName != other.PivotName) return false; + if (PivotForPredName != other.PivotForPredName) return false; + if (PivotForBodyName != other.PivotForBodyName) return false; + if(!loopExitNames_.Equals(other.loopExitNames_)) return false; + if(!loopEnterNames_.Equals(other.loopEnterNames_)) return false; + if (!object.Equals(ValuesDef, other.ValuesDef)) return false; + if (MaximumIterationsName != other.MaximumIterationsName) return false; + if(!nestedContexts_.Equals(other.nestedContexts_)) return false; + return Equals(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (ContextName.Length != 0) hash ^= ContextName.GetHashCode(); + if (ParallelIterations != 0) hash ^= ParallelIterations.GetHashCode(); + if (BackProp != false) hash ^= BackProp.GetHashCode(); + if (SwapMemory != false) hash ^= SwapMemory.GetHashCode(); + if (PivotName.Length != 0) hash ^= PivotName.GetHashCode(); + if (PivotForPredName.Length != 0) hash ^= PivotForPredName.GetHashCode(); + if (PivotForBodyName.Length != 0) hash ^= PivotForBodyName.GetHashCode(); + hash ^= loopExitNames_.GetHashCode(); + hash ^= loopEnterNames_.GetHashCode(); + if (valuesDef_ != null) hash ^= ValuesDef.GetHashCode(); + if (MaximumIterationsName.Length != 0) hash ^= MaximumIterationsName.GetHashCode(); + hash ^= nestedContexts_.GetHashCode(); + if (_unknownFields != null) { + hash ^= _unknownFields.GetHashCode(); + } + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (ContextName.Length != 0) { + output.WriteRawTag(10); + output.WriteString(ContextName); + } + if (ParallelIterations != 0) { + output.WriteRawTag(16); + output.WriteInt32(ParallelIterations); + } + if (BackProp != false) { + output.WriteRawTag(24); + output.WriteBool(BackProp); + } + if (SwapMemory != false) { + output.WriteRawTag(32); + output.WriteBool(SwapMemory); + } + if (PivotName.Length != 0) { + output.WriteRawTag(42); + output.WriteString(PivotName); + } + if (PivotForPredName.Length != 0) { + output.WriteRawTag(50); + output.WriteString(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + output.WriteRawTag(58); + output.WriteString(PivotForBodyName); + } + loopExitNames_.WriteTo(output, _repeated_loopExitNames_codec); + if (valuesDef_ != null) { + output.WriteRawTag(74); + output.WriteMessage(ValuesDef); + } + loopEnterNames_.WriteTo(output, _repeated_loopEnterNames_codec); + if (MaximumIterationsName.Length != 0) { + output.WriteRawTag(90); + output.WriteString(MaximumIterationsName); + } + nestedContexts_.WriteTo(output, _repeated_nestedContexts_codec); + if (_unknownFields != null) { + _unknownFields.WriteTo(output); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (ContextName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(ContextName); + } + if (ParallelIterations != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(ParallelIterations); + } + if (BackProp != false) { + size += 1 + 1; + } + if (SwapMemory != false) { + size += 1 + 1; + } + if (PivotName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotName); + } + if (PivotForPredName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotForPredName); + } + if (PivotForBodyName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(PivotForBodyName); + } + size += loopExitNames_.CalculateSize(_repeated_loopExitNames_codec); + size += loopEnterNames_.CalculateSize(_repeated_loopEnterNames_codec); + if (valuesDef_ != null) { + size += 1 + pb::CodedOutputStream.ComputeMessageSize(ValuesDef); + } + if (MaximumIterationsName.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(MaximumIterationsName); + } + size += nestedContexts_.CalculateSize(_repeated_nestedContexts_codec); + if (_unknownFields != null) { + size += _unknownFields.CalculateSize(); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(WhileContextDef other) { + if (other == null) { + return; + } + if (other.ContextName.Length != 0) { + ContextName = other.ContextName; + } + if (other.ParallelIterations != 0) { + ParallelIterations = other.ParallelIterations; + } + if (other.BackProp != false) { + BackProp = other.BackProp; + } + if (other.SwapMemory != false) { + SwapMemory = other.SwapMemory; + } + if (other.PivotName.Length != 0) { + PivotName = other.PivotName; + } + if (other.PivotForPredName.Length != 0) { + PivotForPredName = other.PivotForPredName; + } + if (other.PivotForBodyName.Length != 0) { + PivotForBodyName = other.PivotForBodyName; + } + loopExitNames_.Add(other.loopExitNames_); + loopEnterNames_.Add(other.loopEnterNames_); + if (other.valuesDef_ != null) { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + ValuesDef.MergeFrom(other.ValuesDef); + } + if (other.MaximumIterationsName.Length != 0) { + MaximumIterationsName = other.MaximumIterationsName; + } + nestedContexts_.Add(other.nestedContexts_); + _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + _unknownFields = pb::UnknownFieldSet.MergeFieldFrom(_unknownFields, input); + break; + case 10: { + ContextName = input.ReadString(); + break; + } + case 16: { + ParallelIterations = input.ReadInt32(); + break; + } + case 24: { + BackProp = input.ReadBool(); + break; + } + case 32: { + SwapMemory = input.ReadBool(); + break; + } + case 42: { + PivotName = input.ReadString(); + break; + } + case 50: { + PivotForPredName = input.ReadString(); + break; + } + case 58: { + PivotForBodyName = input.ReadString(); + break; + } + case 66: { + loopExitNames_.AddEntriesFrom(input, _repeated_loopExitNames_codec); + break; + } + case 74: { + if (valuesDef_ == null) { + valuesDef_ = new global::Tensorflow.ValuesDef(); + } + input.ReadMessage(valuesDef_); + break; + } + case 82: { + loopEnterNames_.AddEntriesFrom(input, _repeated_loopEnterNames_codec); + break; + } + case 90: { + MaximumIterationsName = input.ReadString(); + break; + } + case 98: { + nestedContexts_.AddEntriesFrom(input, _repeated_nestedContexts_codec); + break; + } + } + } + } + + } + + #endregion + +} + +#endregion Designer generated code diff --git a/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs b/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs index 36ce9088..ce08f5ed 100644 --- a/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs +++ b/src/TensorFlowNET.Core/Protobuf/IProtoBuf.cs @@ -8,7 +8,7 @@ namespace Tensorflow /// In order for a object to be serialized to and from MetaGraphDef, /// the class must implement to_proto() and from_proto() methods /// - public interface IProtoBuf + public interface IProtoBuf { string name { get; } @@ -17,15 +17,15 @@ namespace Tensorflow /// /// /// - VariableDef to_proto(string export_scope); + TProtoDef to_proto(string export_scope); /// /// Returns a `Variable` object created from `variable_def`. /// /// - /// + /// /// /// - T from_proto(VariableDef variable_def, string import_scope); + TDef from_proto(TProtoDef proto, string import_scope); } } diff --git a/src/TensorFlowNET.Core/Protobuf/README.md b/src/TensorFlowNET.Core/Protobuf/README.md index 0c8bb9ed..2cc8356e 100644 --- a/src/TensorFlowNET.Core/Protobuf/README.md +++ b/src/TensorFlowNET.Core/Protobuf/README.md @@ -1,10 +1,12 @@ ### Download compiler from https://github.com/protocolbuffers/protobuf/releases +Work in command line + ```shell +cd tensorflow + set SRC_DIR=D:/Projects/tensorflow set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf -cd tensorflow - protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto @@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto +protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto ``` diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 8d20f34d..95d5520d 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -7,7 +7,7 @@ using System.Text; namespace Tensorflow { - public partial class RefVariable : VariableV1, IProtoBuf + public partial class RefVariable : VariableV1, IProtoBuf { public bool _in_graph_mode = true; public Tensor _initial_value; @@ -288,7 +288,7 @@ namespace Tensorflow throw new NotImplementedException("to_proto RefVariable"); } - public T from_proto(VariableDef variable_def, string import_scope) + public RefVariable from_proto(VariableDef proto, string import_scope) { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index 34885776..c147e1b5 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -376,7 +376,7 @@ namespace Tensorflow if (import_scope.EndsWith("/")) import_scope = import_scope.Substring(0, import_scope.Length - 1); - throw new NotImplementedException("prepend_name_scope"); + return $"{import_scope}/{name}"; } else return name; diff --git a/test/TensorFlowNET.UnitTest/PythonTest.cs b/test/TensorFlowNET.UnitTest/PythonTest.cs index 508a0a81..b888f883 100644 --- a/test/TensorFlowNET.UnitTest/PythonTest.cs +++ b/test/TensorFlowNET.UnitTest/PythonTest.cs @@ -132,10 +132,11 @@ namespace TensorFlowNET.UnitTest } /// - /// Evaluates tensors and returns a dictionary of {name:result, ...}. - /// A Tensor or a nested list/tuple of Tensors. + /// This function is used in many original tensorflow unit tests to evaluate tensors + /// in a test session with special settings (for instance constant folding off) + /// /// - public Dictionary evaluate(params Tensor[] tensors) + public T evaluate(Tensor tensor) { var results = new Dictionary(); // if context.executing_eagerly(): @@ -145,49 +146,26 @@ namespace TensorFlowNET.UnitTest var sess = ops.get_default_session(); if (sess == null) sess = self.session(); - - with(sess, s => - { - foreach (var t in tensors) - results[t.name] = t.eval(); - }); - return results; - } - } - - public NDArray evaluate(Tensor tensor) - { - NDArray result = null; - // if context.executing_eagerly(): - // return self._eval_helper(tensors) - // else: - { - var sess = ops.get_default_session(); - if (sess == null) - sess = self.session(); - with(sess, s => - { - result = tensor.eval(); - }); - return result; - } - } - - public object eval_scalar(Tensor tensor) - { - NDArray result = null; - // if context.executing_eagerly(): - // return self._eval_helper(tensors) - // else: - { - var sess = ops.get_default_session(); - if (sess == null) - sess = self.session(); + T t_result = (T)(object)null; with(sess, s => { - result = tensor.eval(); + var ndarray=tensor.eval(); + if (typeof(T) == typeof(double)) + { + double d = ndarray; + t_result = (T)(object)d; + } + else if (typeof(T) == typeof(int)) + { + int d = ndarray; + t_result = (T) (object) d; + } + else + { + t_result = (T)(object)ndarray; + } }); - return result.Array.GetValue(0); + return t_result; } } diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index 8fb9d9bb..7206fada 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; using Tensorflow; namespace TensorFlowNET.UnitTest.control_flow_ops_test @@ -9,32 +10,73 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test [TestClass] public class CondTestCases : PythonTest { - [TestMethod] public void testCondTrue() { - with(tf.Graph().as_default(), g => + var graph = tf.Graph().as_default(); + + with(tf.Session(graph), sess => { var x = tf.constant(2); var y = tf.constant(5); - var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); - //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); - self.assertEquals(eval_scalar(z), 34); + var pred = tf.less(x, y); + + Func if_true = delegate + { + return tf.multiply(x, 17); + }; + + Func if_false = delegate + { + return tf.add(y, 23); + }; + + var z = control_flow_ops.cond(pred, if_true, if_false); + int result = z.eval(sess); + assertEquals(result, 34); }); } - //[Ignore("This Test Fails due to missing edges in the graph!")] [TestMethod] public void testCondFalse() { - with(tf.Graph().as_default(), g => + /* python + * import tensorflow as tf + from tensorflow.python.framework import ops + + def if_true(): + return tf.math.multiply(x, 17) + def if_false(): + return tf.math.add(y, 23) + + with tf.Session() as sess: + x = tf.constant(2) + y = tf.constant(1) + pred = tf.math.less(x,y) + z = tf.cond(pred, if_true, if_false) + result = z.eval() + + print(result == 24) */ + + with(tf.Session(), sess => { var x = tf.constant(2); var y = tf.constant(1); - var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), - () => tf.add(y, tf.constant(23))); - self.assertEquals(eval_scalar(z), 24); + var pred = tf.less(x, y); + + Func if_true = delegate + { + return tf.multiply(x, 17); + }; + + Func if_false = delegate + { + return tf.add(y, 23); + }; + + var z = control_flow_ops.cond(pred, if_true, if_false); + int result = z.eval(sess); + assertEquals(result, 24); }); } diff --git a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs index 74935c76..ca2665ff 100644 --- a/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs +++ b/test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs @@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest.ops_test { var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); }); - tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); + //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); }