using Google.Protobuf.Collections;
//using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
{
/// <summary>
/// Represents a graph node that performs computation on tensors.
/// </summary>
/// <remarks>
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor (such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example, `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
/// </remarks>
public partial class Operation : ITensorOrOperation
{
// Native operation handle; every c_api.TF_Operation* call below receives it.
private readonly IntPtr _handle; // _c_op in python
// Native operation-description handle; assigned from c_api.TF_NewOperation or
// from the second element returned by ops._create_c_op, depending on the constructor.
private readonly IntPtr _operDesc;
// Graph this operation belongs to; assigned by the constructors.
private Graph _graph;
//[JsonIgnore]
/// <summary>The graph this operation belongs to.</summary>
public Graph graph => _graph;
//[JsonIgnore]
/// <summary>Read-only view of <see cref="_id_value"/>.</summary>
public int _id => _id_value;
//[JsonIgnore]
/// <summary>Backing storage for <see cref="_id"/>; not assigned anywhere in this part of the class.</summary>
public int _id_value;
/// <summary>Alias for <see cref="OpType"/>.</summary>
public string type => OpType;
//[JsonIgnore]
/// <summary>An operation is its own op: always returns this instance.</summary>
public Operation op => this;
/// <summary>Always <c>TF_DataType.DtInvalid</c> for an operation.</summary>
public TF_DataType dtype => TF_DataType.DtInvalid;
// Status object handed to the native API (passed to c_api.TF_FinishOperation).
private Status status = new Status();
/// <summary>Operation name, read from the native handle on every access.</summary>
public string name => c_api.StringPiece(c_api.TF_OperationName(_handle));
/// <summary>Operation type string, read from the native handle on every access.</summary>
public string OpType => c_api.StringPiece(c_api.TF_OperationOpType(_handle));
/// <summary>Device string, read from the native handle on every access.</summary>
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
// Cache for the node_def property; null until first requested.
private NodeDef _node_def;
/// <summary>
/// The <c>NodeDef</c> of this operation. It is obtained from <c>GetNodeDef()</c>
/// the first time it is requested and the same instance is returned afterwards.
/// </summary>
public NodeDef node_def
{
    get { return _node_def ?? (_node_def = GetNodeDef()); }
}
/// <summary>
/// Wraps an already existing native operation handle.
/// </summary>
/// <param name="handle">
/// Native operation handle. When it is <c>IntPtr.Zero</c> the constructor returns
/// immediately and leaves every field at its default value.
/// </param>
/// <param name="g">Owning graph; when null, <c>ops.get_default_graph()</c> is used.</param>
public Operation(IntPtr handle, Graph g=null)
{
    // Nothing to wrap: leave the instance uninitialized.
    if (handle == IntPtr.Zero)
        return;

    _handle = handle;
    _graph = g ?? ops.get_default_graph();

    // Allocate the output array first, then fill one Tensor per output slot.
    _outputs = new Tensor[NumOutputs];
    var slot = 0;
    while (slot < NumOutputs)
    {
        _outputs[slot] = new Tensor(this, slot, OutputType(slot));
        slot++;
    }

    // Capture the control-flow context that is current on the owning graph.
    _control_flow_context = _graph._get_control_flow_context();

    // Note: _control_flow_post_processing() must not be called here, the caller is
    // responsible for calling it when using this constructor.
}
/// <summary>
/// Creates and finishes a new native operation of type <paramref name="opType"/>
/// named <paramref name="oper_name"/> in graph <paramref name="g"/>.
/// </summary>
/// <remarks>
/// The "dtype" attribute is always set to <c>TF_DataType.TF_INT32</c> before the
/// operation is finished.
/// NOTE(review): the status passed to <c>TF_FinishOperation</c> is not inspected
/// here - confirm whether failures are expected to be surfaced by the caller.
/// </remarks>
/// <param name="g">Graph the operation is added to.</param>
/// <param name="opType">Type of the operation to create.</param>
/// <param name="oper_name">Name given to the new operation.</param>
public Operation(Graph g, string opType, string oper_name)
{
    _graph = g;

    // Describe the op, set its only attribute, then finish it to obtain the handle.
    IntPtr description = c_api.TF_NewOperation(g, opType, oper_name);
    _operDesc = description;
    c_api.TF_SetAttrType(description, "dtype", TF_DataType.TF_INT32);
    _handle = c_api.TF_FinishOperation(description, status);

    // Capture the control-flow context that is current on the owning graph.
    _control_flow_context = _graph._get_control_flow_context();
}
/// <summary>
/// Creates an `Operation`.
/// </summary>
/// <param name="node_def">`node_def_pb2.NodeDef`. `NodeDef` for the `Operation`.</param>
/// <param name="g">`Graph`. The parent graph.</param>
/// <param name="inputs">list of `Tensor` objects. The inputs to this `Operation`.</param>
/// <param name="output_types">
/// list of `DType` objects. Not consulted by this constructor: the output types
/// are read back from the created operation through <c>OutputType(i)</c>.
/// </param>
/// <param name="control_inputs">
/// list of operations or tensors from which to have a
/// control dependency.
/// </param>
/// <param name="input_types">
/// List of `DType` objects representing the
/// types of the tensors accepted by the `Operation`. By default
/// uses `[x.dtype.base_dtype for x in inputs]`. Operations that expect
/// reference-typed inputs must specify these explicitly.
/// Not read by this constructor.
/// </param>
/// <param name="original_op">Not read by this constructor.</param>
/// <param name="op_def">
/// Op definition used to group the inputs; when null it is looked up with
/// <c>g.GetOpDef(node_def.Op)</c>.
/// </param>
/// <exception cref="NotImplementedException">
/// A control input is neither an <c>Operation</c> nor a <c>Tensor</c>.
/// </exception>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
    _graph = g;

    // Build the list of control inputs: operations are taken as-is, tensors
    // contribute the operation that produces them.
    var control_input_ops = new List<Operation>();
    if (control_inputs != null)
    {
        foreach (var c in control_inputs)
        {
            switch (c)
            {
                case Operation c1:
                    control_input_ops.Add(c1);
                    break;
                case Tensor tensor:
                    control_input_ops.Add(tensor.op);
                    break;
                // TODO: IndexedSlices don't yet exist, but once they do, this needs to be uncommented
                //case IndexedSlices islices:
                //    control_input_ops.Add(islices.op);
                //    break;
                default:
                    throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
            }
        }
    }

    // Capture the control-flow context that is current on the owning graph.
    _control_flow_context = graph._get_control_flow_context();

    // Fall back to the graph's registered definition when none was supplied.
    if (op_def == null)
    {
        op_def = g.GetOpDef(node_def.Op);
    }

    var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
    (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

    // Initialize the outputs. The output count is read once, and each output's
    // type is queried exactly once while its Tensor is created.
    var outputCount = NumOutputs;
    _outputs = new Tensor[outputCount];
    for (int i = 0; i < outputCount; i++)
    {
        _outputs[i] = new Tensor(this, i, OutputType(i));
    }

    graph._add_op(this);

    // Post-processing only applies when a native operation was actually created.
    if (_handle != IntPtr.Zero)
    {
        _control_flow_post_processing();
    }
}
/// <summary>
/// Runs this operation by forwarding to <c>ops._run_using_default_session</c>
/// together with the operation's graph.
/// </summary>
/// <param name="feed_dict">Feed items, passed through unchanged; may be null.</param>
/// <param name="session">Session, passed through unchanged; may be null.</param>
public void run(FeedItem[] feed_dict = null, Session session = null)
    => ops._run_using_default_session(this, feed_dict, graph, session);
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs)
{
var grouped_inputs = new List