Browse Source

seperate Input and Output implementation from Operation

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
666372455c
4 changed files with 128 additions and 108 deletions
  1. +60
    -0
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  2. +66
    -0
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  3. +1
    -107
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +1
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 60
- 0
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -0,0 +1,60 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class Operation
{
public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();

private InputList _inputs;
public InputList inputs
{
get
{
if (_inputs == null)
{
var retval = new Tensor[NumInputs];

for (int i = 0; i < NumInputs; i++)
{
var tf_outpus = Input(i);
var op = new Operation(tf_outpus.oper);
retval[i] = op.outputs[tf_outpus.index];
}

_inputs = new InputList(retval);
}

return _inputs;
}
}

public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

public unsafe Operation[] GetControlInputs()
{
var control_inputs = new Operation[NumControlInputs];

if (NumControlInputs > 0)
{
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
{
var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
control_inputs[i] = new Operation(*(IntPtr*)handle);
}
}

return control_inputs;
}
}
}

+ 66
- 0
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class Operation
{
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);

private Tensor[] _outputs;
public Tensor[] outputs
{
get
{
if (_outputs == null)
{
_outputs = new Tensor[NumOutputs];

for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
}

return _outputs;
}
}

public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));

public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{
int size = Marshal.SizeOf<TF_Input>();
var handle = Marshal.AllocHGlobal(size);
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num];
for (int i = 0; i < num; i++)
{
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
}

return consumers;
}

public unsafe Operation[] GetControlOutputs()
{
var control_outputs = new Operation[NumControlOutputs];

if (NumControlOutputs > 0)
{
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
{
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
control_outputs[i] = new Operation(*(IntPtr*)handle);
}
}

return control_outputs;
}
}
}

+ 1
- 107
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -6,7 +6,7 @@ using System.Text;

namespace Tensorflow
{
public class Operation
public partial class Operation
{
private readonly IntPtr _handle;

@@ -20,112 +20,6 @@ namespace Tensorflow
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));
public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);

public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));
public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));
public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);
public int NumInputs => c_api.TF_OperationNumInputs(_handle);

public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
{
int size = Marshal.SizeOf<TF_Input>();
var handle = Marshal.AllocHGlobal(size);
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers);
var consumers = new TF_Input[num];
for (int i = 0; i < num; i++)
{
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size);
}

return consumers;
}

public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

public unsafe Operation[] GetControlInputs()
{
var control_inputs = new Operation[NumControlInputs];

if (NumControlInputs > 0)
{
IntPtr control_input_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlInputs);
c_api.TF_OperationGetControlInputs(_handle, control_input_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
{
var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
control_inputs[i] = new Operation(*(IntPtr*)handle);
}
}

return control_inputs;
}

public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

public unsafe Operation[] GetControlOutputs()
{
var control_outputs = new Operation[NumControlOutputs];

if (NumControlOutputs > 0)
{
IntPtr control_output_handle = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * NumControlOutputs);
c_api.TF_OperationGetControlOutputs(_handle, control_output_handle, NumControlInputs);
for (int i = 0; i < NumControlInputs; i++)
{
var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
control_outputs[i] = new Operation(*(IntPtr*)handle);
}
}

return control_outputs;
}

private Tensor[] _outputs;
public Tensor[] outputs
{
get
{
if (_outputs == null)
{
_outputs = new Tensor[NumOutputs];

for (int i = 0; i < NumOutputs; i++)
_outputs[i] = new Tensor(this, i, OutputType(i));
}

return _outputs;
}
}

private InputList _inputs;
public InputList inputs
{
get
{
if (_inputs == null)
{
var retval = new Tensor[NumInputs];

for (int i = 0; i < NumInputs; i++)
{
var tf_outpus = Input(i);
var op = new Operation(tf_outpus.oper);
retval[i] = op.outputs[tf_outpus.index];
}

_inputs = new InputList(retval);
}

return _inputs;
}
}

private TF_DataType[] _input_types => _inputs._inputs.Select(x => x.dtype).ToArray();

private NodeDef _node_def;
public NodeDef node_def
{


+ 1
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -49,7 +49,7 @@ namespace TensorFlowNET.UnitTest

using (var session = tf.Session())
{
var sm = session.run(model);
session.run(x.initializer);
for(int i = 0; i < 5; i++)
{
var x1 = x + 1;


Loading…
Cancel
Save