# Conflicts: # src/TensorFlowNET.Core/Operations/Operation.cs # test/TensorFlowNET.UnitTest/PythonTest.cs # test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cstags/v0.9
@@ -27,7 +27,7 @@ namespace Tensorflow | |||||
public static Tensor asin(Tensor x, string name = null) | public static Tensor asin(Tensor x, string name = null) | ||||
=> gen_math_ops.asin(x, name); | => gen_math_ops.asin(x, name); | ||||
public static Tensor add(Tensor a, Tensor b) | |||||
public static Tensor add<Tx, Ty>(Tx a, Ty b) | |||||
=> gen_math_ops.add(a, b); | => gen_math_ops.add(a, b); | ||||
/// <summary> | /// <summary> | ||||
@@ -251,7 +251,7 @@ namespace Tensorflow | |||||
public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | public static Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | ||||
=> gen_math_ops.minimum(x, y, name: name); | => gen_math_ops.minimum(x, y, name: name); | ||||
public static Tensor multiply(Tensor x, Tensor y) | |||||
public static Tensor multiply<Tx, Ty>(Tx x, Ty y) | |||||
=> gen_math_ops.mul(x, y); | => gen_math_ops.mul(x, y); | ||||
public static Tensor negative(Tensor x, string name = null) | public static Tensor negative(Tensor x, string name = null) | ||||
@@ -4,6 +4,7 @@ using System.Collections.Generic; | |||||
using System.IO; | using System.IO; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | |||||
using static Tensorflow.CollectionDef; | using static Tensorflow.CollectionDef; | ||||
using static Tensorflow.MetaGraphDef.Types; | using static Tensorflow.MetaGraphDef.Types; | ||||
@@ -95,15 +96,29 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
foreach(var value in col.Value.BytesList.Value) | |||||
{ | |||||
switch (col.Key) | |||||
{ | |||||
case "cond_context": | |||||
var proto = CondContextDef.Parser.ParseFrom(value); | |||||
var condContext = new CondContext().from_proto(proto, import_scope); | |||||
graph.add_to_collection(col.Key, condContext); | |||||
break; | |||||
default: | |||||
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
} | |||||
} | |||||
} | } | ||||
break; | break; | ||||
default: | |||||
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements"); | |||||
} | } | ||||
} | } | ||||
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
scope: scope_to_prepend_to_names) as List<RefVariable>; | |||||
var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES, | |||||
scope: scope_to_prepend_to_names); | |||||
var var_list = new Dictionary<string, RefVariable>(); | var var_list = new Dictionary<string, RefVariable>(); | ||||
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v); | ||||
@@ -412,6 +412,11 @@ namespace Tensorflow | |||||
return _collections.ContainsKey(name) ? _collections[name] : null; | return _collections.ContainsKey(name) ? _collections[name] : null; | ||||
} | } | ||||
public List<T> get_collection<T>(string name, string scope = null) | |||||
{ | |||||
return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>(); | |||||
} | |||||
public object get_collection_ref(string name) | public object get_collection_ref(string name) | ||||
{ | { | ||||
if (!_collections.ContainsKey(name)) | if (!_collections.ContainsKey(name)) | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow.Operations | |||||
/// <summary> | /// <summary> | ||||
/// The context for the conditional construct. | /// The context for the conditional construct. | ||||
/// </summary> | /// </summary> | ||||
public class CondContext : ControlFlowContext | |||||
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext> | |||||
{ | { | ||||
@@ -35,16 +35,20 @@ namespace Tensorflow.Operations | |||||
/// <param name="name">Name of the `CondContext` python object.</param> | /// <param name="name">Name of the `CondContext` python object.</param> | ||||
/// <param name="context_def"></param> | /// <param name="context_def"></param> | ||||
/// <param name="import_scope"></param> | /// <param name="import_scope"></param> | ||||
public CondContext(Tensor pred, | |||||
Tensor pivot, | |||||
int branch, | |||||
public CondContext(Tensor pred = null, | |||||
Tensor pivot = null, | |||||
int? branch = null, | |||||
string name = "cond_text", | string name = "cond_text", | ||||
object context_def = null, | |||||
CondContextDef context_def = null, | |||||
string import_scope = null) | string import_scope = null) | ||||
{ | { | ||||
if (pred == null && context_def == null) return; | |||||
_name = ops.get_default_graph().unique_name(name); | _name = ops.get_default_graph().unique_name(name); | ||||
if (context_def != null) | |||||
throw new NotImplementedException("CondContext context_def is not null"); | |||||
if (context_def != null) | |||||
{ | |||||
_init_from_proto(context_def, import_scope: import_scope); | |||||
} | |||||
else | else | ||||
{ | { | ||||
// Initializes the default fields. | // Initializes the default fields. | ||||
@@ -61,6 +65,18 @@ namespace Tensorflow.Operations | |||||
} | } | ||||
} | } | ||||
private void _init_from_proto(CondContextDef context_def, string import_scope = null) | |||||
{ | |||||
var g = ops.get_default_graph(); | |||||
_name = ops.prepend_name_scope(context_def.ContextName, import_scope); | |||||
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope); | |||||
_pred = g.as_graph_element(p1) as Tensor; | |||||
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope); | |||||
_pivot = g.as_graph_element(p2) as Tensor; | |||||
_branch = context_def.Branch; | |||||
__init__(values_def: context_def.ValuesDef, import_scope: import_scope); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Add `val` to the current context and its outer context recursively. | /// Add `val` to the current context and its outer context recursively. | ||||
/// </summary> | /// </summary> | ||||
@@ -230,6 +246,22 @@ namespace Tensorflow.Operations | |||||
public override void AddInnerOp(Operation resultOp) | public override void AddInnerOp(Operation resultOp) | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | |||||
} | |||||
public CondContextDef to_proto(string export_scope) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public CondContext from_proto(CondContextDef proto, string import_scope) | |||||
{ | |||||
var ret = new CondContext(context_def: proto, import_scope: import_scope); | |||||
ret.Enter(); | |||||
foreach (var nested_def in proto.NestedContexts) | |||||
throw new NotImplementedException(""); | |||||
ret.Exit(); | |||||
return ret; | |||||
} | |||||
} | } | ||||
} | } |
@@ -32,6 +32,8 @@ namespace Tensorflow.Operations | |||||
protected Stack<IControlFlowContext> _context_stack; | protected Stack<IControlFlowContext> _context_stack; | ||||
protected IControlFlowContext _outer_context; | protected IControlFlowContext _outer_context; | ||||
protected Dictionary<string, ITensorOrOperation> _external_values; | |||||
public ControlFlowContext() | public ControlFlowContext() | ||||
{ | { | ||||
_context_stack = new Stack<IControlFlowContext>(); | _context_stack = new Stack<IControlFlowContext>(); | ||||
@@ -40,15 +42,43 @@ namespace Tensorflow.Operations | |||||
public string name { get => _name; } | public string name { get => _name; } | ||||
protected string _name; | protected string _name; | ||||
public void __init__() | |||||
public void __init__(ValuesDef values_def = null, string import_scope = null) | |||||
{ | { | ||||
_outer_context = ops.get_default_graph()._get_control_flow_context(); | |||||
if (values_def != null) | |||||
_init_values_from_proto(values_def, import_scope: import_scope); | |||||
} | } | ||||
public void __enter__() | public void __enter__() | ||||
{ | { | ||||
} | } | ||||
/// <summary> | |||||
/// Initializes values and external_values from `ValuesDef` protocol buffer. | |||||
/// </summary> | |||||
/// <param name="values_def"></param> | |||||
/// <param name="import_scope"></param> | |||||
protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null) | |||||
{ | |||||
_external_values = new Dictionary<string, ITensorOrOperation>(); | |||||
foreach (var value in values_def.Values) | |||||
_values.Add(value); | |||||
var g = ops.get_default_graph(); | |||||
foreach(var value in values_def.ExternalValues) | |||||
{ | |||||
var k = ops.prepend_name_scope(value.Key, import_scope); | |||||
var v = value.Value; | |||||
_external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope)); | |||||
} | |||||
var op_names = _values.Where(x => !_external_values.ContainsKey(x)) | |||||
.Select(x => x.Split(':')[0]) | |||||
.ToArray(); | |||||
foreach (var op in op_names) | |||||
(g.as_graph_element(op) as Operation)._set_control_flow_context(this); | |||||
} | |||||
public void __exit__() | public void __exit__() | ||||
{ | { | ||||
} | } | ||||
@@ -42,8 +42,8 @@ namespace Tensorflow | |||||
if (NumControlOutputs > 0) | if (NumControlOutputs > 0) | ||||
{ | { | ||||
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs); | ||||
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs); | |||||
for (int i = 0; i < NumControlInputs; i++) | |||||
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlOutputs); | |||||
for (int i = 0; i < NumControlOutputs; i++) | |||||
{ | { | ||||
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; | var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i; | ||||
control_outputs[i] = new Operation(*(IntPtr*)handle); | control_outputs[i] = new Operation(*(IntPtr*)handle); | ||||
@@ -1,319 +1,318 @@ | |||||
using Google.Protobuf.Collections; | |||||
//using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
using Google.Protobuf.Collections; | |||||
//using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | { | ||||
/// <summary> | |||||
/// Represents a graph node that performs computation on tensors. | |||||
/// | |||||
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
/// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
/// objects as output. Objects of type `Operation` are created by | |||||
/// calling an op constructor(such as `tf.matmul`) | |||||
/// or `tf.Graph.create_op`. | |||||
/// | |||||
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
/// as output. | |||||
/// | |||||
/// After the graph has been launched in a session, an `Operation` can | |||||
/// be executed by passing it to | |||||
/// `tf.Session.run`. | |||||
/// <summary> | |||||
/// Represents a graph node that performs computation on tensors. | |||||
/// | |||||
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
/// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
/// objects as output. Objects of type `Operation` are created by | |||||
/// calling an op constructor(such as `tf.matmul`) | |||||
/// or `tf.Graph.create_op`. | |||||
/// | |||||
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
/// as output. | |||||
/// | |||||
/// After the graph has been launched in a session, an `Operation` can | |||||
/// be executed by passing it to | |||||
/// `tf.Session.run`. | |||||
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | /// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | ||||
/// </summary> | |||||
public partial class Operation : ITensorOrOperation | |||||
{ | |||||
private readonly IntPtr _handle; // _c_op in python | |||||
private readonly IntPtr _operDesc; | |||||
private Graph _graph; | |||||
//[JsonIgnore] | |||||
public Graph graph => _graph; | |||||
//[JsonIgnore] | |||||
public int _id => _id_value; | |||||
//[JsonIgnore] | |||||
public int _id_value; | |||||
public string type => OpType; | |||||
//[JsonIgnore] | |||||
public Operation op => this; | |||||
public TF_DataType dtype => TF_DataType.DtInvalid; | |||||
private Status status = new Status(); | |||||
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
private NodeDef _node_def; | |||||
public NodeDef node_def | |||||
{ | |||||
get | |||||
{ | |||||
if(_node_def == null) | |||||
_node_def = GetNodeDef(); | |||||
return _node_def; | |||||
} | |||||
} | |||||
public Operation(IntPtr handle, Graph g=null) | |||||
{ | |||||
if (handle == IntPtr.Zero) | |||||
return; | |||||
_handle = handle; | |||||
_graph = g ?? ops.get_default_graph(); | |||||
_outputs = new Tensor[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | |||||
// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||||
} | |||||
public Operation(Graph g, string opType, string oper_name) | |||||
{ | |||||
_graph = g; | |||||
_operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | |||||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | |||||
} | |||||
/// <summary> | |||||
/// Creates an `Operation`. | |||||
/// </summary> | |||||
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||||
/// <param name="g">`Graph`. The parent graph.</param> | |||||
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||||
/// <param name="output_types">list of `DType` objects.</param> | |||||
/// <param name="control_inputs"> | |||||
/// list of operations or tensors from which to have a | |||||
/// control dependency. | |||||
/// </param> | |||||
/// <param name="input_types"> | |||||
/// List of `DType` objects representing the | |||||
/// types of the tensors accepted by the `Operation`. By default | |||||
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||||
/// reference-typed inputs must specify these explicitly. | |||||
/// </param> | |||||
/// <param name="original_op"></param> | |||||
/// <param name="op_def"></param> | |||||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
{ | |||||
_graph = g; | |||||
// Build the list of control inputs. | |||||
var control_input_ops = new List<Operation>(); | |||||
if(control_inputs != null) | |||||
{ | |||||
foreach(var c in control_inputs) | |||||
{ | |||||
switch (c) | |||||
{ | |||||
case Operation c1: | |||||
control_input_ops.Add(c1); | |||||
break; | |||||
case Tensor tensor: | |||||
control_input_ops.Add(tensor.op); | |||||
break; | |||||
// TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented | |||||
//case IndexedSlices islices: | |||||
// control_input_ops.Add(islices.op); | |||||
// break; | |||||
default: | |||||
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||||
} | |||||
} | |||||
} | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | |||||
// This will be set by self.inputs. | |||||
if (op_def == null) | |||||
op_def = g.GetOpDef(node_def.Op); | |||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
// Initialize self._outputs. | |||||
output_types = new TF_DataType[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
output_types[i] = OutputType(i); | |||||
_outputs = new Tensor[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
graph._add_op(this); | |||||
if (_handle != IntPtr.Zero) | |||||
_control_flow_post_processing(); | |||||
} | |||||
public void run(FeedItem[] feed_dict = null, Session session = null) | |||||
{ | |||||
ops._run_using_default_session(this, feed_dict, graph, session); | |||||
} | |||||
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||||
{ | |||||
var grouped_inputs = new List<object>(); | |||||
int i = 0; | |||||
int input_len = 0; | |||||
bool is_sequence = false; | |||||
foreach (var input_arg in op_def.InputArg) | |||||
{ | |||||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||||
{ | |||||
input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
is_sequence = true; | |||||
} | |||||
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
{ | |||||
input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | |||||
is_sequence = true; | |||||
} | |||||
else | |||||
{ | |||||
input_len = 1; | |||||
is_sequence = false; | |||||
} | |||||
if (is_sequence) | |||||
grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); | |||||
else | |||||
grouped_inputs.Add(inputs[i]); | |||||
i += input_len; | |||||
} | |||||
return grouped_inputs.ToArray(); | |||||
} | |||||
public object get_attr(string name) | |||||
{ | |||||
AttrValue x = null; | |||||
using (var buf = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
status.Check(true); | |||||
x = AttrValue.Parser.ParseFrom(buf); | |||||
} | |||||
string oneof_value = x.ValueCase.ToString(); | |||||
if (string.IsNullOrEmpty(oneof_value)) | |||||
return null; | |||||
if(oneof_value == "list") | |||||
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||||
if (oneof_value == "type") | |||||
return x.Type; | |||||
object result = x.GetType().GetProperty(oneof_value).GetValue(x); | |||||
if (result is Google.Protobuf.ByteString byteString) | |||||
return byteString.ToStringUtf8(); | |||||
return result; | |||||
} | |||||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||||
{ | |||||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||||
} | |||||
private NodeDef GetNodeDef() | |||||
{ | |||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
s.Check(); | |||||
return NodeDef.Parser.ParseFrom(buffer); | |||||
} | |||||
} | |||||
public override string ToString() | |||||
{ | |||||
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; | |||||
} | |||||
public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||||
public static implicit operator IntPtr(Operation op) => op._handle; | |||||
public override bool Equals(object obj) | |||||
{ | |||||
switch (obj) | |||||
{ | |||||
case IntPtr val: | |||||
return val == _handle; | |||||
case Operation val: | |||||
return val._handle == _handle; | |||||
} | |||||
return base.Equals(obj); | |||||
/// </summary> | |||||
public partial class Operation : ITensorOrOperation | |||||
{ | |||||
private readonly IntPtr _handle; // _c_op in python | |||||
private readonly IntPtr _operDesc; | |||||
private Graph _graph; | |||||
//[JsonIgnore] | |||||
public Graph graph => _graph; | |||||
//[JsonIgnore] | |||||
public int _id => _id_value; | |||||
//[JsonIgnore] | |||||
public int _id_value; | |||||
public string type => OpType; | |||||
//[JsonIgnore] | |||||
public Operation op => this; | |||||
public TF_DataType dtype => TF_DataType.DtInvalid; | |||||
private Status status = new Status(); | |||||
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
private NodeDef _node_def; | |||||
public NodeDef node_def | |||||
{ | |||||
get | |||||
{ | |||||
if(_node_def == null) | |||||
_node_def = GetNodeDef(); | |||||
return _node_def; | |||||
} | |||||
} | |||||
public Operation(IntPtr handle, Graph g=null) | |||||
{ | |||||
if (handle == IntPtr.Zero) | |||||
return; | |||||
_handle = handle; | |||||
_graph = g ?? ops.get_default_graph(); | |||||
_outputs = new Tensor[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | |||||
// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||||
} | |||||
public Operation(Graph g, string opType, string oper_name) | |||||
{ | |||||
_graph = g; | |||||
_operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | |||||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | |||||
} | } | ||||
/// <summary> | |||||
/// Update the input to this operation at the given index. | |||||
/// | |||||
/// NOTE: This is for TF internal use only.Please don't use it. | |||||
/// </summary> | |||||
/// <param name="index">the index of the input to update.</param> | |||||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
public void _update_input(int index, Tensor tensor) | |||||
{ | |||||
_assert_same_graph(tensor); | |||||
var input = _tf_input(index); | |||||
/// <summary> | |||||
/// Creates an `Operation`. | |||||
/// </summary> | |||||
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param> | |||||
/// <param name="g">`Graph`. The parent graph.</param> | |||||
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param> | |||||
/// <param name="output_types">list of `DType` objects.</param> | |||||
/// <param name="control_inputs"> | |||||
/// list of operations or tensors from which to have a | |||||
/// control dependency. | |||||
/// </param> | |||||
/// <param name="input_types"> | |||||
/// List of `DType` objects representing the | |||||
/// types of the tensors accepted by the `Operation`. By default | |||||
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect | |||||
/// reference-typed inputs must specify these explicitly. | |||||
/// </param> | |||||
/// <param name="original_op"></param> | |||||
/// <param name="op_def"></param> | |||||
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||||
{ | |||||
_graph = g; | |||||
// Build the list of control inputs. | |||||
var control_input_ops = new List<Operation>(); | |||||
if(control_inputs != null) | |||||
{ | |||||
foreach(var c in control_inputs) | |||||
{ | |||||
switch (c) | |||||
{ | |||||
case Operation c1: | |||||
control_input_ops.Add(c1); | |||||
break; | |||||
case Tensor tensor: | |||||
control_input_ops.Add(tensor.op); | |||||
break; | |||||
// TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented | |||||
//case IndexedSlices islices: | |||||
// control_input_ops.Add(islices.op); | |||||
// break; | |||||
default: | |||||
throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}"); | |||||
} | |||||
} | |||||
} | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | |||||
// This will be set by self.inputs. | |||||
if (op_def == null) | |||||
op_def = g.GetOpDef(node_def.Op); | |||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||||
(_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
// Initialize self._outputs. | |||||
output_types = new TF_DataType[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
output_types[i] = OutputType(i); | |||||
_outputs = new Tensor[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
graph._add_op(this); | |||||
if (_handle != IntPtr.Zero) | |||||
_control_flow_post_processing(); | |||||
} | |||||
public void run(FeedItem[] feed_dict = null, Session session = null) | |||||
{ | |||||
ops._run_using_default_session(this, feed_dict, graph, session); | |||||
} | |||||
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||||
{ | |||||
var grouped_inputs = new List<object>(); | |||||
int i = 0; | |||||
int input_len = 0; | |||||
bool is_sequence = false; | |||||
foreach (var input_arg in op_def.InputArg) | |||||
{ | |||||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||||
{ | |||||
input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
is_sequence = true; | |||||
} | |||||
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
{ | |||||
input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | |||||
is_sequence = true; | |||||
} | |||||
else | |||||
{ | |||||
input_len = 1; | |||||
is_sequence = false; | |||||
} | |||||
if (is_sequence) | |||||
grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray()); | |||||
else | |||||
grouped_inputs.Add(inputs[i]); | |||||
i += input_len; | |||||
} | |||||
return grouped_inputs.ToArray(); | |||||
} | |||||
public object get_attr(string name) | |||||
{ | |||||
AttrValue x = null; | |||||
using (var buf = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||||
status.Check(true); | |||||
x = AttrValue.Parser.ParseFrom(buf); | |||||
} | |||||
string oneof_value = x.ValueCase.ToString(); | |||||
if (string.IsNullOrEmpty(oneof_value)) | |||||
return null; | |||||
if(oneof_value == "list") | |||||
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||||
if (oneof_value == "type") | |||||
return x.Type; | |||||
object result = x.GetType().GetProperty(oneof_value).GetValue(x); | |||||
if (result is Google.Protobuf.ByteString byteString) | |||||
return byteString.ToStringUtf8(); | |||||
return result; | |||||
} | |||||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||||
{ | |||||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||||
} | |||||
private NodeDef GetNodeDef() | |||||
{ | |||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
s.Check(); | |||||
return NodeDef.Parser.ParseFrom(buffer); | |||||
} | |||||
} | |||||
public override string ToString() | |||||
{ | |||||
return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}"; | |||||
} | |||||
public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||||
public static implicit operator IntPtr(Operation op) => op._handle; | |||||
public override bool Equals(object obj) | |||||
{ | |||||
switch (obj) | |||||
{ | |||||
case IntPtr val: | |||||
return val == _handle; | |||||
case Operation val: | |||||
return val._handle == _handle; | |||||
} | |||||
return base.Equals(obj); | |||||
} | |||||
/// <summary> | |||||
/// Update the input to this operation at the given index. | |||||
/// | |||||
/// NOTE: This is for TF internal use only.Please don't use it. | |||||
/// </summary> | |||||
/// <param name="index">the index of the input to update.</param> | |||||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
public void _update_input(int index, Tensor tensor) | |||||
{ | |||||
_assert_same_graph(tensor); | |||||
var input = _tf_input(index); | |||||
var output = tensor._as_tf_output(); | var output = tensor._as_tf_output(); | ||||
// Reset cached inputs. | // Reset cached inputs. | ||||
_inputs = null; | _inputs = null; | ||||
// after the c_api call next time _inputs is accessed | // after the c_api call next time _inputs is accessed | ||||
// the updated inputs are reloaded from the c_api | |||||
c_api.TF_UpdateEdge(_graph, output, input, status); | |||||
//var updated_inputs = inputs; | |||||
} | |||||
private void _assert_same_graph(Tensor tensor) | |||||
{ | |||||
//TODO: implement | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||||
/// </summary> | |||||
public TF_Output _tf_output(int output_idx) | |||||
{ | |||||
var tf_output = new TF_Output(op, output_idx); | |||||
return tf_output; | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||||
/// </summary> | |||||
public TF_Input _tf_input(int input_idx) | |||||
{ | |||||
var tf_input = new TF_Input(op, input_idx); | |||||
return tf_input; | |||||
} | |||||
} | |||||
} | |||||
// the updated inputs are reloaded from the c_api | |||||
c_api.TF_UpdateEdge(_graph, output, input, status); | |||||
//var updated_inputs = inputs; | |||||
} | |||||
private void _assert_same_graph(Tensor tensor) | |||||
{ | |||||
//TODO: implement | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||||
/// </summary> | |||||
public TF_Output _tf_output(int output_idx) | |||||
{ | |||||
var tf_output = new TF_Output(op, output_idx); | |||||
return tf_output; | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||||
/// </summary> | |||||
public TF_Input _tf_input(int input_idx) | |||||
{ | |||||
var tf_input = new TF_Input(op, input_idx); | |||||
return tf_input; | |||||
} | |||||
} | |||||
} |
@@ -308,7 +308,7 @@ namespace Tensorflow | |||||
tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
// Build the graph for the true branch in a new context. | // Build the graph for the true branch in a new context. | ||||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||||
ITensorOrOperation orig_res_t; | ITensorOrOperation orig_res_t; | ||||
Tensor res_t; | Tensor res_t; | ||||
try | try | ||||
@@ -321,7 +321,7 @@ namespace Tensorflow | |||||
context_t.Exit(); | context_t.Exit(); | ||||
} | } | ||||
// Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||||
ITensorOrOperation orig_res_f; | ITensorOrOperation orig_res_f; | ||||
Tensor res_f; | Tensor res_f; | ||||
try | try | ||||
@@ -389,13 +389,13 @@ namespace Tensorflow | |||||
tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
// Build the graph for the true branch in a new context. | // Build the graph for the true branch in a new context. | ||||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1); | |||||
context_t.Enter(); | context_t.Enter(); | ||||
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn); | ||||
context_t.Exit(); | context_t.Exit(); | ||||
// Build the graph for the false branch in a new context. | // Build the graph for the false branch in a new context. | ||||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0); | |||||
context_f.Enter(); | context_f.Enter(); | ||||
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn); | ||||
context_f.Exit(); | context_f.Exit(); | ||||
@@ -80,7 +80,7 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor add(Tensor x, Tensor y, string name = null) | |||||
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); | ||||
@@ -300,7 +300,7 @@ namespace Tensorflow | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static Tensor mul(Tensor x, Tensor y, string name = null) | |||||
public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y }); | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||||
/// In order for a object to be serialized to and from MetaGraphDef, | /// In order for a object to be serialized to and from MetaGraphDef, | ||||
/// the class must implement to_proto() and from_proto() methods | /// the class must implement to_proto() and from_proto() methods | ||||
/// </summary> | /// </summary> | ||||
public interface IProtoBuf | |||||
public interface IProtoBuf<TProtoDef, TDef> | |||||
{ | { | ||||
string name { get; } | string name { get; } | ||||
@@ -17,15 +17,15 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <param name="export_scope"></param> | /// <param name="export_scope"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
VariableDef to_proto(string export_scope); | |||||
TProtoDef to_proto(string export_scope); | |||||
/// <summary> | /// <summary> | ||||
/// Returns a `Variable` object created from `variable_def`. | /// Returns a `Variable` object created from `variable_def`. | ||||
/// </summary> | /// </summary> | ||||
/// <typeparam name="T"></typeparam> | /// <typeparam name="T"></typeparam> | ||||
/// <param name="variable_def"></param> | |||||
/// <param name="proto"></param> | |||||
/// <param name="import_scope"></param> | /// <param name="import_scope"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
T from_proto<T>(VariableDef variable_def, string import_scope); | |||||
TDef from_proto(TProtoDef proto, string import_scope); | |||||
} | } | ||||
} | } |
@@ -1,10 +1,12 @@ | |||||
### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ### Download compiler from https://github.com/protocolbuffers/protobuf/releases | ||||
Work in command line | |||||
```shell | ```shell | ||||
cd tensorflow | |||||
set SRC_DIR=D:/Projects/tensorflow | set SRC_DIR=D:/Projects/tensorflow | ||||
set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | set DST_DIR=D:/Projects/TensorFlow.NET/src/TensorFlowNET.Core/Protobuf | ||||
cd tensorflow | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/resource_handle.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/tensor_shape.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/framework/types.proto | ||||
@@ -32,6 +34,7 @@ protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/cluster.prot | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/config.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/debug.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/rewriter_config.proto | ||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/core/protobuf/control_flow.proto | |||||
protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | protoc -I=%SRC_DIR% --csharp_out=%DST_DIR% tensorflow/python/training/checkpoint_state.proto | ||||
``` | ``` | ||||
@@ -7,7 +7,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class RefVariable : VariableV1, IProtoBuf | |||||
public partial class RefVariable : VariableV1, IProtoBuf<VariableDef, RefVariable> | |||||
{ | { | ||||
public bool _in_graph_mode = true; | public bool _in_graph_mode = true; | ||||
public Tensor _initial_value; | public Tensor _initial_value; | ||||
@@ -288,7 +288,7 @@ namespace Tensorflow | |||||
throw new NotImplementedException("to_proto RefVariable"); | throw new NotImplementedException("to_proto RefVariable"); | ||||
} | } | ||||
public T from_proto<T>(VariableDef variable_def, string import_scope) | |||||
public RefVariable from_proto(VariableDef proto, string import_scope) | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
} | } | ||||
@@ -376,7 +376,7 @@ namespace Tensorflow | |||||
if (import_scope.EndsWith("/")) | if (import_scope.EndsWith("/")) | ||||
import_scope = import_scope.Substring(0, import_scope.Length - 1); | import_scope = import_scope.Substring(0, import_scope.Length - 1); | ||||
throw new NotImplementedException("prepend_name_scope"); | |||||
return $"{import_scope}/{name}"; | |||||
} | } | ||||
else | else | ||||
return name; | return name; | ||||
@@ -132,10 +132,11 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
/// Evaluates tensors and returns a dictionary of {name:result, ...}. | |||||
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param> | |||||
/// This function is used in many original tensorflow unit tests to evaluate tensors | |||||
/// in a test session with special settings (for instance constant folding off) | |||||
/// | |||||
/// </summary> | /// </summary> | ||||
public Dictionary<string, NDArray> evaluate(params Tensor[] tensors) | |||||
public T evaluate<T>(Tensor tensor) | |||||
{ | { | ||||
var results = new Dictionary<string, NDArray>(); | var results = new Dictionary<string, NDArray>(); | ||||
// if context.executing_eagerly(): | // if context.executing_eagerly(): | ||||
@@ -145,49 +146,26 @@ namespace TensorFlowNET.UnitTest | |||||
var sess = ops.get_default_session(); | var sess = ops.get_default_session(); | ||||
if (sess == null) | if (sess == null) | ||||
sess = self.session(); | sess = self.session(); | ||||
with<Session>(sess, s => | |||||
{ | |||||
foreach (var t in tensors) | |||||
results[t.name] = t.eval(); | |||||
}); | |||||
return results; | |||||
} | |||||
} | |||||
public NDArray evaluate(Tensor tensor) | |||||
{ | |||||
NDArray result = null; | |||||
// if context.executing_eagerly(): | |||||
// return self._eval_helper(tensors) | |||||
// else: | |||||
{ | |||||
var sess = ops.get_default_session(); | |||||
if (sess == null) | |||||
sess = self.session(); | |||||
with<Session>(sess, s => | |||||
{ | |||||
result = tensor.eval(); | |||||
}); | |||||
return result; | |||||
} | |||||
} | |||||
public object eval_scalar(Tensor tensor) | |||||
{ | |||||
NDArray result = null; | |||||
// if context.executing_eagerly(): | |||||
// return self._eval_helper(tensors) | |||||
// else: | |||||
{ | |||||
var sess = ops.get_default_session(); | |||||
if (sess == null) | |||||
sess = self.session(); | |||||
T t_result = (T)(object)null; | |||||
with<Session>(sess, s => | with<Session>(sess, s => | ||||
{ | { | ||||
result = tensor.eval(); | |||||
var ndarray=tensor.eval(); | |||||
if (typeof(T) == typeof(double)) | |||||
{ | |||||
double d = ndarray; | |||||
t_result = (T)(object)d; | |||||
} | |||||
else if (typeof(T) == typeof(int)) | |||||
{ | |||||
int d = ndarray; | |||||
t_result = (T) (object) d; | |||||
} | |||||
else | |||||
{ | |||||
t_result = (T)(object)ndarray; | |||||
} | |||||
}); | }); | ||||
return result.Array.GetValue(0); | |||||
return t_result; | |||||
} | } | ||||
} | } | ||||
@@ -1,4 +1,5 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | |||||
using Tensorflow; | using Tensorflow; | ||||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
@@ -9,32 +10,73 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
[TestClass] | [TestClass] | ||||
public class CondTestCases : PythonTest | public class CondTestCases : PythonTest | ||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void testCondTrue() | public void testCondTrue() | ||||
{ | { | ||||
with(tf.Graph().as_default(), g => | |||||
var graph = tf.Graph().as_default(); | |||||
with(tf.Session(graph), sess => | |||||
{ | { | ||||
var x = tf.constant(2); | var x = tf.constant(2); | ||||
var y = tf.constant(5); | var y = tf.constant(5); | ||||
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), | |||||
() => tf.add(y, tf.constant(23))); | |||||
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
self.assertEquals(eval_scalar(z), 34); | |||||
var pred = tf.less(x, y); | |||||
Func<ITensorOrOperation> if_true = delegate | |||||
{ | |||||
return tf.multiply(x, 17); | |||||
}; | |||||
Func<ITensorOrOperation> if_false = delegate | |||||
{ | |||||
return tf.add(y, 23); | |||||
}; | |||||
var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
int result = z.eval(sess); | |||||
assertEquals(result, 34); | |||||
}); | }); | ||||
} | } | ||||
//[Ignore("This Test Fails due to missing edges in the graph!")] | |||||
[TestMethod] | [TestMethod] | ||||
public void testCondFalse() | public void testCondFalse() | ||||
{ | { | ||||
with(tf.Graph().as_default(), g => | |||||
/* python | |||||
* import tensorflow as tf | |||||
from tensorflow.python.framework import ops | |||||
def if_true(): | |||||
return tf.math.multiply(x, 17) | |||||
def if_false(): | |||||
return tf.math.add(y, 23) | |||||
with tf.Session() as sess: | |||||
x = tf.constant(2) | |||||
y = tf.constant(1) | |||||
pred = tf.math.less(x,y) | |||||
z = tf.cond(pred, if_true, if_false) | |||||
result = z.eval() | |||||
print(result == 24) */ | |||||
with(tf.Session(), sess => | |||||
{ | { | ||||
var x = tf.constant(2); | var x = tf.constant(2); | ||||
var y = tf.constant(1); | var y = tf.constant(1); | ||||
var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)), | |||||
() => tf.add(y, tf.constant(23))); | |||||
self.assertEquals(eval_scalar(z), 24); | |||||
var pred = tf.less(x, y); | |||||
Func<ITensorOrOperation> if_true = delegate | |||||
{ | |||||
return tf.multiply(x, 17); | |||||
}; | |||||
Func<ITensorOrOperation> if_false = delegate | |||||
{ | |||||
return tf.add(y, 23); | |||||
}; | |||||
var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
int result = z.eval(sess); | |||||
assertEquals(result, 24); | |||||
}); | }); | ||||
} | } | ||||
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
{ | { | ||||
var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | var z1 = tf.add(a_3, tf.multiply(a_4, a_2)); | ||||
}); | }); | ||||
tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||||
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | ||||
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | ||||
} | } | ||||