diff --git a/src/TensorFlowNET.Console/MemoryMonitor.cs b/src/TensorFlowNET.Console/MemoryMonitor.cs
index 92cd224f..f9a6bfd1 100644
--- a/src/TensorFlowNET.Console/MemoryMonitor.cs
+++ b/src/TensorFlowNET.Console/MemoryMonitor.cs
@@ -23,11 +23,9 @@ namespace Tensorflow
var x = tf.placeholder(tf.float64, shape: (1024, 1024));
var log = tf.log(x);
- using (var sess = tf.Session())
- {
- var ones = np.ones((1024, 1024), dtype: np.float64);
- var o = sess.run(log, new FeedItem(x, ones));
- }
+ var sess = tf.Session();
+ var ones = np.ones((1024, 1024), dtype: np.float64);
+ var o = sess.run(log, new FeedItem(x, ones));
// Thread.Sleep(1);
}
diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs
index bb4b880a..9ec9e22f 100644
--- a/src/TensorFlowNET.Core/Buffers/Buffer.cs
+++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs
@@ -25,15 +25,15 @@ namespace Tensorflow
///
/// Represents a TF_Buffer that can be passed to Tensorflow.
///
- public sealed class Buffer : IDisposable
+ public sealed class Buffer
{
- public SafeBufferHandle Handle { get; }
+ SafeBufferHandle _handle;
///
///
///
private unsafe ref readonly TF_Buffer DangerousBuffer
- => ref Unsafe.AsRef(Handle.DangerousGetHandle().ToPointer());
+ => ref Unsafe.AsRef(_handle.DangerousGetHandle().ToPointer());
///
/// The memory block representing this buffer.
@@ -59,7 +59,7 @@ namespace Tensorflow
{
get
{
- using (Handle.Lease())
+ using (_handle.Lease())
{
return DangerousBuffer.length;
}
@@ -67,13 +67,13 @@ namespace Tensorflow
}
public Buffer()
- => Handle = TF_NewBuffer();
+ => _handle = TF_NewBuffer();
public Buffer(SafeBufferHandle handle)
- => Handle = handle;
+ => _handle = handle;
public Buffer(byte[] data)
- => Handle = _toBuffer(data);
+ => _handle = _toBuffer(data);
private static SafeBufferHandle _toBuffer(byte[] data)
{
@@ -92,7 +92,7 @@ namespace Tensorflow
///
public unsafe byte[] ToArray()
{
- using (Handle.Lease())
+ using (_handle.Lease())
{
ref readonly TF_Buffer buffer = ref DangerousBuffer;
@@ -107,7 +107,12 @@ namespace Tensorflow
}
}
- public void Dispose()
- => Handle.Dispose();
+ public override string ToString()
+ => $"0x{_handle.DangerousGetHandle():x16}";
+
+ public static implicit operator SafeBufferHandle(Buffer buffer)
+ {
+ return buffer._handle;
+ }
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
index ffefe312..a1dba371 100644
--- a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
+++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
@@ -11,7 +11,7 @@ public class CheckpointReader
Status status = new Status();
VariableToDataTypeMap = new Dictionary();
VariableToShapeMap = new Dictionary();
- _handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
+ _handle = c_api.TF_NewCheckpointReader(filename, status);
status.Check(true);
ReadAllShapeAndType();
}
@@ -38,7 +38,7 @@ public class CheckpointReader
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
- c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
+ c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status);
status.Check(true);
return new Shape(dims);
}
@@ -49,7 +49,7 @@ public class CheckpointReader
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
- var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
+ var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status);
status.Check(true);
return new Tensor(tensor);
}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.Device.cs b/src/TensorFlowNET.Core/Contexts/Context.Device.cs
index fea2c824..97c550e8 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.Device.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.Device.cs
@@ -37,7 +37,7 @@ namespace Tensorflow.Contexts
public void log_device_placement(bool enable)
{
if (_handle != null)
- c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle);
+ c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status);
_log_device_placement = enable;
// _thread_local_data.function_call_options = null;
}
@@ -60,15 +60,15 @@ namespace Tensorflow.Contexts
public PhysicalDevice[] list_physical_devices(string device_type = null)
{
using var opts = c_api.TFE_NewContextOptions();
- using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle);
- using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle);
+ using var ctx = c_api.TFE_NewContext(opts, tf.Status);
+ using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status);
tf.Status.Check(true);
int num_devices = c_api.TF_DeviceListCount(devices);
var results = new List();
for (int i = 0; i < num_devices; ++i)
{
- var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle));
+ var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status));
tf.Status.Check(true);
if (dev_type.StartsWith("XLA"))
@@ -76,7 +76,7 @@ namespace Tensorflow.Contexts
if (device_type == null || dev_type == device_type)
{
- var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle);
+ var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status);
tf.Status.Check(true);
results.Add(new PhysicalDevice
diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs
index 5d02c027..21a14831 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.cs
@@ -28,7 +28,7 @@ namespace Tensorflow.Contexts
///
/// Environment in which eager operations execute.
///
- public sealed partial class Context : IDisposable
+ public sealed partial class Context
{
public const int GRAPH_MODE = 0;
public const int EAGER_MODE = 1;
@@ -41,15 +41,7 @@ namespace Tensorflow.Contexts
public FunctionCallOptions FunctionCallOptions { get; }
SafeContextHandle _handle;
- public SafeContextHandle Handle
- {
- get
- {
- if (_handle == null)
- ensure_initialized();
- return _handle;
- }
- }
+
int? _seed;
Random _rng;
@@ -59,6 +51,7 @@ namespace Tensorflow.Contexts
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
initialized = false;
FunctionCallOptions = new FunctionCallOptions();
+ ensure_initialized();
}
///
@@ -72,12 +65,12 @@ namespace Tensorflow.Contexts
Config = MergeConfig();
FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray();
- using var opts = new ContextOptions();
- using var status = new Status();
- c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
+ var opts = new ContextOptions();
+ var status = new Status();
+ c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status);
status.Check(true);
- c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy);
- _handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
+ c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
+ _handle = c_api.TFE_NewContext(opts, status);
status.Check(true);
initialized = true;
}
@@ -178,10 +171,14 @@ namespace Tensorflow.Contexts
tf.Context.ensure_initialized();
if (_handle != null)
+ {
c_api.TFE_ContextClearCaches(_handle);
+ }
}
- public void Dispose()
- => _handle.Dispose();
+ public static implicit operator SafeContextHandle(Context ctx)
+ {
+ return ctx._handle;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Contexts/ContextOptions.cs b/src/TensorFlowNET.Core/Contexts/ContextOptions.cs
index 6c2156a9..4a07f1f5 100644
--- a/src/TensorFlowNET.Core/Contexts/ContextOptions.cs
+++ b/src/TensorFlowNET.Core/Contexts/ContextOptions.cs
@@ -14,21 +14,21 @@
limitations under the License.
******************************************************************************/
-using System;
using Tensorflow.Eager;
-namespace Tensorflow.Contexts
+namespace Tensorflow.Contexts;
+
+public sealed class ContextOptions
{
- public sealed class ContextOptions : IDisposable
- {
- public SafeContextOptionsHandle Handle { get; }
+ SafeContextOptionsHandle _handle { get; }
- public ContextOptions()
- {
- Handle = c_api.TFE_NewContextOptions();
- }
+ public ContextOptions()
+ {
+ _handle = c_api.TFE_NewContextOptions();
+ }
- public void Dispose()
- => Handle.Dispose();
+ public static implicit operator SafeContextOptionsHandle(ContextOptions opt)
+ {
+ return opt._handle;
}
}
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
index 4aad851f..aa205d45 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
@@ -43,7 +43,7 @@ namespace Tensorflow.Eager
{
var status = tf.Status;
var op = GetOp(ctx, op_name, status);
- c_api.TFE_OpSetDevice(op, device_name, status.Handle);
+ c_api.TFE_OpSetDevice(op, device_name, status);
if (status.ok())
{
for (int i = 0; i < inputs.Length; ++i)
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager
Tensor nd => nd.EagerTensorHandle,
_ => throw new NotImplementedException("Eager tensor handle has not been allocated.")
};
- c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
+ c_api.TFE_OpAddInput(op, tensor_handle, status);
status.Check(true);
}
}
@@ -64,7 +64,7 @@ namespace Tensorflow.Eager
var outputs = new SafeEagerTensorHandle[num_outputs];
if (status.ok())
{
- c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
+ c_api.TFE_Execute(op, outputs, out num_outputs, status);
status.Check(true);
}
return outputs.Select(x => new EagerTensor(x)).ToArray();
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
index c6158ab0..92d5b2a4 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
@@ -104,7 +104,7 @@ namespace Tensorflow.Eager
var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
attr_values[j] = eager_tensor.dtype;
- c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
+ c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status);
if (op_exec_info.run_callbacks)
{
@@ -142,7 +142,7 @@ namespace Tensorflow.Eager
}
var retVals = new SafeEagerTensorHandle[num_retvals];
- c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
+ c_api.TFE_Execute(op, retVals, out num_retvals, status);
status.Check(true);
var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();
@@ -160,10 +160,10 @@ namespace Tensorflow.Eager
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
{
if (thread_local_eager_operation_map.find(op_or_function_name, out var op))
- c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle);
+ c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status);
else
{
- op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
+ op = c_api.TFE_NewOp(ctx, op_or_function_name, status);
thread_local_eager_operation_map[op_or_function_name] = op;
}
@@ -219,7 +219,7 @@ namespace Tensorflow.Eager
flattened_attrs.Add(dtype);
}
- c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle);
+ c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status);
status.Check(true);
return true;
@@ -235,7 +235,7 @@ namespace Tensorflow.Eager
var value = attrs[i + 1];
byte is_list = 0;
- var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle);
+ var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status);
if (!status.ok()) return;
if (is_list != 0)
SetOpAttrList(tf.Context, op, key, value as object[], type, null, status);
@@ -264,7 +264,7 @@ namespace Tensorflow.Eager
Status status)
{
byte is_list = 0;
- var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle);
+ var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status);
if (status.Code != TF_Code.TF_OK) return;
if (attr_value == null)
@@ -305,7 +305,7 @@ namespace Tensorflow.Eager
tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long));
}
- c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
+ c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status);
Array.ForEach(dims, x => Marshal.FreeHGlobal(x));
}
else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2)
@@ -353,7 +353,7 @@ namespace Tensorflow.Eager
break;
case TF_AttrType.TF_ATTR_SHAPE:
var dims = (value as long[]).ToArray();
- c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
+ c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
index b9f741f3..c7d71de3 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager
void NewEagerTensorHandle(SafeTensorHandle h)
{
_id = ops.uid();
- _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
+ _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status);
#if TRACK_TENSOR_LIFE
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}");
#endif
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{
if (_handle != null)
return;
- _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle);
+ _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status);
tf.Status.Check(true);
}
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs
index f85e8df6..02bd0bdf 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs
@@ -24,10 +24,10 @@ namespace Tensorflow.Eager
}
}
- public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
+ public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status));
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);
- public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);
+ public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status);
public override ulong bytesize
{
@@ -49,9 +49,9 @@ namespace Tensorflow.Eager
protected override Shape GetShapeInternal()
{
- var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
+ var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status)];
for (int i = 0; i < dims.Length; i++)
- dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle);
+ dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status);
return dims;
}
@@ -64,15 +64,15 @@ namespace Tensorflow.Eager
public static int GetRank(IntPtr handle)
{
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
- return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle);
+ return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status);
}
public static int[] GetDims(IntPtr handle)
{
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
- var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle)];
+ var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status)];
for (int i = 0; i < dims.Length; i++)
- dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle);
+ dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status);
return dims;
}
diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index d874ac93..6930b0c7 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -114,7 +114,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status);
+ public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status);
///
/// Removes a function from the context. Once removed, you can no longer
diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs
index 1d0098b4..5b99c200 100644
--- a/src/TensorFlowNET.Core/Framework/importer.cs
+++ b/src/TensorFlowNET.Core/Framework/importer.cs
@@ -56,15 +56,14 @@ namespace Tensorflow
TF_ImportGraphDefResults results = null;
var bytes = graph_def.ToByteString().ToArray();
- using (var buffer = c_api_util.tf_buffer(bytes))
- using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
- using (var status = new Status())
- {
- _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
- // need to create a class ImportGraphDefWithResults with IDisposal
- results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle));
- status.Check(true);
- }
+ var buffer = c_api_util.tf_buffer(bytes);
+ var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
+ var status = new Status();
+
+ _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
+ // need to create a class ImportGraphDefWithResults with IDisposal
+ results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status));
+ status.Check(true);
_ProcessNewOps(graph);
@@ -116,13 +115,13 @@ namespace Tensorflow
Dictionary input_map,
string[] return_elements)
{
- c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix);
- c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1);
+ c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
+ c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);
foreach (var input in input_map)
{
var (src_name, src_index) = _ParseTensorName(input.Key);
- c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Handle, src_name, src_index, input.Value._as_tf_output());
+ c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output());
}
if (return_elements == null)
@@ -133,11 +132,11 @@ namespace Tensorflow
if (name.Contains(":"))
{
var (op_name, index) = _ParseTensorName(name);
- c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index);
+ c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
}
else
{
- c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name);
+ c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
}
}
diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
index eec234c6..111719aa 100644
--- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
+++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
@@ -33,7 +33,7 @@ namespace Tensorflow
if (_registered_ops.Count > 0)
return _registered_ops;
- using var buffer = new Buffer(c_api.TF_GetAllOpList());
+ var buffer = new Buffer(c_api.TF_GetAllOpList());
var op_list = OpList.Parser.ParseFrom(buffer.ToArray());
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs
index d9e35a6d..e1f84d7e 100644
--- a/src/TensorFlowNET.Core/Framework/smart_module.cs
+++ b/src/TensorFlowNET.Core/Framework/smart_module.cs
@@ -56,8 +56,8 @@ namespace Tensorflow.Framework
if (pred_value is null)
{
var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray();
- var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle);
- if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK)
+ var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status);
+ if (!evaluated || c_api.TF_GetCode(tf.Status) != TF_Code.TF_OK)
return null;
else
throw new NotImplementedException("");
diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs
index 230d85ba..3fbb3868 100644
--- a/src/TensorFlowNET.Core/Functions/c_api.function.cs
+++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs
@@ -34,10 +34,10 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status);
+ public static extern void TF_FunctionToFunctionDef(SafeFuncGraphHandle func, SafeBufferHandle output_func_def, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
+ public static extern SafeFuncGraphHandle TF_GraphToFunction(SafeGraphHandle fn_body, string fn_name,
bool append_hash_to_fn_name,
int num_opers, IntPtr[] opers,
int ninputs, TF_Output[] inputs,
@@ -48,12 +48,12 @@ namespace Tensorflow
SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
+ public static extern IntPtr TF_FunctionSetAttrValueProto(SafeFuncGraphHandle func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_FunctionName(IntPtr func);
+ public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func);
[DllImport(TensorFlowLibName)]
- public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status);
+ public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status);
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
index 70dcfd67..901a33ca 100644
--- a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
+++ b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
@@ -37,7 +37,7 @@ namespace Tensorflow
/// TF_Status*
/// TF_Output*
[DllImport(TensorFlowLibName)]
- public static extern void TF_AddGradientsWithPrefix(IntPtr g, string prefix, TF_Output[] y, int ny,
+ public static extern void TF_AddGradientsWithPrefix(SafeGraphHandle g, string prefix, TF_Output[] y, int ny,
TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy);
}
}
diff --git a/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs b/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs
index 8870e295..f662b448 100644
--- a/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs
+++ b/src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs
@@ -22,21 +22,19 @@ namespace Tensorflow
var inputs_string = string.Join(",", inputs);
var outputs_string = string.Join(",", outputs);
var transforms_string = string.Join(" ", transforms);
- using (var status = new Status())
- {
- var buffer = new Buffer();
- var len = c_api.TransformGraphWithStringInputs(input_graph_def_string,
- input_graph_def_string.Length,
- inputs_string,
- outputs_string,
- transforms_string,
- buffer.Handle,
- status.Handle);
+ var status = new Status();
+ var buffer = new Buffer();
+ var len = c_api.TransformGraphWithStringInputs(input_graph_def_string,
+ input_graph_def_string.Length,
+ inputs_string,
+ outputs_string,
+ transforms_string,
+ buffer,
+ status);
- status.Check(false);
- var bytes = buffer.ToArray();
- return GraphDef.Parser.ParseFrom(bytes);
- }
+ status.Check(false);
+ var bytes = buffer.ToArray();
+ return GraphDef.Parser.ParseFrom(bytes);
}
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs
index ceeca8ab..48d14d6b 100644
--- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs
+++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs
@@ -37,11 +37,9 @@ namespace Tensorflow.Graphs
1);
return result[0];
}
- using (var s = tf.Session(input.graph))
- {
- var output = func(input);
- return output;
- }
+ var s = tf.Session(input.graph);
+ var output = func(input);
+ return output;
};
}
@@ -75,12 +73,10 @@ namespace Tensorflow.Graphs
1);
return result[0];
}
- using (var s = tf.Session(a.graph))
- {
- Debug.Assert(a.graph == b.graph);
- var output = func(a, b);
- return output;
- }
+ var s = tf.Session(a.graph);
+ Debug.Assert(a.graph == b.graph);
+ var output = func(a, b);
+ return output;
};
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
index df750813..a8dd4eb9 100644
--- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
+++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs
@@ -1,258 +1,252 @@
using Google.Protobuf;
-using System;
-using System.Collections.Generic;
-using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;
-namespace Tensorflow.Graphs
+namespace Tensorflow.Graphs;
+
+///
+/// Graph representing a function body.
+///
+public class FuncGraph : Graph, IDisposable
{
+ SafeFuncGraphHandle _func_graph_handle;
+ public string FuncName => _graph_key;
+
+ public Tensors Inputs { get; set; } = new Tensors();
+ public Tensors Outputs { get; set; } = new Tensors();
+ public Dictionary Attrs { get; set; }
+
+ Dictionary _captures
+ = new Dictionary();
+
+ public Tensor[] external_captures
+ => _captures.Select(x => x.Value.Item1).ToArray();
+ public (Tensor, Tensor)[] captures
+ => _captures.Values.Select(x => x).ToArray();
+
+ public Tensor[] internal_captures
+ => _captures.Select(x => x.Value.Item2).ToArray();
+
+ public Tensor[] captured_inputs
+ => external_captures;
+
///
- /// Graph representing a function body.
+ /// Construct a new FuncGraph.
///
- public class FuncGraph : Graph
+ public FuncGraph(string name) : base()
{
- IntPtr _func_graph_handle;
- public string FuncName => _graph_key;
-
- public Tensors Inputs { get; set; } = new Tensors();
- public Tensors Outputs { get; set; } = new Tensors();
- public Dictionary Attrs { get; set; }
+ outer_graph = ops.get_default_graph();
+ while (outer_graph.building_function)
+ outer_graph = outer_graph.OuterGraph;
+ _graph_key = name;
+ building_function = true;
+ }
- Dictionary _captures
- = new Dictionary();
+ public FuncGraph(SafeGraphHandle handle, string name, Dictionary attrs) : base()
+ {
+ outer_graph = ops.get_default_graph();
+ while (outer_graph.building_function)
+ outer_graph = outer_graph.OuterGraph;
+ _graph_key = name;
+ building_function = true;
+ Attrs = attrs;
+ // Will to test if FuncGraph has memory leak
+ // c_api.TF_DeleteGraph(_handle);
+ _handle = handle;
+ }
- public Tensor[] external_captures
- => _captures.Select(x => x.Value.Item1).ToArray();
- public (Tensor, Tensor)[] captures
- => _captures.Values.Select(x => x).ToArray();
+ public void ToGraph(Operation[] opers,
+ Tensor[] inputs, Tensor[] outputs,
+ string[] output_names)
+ {
+ var status = new Status();
+ _func_graph_handle = c_api.TF_GraphToFunction(_handle,
+ _graph_key,
+ false,
+ opers.Length,
+ opers.Select(x => (IntPtr)x).ToArray(),
+ inputs.Length,
+ inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
+ outputs.Length,
+ outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
+ output_names,
+ IntPtr.Zero,
+ null,
+ status);
+ status.Check(true);
+
+ SetAttrs();
+
+ // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle);
+ // status.Check(true);
+
+ c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status);
+ status.Check(true);
+
+ _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle));
+
+ Inputs = inputs;
+ // mark_as_return
+ Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
+ }
- public Tensor[] internal_captures
- => _captures.Select(x => x.Value.Item2).ToArray();
+ public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary attrs = null, OpDef op_def = null, bool compute_device = true)
+ {
+ foreach(var (i, inp) in enumerate(inputs))
+ inputs[i] = capture(inp);
- public Tensor[] captured_inputs
- => external_captures;
+ return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
+ }
- ///
- /// Construct a new FuncGraph.
- ///
- public FuncGraph(string name) : base()
+ const int _EAGER_CONST_THRESHOLD = 128;
+ public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
+ {
+ if(tensor is EagerTensor)
{
- outer_graph = ops.get_default_graph();
- while (outer_graph.building_function)
- outer_graph = outer_graph.OuterGraph;
- _graph_key = name;
- building_function = true;
+ if (name == null)
+ name = ops.uid().ToString();
+
+ // Small EagerTensors are captured with Const ops
+ if (dtypes.is_value_dtype(tensor.dtype)
+ && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
+ return capture_eager_tensor(tensor, name);
+
+ // Large EagerTensors and resources are captured with Placeholder ops
+ return _capture_helper(tensor, name, shape: shape);
}
- public FuncGraph(IntPtr handle, string name, Dictionary attrs) : base()
+ if(tensor.graph != this)
{
- outer_graph = ops.get_default_graph();
- while (outer_graph.building_function)
- outer_graph = outer_graph.OuterGraph;
- _graph_key = name;
- building_function = true;
- Attrs = attrs;
- // Will to test if FuncGraph has memory leak
- // c_api.TF_DeleteGraph(_handle);
- _handle = handle;
+ if (name == null)
+ name = tensor.op.name;
+ var inner_graph = tensor.graph;
+ while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
+ {
+ if (inner_graph == this)
+ throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
+ " in another function or code block. Use return values," +
+ " explicit Python locals or TensorFlow collections to access" +
+ $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
+ inner_graph = inner_func_graph.outer_graph;
+ }
+ return _capture_helper(tensor, name);
}
- public void ToGraph(Operation[] opers,
- Tensor[] inputs, Tensor[] outputs,
- string[] output_names)
+ return tensor;
+ }
+
+ Tensor capture_eager_tensor(Tensor tensor, string name)
+ {
+ Tensor graph_const = null;
+ if (!_captures.ContainsKey(tensor.Id))
{
- var status = new Status();
- _func_graph_handle = c_api.TF_GraphToFunction(_handle,
- _graph_key,
- false,
- opers.Length,
- opers.Select(x => (IntPtr)x).ToArray(),
- inputs.Length,
- inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
- outputs.Length,
- outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
- output_names == null || output_names.Length == 0 ? null : output_names,
- IntPtr.Zero,
- null,
- status.Handle);
- status.Check(true);
-
- SetAttrs();
-
- // c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle);
- // status.Check(true);
-
- c_api.TFE_ContextAddFunction(tf.Context.Handle, _func_graph_handle, status.Handle);
- status.Check(true);
-
- _graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle));
-
- Inputs = inputs;
- // mark_as_return
- Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
+ graph_const = tf_with(ops.control_dependencies(null), ctl
+ => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
+ add_capture(tensor, graph_const);
}
-
- public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary attrs = null, OpDef op_def = null, bool compute_device = true)
+ else
{
- foreach(var (i, inp) in enumerate(inputs))
- inputs[i] = capture(inp);
-
- return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
+ graph_const = _captures[tensor.Id].Item2;
}
- const int _EAGER_CONST_THRESHOLD = 128;
- public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
+ BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
- if(tensor is EagerTensor)
- {
- if (name == null)
- name = ops.uid().ToString();
-
- // Small EagerTensors are captured with Const ops
- if (dtypes.is_value_dtype(tensor.dtype)
- && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
- return capture_eager_tensor(tensor, name);
+ return output_grads;
+ };
- // Large EagerTensors and resources are captured with Placeholder ops
- return _capture_helper(tensor, name, shape: shape);
- }
+ tf.Runner.RecordGradient("captured_value",
+ new[] { graph_const }, null,
+ new[] { tensor },
+ getBackwardFunction: _backward_function_wrapper
+ /*getForwardFunction: forward_function*/);
- if(tensor.graph != this)
- {
- if (name == null)
- name = tensor.op.name;
- var inner_graph = tensor.graph;
- while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
- {
- if (inner_graph == this)
- throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
- " in another function or code block. Use return values," +
- " explicit Python locals or TensorFlow collections to access" +
- $" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
- inner_graph = inner_func_graph.outer_graph;
- }
- return _capture_helper(tensor, name);
- }
+ return graph_const;
+ }
- return tensor;
+ Tensor _capture_helper(Tensor tensor, string name, Shape shape = null)
+ {
+ Tensor placeholder = null;
+ if (!_captures.ContainsKey(tensor.Id))
+ {
+ placeholder = _create_substitute_placeholder(tensor,
+ name: name,
+ dtype: tensor.dtype,
+ shape: shape);
+ add_capture(tensor, placeholder);
}
-
- Tensor capture_eager_tensor(Tensor tensor, string name)
+ else
{
- Tensor graph_const = null;
- if (!_captures.ContainsKey(tensor.Id))
- {
- graph_const = tf_with(ops.control_dependencies(null), ctl
- => constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
- add_capture(tensor, graph_const);
- }
- else
- {
- graph_const = _captures[tensor.Id].Item2;
- }
-
- BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
- {
- return output_grads;
- };
-
- tf.Runner.RecordGradient("captured_value",
- new[] { graph_const }, null,
- new[] { tensor },
- getBackwardFunction: _backward_function_wrapper
- /*getForwardFunction: forward_function*/);
-
- return graph_const;
+ placeholder = _captures[tensor.Id].Item2;
}
- Tensor _capture_helper(Tensor tensor, string name, Shape shape = null)
+ BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
- Tensor placeholder = null;
- if (!_captures.ContainsKey(tensor.Id))
- {
- placeholder = _create_substitute_placeholder(tensor,
- name: name,
- dtype: tensor.dtype,
- shape: shape);
- add_capture(tensor, placeholder);
- }
- else
- {
- placeholder = _captures[tensor.Id].Item2;
- }
+ return output_grads;
+ };
- BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
- {
- return output_grads;
- };
+ tf.Runner.RecordGradient("captured_value",
+ new[] { placeholder }, null,
+ new[] { tensor },
+ getBackwardFunction: _backward_function_wrapper
+ /*getForwardFunction: forward_function*/);
- tf.Runner.RecordGradient("captured_value",
- new[] { placeholder }, null,
- new[] { tensor },
- getBackwardFunction: _backward_function_wrapper
- /*getForwardFunction: forward_function*/);
+ return placeholder;
+ }
- return placeholder;
- }
+ void add_capture(Tensor tensor, Tensor placeholder)
+ {
+ _captures.Add(tensor.Id, (tensor, placeholder));
+ Inputs.Add(placeholder);
+ }
- void add_capture(Tensor tensor, Tensor placeholder)
- {
- _captures.Add(tensor.Id, (tensor, placeholder));
- Inputs.Add(placeholder);
- }
+ Tensor _create_substitute_placeholder(Tensor value,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ Shape shape = null)
+ {
+ if (shape is null)
+ shape = value.shape;
+ if (dtype == TF_DataType.DtInvalid)
+ dtype = value.dtype;
+
+ var placeholder = tf_with(ops.control_dependencies(null), ctl
+ => array_ops.placeholder(dtype, shape: shape, name: name));
+ // custom_gradient.copy_handle_data(value, placeholder)
+ return placeholder;
+ }
- Tensor _create_substitute_placeholder(Tensor value,
- string name = null,
- TF_DataType dtype = TF_DataType.DtInvalid,
- Shape shape = null)
- {
- if (shape is null)
- shape = value.shape;
- if (dtype == TF_DataType.DtInvalid)
- dtype = value.dtype;
-
- var placeholder = tf_with(ops.control_dependencies(null), ctl
- => array_ops.placeholder(dtype, shape: shape, name: name));
- // custom_gradient.copy_handle_data(value, placeholder)
- return placeholder;
- }
+ void SetAttrs()
+ {
+ if (Attrs == null)
+ return;
- void SetAttrs()
+ foreach (var (_name, attr_value) in enumerate(Attrs))
{
- if (Attrs == null)
- return;
-
- foreach (var (_name, attr_value) in enumerate(Attrs))
+ var serialized = new AttrValue
{
- var serialized = new AttrValue
- {
- S = ByteString.CopyFromUtf8(attr_value)
- }.ToByteArray();
- c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status.Handle);
- tf.Status.Check(true);
- }
+ S = ByteString.CopyFromUtf8(attr_value)
+ }.ToByteArray();
+ c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status);
+ tf.Status.Check(true);
}
+ }
- public override Graph as_default()
- {
- tf.Context.graph_mode(isFunc: true);
- ops.set_default_graph(this);
- return this;
- }
+ public override Graph as_default()
+ {
+ tf.Context.graph_mode(isFunc: true);
+ ops.set_default_graph(this);
+ return this;
+ }
- public override void Exit()
- {
- tf.Context.restore_mode();
- ops.pop_graph();
- }
+ public override void Exit()
+ {
+ tf.Context.restore_mode();
+ ops.pop_graph();
+ }
- protected override void DisposeUnmanagedResources(IntPtr handle)
- {
- c_api.TFE_ContextRemoveFunction(tf.Context.Handle, _graph_key, tf.Status.Handle);
- c_api.TF_DeleteFunction(_func_graph_handle);
- base.DisposeUnmanagedResources(handle);
- }
+ public void Dispose()
+ {
+ c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status);
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs
index 612c7401..a11d91e7 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs
@@ -24,7 +24,7 @@ namespace Tensorflow
public Buffer ToGraphDef(Status s)
{
var buffer = new Buffer();
- c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle);
+ c_api.TF_GraphToGraphDef(_handle, buffer, s);
s.Check(true);
return buffer;
@@ -33,14 +33,12 @@ namespace Tensorflow
private GraphDef _as_graph_def(bool add_shapes = false)
{
GraphDef def;
- using (var status = new Status())
- using (var buffer = ToGraphDef(status))
- {
- status.Check(true);
- // limit size to 250M, recursion to max 100
- var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100);
- def = GraphDef.Parser.ParseFrom(inputStream);
- }
+ var status = new Status();
+ var buffer = ToGraphDef(status);
+ status.Check(true);
+ // limit size to 250M, recursion to max 100
+ var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100);
+ def = GraphDef.Parser.ParseFrom(inputStream);
// Strip the experimental library field iff it's empty.
// if(def.Library.Function.Count == 0)
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs
index 28ecd64e..b80e2659 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs
@@ -29,7 +29,7 @@ namespace Tensorflow
int size = Marshal.SizeOf();
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);
- c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle);
+ c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);
var tf_output_ptr = (TF_Output*)return_output_handle;
for (int i = 0; i < num_return_outputs; i++)
@@ -48,15 +48,14 @@ namespace Tensorflow
public bool Import(byte[] bytes, string prefix = "")
{
- using (var opts = new ImportGraphDefOptions())
- using (var status = new Status())
- using (var graph_def = new Buffer(bytes))
- {
- c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix);
- c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle);
- status.Check(true);
- return status.Code == TF_Code.TF_OK;
- }
+ var opts = new ImportGraphDefOptions();
+ var status = new Status();
+ var graph_def = new Buffer(bytes);
+
+ c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
+ c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
+ status.Check(true);
+ return status.Code == TF_Code.TF_OK;
}
public Graph ImportGraphDef(string file_path, string name = null)
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 2a982274..98cad3b2 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -75,9 +75,9 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
///
/// https://www.tensorflow.org/guide/graphs
https://www.tensorflow.org/api_docs/python/tf/Graph
- public partial class Graph : DisposableObject
- , IEnumerable
+ public partial class Graph : IEnumerable
{
+ protected new SafeGraphHandle _handle;
private Dictionary _nodes_by_id;
public Dictionary _nodes_by_name;
private Dictionary _names_in_use;
@@ -130,15 +130,6 @@ namespace Tensorflow
_graph_key = $"graph-{ops.GraphUniqueId()}/";
}
- public Graph(IntPtr handle)
- {
- _handle = handle;
- _nodes_by_id = new Dictionary();
- _nodes_by_name = new Dictionary();
- _names_in_use = new Dictionary();
- _graph_key = $"grap-{ops.GraphUniqueId()}/";
- }
-
public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
{
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
@@ -486,16 +477,6 @@ namespace Tensorflow
_unfetchable_ops.Add(op);
}
- protected override void DisposeManagedResources()
- {
-
- }
-
- protected override void DisposeUnmanagedResources(IntPtr handle)
- {
- c_api.TF_DeleteGraph(handle);
- }
-
public Tensor get_tensor_by_tf_output(TF_Output tf_output)
{
var op = _get_operation_by_tf_operation(tf_output.oper);
@@ -517,14 +498,14 @@ namespace Tensorflow
public Shape GetTensorShape(TF_Output output)
{
var status = tf.Status;
- var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status.Handle);
+ var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status);
status.Check();
if (ndim == -1)
return Shape.Null;
var dims = new long[ndim];
- c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status.Handle);
+ c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status);
status.Check();
return new Shape(dims.Select(x => (int)x).ToArray());
@@ -539,7 +520,7 @@ namespace Tensorflow
string debugString = string.Empty;
public override string ToString()
{
- return $"{graph_key}, 0x{_handle.ToString("x16")}";
+ return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}";
/*if (string.IsNullOrEmpty(debugString))
{
int len = 0;
@@ -558,7 +539,7 @@ namespace Tensorflow
IEnumerator IEnumerable.GetEnumerator()
=> throw new NotImplementedException();
- public static implicit operator IntPtr(Graph graph)
+ public static implicit operator SafeGraphHandle(Graph graph)
{
return graph._handle;
}
diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
index a04cf55a..859465fc 100644
--- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
+++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
@@ -14,28 +14,27 @@
limitations under the License.
******************************************************************************/
-using System;
+namespace Tensorflow;
-namespace Tensorflow
+public sealed class ImportGraphDefOptions
{
- public sealed class ImportGraphDefOptions : IDisposable
- {
- public SafeImportGraphDefOptionsHandle Handle { get; }
+ SafeImportGraphDefOptionsHandle _handle { get; }
- public int NumReturnOutputs
- => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle);
+ public int NumReturnOutputs
+ => c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);
- public ImportGraphDefOptions()
- {
- Handle = c_api.TF_NewImportGraphDefOptions();
- }
+ public ImportGraphDefOptions()
+ {
+ _handle = c_api.TF_NewImportGraphDefOptions();
+ }
- public void AddReturnOutput(string name, int index)
- {
- c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index);
- }
+ public void AddReturnOutput(string name, int index)
+ {
+ c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
+ }
- public void Dispose()
- => Handle.Dispose();
+ public static implicit operator SafeImportGraphDefOptionsHandle(ImportGraphDefOptions opt)
+ {
+ return opt._handle;
}
}
diff --git a/src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs
new file mode 100644
index 00000000..f38301b6
--- /dev/null
+++ b/src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs
@@ -0,0 +1,22 @@
+using Tensorflow.Util;
+
+namespace Tensorflow;
+
+public sealed class SafeFuncGraphHandle : SafeTensorflowHandle
+{
+ private SafeFuncGraphHandle()
+ {
+ }
+
+ public SafeFuncGraphHandle(IntPtr handle)
+ : base(handle)
+ {
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ c_api.TF_DeleteFunction(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs b/src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs
new file mode 100644
index 00000000..a6da0198
--- /dev/null
+++ b/src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs
@@ -0,0 +1,22 @@
+using Tensorflow.Util;
+
+namespace Tensorflow;
+
+public sealed class SafeGraphHandle : SafeTensorflowHandle
+{
+ private SafeGraphHandle()
+ {
+ }
+
+ public SafeGraphHandle(IntPtr handle)
+ : base(handle)
+ {
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ c_api.TF_DeleteGraph(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
index 6eb8f367..dc1827d8 100644
--- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
@@ -60,7 +60,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);
+ public static extern void TF_GraphGetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);
///
/// Import the graph serialized in `graph_def` into `graph`.
@@ -78,7 +78,7 @@ namespace Tensorflow
/// int
/// TF_Status*
[DllImport(TensorFlowLibName)]
- public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);
+ public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);
///
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
@@ -92,7 +92,7 @@ namespace Tensorflow
/// TF_Status*
/// TF_ImportGraphDefResults*
[DllImport(TensorFlowLibName)]
- public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
+ public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
///
/// Import the graph serialized in `graph_def` into `graph`.
@@ -102,7 +102,7 @@ namespace Tensorflow
/// TF_ImportGraphDefOptions*
/// TF_Status*
[DllImport(TensorFlowLibName)]
- public static extern void TF_GraphImportGraphDef(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
+ public static extern void TF_GraphImportGraphDef(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
///
/// Iterate through the operations of a graph.
@@ -111,7 +111,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos);
+ public static extern IntPtr TF_GraphNextOperation(SafeGraphHandle graph, ref uint pos);
///
/// Returns the operation in the graph with `oper_name`. Returns nullptr if
@@ -121,14 +121,14 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name);
+ public static extern IntPtr TF_GraphOperationByName(SafeGraphHandle graph, string oper_name);
///
/// Sets the shape of the Tensor referenced by `output` in `graph` to
/// the shape described by `dims` and `num_dims`.
///
[DllImport(TensorFlowLibName)]
- public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);
+ public static extern void TF_GraphSetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);
///
/// Write out a serialized representation of `graph` (as a GraphDef protocol
@@ -138,7 +138,7 @@ namespace Tensorflow
/// TF_Buffer*
/// TF_Status*
[DllImport(TensorFlowLibName)]
- public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
+ public static extern void TF_GraphToGraphDef(SafeGraphHandle graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
///
/// Returns the number of dimensions of the Tensor referenced by `output`
@@ -151,7 +151,7 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, SafeStatusHandle status);
+ public static extern int TF_GraphGetTensorNumDims(SafeGraphHandle graph, TF_Output output, SafeStatusHandle status);
///
/// Cause the imported graph to have a control dependency on `oper`. `oper`
@@ -287,12 +287,12 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
+ public static extern SafeSessionHandle TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
string export_dir, string[] tags, int tags_len,
- IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status);
+ SafeGraphHandle graph, IntPtr meta_graph_def, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_NewGraph();
+ public static extern SafeGraphHandle TF_NewGraph();
[DllImport(TensorFlowLibName)]
public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions();
@@ -334,6 +334,6 @@ namespace Tensorflow
///
///
[DllImport(TensorFlowLibName)]
- public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status);
+ public static extern bool TF_TryEvaluateConstant(SafeGraphHandle graph, TF_Output output, IntPtr[] result, SafeStatusHandle status);
}
}
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
index 7417476e..d9743ead 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
@@ -61,7 +61,7 @@ namespace Tensorflow.NumPy
{
if (_handle is not null)
{
- _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
+ _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
index 44ac52e1..9aa6fde2 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
@@ -31,7 +31,7 @@ namespace Tensorflow
public int InputListLength(string name)
{
int num = 0;
- num = c_api.TF_OperationInputListLength(_handle, name, tf.Status.Handle);
+ num = c_api.TF_OperationInputListLength(_handle, name, tf.Status);
tf.Status.Check(true);
return num;
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index b5d6191d..2955a13f 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -28,7 +28,7 @@ namespace Tensorflow
public int OutputListLength(string name)
{
- int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status.Handle);
+ int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status);
tf.Status.Check(true);
return num;
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index fb9a4a27..751ade5d 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -187,8 +187,8 @@ namespace Tensorflow
if (tf.executing_eagerly())
return (T[])get_attr(name);
- using var buf = new Buffer();
- c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
+ var buf = new Buffer();
+ c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
tf.Status.Check(true);
var x = AttrValue.Parser.ParseFrom(buf.ToArray());
@@ -210,8 +210,8 @@ namespace Tensorflow
public virtual object get_attr(string name)
{
- using var buf = new Buffer();
- c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
+ var buf = new Buffer();
+ c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
tf.Status.Check(true);
var x = AttrValue.Parser.ParseFrom(buf.ToArray());
@@ -235,13 +235,13 @@ namespace Tensorflow
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
{
- return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s.Handle);
+ return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
}
private NodeDef GetNodeDef()
{
- using var buffer = new Buffer();
- c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle);
+ var buffer = new Buffer();
+ c_api.TF_OperationToNodeDef(_handle, buffer, tf.Status);
tf.Status.Check(throwException: true);
return NodeDef.Parser.ParseFrom(buffer.ToArray());
}
diff --git a/src/TensorFlowNET.Core/Operations/OperationDescription.cs b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
index 384f5386..28df548d 100644
--- a/src/TensorFlowNET.Core/Operations/OperationDescription.cs
+++ b/src/TensorFlowNET.Core/Operations/OperationDescription.cs
@@ -50,7 +50,7 @@ namespace Tensorflow
public Operation FinishOperation(Status status)
{
- return c_api.TF_FinishOperation(_handle, status.Handle);
+ return c_api.TF_FinishOperation(_handle, status);
}
public static implicit operator OperationDescription(IntPtr handle)
diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
index 0fc92454..900db8ca 100644
--- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs
+++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs
@@ -96,7 +96,7 @@ namespace Tensorflow
/// const char*
/// TF_OperationDescription*
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);
+ public static extern IntPtr TF_NewOperation(SafeGraphHandle graph, string opType, string oper_name);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_OperationDevice(IntPtr oper);
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 4e131b36..0a9cfc2e 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -14,281 +14,272 @@
limitations under the License.
******************************************************************************/
-using Google.Protobuf;
using Tensorflow.NumPy;
-using System;
-using System.Collections;
-using System.Collections.Generic;
-using System.Linq;
-using System.Numerics;
-using System.Text;
-using Tensorflow.Util;
using static Tensorflow.Binding;
-namespace Tensorflow
+namespace Tensorflow;
+
+public class BaseSession : IDisposable
{
- public class BaseSession : DisposableObject
+ protected SafeSessionHandle _handle;
+ protected Graph _graph;
+ protected Status _status;
+ public Graph graph => _graph;
+
+ public BaseSession(SafeSessionHandle handle, Graph g)
{
- protected Graph _graph;
- protected Status _status;
- public Graph graph => _graph;
+ _handle = handle;
+ _graph = g ?? ops.get_default_graph();
+ }
- public BaseSession(IntPtr handle, Graph g)
+ public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
+ {
+ _graph = g ?? ops.get_default_graph();
+ if (!_graph.building_function)
{
- _handle = handle;
- _graph = g ?? ops.get_default_graph();
+ if (ops.get_default_graph() != _graph)
+ _graph.as_default();
}
+
+ var opts = new SessionOptions(target, config);
+ _status = status ?? tf.Status;
+ _handle = c_api.TF_NewSession(_graph, opts, _status);
+ _status.Check(true);
+ }
- public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
- {
- _graph = g ?? ops.get_default_graph();
- if (!_graph.building_function)
- {
- if (ops.get_default_graph() != _graph)
- _graph.as_default();
- }
-
- using var opts = new SessionOptions(target, config);
- _status = status ?? tf.Status;
- _handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle);
- _status.Check(true);
- }
+ public virtual void run(Operation op, params FeedItem[] feed_dict)
+ {
+ _run(op, feed_dict);
+ }
- public virtual void run(Operation op, params FeedItem[] feed_dict)
- {
- _run(op, feed_dict);
- }
+ public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
+ {
+ return _run(fetche, feed_dict)[0];
+ }
- public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
- {
- return _run(fetche, feed_dict)[0];
- }
+ public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
+ {
+ var results = _run(fetche, feed_dict);
+ return fetche is Tensor ? results[0] : null;
+ }
- public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
- {
- var results = _run(fetche, feed_dict);
- return fetche is Tensor ? results[0] : null;
- }
+ public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run(
+ (ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches,
+ params FeedItem[] feed_dict)
+ {
+ var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict);
+ return (results[0], results[1], results[2], results[3], results[4]);
+ }
- public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run(
- (ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches,
- params FeedItem[] feed_dict)
- {
- var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict);
- return (results[0], results[1], results[2], results[3], results[4]);
- }
+ public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
+ {
+ var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
+ return (results[0], results[1], results[2], results[3]);
+ }
- public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
- {
- var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
- return (results[0], results[1], results[2], results[3]);
- }
+ public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
+ {
+ var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
+ return (results[0], results[1], results[2]);
+ }
- public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
- {
- var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
- return (results[0], results[1], results[2]);
- }
+ public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
+ {
+ var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
+ return (results[0], results[1]);
+ }
- public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
- {
- var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
- return (results[0], results[1]);
- }
+ public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
+ {
+ return _run(fetches, feed_dict);
+ }
- public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
- {
- return _run(fetches, feed_dict);
- }
+ public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
+ {
+ var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType