using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// Managed wrapper around a native operation handle. Most members forward
    /// directly to the matching <c>c_api.TF_Operation*</c> function.
    /// </summary>
    public class Operation
    {
        // Native handle forwarded to every c_api call; IntPtr.Zero when none was supplied.
        private readonly IntPtr _handle;

        /// <summary>Graph passed to the constructor; not set by the handle-only constructor.</summary>
        public Graph Graph { get; }

        /// <summary>Identifier taken from <c>Graph._next_id()</c> by the NodeDef constructor; 0 otherwise.</summary>
        public int _id => _id_value;
        private int _id_value;

        // Shared status object handed to the native calls that report one.
        // It is not inspected anywhere in this class.
        private Status status = new Status();

        /// <summary>Result of <c>TF_OperationName</c>, converted with <c>c_api.StringPiece</c>.</summary>
        public string Name => c_api.StringPiece(c_api.TF_OperationName(_handle));

        /// <summary>Result of <c>TF_OperationOpType</c>, converted with <c>c_api.StringPiece</c>.</summary>
        public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));

        /// <summary>Result of <c>TF_OperationDevice</c>, converted with <c>c_api.StringPiece</c>.</summary>
        public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));

        /// <summary>Output count reported by <c>TF_OperationNumOutputs</c>.</summary>
        public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);

        /// <summary>Data type of the output at <paramref name="index"/>, per <c>TF_OperationOutputType</c>.</summary>
        /// <param name="index">Output index on this operation.</param>
        public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index));

        /// <summary>Forwards to <c>TF_OperationOutputListLength</c>; the resulting status is not checked here.</summary>
        /// <param name="name">Output argument name passed to the native call.</param>
        public int OutputListLength(string name) => c_api.TF_OperationOutputListLength(_handle, name, status);

        /// <summary>Producer of the input at <paramref name="index"/>, per <c>TF_OperationInput</c>.</summary>
        /// <param name="index">Input index on this operation.</param>
        public TF_Output Input(int index) => c_api.TF_OperationInput(new TF_Input(_handle, index));

        /// <summary>Data type of the input at <paramref name="index"/>, per <c>TF_OperationInputType</c>.</summary>
        /// <param name="index">Input index on this operation.</param>
        public TF_DataType InputType(int index) => c_api.TF_OperationInputType(new TF_Input(_handle, index));

        /// <summary>Forwards to <c>TF_OperationInputListLength</c>; the resulting status is not checked here.</summary>
        /// <param name="name">Input argument name passed to the native call.</param>
        public int InputListLength(string name) => c_api.TF_OperationInputListLength(_handle, name, status);

        /// <summary>Input count reported by <c>TF_OperationNumInputs</c>.</summary>
        public int NumInputs => c_api.TF_OperationNumInputs(_handle);

        /// <summary>Consumer count for the output at <paramref name="index"/>, per <c>TF_OperationOutputNumConsumers</c>.</summary>
        /// <param name="index">Output index on this operation.</param>
        public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index));

        /// <summary>
        /// Collects the consumers of the output at <paramref name="index"/> through
        /// <c>TF_OperationOutputConsumers</c>.
        /// </summary>
        /// <param name="index">Output index on this operation.</param>
        /// <param name="max_consumers">Capacity, in entries, of the buffer handed to the native call.</param>
        /// <returns>
        /// The entries written by the native call, at most <paramref name="max_consumers"/> of them.
        /// An empty array when <paramref name="max_consumers"/> is zero or negative.
        /// </returns>
        public unsafe TF_Input[] OutputConsumers(int index, int max_consumers)
        {
            if (max_consumers <= 0)
            {
                return new TF_Input[0];
            }

            // The buffer must hold max_consumers entries, since that is the capacity
            // the native call is told it may fill. checked() guards the multiplication.
            IntPtr buffer = Marshal.AllocHGlobal(checked(sizeof(TF_Input) * max_consumers));
            try
            {
                var entries = (TF_Input*)buffer;
                int reported = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), entries, max_consumers);

                // Never read past what the buffer can hold, whatever count comes back.
                int count = Math.Min(reported, max_consumers);
                var consumers = new TF_Input[count];
                for (int i = 0; i < count; i++)
                {
                    consumers[i] = entries[i];
                }

                return consumers;
            }
            finally
            {
                // Unmanaged memory is not tracked by the GC; release it on every path.
                Marshal.FreeHGlobal(buffer);
            }
        }

        /// <summary>Control input count reported by <c>TF_OperationNumControlInputs</c>.</summary>
        public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);

        /// <summary>Control output count reported by <c>TF_OperationNumControlOutputs</c>.</summary>
        public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);

        // Populated only by the NodeDef constructor; null for the other constructors.
        private Tensor[] _outputs;

        /// <summary>Output tensors created by the NodeDef constructor; null otherwise.</summary>
        public Tensor[] outputs => _outputs;

        /// <summary>Input tensors. Never assigned inside this class.</summary>
        public Tensor[] inputs;

        /// <summary>Wraps an existing native handle.</summary>
        /// <param name="handle">Native handle; <see cref="IntPtr.Zero"/> is accepted and stored as-is.</param>
        public Operation(IntPtr handle)
        {
            _handle = handle;
        }

        /// <summary>
        /// Creates a native operation in <paramref name="g"/> via <c>TF_NewOperation</c>,
        /// sets its "dtype" attribute to <c>TF_INT32</c>, and finishes it.
        /// </summary>
        /// <param name="g">Graph that receives the operation.</param>
        /// <param name="opType">Operation type passed to <c>TF_NewOperation</c>.</param>
        /// <param name="oper_name">Operation name passed to <c>TF_NewOperation</c>.</param>
        public Operation(Graph g, string opType, string oper_name)
        {
            Graph = g;

            var desc = c_api.TF_NewOperation(g, opType, oper_name);
            c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);

            // Retain the finished operation so the accessors above have a handle to query.
            // NOTE(review): assumes the binding returns the operation pointer, as the C API does.
            _handle = c_api.TF_FinishOperation(desc, status);
        }

        /// <summary>
        /// Creates the native operation from <paramref name="node_def"/> via <c>ops._create_c_op</c>,
        /// builds one output tensor per native output, and registers the operation with the graph.
        /// </summary>
        /// <param name="node_def">Node definition forwarded to <c>ops._create_c_op</c>.</param>
        /// <param name="g">Graph that supplies the id and receives the operation.</param>
        /// <param name="inputs">
        /// Inputs forwarded to <c>ops._create_c_op</c>.
        /// NOTE(review): the element type was missing in the reviewed text; Tensor is presumed - confirm against ops._create_c_op.
        /// </param>
        /// <param name="output_types">
        /// Data type for each output tensor. When null, each type is read from the native
        /// operation through <see cref="OutputType"/>.
        /// </param>
        /// <param name="control_inputs">Not used by this constructor.</param>
        /// <param name="input_types">Not used by this constructor.</param>
        /// <param name="original_op">Not used by this constructor.</param>
        /// <param name="op_def">Not used by this constructor.</param>
        public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
        {
            Graph = g;
            _id_value = Graph._next_id();
            _handle = ops._create_c_op(g, node_def, inputs);

            // Read the count once instead of calling into native code on every iteration.
            int outputCount = NumOutputs;
            _outputs = new Tensor[outputCount];
            for (int i = 0; i < outputCount; i++)
            {
                // output_types is optional; fall back to the native output type when omitted.
                TF_DataType dtype = output_types != null ? output_types[i] : OutputType(i);
                _outputs[i] = new Tensor(this, i, dtype);
            }

            Graph._add_op(this);
        }

        /// <summary>Looks up one of the two attribute names handled here.</summary>
        /// <param name="name">Attribute name.</param>
        /// <returns>
        /// For "dtype", the first output tensor; for "shape", a new <c>TensorShapeProto</c>;
        /// null for any other name.
        /// </returns>
        /// <remarks>
        /// "dtype" indexes the output array directly, so it requires outputs that were
        /// populated by the NodeDef constructor and contain at least one element.
        /// </remarks>
        public object get_attr(string name)
        {
            switch (name)
            {
                case "dtype":
                    return _outputs[0];
                case "shape":
                    return new TensorShapeProto();
                default:
                    return null;
            }
        }

        /// <summary>Wraps a native handle in a new <see cref="Operation"/>.</summary>
        public static implicit operator Operation(IntPtr handle)
        {
            return new Operation(handle);
        }

        /// <summary>Exposes the wrapped native handle.</summary>
        public static implicit operator IntPtr(Operation op)
        {
            return op._handle;
        }

        /// <summary>
        /// Compares against a boxed <see cref="IntPtr"/> by handle value; any other
        /// argument falls through to the base implementation.
        /// </summary>
        public override bool Equals(object obj)
        {
            if (obj is IntPtr ptr)
            {
                return ptr == _handle;
            }

            return base.Equals(obj);
        }

        /// <summary>Hash of the wrapped handle, supplied because <see cref="Equals(object)"/> is overridden.</summary>
        public override int GetHashCode()
        {
            return _handle.GetHashCode();
        }
    }
}