/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow.Util;
namespace Tensorflow
{
/// <summary>
/// Represents a graph node that performs computation on tensors.
/// </summary>
/// <remarks>
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor (such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example, `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
/// </remarks>
public partial class Operation : ITensorOrOperation
{
// Native operation handle (`_c_op` in the Python implementation).
private readonly IntPtr _handle;
// Native operation description handle, when this instance created one.
private readonly IntPtr _operDesc;
private Graph _graph;

/// <summary>Alias of <see cref="OpType"/>.</summary>
public string type => OpType;

/// <summary>The graph this operation was created in.</summary>
public Graph graph => _graph;

/// <summary>Read-only accessor for <see cref="_id_value"/>.</summary>
public int _id => _id_value;
public int _id_value;

/// <summary>Returns this instance.</summary>
public Operation op => this;

/// <summary>Always <see cref="TF_DataType.DtInvalid"/> for an operation.</summary>
public TF_DataType dtype => TF_DataType.DtInvalid;

/// <summary>Name of the native operation, or <c>null</c> when there is no native handle.</summary>
public string name
{
    get
    {
        if (_handle == IntPtr.Zero)
            return null;
        return c_api.StringPiece(c_api.TF_OperationName(_handle));
    }
}

/// <summary>Op type of the native operation, or <c>null</c> when there is no native handle.</summary>
public string OpType
{
    get
    {
        if (_handle == IntPtr.Zero)
            return null;
        return c_api.StringPiece(c_api.TF_OperationOpType(_handle));
    }
}

/// <summary>Device assigned to the native operation, or <c>null</c> when there is no native handle.</summary>
public string Device
{
    get
    {
        if (_handle == IntPtr.Zero)
            return null;
        return c_api.StringPiece(c_api.TF_OperationDevice(_handle));
    }
}
// Backing cache for node_def; stays null until first access.
private NodeDef _node_def;

/// <summary>
/// The <see cref="NodeDef"/> describing this operation. It is produced by
/// <c>GetNodeDef()</c> on first access and cached in <c>_node_def</c>.
/// </summary>
/// <remarks>
/// If <c>GetNodeDef()</c> returns <c>null</c>, nothing is cached and it is called again on the next access.
/// NOTE(review): the cache is never cleared here; confirm that changes made to the
/// operation after first access do not need to be reflected in this value.
/// </remarks>
public NodeDef node_def => _node_def ?? (_node_def = GetNodeDef());
/// <summary>
/// Wraps an already-created native operation handle.
/// </summary>
/// <param name="handle">
/// Native operation handle (<c>_c_op</c> in Python). When it is <see cref="IntPtr.Zero"/>
/// the constructor returns immediately and leaves every field at its default value.
/// </param>
/// <param name="g">Owning graph; <c>ops.get_default_graph()</c> is used when <c>null</c>.</param>
/// <remarks>
/// This constructor does NOT call <c>_control_flow_post_processing()</c>; the caller is
/// responsible for calling it when using this constructor.
/// </remarks>
public Operation(IntPtr handle, Graph g = null)
{
    if (handle == IntPtr.Zero)
        return;

    _handle = handle;
    _graph = g ?? ops.get_default_graph();

    // NumOutputs is fixed for a finished op, so read it once instead of on every loop check.
    int num_outputs = NumOutputs;
    _outputs = new Tensor[num_outputs];
    for (int i = 0; i < num_outputs; i++)
        _outputs[i] = new Tensor(this, i, OutputType(i));

    // Capture the control-flow context that is current on the graph at construction time.
    _control_flow_context = _graph._get_control_flow_context();
}
/// <summary>
/// Creates and finishes a new native operation of type <paramref name="opType"/> named
/// <paramref name="oper_name"/> in graph <paramref name="g"/>.
/// </summary>
/// <param name="g">Graph in which the native operation is created.</param>
/// <param name="opType">Registered op type name.</param>
/// <param name="oper_name">Name for the new operation.</param>
/// <remarks>
/// The <c>"dtype"</c> attribute is always set to <see cref="TF_DataType.TF_INT32"/> before finishing.
/// The finish status is checked via <c>Status.Check(true)</c>.
/// NOTE(review): unlike the other constructors, this one neither fills the outputs array nor
/// calls <c>_add_op</c>; confirm callers do not depend on either.
/// </remarks>
public Operation(Graph g, string opType, string oper_name)
{
    _graph = g;

    _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
    c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);

    using (var finishStatus = new Status())
    {
        _handle = c_api.TF_FinishOperation(_operDesc, finishStatus);
        finishStatus.Check(true);
    }

    // Remember whichever control-flow context the graph reports as current.
    _control_flow_context = _graph._get_control_flow_context();
}
/// <summary>
/// Creates an <see cref="Operation"/> from a <see cref="NodeDef"/>, builds the native op in
/// <paramref name="g"/>, and initializes its output tensors.
/// </summary>
/// <param name="node_def">
/// <c>NodeDef</c> for the <c>Operation</c>. Its <c>Op</c> selects the <c>OpDef</c> (when
/// <paramref name="op_def"/> is null), and its <c>Attr</c> map is used when grouping inputs.
/// </param>
/// <param name="g">The parent graph.</param>
/// <param name="inputs">The inputs to this <c>Operation</c>.</param>
/// <param name="output_types">
/// Not consulted: output types are always read back from the created native op.
/// </param>
/// <param name="control_inputs">
/// Operations or tensors from which to have a control dependency. For a <see cref="Tensor"/>,
/// its <c>op</c> is used.
/// </param>
/// <param name="input_types">Not currently used by this constructor.</param>
/// <param name="original_op">Not currently used by this constructor.</param>
/// <param name="op_def">
/// <c>OpDef</c> for the operation; looked up with <c>g.GetOpDef(node_def.Op)</c> when <c>null</c>.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="node_def"/> or <paramref name="g"/> is <c>null</c>.</exception>
/// <exception cref="NotImplementedException">A control input is neither an <see cref="Operation"/> nor a <see cref="Tensor"/>.</exception>
public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
{
    // Both are dereferenced below; fail with a clear argument error instead of a NullReferenceException.
    if (node_def == null)
        throw new ArgumentNullException(nameof(node_def));
    if (g == null)
        throw new ArgumentNullException(nameof(g));

    _graph = g;

    // Build the list of control inputs.
    var control_input_ops = new List<Operation>();
    if (control_inputs != null)
    {
        foreach (var c in control_inputs)
        {
            switch (c)
            {
                case Operation c1:
                    control_input_ops.Add(c1);
                    break;
                case Tensor tensor:
                    control_input_ops.Add(tensor.op);
                    break;
                // TODO: accept IndexedSlices (adding islices.op) once that type exists.
                default:
                    throw new NotImplementedException($"Control input must be an Operation, a Tensor, or IndexedSlices: {c}");
            }
        }
    }

    // Capture the control-flow context that is current on the graph at construction time.
    _control_flow_context = g._get_control_flow_context();

    // Fall back to the graph's OpDef for this op type when the caller did not supply one.
    if (op_def == null)
        op_def = g.GetOpDef(node_def.Op);

    var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
    (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());

    // Initialize _outputs. Output types come from the created op itself (as in the Python
    // implementation), so the output_types argument is intentionally not used.
    int num_outputs = NumOutputs;
    _outputs = new Tensor[num_outputs];
    for (int i = 0; i < num_outputs; i++)
        _outputs[i] = new Tensor(this, i, OutputType(i));

    g._add_op(this);

    if (_handle != IntPtr.Zero)
        _control_flow_post_processing();
}
/// <summary>
/// Runs this operation by delegating to <c>ops._run_using_default_session</c>, passing along
/// this operation's graph.
/// </summary>
/// <param name="feed_dict">Optional feed items forwarded unchanged.</param>
/// <param name="session">Optional session forwarded unchanged.</param>
public void run(FeedItem[] feed_dict = null, Session session = null)
    => ops._run_using_default_session(this, feed_dict, _graph, session);
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField attrs)
{
var grouped_inputs = new List