@@ -1,16 +1,23 @@ | |||||
using System; | using System; | ||||
using System.Collections; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class InputList | |||||
public class InputList : IEnumerable | |||||
{ | { | ||||
public Tensor[] _inputs; | public Tensor[] _inputs; | ||||
public Tensor this[int index] => _inputs[index]; | |||||
public InputList(Tensor[] inputs) | public InputList(Tensor[] inputs) | ||||
{ | { | ||||
_inputs = inputs; | _inputs = inputs; | ||||
} | } | ||||
public IEnumerator GetEnumerator() | |||||
{ | |||||
return _inputs.GetEnumerator(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,20 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class Operation | |||||
{ | |||||
/// <summary> | |||||
/// Add this op to its control flow context. | |||||
/// </summary> | |||||
public void _control_flow_post_processing() | |||||
{ | |||||
foreach(var input_tensor in inputs) | |||||
{ | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -12,21 +12,7 @@ namespace Tensorflow | |||||
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | ||||
private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
public Tensor[] outputs | |||||
{ | |||||
get | |||||
{ | |||||
if (_outputs == null) | |||||
{ | |||||
_outputs = new Tensor[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
} | |||||
return _outputs; | |||||
} | |||||
} | |||||
public Tensor[] outputs => _outputs; | |||||
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class Operation | public partial class Operation | ||||
{ | { | ||||
private readonly IntPtr _handle; | |||||
private readonly IntPtr _handle; // _c_op in python | |||||
public Graph Graph { get; } | public Graph Graph { get; } | ||||
public int _id => _id_value; | public int _id => _id_value; | ||||
@@ -97,12 +97,20 @@ namespace Tensorflow | |||||
_handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | _handle = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray()); | ||||
// Initialize self._outputs. | |||||
output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
output_types[i] = OutputType(i); | output_types[i] = OutputType(i); | ||||
_outputs = new Tensor[NumOutputs]; | |||||
for (int i = 0; i < NumOutputs; i++) | |||||
_outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
Graph._add_op(this); | Graph._add_op(this); | ||||
if (_handle != IntPtr.Zero) | |||||
_control_flow_post_processing(); | |||||
} | } | ||||
public object get_attr<T>(string name) | public object get_attr<T>(string name) | ||||
@@ -18,7 +18,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
public static Tensor sub(Tensor x, Tensor y) | public static Tensor sub(Tensor x, Tensor y) | ||||
@@ -29,7 +29,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Sub", name: "sub", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
public static Tensor mul(Tensor x, Tensor y) | public static Tensor mul(Tensor x, Tensor y) | ||||
@@ -40,7 +40,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
public static Tensor real_div(Tensor x, Tensor y) | public static Tensor real_div(Tensor x, Tensor y) | ||||
@@ -51,7 +51,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("RealDiv", name: "truediv", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) | public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) | ||||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
public static Tensor pow(Tensor x, double y) | public static Tensor pow(Tensor x, double y) | ||||
@@ -75,7 +75,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
public static Tensor sum(Tensor input, Tensor axis = null) | public static Tensor sum(Tensor input, Tensor axis = null) | ||||
@@ -87,7 +87,7 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Sum", keywords: keywords); | ||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
return _op.outputs[0]; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -23,6 +23,12 @@ namespace Tensorflow | |||||
public static implicit operator RefVariable(Tensor var) | public static implicit operator RefVariable(Tensor var) | ||||
{ | { | ||||
switch (var.dtype) | |||||
{ | |||||
case TF_DataType.TF_INT32: | |||||
return tf.Variable(var.Data<int>()[0]); | |||||
} | |||||
return null; | return null; | ||||
} | } | ||||
} | } | ||||