using System.Collections.Generic;
using System;
using System.Linq;
namespace Tensorflow.Eager
{
/// <summary>
/// Helpers for running TensorFlow operations in eager mode.
/// </summary>
public class Execute
{
    /// <summary>
    /// Execute a TensorFlow operation and return its first output.
    /// </summary>
    /// <param name="ctx">The value of context.context().</param>
    /// <param name="op_name">
    /// Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
    /// execute.
    /// </param>
    /// <param name="num_outputs">
    /// The number of outputs of the operation to fetch.
    /// </param>
    /// <param name="inputs">
    /// The inputs to the operation. Every entry must be an
    /// <see cref="EagerTensor"/>; <c>null</c> is treated as "no inputs".
    /// </param>
    /// <param name="attrs">
    /// Alternating string attr names and attr values for this operation.
    /// Forwarded unchanged to <c>wrap_tfe_src.SetOpAttrs</c>.
    /// </param>
    /// <param name="name">
    /// Customized name for the operation. Not currently used by this method.
    /// </param>
    /// <returns>
    /// An <see cref="EagerTensor"/> wrapping the first output handle, or
    /// <c>null</c> when <paramref name="num_outputs"/> is zero.
    /// </returns>
    /// <exception cref="ArgumentException">
    /// An entry of <paramref name="inputs"/> is not an <see cref="EagerTensor"/>.
    /// </exception>
    public Tensor execute(Context ctx, string op_name, int num_outputs,
        Tensor[] inputs, object[] attrs,
        string name = null)
    {
        ctx.ensure_initialized();

        // An operation without inputs may be invoked with a null array.
        inputs ??= Array.Empty<Tensor>();

        // Resolve the native handles up front so a non-eager input produces a
        // descriptive error instead of a NullReferenceException.
        var handles = inputs
            .Select(x => x is EagerTensor eager
                ? eager.GetTfeTensorHandle()
                : throw new ArgumentException(
                    $"Every input of '{op_name}' must be an EagerTensor.", nameof(inputs)))
            .ToArray();

        using var status = new Status();
        var outputs = new IntPtr[num_outputs];
        c_api.TFE_QuickExecute(ctx,
            ctx.device_name,
            op_name,
            handles,
            handles.Length,
            op => wrap_tfe_src.SetOpAttrs(op, attrs),
            outputs,
            num_outputs,
            status);
        status.Check(true);

        // The operation has already run; with no outputs requested there is
        // nothing to wrap, so avoid indexing an empty array.
        if (outputs.Length == 0)
        {
            return null;
        }

        TFE_TensorHandle tfe_tensor_handle = outputs[0];
        return new EagerTensor(tfe_tensor_handle);
    }

    /// <summary>
    /// Convert <paramref name="args"/> to tensors and report the dtype chosen for them.
    /// </summary>
    /// <param name="ctx">Context forwarded to <c>ops.convert_to_tensor</c>.</param>
    /// <param name="default_dtype">
    /// Dtype reported when <paramref name="args"/> is empty, and passed as the
    /// preferred dtype when values have to be converted.
    /// </param>
    /// <param name="args">
    /// Values to convert. <c>null</c> is treated as an empty list.
    /// </param>
    /// <returns>
    /// The selected dtype together with the tensors:
    /// <list type="bullet">
    /// <item>empty input with a valid default dtype: <c>(default_dtype, null)</c>;</item>
    /// <item>empty input without a default dtype: <c>(DtInvalid, empty array)</c>;</item>
    /// <item>all entries already eager: the first entry's dtype and the entries themselves;</item>
    /// <item>no eager entries: each entry converted in order, using the dtype of the first converted tensor for the rest.</item>
    /// </list>
    /// </returns>
    /// <exception cref="NotImplementedException">
    /// Only some of the entries are <see cref="EagerTensor"/> instances.
    /// </exception>
    public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
    {
        // The parameter defaults to null, so it must not be dereferenced blindly.
        args ??= Array.Empty<object>();

        if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid)
        {
            // Nothing to convert; report the default dtype. The null tensor
            // array is kept as-is because callers may already test for it.
            return (default_dtype, null);
        }

        // Fast path: everything is already an eager tensor. The length guard
        // keeps an empty list from reaching args[0].
        if (args.Length > 0 && args.All(x => x is EagerTensor))
        {
            return (((EagerTensor)args[0]).dtype, args.Cast<EagerTensor>().ToArray());
        }

        // Look for a dtype among the entries that are already eager tensors.
        var dtype = TF_DataType.DtInvalid;
        foreach (var x in args)
        {
            if (x is EagerTensor et)
            {
                dtype = et.dtype;
            }
        }

        if (dtype != TF_DataType.DtInvalid)
        {
            throw new NotImplementedException(
                "args_to_matching_eager does not support mixing EagerTensor and non-EagerTensor arguments.");
        }

        var ret = new List<Tensor>(args.Length);
        foreach (var t in args)
        {
            var converted = ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx);
            ret.Add(converted);

            // The first converted value fixes the dtype for the remaining ones.
            if (dtype == TF_DataType.DtInvalid)
            {
                dtype = converted.dtype;
            }
        }

        return (dtype, ret.ToArray());
    }
}
}