@@ -26,7 +26,18 @@ namespace Tensorflow | |||||
{ | { | ||||
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); | public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index)); | ||||
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); | public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index)); | ||||
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status); | |||||
public int InputListLength(string name) | |||||
{ | |||||
int num = 0; | |||||
using(var status = new Status()) | |||||
{ | |||||
num = c_api.TF_OperationInputListLength(_handle, name, status); | |||||
status.Check(true); | |||||
} | |||||
return num; | |||||
} | |||||
public int NumInputs => c_api.TF_OperationNumInputs(_handle); | public int NumInputs => c_api.TF_OperationNumInputs(_handle); | ||||
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); | private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray(); | ||||
@@ -24,7 +24,18 @@ namespace Tensorflow | |||||
{ | { | ||||
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | ||||
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | ||||
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status); | |||||
public int OutputListLength(string name) | |||||
{ | |||||
int num = 0; | |||||
using (var status = new Status()) | |||||
{ | |||||
num = c_api.TF_OperationOutputListLength(_handle, name, status); | |||||
status.Check(true); | |||||
} | |||||
return num; | |||||
} | |||||
private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
@@ -54,7 +54,6 @@ namespace Tensorflow | |||||
public Operation op => this; | public Operation op => this; | ||||
public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
private Status status = new Status(); | |||||
public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
@@ -96,10 +95,14 @@ namespace Tensorflow | |||||
_operDesc = c_api.TF_NewOperation(g, opType, oper_name); | _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | ||||
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | ||||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
using (var status = new Status()) | |||||
{ | |||||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
status.Check(true); | |||||
} | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
} | } | ||||
@@ -220,6 +223,7 @@ namespace Tensorflow | |||||
{ | { | ||||
AttrValue x = null; | AttrValue x = null; | ||||
using (var status = new Status()) | |||||
using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
{ | { | ||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | ||||
@@ -274,12 +278,15 @@ namespace Tensorflow | |||||
var output = tensor._as_tf_output(); | var output = tensor._as_tf_output(); | ||||
// Reset cached inputs. | // Reset cached inputs. | ||||
_inputs = null; | |||||
_inputs = null; | |||||
// after the c_api call next time _inputs is accessed | // after the c_api call next time _inputs is accessed | ||||
// the updated inputs are reloaded from the c_api | // the updated inputs are reloaded from the c_api | ||||
c_api.UpdateEdge(_graph, output, input, status); | |||||
//var updated_inputs = inputs; | |||||
status.Check(); | |||||
using (var status = new Status()) | |||||
{ | |||||
c_api.UpdateEdge(_graph, output, input, status); | |||||
//var updated_inputs = inputs; | |||||
status.Check(); | |||||
} | |||||
} | } | ||||
private void _assert_same_graph(Tensor tensor) | private void _assert_same_graph(Tensor tensor) | ||||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, IntPtr dims, int num_dims, UIntPtr len); | public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, IntPtr dims, int num_dims, UIntPtr len); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, UIntPtr len); | |||||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len); | |||||
/// <summary> | /// <summary> | ||||
/// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | ||||
@@ -77,14 +77,14 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void AllocateTensor() | public void AllocateTensor() | ||||
{ | { | ||||
/*ulong num_bytes = 6 * sizeof(float); | |||||
ulong num_bytes = 6 * sizeof(float); | |||||
long[] dims = { 2, 3 }; | long[] dims = { 2, 3 }; | ||||
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | ||||
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); | ||||
EXPECT_EQ(2, t.NDims); | EXPECT_EQ(2, t.NDims); | ||||
Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | |||||
EXPECT_EQ((int)dims[0], t.shape[0]); | |||||
EXPECT_EQ(num_bytes, t.bytesize); | EXPECT_EQ(num_bytes, t.bytesize); | ||||
t.Dispose();*/ | |||||
t.Dispose(); | |||||
} | } | ||||