using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
{
public partial class Operation : IReturnTensorOrOperation
{
/// <summary>Handle passed to the <c>c_api.TF_Operation*</c> queries made by this class.</summary>
private readonly IntPtr _handle; // _c_op in python
/// <summary>Backing store for <see cref="_id"/>.</summary>
private int _id_value;
/// <summary>Status object handed to <c>c_api.TF_FinishOperation</c>.</summary>
private Status status = new Status();

/// <summary>Graph assigned to this operation by its constructors.</summary>
public Graph Graph { get; }

/// <summary>Identifier of this operation, read from <c>_id_value</c>.</summary>
public int _id { get => _id_value; }

/// <summary>Alias for <see cref="OpType"/>.</summary>
public string type { get => OpType; }

/// <summary>Operation name, obtained through <c>c_api.TF_OperationName</c>.</summary>
public string Name { get => c_api.StringPiece(c_api.TF_OperationName(_handle)); }

/// <summary>Op type, obtained through <c>c_api.TF_OperationOpType</c>.</summary>
public string OpType { get => c_api.StringPiece(c_api.TF_OperationOpType(_handle)); }

/// <summary>Device string, obtained through <c>c_api.TF_OperationDevice</c>.</summary>
public string Device { get => c_api.StringPiece(c_api.TF_OperationDevice(_handle)); }
/// <summary>Cached result of <c>GetNodeDef()</c>; null until <see cref="node_def"/> is first read.</summary>
private NodeDef _node_def;

/// <summary>
/// The <c>NodeDef</c> of this operation. It is fetched with <c>GetNodeDef()</c>
/// on first access and the same instance is returned afterwards.
/// </summary>
public NodeDef node_def
{
    get
    {
        if (_node_def != null)
        {
            return _node_def;
        }

        _node_def = GetNodeDef();
        return _node_def;
    }
}
/// <summary>
/// Wraps an existing operation handle and attaches it to the default graph.
/// </summary>
/// <param name="handle">
/// Handle to wrap. When it is <see cref="IntPtr.Zero"/> the constructor returns
/// without assigning the handle, the graph or the outputs.
/// </param>
public Operation(IntPtr handle)
{
    if (handle != IntPtr.Zero)
    {
        _handle = handle;
        Graph = ops.get_default_graph();

        // One Tensor per output, typed from OutputType at the same index.
        _outputs = new Tensor[NumOutputs];
        int index = 0;
        while (index < NumOutputs)
        {
            _outputs[index] = new Tensor(this, index, OutputType(index));
            index++;
        }
    }
}
/// <summary>
/// Builds a new operation through the c_api and keeps the handle of the
/// finished operation.
/// </summary>
/// <param name="g">Graph passed to <c>TF_NewOperation</c>; also stored in <see cref="Graph"/>.</param>
/// <param name="opType">Op type name passed to <c>TF_NewOperation</c>.</param>
/// <param name="oper_name">Operation name passed to <c>TF_NewOperation</c>.</param>
public Operation(Graph g, string opType, string oper_name)
{
    Graph = g;

    var desc = c_api.TF_NewOperation(g, opType, oper_name);

    // The "dtype" attribute is always set to TF_INT32 by this constructor.
    c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32);

    // TF_FinishOperation hands back the finished operation. Discarding it left
    // _handle at IntPtr.Zero, so Name/OpType/Device queried a null handle.
    // NOTE(review): `status` is not inspected here; on failure the C API
    // returns a null operation — confirm how callers expect that to surface.
    _handle = c_api.TF_FinishOperation(desc, status);
}
/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
/// <param name="g">`Graph`. The parent graph.</param>
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
/// <param name="output_types">
/// list of `DType` objects. Not read by this constructor: the output types
/// are taken from `OutputType` once the op has been created.
/// </param>
/// <param name="control_inputs">
/// list of operations from which to have a
/// control dependency. A null entry is rejected.
/// </param>
/// <param name="input_types">
/// List of `DType` objects representing the
/// types of the tensors accepted by the `Operation`. By default
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
/// reference-typed inputs must specify these explicitly.
/// Not read by this constructor.
/// </param>
/// <param name="original_op">Not read by this constructor.</param>
/// <param name="op_def">
/// `OpDef` for the op. When null it is looked up with `g.GetOpDef(node_def.Op)`.
/// </param>
/// <exception cref="ArgumentNullException">
/// <paramref name="node_def"/> or <paramref name="g"/> is null.
/// </exception>
/// <exception cref="NotImplementedException">
/// An entry of <paramref name="control_inputs"/> is not an `Operation` (i.e. it is null).
/// </exception>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
    // Fail before a graph id is consumed; both values are dereferenced below.
    if (node_def is null)
    {
        throw new ArgumentNullException(nameof(node_def));
    }

    if (g is null)
    {
        throw new ArgumentNullException(nameof(g));
    }

    Graph = g;

    // Build the list of control inputs.
    var control_input_ops = new List<Operation>();
    if (control_inputs != null)
    {
        foreach (var c in control_inputs)
        {
            if (c is Operation c1)
            {
                control_input_ops.Add(c1);
            }
            else
            {
                throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
            }
        }
    }

    // Reserve this operation's id in the parent graph.
    _id_value = Graph._next_id();

    if (op_def == null)
    {
        op_def = g.GetOpDef(node_def.Op);
    }

    var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
    _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

    // Initialize self._outputs: one Tensor per output, typed from OutputType.
    _outputs = new Tensor[NumOutputs];
    for (int i = 0; i < NumOutputs; i++)
    {
        _outputs[i] = new Tensor(this, i, OutputType(i));
    }

    Graph._add_op(this);

    if (_handle != IntPtr.Zero)
    {
        _control_flow_post_processing();
    }
}
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs)
{
var grouped_inputs = new List