Browse Source

More SafeHandles.

tags/v0.100.5-BERT-load
Haiping Chen 2 years ago
parent
commit
e5dfe90e9b
69 changed files with 1524 additions and 1826 deletions
  1. +3
    -5
      src/TensorFlowNET.Console/MemoryMonitor.cs
  2. +15
    -10
      src/TensorFlowNET.Core/Buffers/Buffer.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  4. +5
    -5
      src/TensorFlowNET.Core/Contexts/Context.Device.cs
  5. +14
    -17
      src/TensorFlowNET.Core/Contexts/Context.cs
  6. +11
    -11
      src/TensorFlowNET.Core/Contexts/ContextOptions.cs
  7. +3
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  8. +9
    -9
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  10. +7
    -7
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  12. +13
    -14
      src/TensorFlowNET.Core/Framework/importer.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
  14. +2
    -2
      src/TensorFlowNET.Core/Framework/smart_module.cs
  15. +5
    -5
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Gradients/c_api.gradient.cs
  17. +12
    -14
      src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs
  18. +7
    -11
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  19. +204
    -210
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  20. +7
    -9
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  21. +9
    -10
      src/TensorFlowNET.Core/Graphs/Graph.Import.cs
  22. +6
    -25
      src/TensorFlowNET.Core/Graphs/Graph.cs
  23. +16
    -17
      src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
  24. +22
    -0
      src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs
  25. +22
    -0
      src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs
  26. +13
    -13
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  27. +1
    -1
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  28. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Input.cs
  29. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  30. +7
    -7
      src/TensorFlowNET.Core/Operations/Operation.cs
  31. +1
    -1
      src/TensorFlowNET.Core/Operations/OperationDescription.cs
  32. +1
    -1
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  33. +221
    -230
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  34. +46
    -0
      src/TensorFlowNET.Core/Sessions/SafeSessionHandle.cs
  35. +38
    -64
      src/TensorFlowNET.Core/Sessions/Session.cs
  36. +12
    -12
      src/TensorFlowNET.Core/Sessions/SessionOptions.cs
  37. +2
    -2
      src/TensorFlowNET.Core/Sessions/c_api.session.cs
  38. +14
    -12
      src/TensorFlowNET.Core/Status/Status.cs
  39. +4
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  40. +10
    -12
      src/TensorFlowNET.Core/Training/Saving/saver.py.cs
  41. +1
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  42. +1
    -1
      src/TensorFlowNET.Core/ops.cs
  43. +7
    -9
      src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs
  44. +38
    -46
      test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs
  45. +9
    -15
      test/TensorFlowNET.Graph.UnitTest/Basics/SessionTest.cs
  46. +21
    -29
      test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs
  47. +1
    -1
      test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs
  48. +18
    -22
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs
  49. +12
    -14
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs
  50. +47
    -66
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  51. +12
    -14
      test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
  52. +8
    -8
      test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs
  53. +453
    -675
      test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs
  54. +14
    -16
      test/TensorFlowNET.Graph.UnitTest/PythonTest.cs
  55. +2
    -4
      test/TensorFlowNET.Native.UnitTest/Attributes/AttributesTestcs.cs
  56. +1
    -3
      test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs
  57. +2
    -2
      test/TensorFlowNET.Native.UnitTest/CApiTest.cs
  58. +3
    -3
      test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs
  59. +10
    -17
      test/TensorFlowNET.Native.UnitTest/Gradients/GradientsTest.cs
  60. +1
    -1
      test/TensorFlowNET.Native.UnitTest/Graphs/GraphBuildTest.cs
  61. +15
    -27
      test/TensorFlowNET.Native.UnitTest/Graphs/GraphTest.cs
  62. +4
    -7
      test/TensorFlowNET.Native.UnitTest/Sessions/CSession.cs
  63. +2
    -2
      test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs
  64. +16
    -17
      test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
  65. +20
    -28
      test/TensorFlowNET.Native.UnitTest/c_test_util.cs
  66. +16
    -22
      test/TensorFlowNET.UnitTest/Basics/TrainSaverTest.cs
  67. +4
    -6
      test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs
  68. +14
    -16
      test/TensorFlowNET.UnitTest/PythonTest.cs
  69. +0
    -1
      test/TensorFlowNET.UnitTest/StatusTest.cs

+ 3
- 5
src/TensorFlowNET.Console/MemoryMonitor.cs View File

@@ -23,11 +23,9 @@ namespace Tensorflow
var x = tf.placeholder(tf.float64, shape: (1024, 1024)); var x = tf.placeholder(tf.float64, shape: (1024, 1024));
var log = tf.log(x); var log = tf.log(x);


using (var sess = tf.Session())
{
var ones = np.ones((1024, 1024), dtype: np.float64);
var o = sess.run(log, new FeedItem(x, ones));
}
var sess = tf.Session();
var ones = np.ones((1024, 1024), dtype: np.float64);
var o = sess.run(log, new FeedItem(x, ones));
// Thread.Sleep(1); // Thread.Sleep(1);
} }




+ 15
- 10
src/TensorFlowNET.Core/Buffers/Buffer.cs View File

@@ -25,15 +25,15 @@ namespace Tensorflow
/// <summary> /// <summary>
/// Represents a TF_Buffer that can be passed to Tensorflow. /// Represents a TF_Buffer that can be passed to Tensorflow.
/// </summary> /// </summary>
public sealed class Buffer : IDisposable
public sealed class Buffer
{ {
public SafeBufferHandle Handle { get; }
SafeBufferHandle _handle;


/// <remarks> /// <remarks>
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/> /// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/>
/// </remarks> /// </remarks>
private unsafe ref readonly TF_Buffer DangerousBuffer private unsafe ref readonly TF_Buffer DangerousBuffer
=> ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer());
=> ref Unsafe.AsRef<TF_Buffer>(_handle.DangerousGetHandle().ToPointer());


/// <summary> /// <summary>
/// The memory block representing this buffer. /// The memory block representing this buffer.
@@ -59,7 +59,7 @@ namespace Tensorflow
{ {
get get
{ {
using (Handle.Lease())
using (_handle.Lease())
{ {
return DangerousBuffer.length; return DangerousBuffer.length;
} }
@@ -67,13 +67,13 @@ namespace Tensorflow
} }


public Buffer() public Buffer()
=> Handle = TF_NewBuffer();
=> _handle = TF_NewBuffer();


public Buffer(SafeBufferHandle handle) public Buffer(SafeBufferHandle handle)
=> Handle = handle;
=> _handle = handle;


public Buffer(byte[] data) public Buffer(byte[] data)
=> Handle = _toBuffer(data);
=> _handle = _toBuffer(data);


private static SafeBufferHandle _toBuffer(byte[] data) private static SafeBufferHandle _toBuffer(byte[] data)
{ {
@@ -92,7 +92,7 @@ namespace Tensorflow
/// </summary> /// </summary>
public unsafe byte[] ToArray() public unsafe byte[] ToArray()
{ {
using (Handle.Lease())
using (_handle.Lease())
{ {
ref readonly TF_Buffer buffer = ref DangerousBuffer; ref readonly TF_Buffer buffer = ref DangerousBuffer;


@@ -107,7 +107,12 @@ namespace Tensorflow
} }
} }


public void Dispose()
=> Handle.Dispose();
public override string ToString()
=> $"0x{_handle.DangerousGetHandle():x16}";

public static implicit operator SafeBufferHandle(Buffer buffer)
{
return buffer._handle;
}
} }
} }

+ 3
- 3
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -11,7 +11,7 @@ public class CheckpointReader
Status status = new Status(); Status status = new Status();
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
VariableToShapeMap = new Dictionary<string, Shape>(); VariableToShapeMap = new Dictionary<string, Shape>();
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
_handle = c_api.TF_NewCheckpointReader(filename, status);
status.Check(true); status.Check(true);
ReadAllShapeAndType(); ReadAllShapeAndType();
} }
@@ -38,7 +38,7 @@ public class CheckpointReader
int num_dims = GetVariableNumDims(name); int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims]; long[] dims = new long[num_dims];
Status status = new Status(); Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status);
status.Check(true); status.Check(true);
return new Shape(dims); return new Shape(dims);
} }
@@ -49,7 +49,7 @@ public class CheckpointReader
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
Status status = new Status(); Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status);
status.Check(true); status.Check(true);
return new Tensor(tensor); return new Tensor(tensor);
} }


+ 5
- 5
src/TensorFlowNET.Core/Contexts/Context.Device.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow.Contexts
public void log_device_placement(bool enable) public void log_device_placement(bool enable)
{ {
if (_handle != null) if (_handle != null)
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle);
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status);
_log_device_placement = enable; _log_device_placement = enable;
// _thread_local_data.function_call_options = null; // _thread_local_data.function_call_options = null;
} }
@@ -60,15 +60,15 @@ namespace Tensorflow.Contexts
public PhysicalDevice[] list_physical_devices(string device_type = null) public PhysicalDevice[] list_physical_devices(string device_type = null)
{ {
using var opts = c_api.TFE_NewContextOptions(); using var opts = c_api.TFE_NewContextOptions();
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle);
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle);
using var ctx = c_api.TFE_NewContext(opts, tf.Status);
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);


int num_devices = c_api.TF_DeviceListCount(devices); int num_devices = c_api.TF_DeviceListCount(devices);
var results = new List<PhysicalDevice>(); var results = new List<PhysicalDevice>();
for (int i = 0; i < num_devices; ++i) for (int i = 0; i < num_devices; ++i)
{ {
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle));
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status));
tf.Status.Check(true); tf.Status.Check(true);


if (dev_type.StartsWith("XLA")) if (dev_type.StartsWith("XLA"))
@@ -76,7 +76,7 @@ namespace Tensorflow.Contexts


if (device_type == null || dev_type == device_type) if (device_type == null || dev_type == device_type)
{ {
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle);
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);


results.Add(new PhysicalDevice results.Add(new PhysicalDevice


+ 14
- 17
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Contexts
/// <summary> /// <summary>
/// Environment in which eager operations execute. /// Environment in which eager operations execute.
/// </summary> /// </summary>
public sealed partial class Context : IDisposable
public sealed partial class Context
{ {
public const int GRAPH_MODE = 0; public const int GRAPH_MODE = 0;
public const int EAGER_MODE = 1; public const int EAGER_MODE = 1;
@@ -41,15 +41,7 @@ namespace Tensorflow.Contexts
public FunctionCallOptions FunctionCallOptions { get; } public FunctionCallOptions FunctionCallOptions { get; }


SafeContextHandle _handle; SafeContextHandle _handle;
public SafeContextHandle Handle
{
get
{
if (_handle == null)
ensure_initialized();
return _handle;
}
}

int? _seed; int? _seed;
Random _rng; Random _rng;


@@ -59,6 +51,7 @@ namespace Tensorflow.Contexts
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
initialized = false; initialized = false;
FunctionCallOptions = new FunctionCallOptions(); FunctionCallOptions = new FunctionCallOptions();
ensure_initialized();
} }


/// <summary> /// <summary>
@@ -72,12 +65,12 @@ namespace Tensorflow.Contexts
Config = MergeConfig(); Config = MergeConfig();
FunctionCallOptions.Config = Config; FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray(); var config_str = Config.ToByteArray();
using var opts = new ContextOptions();
using var status = new Status();
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
var opts = new ContextOptions();
var status = new Status();
c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status);
status.Check(true); status.Check(true);
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy);
_handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
_handle = c_api.TFE_NewContext(opts, status);
status.Check(true); status.Check(true);
initialized = true; initialized = true;
} }
@@ -178,10 +171,14 @@ namespace Tensorflow.Contexts
tf.Context.ensure_initialized(); tf.Context.ensure_initialized();


if (_handle != null) if (_handle != null)
{
c_api.TFE_ContextClearCaches(_handle); c_api.TFE_ContextClearCaches(_handle);
}
} }


public void Dispose()
=> _handle.Dispose();
public static implicit operator SafeContextHandle(Context ctx)
{
return ctx._handle;
}
} }
} }

+ 11
- 11
src/TensorFlowNET.Core/Contexts/ContextOptions.cs View File

@@ -14,21 +14,21 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using Tensorflow.Eager; using Tensorflow.Eager;


namespace Tensorflow.Contexts
namespace Tensorflow.Contexts;

public sealed class ContextOptions
{ {
public sealed class ContextOptions : IDisposable
{
public SafeContextOptionsHandle Handle { get; }
SafeContextOptionsHandle _handle { get; }


public ContextOptions()
{
Handle = c_api.TFE_NewContextOptions();
}
public ContextOptions()
{
_handle = c_api.TFE_NewContextOptions();
}


public void Dispose()
=> Handle.Dispose();
public static implicit operator SafeContextOptionsHandle(ContextOptions opt)
{
return opt._handle;
} }
} }

+ 3
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow.Eager
{ {
var status = tf.Status; var status = tf.Status;
var op = GetOp(ctx, op_name, status); var op = GetOp(ctx, op_name, status);
c_api.TFE_OpSetDevice(op, device_name, status.Handle);
c_api.TFE_OpSetDevice(op, device_name, status);
if (status.ok()) if (status.ok())
{ {
for (int i = 0; i < inputs.Length; ++i) for (int i = 0; i < inputs.Length; ++i)
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager
Tensor nd => nd.EagerTensorHandle, Tensor nd => nd.EagerTensorHandle,
_ => throw new NotImplementedException("Eager tensor handle has not been allocated.") _ => throw new NotImplementedException("Eager tensor handle has not been allocated.")
}; };
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
c_api.TFE_OpAddInput(op, tensor_handle, status);
status.Check(true); status.Check(true);
} }
} }
@@ -64,7 +64,7 @@ namespace Tensorflow.Eager
var outputs = new SafeEagerTensorHandle[num_outputs]; var outputs = new SafeEagerTensorHandle[num_outputs];
if (status.ok()) if (status.ok())
{ {
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
c_api.TFE_Execute(op, outputs, out num_outputs, status);
status.Check(true); status.Check(true);
} }
return outputs.Select(x => new EagerTensor(x)).ToArray(); return outputs.Select(x => new EagerTensor(x)).ToArray();


+ 9
- 9
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -104,7 +104,7 @@ namespace Tensorflow.Eager
var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); var eager_tensor = ops.convert_to_tensor(fast_input_array[j]);
attr_values[j] = eager_tensor.dtype; attr_values[j] = eager_tensor.dtype;


c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle);
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status);


if (op_exec_info.run_callbacks) if (op_exec_info.run_callbacks)
{ {
@@ -142,7 +142,7 @@ namespace Tensorflow.Eager
} }


var retVals = new SafeEagerTensorHandle[num_retvals]; var retVals = new SafeEagerTensorHandle[num_retvals];
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
c_api.TFE_Execute(op, retVals, out num_retvals, status);
status.Check(true); status.Check(true);


var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();
@@ -160,10 +160,10 @@ namespace Tensorflow.Eager
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status)
{ {
if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) if (thread_local_eager_operation_map.find(op_or_function_name, out var op))
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle);
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status);
else else
{ {
op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
op = c_api.TFE_NewOp(ctx, op_or_function_name, status);
thread_local_eager_operation_map[op_or_function_name] = op; thread_local_eager_operation_map[op_or_function_name] = op;
} }


@@ -219,7 +219,7 @@ namespace Tensorflow.Eager
flattened_attrs.Add(dtype); flattened_attrs.Add(dtype);
} }


c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle);
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status);
status.Check(true); status.Check(true);


return true; return true;
@@ -235,7 +235,7 @@ namespace Tensorflow.Eager
var value = attrs[i + 1]; var value = attrs[i + 1];


byte is_list = 0; byte is_list = 0;
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle);
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status);
if (!status.ok()) return; if (!status.ok()) return;
if (is_list != 0) if (is_list != 0)
SetOpAttrList(tf.Context, op, key, value as object[], type, null, status); SetOpAttrList(tf.Context, op, key, value as object[], type, null, status);
@@ -264,7 +264,7 @@ namespace Tensorflow.Eager
Status status) Status status)
{ {
byte is_list = 0; byte is_list = 0;
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle);
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status);
if (status.Code != TF_Code.TF_OK) return; if (status.Code != TF_Code.TF_OK) return;


if (attr_value == null) if (attr_value == null)
@@ -305,7 +305,7 @@ namespace Tensorflow.Eager
tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long)); tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long));
} }


c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status);
Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); Array.ForEach(dims, x => Marshal.FreeHGlobal(x));
} }
else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2)
@@ -353,7 +353,7 @@ namespace Tensorflow.Eager
break; break;
case TF_AttrType.TF_ATTR_SHAPE: case TF_AttrType.TF_ATTR_SHAPE:
var dims = (value as long[]).ToArray(); var dims = (value as long[]).ToArray();
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true); status.Check(true);
break; break;
case TF_AttrType.TF_ATTR_FUNC: case TF_AttrType.TF_ATTR_FUNC:


+ 2
- 2
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -54,7 +54,7 @@ namespace Tensorflow.Eager
void NewEagerTensorHandle(SafeTensorHandle h) void NewEagerTensorHandle(SafeTensorHandle h)
{ {
_id = ops.uid(); _id = ops.uid();
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status);
#if TRACK_TENSOR_LIFE #if TRACK_TENSOR_LIFE
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); Console.WriteLine($"New EagerTensor {_eagerTensorHandle}");
#endif #endif
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{ {
if (_handle != null) if (_handle != null)
return; return;
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle);
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);
} }




+ 7
- 7
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -24,10 +24,10 @@ namespace Tensorflow.Eager
} }
} }


public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle));
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status));
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle);


public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle);
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status);


public override ulong bytesize public override ulong bytesize
{ {
@@ -49,9 +49,9 @@ namespace Tensorflow.Eager


protected override Shape GetShapeInternal() protected override Shape GetShapeInternal()
{ {
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)];
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status)];
for (int i = 0; i < dims.Length; i++) for (int i = 0; i < dims.Length; i++)
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle);
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status);
return dims; return dims;
} }


@@ -64,15 +64,15 @@ namespace Tensorflow.Eager
public static int GetRank(IntPtr handle) public static int GetRank(IntPtr handle)
{ {
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle);
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status);
} }


public static int[] GetDims(IntPtr handle) public static int[] GetDims(IntPtr handle)
{ {
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle);
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle)];
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status)];
for (int i = 0; i < dims.Length; i++) for (int i = 0; i < dims.Length; i++)
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle);
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status);
return dims; return dims;
} }




+ 1
- 1
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -114,7 +114,7 @@ namespace Tensorflow
/// <param name="function"></param> /// <param name="function"></param>
/// <param name="status"></param> /// <param name="status"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status);
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status);


/// <summary> /// <summary>
/// Removes a function from the context. Once removed, you can no longer /// Removes a function from the context. Once removed, you can no longer


+ 13
- 14
src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -56,15 +56,14 @@ namespace Tensorflow


TF_ImportGraphDefResults results = null; TF_ImportGraphDefResults results = null;
var bytes = graph_def.ToByteString().ToArray(); var bytes = graph_def.ToByteString().ToArray();
using (var buffer = c_api_util.tf_buffer(bytes))
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
using (var status = new Status())
{
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle));
status.Check(true);
}
var buffer = c_api_util.tf_buffer(bytes);
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions();
var status = new Status();
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status));
status.Check(true);


_ProcessNewOps(graph); _ProcessNewOps(graph);


@@ -116,13 +115,13 @@ namespace Tensorflow
Dictionary<string, Tensor> input_map, Dictionary<string, Tensor> input_map,
string[] return_elements) string[] return_elements)
{ {
c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1);
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix);
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1);


foreach (var input in input_map) foreach (var input in input_map)
{ {
var (src_name, src_index) = _ParseTensorName(input.Key); var (src_name, src_index) = _ParseTensorName(input.Key);
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Handle, src_name, src_index, input.Value._as_tf_output());
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output());
} }


if (return_elements == null) if (return_elements == null)
@@ -133,11 +132,11 @@ namespace Tensorflow
if (name.Contains(":")) if (name.Contains(":"))
{ {
var (op_name, index) = _ParseTensorName(name); var (op_name, index) = _ParseTensorName(name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
} }
else else
{ {
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name);
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
} }
} }




+ 1
- 1
src/TensorFlowNET.Core/Framework/op_def_registry.py.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow
if (_registered_ops.Count > 0) if (_registered_ops.Count > 0)
return _registered_ops; return _registered_ops;


using var buffer = new Buffer(c_api.TF_GetAllOpList());
var buffer = new Buffer(c_api.TF_GetAllOpList());
var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); var op_list = OpList.Parser.ParseFrom(buffer.ToArray());
foreach (var op_def in op_list.Op) foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def; _registered_ops[op_def.Name] = op_def;


+ 2
- 2
src/TensorFlowNET.Core/Framework/smart_module.cs View File

@@ -56,8 +56,8 @@ namespace Tensorflow.Framework
if (pred_value is null) if (pred_value is null)
{ {
var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray();
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle);
if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK)
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status);
if (!evaluated || c_api.TF_GetCode(tf.Status) != TF_Code.TF_OK)
return null; return null;
else else
throw new NotImplementedException(""); throw new NotImplementedException("");


+ 5
- 5
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -34,10 +34,10 @@ namespace Tensorflow
/// <param name="output_func_def"></param> /// <param name="output_func_def"></param>
/// <param name="status"></param> /// <param name="status"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status);
public static extern void TF_FunctionToFunctionDef(SafeFuncGraphHandle func, SafeBufferHandle output_func_def, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name,
public static extern SafeFuncGraphHandle TF_GraphToFunction(SafeGraphHandle fn_body, string fn_name,
bool append_hash_to_fn_name, bool append_hash_to_fn_name,
int num_opers, IntPtr[] opers, int num_opers, IntPtr[] opers,
int ninputs, TF_Output[] inputs, int ninputs, TF_Output[] inputs,
@@ -48,12 +48,12 @@ namespace Tensorflow
SafeStatusHandle status); SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);
public static extern IntPtr TF_FunctionSetAttrValueProto(SafeFuncGraphHandle func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FunctionName(IntPtr func);
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status);
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status);
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Gradients/c_api.gradient.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <param name="dy">TF_Output*</param> /// <param name="dy">TF_Output*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_AddGradientsWithPrefix(IntPtr g, string prefix, TF_Output[] y, int ny,
public static extern void TF_AddGradientsWithPrefix(SafeGraphHandle g, string prefix, TF_Output[] y, int ny,
TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy); TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy);
} }
} }

+ 12
- 14
src/TensorFlowNET.Core/GraphTransformation/GraphTransformer.cs View File

@@ -22,21 +22,19 @@ namespace Tensorflow
var inputs_string = string.Join(",", inputs); var inputs_string = string.Join(",", inputs);
var outputs_string = string.Join(",", outputs); var outputs_string = string.Join(",", outputs);
var transforms_string = string.Join(" ", transforms); var transforms_string = string.Join(" ", transforms);
using (var status = new Status())
{
var buffer = new Buffer();
var len = c_api.TransformGraphWithStringInputs(input_graph_def_string,
input_graph_def_string.Length,
inputs_string,
outputs_string,
transforms_string,
buffer.Handle,
status.Handle);
var status = new Status();
var buffer = new Buffer();
var len = c_api.TransformGraphWithStringInputs(input_graph_def_string,
input_graph_def_string.Length,
inputs_string,
outputs_string,
transforms_string,
buffer,
status);


status.Check(false);
var bytes = buffer.ToArray();
return GraphDef.Parser.ParseFrom(bytes);
}
status.Check(false);
var bytes = buffer.ToArray();
return GraphDef.Parser.ParseFrom(bytes);
} }
} }
} }

+ 7
- 11
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -37,11 +37,9 @@ namespace Tensorflow.Graphs
1); 1);
return result[0]; return result[0];
} }
using (var s = tf.Session(input.graph))
{
var output = func(input);
return output;
}
var s = tf.Session(input.graph);
var output = func(input);
return output;
}; };
} }


@@ -75,12 +73,10 @@ namespace Tensorflow.Graphs
1); 1);
return result[0]; return result[0];
} }
using (var s = tf.Session(a.graph))
{
Debug.Assert(a.graph == b.graph);
var output = func(a, b);
return output;
}
var s = tf.Session(a.graph);
Debug.Assert(a.graph == b.graph);
var output = func(a, b);
return output;
}; };
} }
} }


+ 204
- 210
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -1,258 +1,252 @@
using Google.Protobuf; using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Exceptions; using Tensorflow.Exceptions;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Graphs
namespace Tensorflow.Graphs;

/// <summary>
/// Graph representing a function body.
/// </summary>
public class FuncGraph : Graph, IDisposable
{ {
SafeFuncGraphHandle _func_graph_handle;
public string FuncName => _graph_key;

public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
public Dictionary<string, string> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures
=> _captures.Select(x => x.Value.Item1).ToArray();
public (Tensor, Tensor)[] captures
=> _captures.Values.Select(x => x).ToArray();

public Tensor[] internal_captures
=> _captures.Select(x => x.Value.Item2).ToArray();

public Tensor[] captured_inputs
=> external_captures;

/// <summary> /// <summary>
/// Graph representing a function body.
/// Construct a new FuncGraph.
/// </summary> /// </summary>
public class FuncGraph : Graph
public FuncGraph(string name) : base()
{ {
IntPtr _func_graph_handle;
public string FuncName => _graph_key;

public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
public Dictionary<string, string> Attrs { get; set; }
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
building_function = true;
}


Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base()
{
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
building_function = true;
Attrs = attrs;
// Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle);
_handle = handle;
}


public Tensor[] external_captures
=> _captures.Select(x => x.Value.Item1).ToArray();
public (Tensor, Tensor)[] captures
=> _captures.Values.Select(x => x).ToArray();
public void ToGraph(Operation[] opers,
Tensor[] inputs, Tensor[] outputs,
string[] output_names)
{
var status = new Status();
_func_graph_handle = c_api.TF_GraphToFunction(_handle,
_graph_key,
false,
opers.Length,
opers.Select(x => (IntPtr)x).ToArray(),
inputs.Length,
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
outputs.Length,
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
output_names,
IntPtr.Zero,
null,
status);
status.Check(true);

SetAttrs();

// c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle);
// status.Check(true);

c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status);
status.Check(true);

_graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle));

Inputs = inputs;
// mark_as_return
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
}


public Tensor[] internal_captures
=> _captures.Select(x => x.Value.Item2).ToArray();
public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true)
{
foreach(var (i, inp) in enumerate(inputs))
inputs[i] = capture(inp);


public Tensor[] captured_inputs
=> external_captures;
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
}


/// <summary>
/// Construct a new FuncGraph.
/// </summary>
public FuncGraph(string name) : base()
const int _EAGER_CONST_THRESHOLD = 128;
public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
{
if(tensor is EagerTensor)
{ {
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
building_function = true;
if (name == null)
name = ops.uid().ToString();

// Small EagerTensors are captured with Const ops
if (dtypes.is_value_dtype(tensor.dtype)
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
return capture_eager_tensor(tensor, name);

// Large EagerTensors and resources are captured with Placeholder ops
return _capture_helper(tensor, name, shape: shape);
} }


public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
if(tensor.graph != this)
{ {
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
building_function = true;
Attrs = attrs;
// Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle);
_handle = handle;
if (name == null)
name = tensor.op.name;
var inner_graph = tensor.graph;
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
{
if (inner_graph == this)
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
" in another function or code block. Use return values," +
" explicit Python locals or TensorFlow collections to access" +
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
inner_graph = inner_func_graph.outer_graph;
}
return _capture_helper(tensor, name);
} }


public void ToGraph(Operation[] opers,
Tensor[] inputs, Tensor[] outputs,
string[] output_names)
return tensor;
}

Tensor capture_eager_tensor(Tensor tensor, string name)
{
Tensor graph_const = null;
if (!_captures.ContainsKey(tensor.Id))
{ {
var status = new Status();
_func_graph_handle = c_api.TF_GraphToFunction(_handle,
_graph_key,
false,
opers.Length,
opers.Select(x => (IntPtr)x).ToArray(),
inputs.Length,
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
outputs.Length,
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
output_names == null || output_names.Length == 0 ? null : output_names,
IntPtr.Zero,
null,
status.Handle);
status.Check(true);

SetAttrs();

// c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle);
// status.Check(true);

c_api.TFE_ContextAddFunction(tf.Context.Handle, _func_graph_handle, status.Handle);
status.Check(true);

_graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle));

Inputs = inputs;
// mark_as_return
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();
graph_const = tf_with(ops.control_dependencies(null), ctl
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
add_capture(tensor, graph_const);
} }

public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true)
else
{ {
foreach(var (i, inp) in enumerate(inputs))
inputs[i] = capture(inp);

return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
graph_const = _captures[tensor.Id].Item2;
} }


const int _EAGER_CONST_THRESHOLD = 128;
public Tensor capture(Tensor tensor, string name = null, Shape shape = null)
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{ {
if(tensor is EagerTensor)
{
if (name == null)
name = ops.uid().ToString();

// Small EagerTensors are captured with Const ops
if (dtypes.is_value_dtype(tensor.dtype)
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD))
return capture_eager_tensor(tensor, name);
return output_grads;
};


// Large EagerTensors and resources are captured with Placeholder ops
return _capture_helper(tensor, name, shape: shape);
}
tf.Runner.RecordGradient("captured_value",
new[] { graph_const }, null,
new[] { tensor },
getBackwardFunction: _backward_function_wrapper
/*getForwardFunction: forward_function*/);


if(tensor.graph != this)
{
if (name == null)
name = tensor.op.name;
var inner_graph = tensor.graph;
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
{
if (inner_graph == this)
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
" in another function or code block. Use return values," +
" explicit Python locals or TensorFlow collections to access" +
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
inner_graph = inner_func_graph.outer_graph;
}
return _capture_helper(tensor, name);
}
return graph_const;
}


return tensor;
Tensor _capture_helper(Tensor tensor, string name, Shape shape = null)
{
Tensor placeholder = null;
if (!_captures.ContainsKey(tensor.Id))
{
placeholder = _create_substitute_placeholder(tensor,
name: name,
dtype: tensor.dtype,
shape: shape);
add_capture(tensor, placeholder);
} }

Tensor capture_eager_tensor(Tensor tensor, string name)
else
{ {
Tensor graph_const = null;
if (!_captures.ContainsKey(tensor.Id))
{
graph_const = tf_with(ops.control_dependencies(null), ctl
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name));
add_capture(tensor, graph_const);
}
else
{
graph_const = _captures[tensor.Id].Item2;
}

BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
return output_grads;
};

tf.Runner.RecordGradient("captured_value",
new[] { graph_const }, null,
new[] { tensor },
getBackwardFunction: _backward_function_wrapper
/*getForwardFunction: forward_function*/);

return graph_const;
placeholder = _captures[tensor.Id].Item2;
} }


Tensor _capture_helper(Tensor tensor, string name, Shape shape = null)
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{ {
Tensor placeholder = null;
if (!_captures.ContainsKey(tensor.Id))
{
placeholder = _create_substitute_placeholder(tensor,
name: name,
dtype: tensor.dtype,
shape: shape);
add_capture(tensor, placeholder);
}
else
{
placeholder = _captures[tensor.Id].Item2;
}
return output_grads;
};


BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
return output_grads;
};
tf.Runner.RecordGradient("captured_value",
new[] { placeholder }, null,
new[] { tensor },
getBackwardFunction: _backward_function_wrapper
/*getForwardFunction: forward_function*/);


tf.Runner.RecordGradient("captured_value",
new[] { placeholder }, null,
new[] { tensor },
getBackwardFunction: _backward_function_wrapper
/*getForwardFunction: forward_function*/);
return placeholder;
}


return placeholder;
}
void add_capture(Tensor tensor, Tensor placeholder)
{
_captures.Add(tensor.Id, (tensor, placeholder));
Inputs.Add(placeholder);
}


void add_capture(Tensor tensor, Tensor placeholder)
{
_captures.Add(tensor.Id, (tensor, placeholder));
Inputs.Add(placeholder);
}
Tensor _create_substitute_placeholder(Tensor value,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
Shape shape = null)
{
if (shape is null)
shape = value.shape;
if (dtype == TF_DataType.DtInvalid)
dtype = value.dtype;

var placeholder = tf_with(ops.control_dependencies(null), ctl
=> array_ops.placeholder(dtype, shape: shape, name: name));
// custom_gradient.copy_handle_data(value, placeholder)
return placeholder;
}


Tensor _create_substitute_placeholder(Tensor value,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
Shape shape = null)
{
if (shape is null)
shape = value.shape;
if (dtype == TF_DataType.DtInvalid)
dtype = value.dtype;

var placeholder = tf_with(ops.control_dependencies(null), ctl
=> array_ops.placeholder(dtype, shape: shape, name: name));
// custom_gradient.copy_handle_data(value, placeholder)
return placeholder;
}
void SetAttrs()
{
if (Attrs == null)
return;


void SetAttrs()
foreach (var (_name, attr_value) in enumerate(Attrs))
{ {
if (Attrs == null)
return;

foreach (var (_name, attr_value) in enumerate(Attrs))
var serialized = new AttrValue
{ {
var serialized = new AttrValue
{
S = ByteString.CopyFromUtf8(attr_value)
}.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status.Handle);
tf.Status.Check(true);
}
S = ByteString.CopyFromUtf8(attr_value)
}.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status);
tf.Status.Check(true);
} }
}


public override Graph as_default()
{
tf.Context.graph_mode(isFunc: true);
ops.set_default_graph(this);
return this;
}
public override Graph as_default()
{
tf.Context.graph_mode(isFunc: true);
ops.set_default_graph(this);
return this;
}


public override void Exit()
{
tf.Context.restore_mode();
ops.pop_graph();
}
public override void Exit()
{
tf.Context.restore_mode();
ops.pop_graph();
}


protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, _graph_key, tf.Status.Handle);
c_api.TF_DeleteFunction(_func_graph_handle);
base.DisposeUnmanagedResources(handle);
}
public void Dispose()
{
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status);
} }
} }

+ 7
- 9
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
public Buffer ToGraphDef(Status s) public Buffer ToGraphDef(Status s)
{ {
var buffer = new Buffer(); var buffer = new Buffer();
c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle);
c_api.TF_GraphToGraphDef(_handle, buffer, s);
s.Check(true); s.Check(true);


return buffer; return buffer;
@@ -33,14 +33,12 @@ namespace Tensorflow
private GraphDef _as_graph_def(bool add_shapes = false) private GraphDef _as_graph_def(bool add_shapes = false)
{ {
GraphDef def; GraphDef def;
using (var status = new Status())
using (var buffer = ToGraphDef(status))
{
status.Check(true);
// limit size to 250M, recursion to max 100
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100);
def = GraphDef.Parser.ParseFrom(inputStream);
}
var status = new Status();
var buffer = ToGraphDef(status);
status.Check(true);
// limit size to 250M, recursion to max 100
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100);
def = GraphDef.Parser.ParseFrom(inputStream);


// Strip the experimental library field iff it's empty. // Strip the experimental library field iff it's empty.
// if(def.Library.Function.Count == 0) // if(def.Library.Function.Count == 0)


+ 9
- 10
src/TensorFlowNET.Core/Graphs/Graph.Import.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow
int size = Marshal.SizeOf<TF_Output>(); int size = Marshal.SizeOf<TF_Output>();
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs);


c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle);
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s);


var tf_output_ptr = (TF_Output*)return_output_handle; var tf_output_ptr = (TF_Output*)return_output_handle;
for (int i = 0; i < num_return_outputs; i++) for (int i = 0; i < num_return_outputs; i++)
@@ -48,15 +48,14 @@ namespace Tensorflow


public bool Import(byte[] bytes, string prefix = "") public bool Import(byte[] bytes, string prefix = "")
{ {
using (var opts = new ImportGraphDefOptions())
using (var status = new Status())
using (var graph_def = new Buffer(bytes))
{
c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle);
status.Check(true);
return status.Code == TF_Code.TF_OK;
}
var opts = new ImportGraphDefOptions();
var status = new Status();
var graph_def = new Buffer(bytes);

c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix);
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status);
status.Check(true);
return status.Code == TF_Code.TF_OK;
} }


public Graph ImportGraphDef(string file_path, string name = null) public Graph ImportGraphDef(string file_path, string name = null)


+ 6
- 25
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -75,9 +75,9 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// </summary> /// </summary>
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
public partial class Graph : DisposableObject
, IEnumerable<Operation>
public partial class Graph : IEnumerable<Operation>
{ {
protected new SafeGraphHandle _handle;
private Dictionary<int, ITensorOrOperation> _nodes_by_id; private Dictionary<int, ITensorOrOperation> _nodes_by_id;
public Dictionary<string, ITensorOrOperation> _nodes_by_name; public Dictionary<string, ITensorOrOperation> _nodes_by_name;
private Dictionary<string, int> _names_in_use; private Dictionary<string, int> _names_in_use;
@@ -130,15 +130,6 @@ namespace Tensorflow
_graph_key = $"graph-{ops.GraphUniqueId()}/"; _graph_key = $"graph-{ops.GraphUniqueId()}/";
} }


public Graph(IntPtr handle)
{
_handle = handle;
_nodes_by_id = new Dictionary<int, ITensorOrOperation>();
_nodes_by_name = new Dictionary<string, ITensorOrOperation>();
_names_in_use = new Dictionary<string, int>();
_graph_key = $"grap-{ops.GraphUniqueId()}/";
}

public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
{ {
return _as_graph_element_locked(obj, allow_tensor, allow_operation); return _as_graph_element_locked(obj, allow_tensor, allow_operation);
@@ -486,16 +477,6 @@ namespace Tensorflow
_unfetchable_ops.Add(op); _unfetchable_ops.Add(op);
} }


protected override void DisposeManagedResources()
{
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TF_DeleteGraph(handle);
}

public Tensor get_tensor_by_tf_output(TF_Output tf_output) public Tensor get_tensor_by_tf_output(TF_Output tf_output)
{ {
var op = _get_operation_by_tf_operation(tf_output.oper); var op = _get_operation_by_tf_operation(tf_output.oper);
@@ -517,14 +498,14 @@ namespace Tensorflow
public Shape GetTensorShape(TF_Output output) public Shape GetTensorShape(TF_Output output)
{ {
var status = tf.Status; var status = tf.Status;
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status.Handle);
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status);
status.Check(); status.Check();


if (ndim == -1) if (ndim == -1)
return Shape.Null; return Shape.Null;


var dims = new long[ndim]; var dims = new long[ndim];
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status.Handle);
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status);
status.Check(); status.Check();


return new Shape(dims.Select(x => (int)x).ToArray()); return new Shape(dims.Select(x => (int)x).ToArray());
@@ -539,7 +520,7 @@ namespace Tensorflow
string debugString = string.Empty; string debugString = string.Empty;
public override string ToString() public override string ToString()
{ {
return $"{graph_key}, 0x{_handle.ToString("x16")}";
return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}";
/*if (string.IsNullOrEmpty(debugString)) /*if (string.IsNullOrEmpty(debugString))
{ {
int len = 0; int len = 0;
@@ -558,7 +539,7 @@ namespace Tensorflow
IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()
=> throw new NotImplementedException(); => throw new NotImplementedException();


public static implicit operator IntPtr(Graph graph)
public static implicit operator SafeGraphHandle(Graph graph)
{ {
return graph._handle; return graph._handle;
} }


+ 16
- 17
src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs View File

@@ -14,28 +14,27 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
namespace Tensorflow;


namespace Tensorflow
public sealed class ImportGraphDefOptions
{ {
public sealed class ImportGraphDefOptions : IDisposable
{
public SafeImportGraphDefOptionsHandle Handle { get; }
SafeImportGraphDefOptionsHandle _handle { get; }


public int NumReturnOutputs
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle);
public int NumReturnOutputs
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle);


public ImportGraphDefOptions()
{
Handle = c_api.TF_NewImportGraphDefOptions();
}
public ImportGraphDefOptions()
{
_handle = c_api.TF_NewImportGraphDefOptions();
}


public void AddReturnOutput(string name, int index)
{
c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index);
}
public void AddReturnOutput(string name, int index)
{
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
}


public void Dispose()
=> Handle.Dispose();
public static implicit operator SafeImportGraphDefOptionsHandle(ImportGraphDefOptions opt)
{
return opt._handle;
} }
} }

+ 22
- 0
src/TensorFlowNET.Core/Graphs/SafeFuncGraphHandle.cs View File

@@ -0,0 +1,22 @@
using Tensorflow.Util;

namespace Tensorflow;

public sealed class SafeFuncGraphHandle : SafeTensorflowHandle
{
private SafeFuncGraphHandle()
{
}

public SafeFuncGraphHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteFunction(handle);
SetHandle(IntPtr.Zero);
return true;
}
}

+ 22
- 0
src/TensorFlowNET.Core/Graphs/SafeGraphHandle.cs View File

@@ -0,0 +1,22 @@
using Tensorflow.Util;

namespace Tensorflow;

public sealed class SafeGraphHandle : SafeTensorflowHandle
{
private SafeGraphHandle()
{
}

public SafeGraphHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteGraph(handle);
SetHandle(IntPtr.Zero);
return true;
}
}

+ 13
- 13
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -60,7 +60,7 @@ namespace Tensorflow
/// <param name="num_dims"></param> /// <param name="num_dims"></param>
/// <param name="status"></param> /// <param name="status"></param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);
public static extern void TF_GraphGetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. /// Import the graph serialized in `graph_def` into `graph`.
@@ -78,7 +78,7 @@ namespace Tensorflow
/// <param name="num_return_outputs">int</param> /// <param name="num_return_outputs">int</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and
@@ -92,7 +92,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>TF_ImportGraphDefResults*</returns> /// <returns>TF_ImportGraphDefResults*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);


/// <summary> /// <summary>
/// Import the graph serialized in `graph_def` into `graph`. /// Import the graph serialized in `graph_def` into `graph`.
@@ -102,7 +102,7 @@ namespace Tensorflow
/// <param name="options">TF_ImportGraphDefOptions*</param> /// <param name="options">TF_ImportGraphDefOptions*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphImportGraphDef(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);
public static extern void TF_GraphImportGraphDef(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status);


/// <summary> /// <summary>
/// Iterate through the operations of a graph. /// Iterate through the operations of a graph.
@@ -111,7 +111,7 @@ namespace Tensorflow
/// <param name="pos"></param> /// <param name="pos"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos);
public static extern IntPtr TF_GraphNextOperation(SafeGraphHandle graph, ref uint pos);


/// <summary> /// <summary>
/// Returns the operation in the graph with `oper_name`. Returns nullptr if /// Returns the operation in the graph with `oper_name`. Returns nullptr if
@@ -121,14 +121,14 @@ namespace Tensorflow
/// <param name="oper_name"></param> /// <param name="oper_name"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name);
public static extern IntPtr TF_GraphOperationByName(SafeGraphHandle graph, string oper_name);


/// <summary> /// <summary>
/// Sets the shape of the Tensor referenced by `output` in `graph` to /// Sets the shape of the Tensor referenced by `output` in `graph` to
/// the shape described by `dims` and `num_dims`. /// the shape described by `dims` and `num_dims`.
/// </summary> /// </summary>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);
public static extern void TF_GraphSetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status);


/// <summary> /// <summary>
/// Write out a serialized representation of `graph` (as a GraphDef protocol /// Write out a serialized representation of `graph` (as a GraphDef protocol
@@ -138,7 +138,7 @@ namespace Tensorflow
/// <param name="output_graph_def">TF_Buffer*</param> /// <param name="output_graph_def">TF_Buffer*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
public static extern void TF_GraphToGraphDef(SafeGraphHandle graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);


/// <summary> /// <summary>
/// Returns the number of dimensions of the Tensor referenced by `output` /// Returns the number of dimensions of the Tensor referenced by `output`
@@ -151,7 +151,7 @@ namespace Tensorflow
/// <param name="status"></param> /// <param name="status"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, SafeStatusHandle status);
public static extern int TF_GraphGetTensorNumDims(SafeGraphHandle graph, TF_Output output, SafeStatusHandle status);


/// <summary> /// <summary>
/// Cause the imported graph to have a control dependency on `oper`. `oper` /// Cause the imported graph to have a control dependency on `oper`. `oper`
@@ -287,12 +287,12 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
public static extern SafeSessionHandle TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
string export_dir, string[] tags, int tags_len, string export_dir, string[] tags, int tags_len,
IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status);
SafeGraphHandle graph, IntPtr meta_graph_def, SafeStatusHandle status);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewGraph();
public static extern SafeGraphHandle TF_NewGraph();


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions();
@@ -334,6 +334,6 @@ namespace Tensorflow
/// <param name="status"></param> /// <param name="status"></param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status);
public static extern bool TF_TryEvaluateConstant(SafeGraphHandle graph, TF_Output output, IntPtr[] result, SafeStatusHandle status);
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow.NumPy
{ {
if (_handle is not null) if (_handle is not null)
{ {
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status);
} }
} }
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Input.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow
public int InputListLength(string name) public int InputListLength(string name)
{ {
int num = 0; int num = 0;
num = c_api.TF_OperationInputListLength(_handle, name, tf.Status.Handle);
num = c_api.TF_OperationInputListLength(_handle, name, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);
return num; return num;
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow


public int OutputListLength(string name) public int OutputListLength(string name)
{ {
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status.Handle);
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);


return num; return num;


+ 7
- 7
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -187,8 +187,8 @@ namespace Tensorflow
if (tf.executing_eagerly()) if (tf.executing_eagerly())
return (T[])get_attr(name); return (T[])get_attr(name);


using var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);


var x = AttrValue.Parser.ParseFrom(buf.ToArray()); var x = AttrValue.Parser.ParseFrom(buf.ToArray());
@@ -210,8 +210,8 @@ namespace Tensorflow


public virtual object get_attr(string name) public virtual object get_attr(string name)
{ {
using var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle);
var buf = new Buffer();
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status);
tf.Status.Check(true); tf.Status.Check(true);


var x = AttrValue.Parser.ParseFrom(buf.ToArray()); var x = AttrValue.Parser.ParseFrom(buf.ToArray());
@@ -235,13 +235,13 @@ namespace Tensorflow


public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)
{ {
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s.Handle);
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s);
} }


private NodeDef GetNodeDef() private NodeDef GetNodeDef()
{ {
using var buffer = new Buffer();
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle);
var buffer = new Buffer();
c_api.TF_OperationToNodeDef(_handle, buffer, tf.Status);
tf.Status.Check(throwException: true); tf.Status.Check(throwException: true);
return NodeDef.Parser.ParseFrom(buffer.ToArray()); return NodeDef.Parser.ParseFrom(buffer.ToArray());
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/OperationDescription.cs View File

@@ -50,7 +50,7 @@ namespace Tensorflow


public Operation FinishOperation(Status status) public Operation FinishOperation(Status status)
{ {
return c_api.TF_FinishOperation(_handle, status.Handle);
return c_api.TF_FinishOperation(_handle, status);
} }


public static implicit operator OperationDescription(IntPtr handle) public static implicit operator OperationDescription(IntPtr handle)


+ 1
- 1
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -96,7 +96,7 @@ namespace Tensorflow
/// <param name="oper_name">const char*</param> /// <param name="oper_name">const char*</param>
/// <returns>TF_OperationDescription*</returns> /// <returns>TF_OperationDescription*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name);
public static extern IntPtr TF_NewOperation(SafeGraphHandle graph, string opType, string oper_name);


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_OperationDevice(IntPtr oper); public static extern IntPtr TF_OperationDevice(IntPtr oper);


+ 221
- 230
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -14,281 +14,272 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Google.Protobuf;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow
namespace Tensorflow;

public class BaseSession : IDisposable
{ {
public class BaseSession : DisposableObject
protected SafeSessionHandle _handle;
protected Graph _graph;
protected Status _status;
public Graph graph => _graph;

public BaseSession(SafeSessionHandle handle, Graph g)
{ {
protected Graph _graph;
protected Status _status;
public Graph graph => _graph;
_handle = handle;
_graph = g ?? ops.get_default_graph();
}


public BaseSession(IntPtr handle, Graph g)
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{
_graph = g ?? ops.get_default_graph();
if (!_graph.building_function)
{ {
_handle = handle;
_graph = g ?? ops.get_default_graph();
if (ops.get_default_graph() != _graph)
_graph.as_default();
} }
var opts = new SessionOptions(target, config);
_status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts, _status);
_status.Check(true);
}


public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{
_graph = g ?? ops.get_default_graph();
if (!_graph.building_function)
{
if (ops.get_default_graph() != _graph)
_graph.as_default();
}
using var opts = new SessionOptions(target, config);
_status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle);
_status.Check(true);
}
public virtual void run(Operation op, params FeedItem[] feed_dict)
{
_run(op, feed_dict);
}


public virtual void run(Operation op, params FeedItem[] feed_dict)
{
_run(op, feed_dict);
}
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}


public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
{
var results = _run(fetche, feed_dict);
return fetche is Tensor ? results[0] : null;
}


public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
{
var results = _run(fetche, feed_dict);
return fetche is Tensor ? results[0] : null;
}
public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run(
(ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches,
params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict);
return (results[0], results[1], results[2], results[3], results[4]);
}


public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run(
(ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches,
params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict);
return (results[0], results[1], results[2], results[3], results[4]);
}
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
return (results[0], results[1], results[2], results[3]);
}


public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
return (results[0], results[1], results[2], results[3]);
}
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
return (results[0], results[1], results[2]);
}


public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
return (results[0], results[1], results[2]);
}
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
return (results[0], results[1]);
}


public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
return (results[0], results[1]);
}
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
{
return _run(fetches, feed_dict);
}


public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict)
{
return _run(fetches, feed_dict);
}
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}


public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();
//var feed_map = new Dictionary<object, object>();


private NDArray[] _run(object fetches, FeedItem[] feed_dict = null)
// Validate and process feed_dict.
if (feed_dict != null)
{ {
var feed_dict_tensor = new Dictionary<object, object>();
//var feed_map = new Dictionary<object, object>();

// Validate and process feed_dict.
if (feed_dict != null)
foreach (var subfeed in feed_dict)
{ {
foreach (var subfeed in feed_dict)
{
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
}
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used
feed_dict_tensor[subfeed_t] = subfeed.Value;
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
} }
}


// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);
// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);


// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
// These movers are no longer needed when _do_run() completes, and
// are deleted when `movers` goes out of scope when this _run() ends.
var _ = _update_with_movers();
var final_fetches = fetch_handler.fetches();
var final_targets = fetch_handler.targets();
// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
// These movers are no longer needed when _do_run() completes, and
// are deleted when `movers` goes out of scope when this _run() ends.
var _ = _update_with_movers();
var final_fetches = fetch_handler.fetches();
var final_targets = fetch_handler.targets();


// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);


return fetch_handler.build_results(this, results);
}
return fetch_handler.build_results(this, results);
}


/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list"></param>
/// <param name="feed_dict"></param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
/// <summary>
/// Runs a step based on the given fetches and feeds.
/// </summary>
/// <param name="target_list">A list of operations to be run, but not fetched.</param>
/// <param name="fetch_list"></param>
/// <param name="feed_dict"></param>
/// <returns>
/// A list of numpy ndarrays, corresponding to the elements of
/// `fetch_list`. If the ith element of `fetch_list` contains the
/// name of an operation, the first Tensor output of that operation
/// will be returned for that element.
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
int i = 0;
foreach (var x in feed_dict)
{ {
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
int i = 0;
foreach (var x in feed_dict)
if (x.Key is Tensor key)
{ {
if (x.Key is Tensor key)
switch (x.Value)
{ {
switch (x.Value)
{
case Tensor v:
if (v.dtype != key.dtype)
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}");
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
break;
case SafeTensorHandle v:
var tensor = new Tensor(v);
if (tensor.dtype != key.dtype)
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
break;
case bool v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case byte v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case int v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case long v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case float v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case double v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case string v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case Array v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape()));
break;
default:
throw new NotImplementedException("");
}
case Tensor v:
if (v.dtype != key.dtype)
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}");
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v);
break;
case SafeTensorHandle v:
var tensor = new Tensor(v);
if (tensor.dtype != key.dtype)
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}");
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor);
break;
case bool v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case byte v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case int v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case long v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case float v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case double v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case string v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v));
break;
case Array v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape()));
break;
default:
throw new NotImplementedException("");
} }
else
throw new NotImplementedException("");
} }

var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
//var targets = target_list;
return _call_tf_sessionrun(feeds, fetches, target_list);
else
throw new NotImplementedException("");
} }


var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
//var targets = target_list;
return _call_tf_sessionrun(feeds, fetches, target_list);
}


private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();


var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();


c_api.TF_SessionRun(_handle,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(),
ninputs: feed_dict.Length,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: _status.Handle);
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();


_status.Check(true);
c_api.TF_SessionRun(_handle,
run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(),
ninputs: feed_dict.Length,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: _status);


var result = new NDArray[fetch_list.Length];
_status.Check(true);


for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(new SafeTensorHandle(output_values[i]));
var result = new NDArray[fetch_list.Length];


return result;
}
for (int i = 0; i < fetch_list.Length; i++)
result[i] = fetchValue(new SafeTensorHandle(output_values[i]));


public unsafe Tensor eval(Tensor tensor)
{
var output_values = new IntPtr[1];
var fetch_list = new[] { tensor._as_tf_output() };

c_api.TF_SessionRun(_handle,
run_options: null,
inputs: new TF_Output[0],
input_values: new IntPtr[0],
ninputs: 0,
outputs: fetch_list,
output_values: output_values,
noutputs: 1,
target_opers: new IntPtr[0],
ntargets: 0,
run_metadata: IntPtr.Zero,
status: _status.Handle);

_status.Check(true);

return new Tensor(new SafeTensorHandle(output_values[0]));
}
return result;
}


private static unsafe NDArray fetchValue(SafeTensorHandle output)
{
var tensor = new Tensor(output);
return tensor.numpy();
}
public unsafe Tensor eval(Tensor tensor)
{
var output_values = new IntPtr[1];
var fetch_list = new[] { tensor._as_tf_output() };

c_api.TF_SessionRun(_handle,
run_options: null,
inputs: new TF_Output[0],
input_values: new IntPtr[0],
ninputs: 0,
outputs: fetch_list,
output_values: output_values,
noutputs: 1,
target_opers: new IntPtr[0],
ntargets: 0,
run_metadata: IntPtr.Zero,
status: _status);

_status.Check(true);

return new Tensor(new SafeTensorHandle(output_values[0]));
}


/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle.
/// </summary>
private List<object> _update_with_movers()
{
return new List<object> { };
}
private static unsafe NDArray fetchValue(SafeTensorHandle output)
{
var tensor = new Tensor(output);
return tensor.numpy();
}


private void _extend_graph()
{ }
/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle.
/// </summary>
private List<object> _update_with_movers()
{
return new List<object> { };
}


protected override void DisposeUnmanagedResources(IntPtr handle)
{
// c_api.TF_CloseSession(handle, tf.Status.Handle);
c_api.TF_DeleteSession(handle, _status.Handle);
}
private void _extend_graph()
{ }

public void Dispose()
{
} }
} }

+ 46
- 0
src/TensorFlowNET.Core/Sessions/SafeSessionHandle.cs View File

@@ -0,0 +1,46 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Net.NetworkInformation;
using Tensorflow.Util;

namespace Tensorflow
{
public sealed class SafeSessionHandle : SafeTensorflowHandle
{
private SafeSessionHandle()
{
}

public SafeSessionHandle(IntPtr handle)
: base(handle)
{
}

public override string ToString()
=> $"0x{handle:x16}";

protected override bool ReleaseHandle()
{
var status = new Status();
// c_api.TF_CloseSession(handle, tf.Status.Handle);
c_api.TF_DeleteSession(handle, status);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 38
- 64
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -14,75 +14,49 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.IO;
using System.Runtime.CompilerServices;
using Tensorflow.Util;
namespace Tensorflow;


namespace Tensorflow
public class Session : BaseSession
{ {
public class Session : BaseSession
{
public Session(string target = "", Graph g = null) : base(target, g, null)
{ }

public Session(IntPtr handle, Graph g = null) : base(handle, g)
{ }

public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
{ }

public Session as_default()
{
return ops.set_default_session(this);
}

public static Session LoadFromSavedModel(string path)
{
var graph = new Graph();
using var status = new Status();
using var opt = c_api.TF_NewSessionOptions();

var tags = new string[] { "serve" };

var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
IntPtr.Zero,
status.Handle);
status.Check(true);

// load graph bytes
// var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
return new Session(sess, g: graph);
}

public static implicit operator IntPtr(Session session) => session._handle;
public static implicit operator Session(IntPtr handle) => new Session(handle);
public Session(string target = "", Graph g = null) : base(target, g, null)
{ }


public void __enter__()
{
public Session(SafeSessionHandle handle, Graph g = null) : base(handle, g)
{ }


}
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
{ }


public void __exit__()
{

}

public void __init__()
{

}

public void __del__()
{
public Session as_default()
{
return ops.set_default_session(this);
}


}
public static Session LoadFromSavedModel(string path)
{
var graph = new Graph();
var status = new Status();
using var opt = c_api.TF_NewSessionOptions();

var tags = new string[] { "serve" };

var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
IntPtr.Zero,
status);
status.Check(true);

// load graph bytes
// var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
return new Session(sess, g: graph);
} }

public static implicit operator SafeSessionHandle(Session session) => session._handle;
public static implicit operator Session(SafeSessionHandle handle) => new Session(handle);
} }

+ 12
- 12
src/TensorFlowNET.Core/Sessions/SessionOptions.cs View File

@@ -19,33 +19,33 @@ using System;


namespace Tensorflow namespace Tensorflow
{ {
internal sealed class SessionOptions : IDisposable
internal sealed class SessionOptions
{ {
public SafeSessionOptionsHandle Handle { get; }
SafeSessionOptionsHandle _handle { get; }


public SessionOptions(string target = "", ConfigProto config = null) public SessionOptions(string target = "", ConfigProto config = null)
{ {
Handle = c_api.TF_NewSessionOptions();
c_api.TF_SetTarget(Handle, target);
_handle = c_api.TF_NewSessionOptions();
c_api.TF_SetTarget(_handle, target);
if (config != null) if (config != null)
SetConfig(config); SetConfig(config);
} }


public void Dispose()
=> Handle.Dispose();

private unsafe void SetConfig(ConfigProto config) private unsafe void SetConfig(ConfigProto config)
{ {
var bytes = config.ToByteArray(); var bytes = config.ToByteArray();


fixed (byte* proto2 = bytes) fixed (byte* proto2 = bytes)
{ {
using (var status = new Status())
{
c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle);
status.Check(false);
}
var status = new Status();
c_api.TF_SetConfig(_handle, (IntPtr)proto2, (ulong)bytes.Length, status);
status.Check(false);
} }
} }

public static implicit operator SafeSessionOptionsHandle(SessionOptions opt)
{
return opt._handle;
}
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns>TF_Session*</returns> /// <returns>TF_Session*</returns>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status);
public static extern SafeSessionHandle TF_NewSession(SafeGraphHandle graph, SafeSessionOptionsHandle opts, SafeStatusHandle status);


/// <summary> /// <summary>
/// Return a new options object. /// Return a new options object.
@@ -110,7 +110,7 @@ namespace Tensorflow
/// <param name="run_metadata">TF_Buffer*</param> /// <param name="run_metadata">TF_Buffer*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options,
public static extern unsafe void TF_SessionRun(SafeSessionHandle session, TF_Buffer* run_options,
TF_Output[] inputs, IntPtr[] input_values, int ninputs, TF_Output[] inputs, IntPtr[] input_values, int ninputs,
TF_Output[] outputs, IntPtr[] output_values, int noutputs, TF_Output[] outputs, IntPtr[] output_values, int noutputs,
IntPtr[] target_opers, int ntargets, IntPtr[] target_opers, int ntargets,


+ 14
- 12
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -26,7 +26,7 @@ namespace Tensorflow
/// TF_Status holds error information. It either has an OK code, or /// TF_Status holds error information. It either has an OK code, or
/// else an error code with an associated error message. /// else an error code with an associated error message.
/// </summary> /// </summary>
public sealed class Status : IDisposable
public sealed class Status
{ {
/// <summary> /// <summary>
/// Error message /// Error message
@@ -35,9 +35,9 @@ namespace Tensorflow
{ {
get get
{ {
using (Handle.Lease())
using (_handle.Lease())
{ {
return StringPiece(TF_Message(Handle));
return StringPiece(TF_Message(_handle));
} }
} }
} }
@@ -45,23 +45,23 @@ namespace Tensorflow
/// <summary> /// <summary>
/// Error code /// Error code
/// </summary> /// </summary>
public TF_Code Code => TF_GetCode(Handle);
public TF_Code Code => TF_GetCode(_handle);


public SafeStatusHandle Handle { get; }
SafeStatusHandle _handle { get; }


public Status() public Status()
{ {
Handle = TF_NewStatus();
_handle = TF_NewStatus();
} }


public Status(SafeStatusHandle handle) public Status(SafeStatusHandle handle)
{ {
Handle = handle ?? throw new ArgumentNullException(nameof(handle));
_handle = handle ?? throw new ArgumentNullException(nameof(handle));
} }


public void SetStatus(TF_Code code, string msg) public void SetStatus(TF_Code code, string msg)
{ {
TF_SetStatus(Handle, code, msg);
TF_SetStatus(_handle, code, msg);
} }


public bool ok() => Code == TF_Code.TF_OK; public bool ok() => Code == TF_Code.TF_OK;
@@ -94,10 +94,12 @@ namespace Tensorflow
} }
} }


public void Dispose()
=> Handle.Dispose();

public override string ToString() public override string ToString()
=> $"{Code} 0x{Handle.DangerousGetHandle():x16}";
=> $"{Code} 0x{_handle.DangerousGetHandle():x16}";

public static implicit operator SafeStatusHandle(Status status)
{
return status._handle;
}
} }
} }

+ 4
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -121,7 +121,7 @@ namespace Tensorflow


if (_handle == null) if (_handle == null)
{ {
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle);
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status);
} }
else else
{ {
@@ -135,9 +135,9 @@ namespace Tensorflow
protected virtual void SetShapeInternal(Shape value) protected virtual void SetShapeInternal(Shape value)
{ {
if (value == null) if (value == null)
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status);
else else
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status);
} }


public int[] _shape_tuple() public int[] _shape_tuple()
@@ -176,7 +176,7 @@ namespace Tensorflow
if (_handle == null) if (_handle == null)
{ {
var output = _as_tf_output(); var output = _as_tf_output();
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status.Handle);
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status);
return ndim; return ndim;
} }




+ 10
- 12
src/TensorFlowNET.Core/Training/Saving/saver.py.cs View File

@@ -94,18 +94,16 @@ namespace Tensorflow


string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb"));


using (var graph = tf.Graph())
using (var sess = tf.Session(graph))
{
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true);
saver.restore(sess, checkpoint);
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(),
output_node_names);
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb;
}
var graph = tf.Graph();
var sess = tf.Session(graph);
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true);
saver.restore(sess, checkpoint);
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(),
output_node_names);
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb;
} }


public static Graph load_graph(string freeze_graph_pb, string name = "") public static Graph load_graph(string freeze_graph_pb, string name = "")


+ 1
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -164,7 +164,7 @@ namespace Tensorflow
result._as_tf_output(), result._as_tf_output(),
shape.dims, shape.dims,
shape.ndim, shape.ndim,
tf.Status.Handle);
tf.Status);
tf.Status.Check(true); tf.Status.Check(true);
} }




+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -247,7 +247,7 @@ namespace Tensorflow
foreach (var attr in node_def.Attr) foreach (var attr in node_def.Attr)
{ {
var bytes = attr.Value.ToByteArray(); var bytes = attr.Value.ToByteArray();
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle);
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status);
status.Check(true); status.Check(true);
} }




+ 7
- 9
src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs View File

@@ -23,16 +23,14 @@ namespace Tensorflow.Benchmark.Leak
var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model");


for (var i = 0; i < 1024; i++) for (var i = 0; i < 1024; i++)
{
using (var sess = Session.LoadFromSavedModel(ClassifierModelPath)) {
using (var g = sess.graph.as_default()) {
var inputOp = g.OperationByName("inference_input");
var outputOp = g.OperationByName("StatefulPartitionedCall");
{
var sess = Session.LoadFromSavedModel(ClassifierModelPath);
var g = sess.graph.as_default();
var inputOp = g.OperationByName("inference_input");
var outputOp = g.OperationByName("StatefulPartitionedCall");


var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT);
sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp));
}
}
var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT);
sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp));
} }
} }
} }


+ 38
- 46
test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs View File

@@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.Basics
var enqueue = queue.enqueue(numbers); var enqueue = queue.enqueue(numbers);
var dequeue_many = queue.dequeue_many(n: 3); var dequeue_many = queue.dequeue_many(n: 3);


using (var sess = tf.Session())
{
sess.run(enqueue, (numbers, new[] { 1 }));
sess.run(enqueue, (numbers, new[] { 2, 3 }));
sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

var result = sess.run(dequeue_many[0]);

Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>()));
}
var sess = tf.Session();
sess.run(enqueue, (numbers, new[] { 1 }));
sess.run(enqueue, (numbers, new[] { 2, 3 }));
sess.run(enqueue, (numbers, new[] { 3, 4, 5 }));

var result = sess.run(dequeue_many[0]);

Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>()));
} }


[TestMethod] [TestMethod]
@@ -45,27 +43,25 @@ namespace TensorFlowNET.UnitTest.Basics
// push back into queue // push back into queue
var inc = queue.enqueue(y); var inc = queue.enqueue(y);


using (var sess = tf.Session())
{
// init queue
init.run();
var sess = tf.Session();
// init queue
init.run();


// pop out first element and push back calculated y
(int dequeued, _) = sess.run((x, inc));
Assert.AreEqual(10, dequeued);
// pop out first element and push back calculated y
(int dequeued, _) = sess.run((x, inc));
Assert.AreEqual(10, dequeued);


(dequeued, _) = sess.run((x, inc));
Assert.AreEqual(20, dequeued);
(dequeued, _) = sess.run((x, inc));
Assert.AreEqual(20, dequeued);


(dequeued, _) = sess.run((x, inc));
Assert.AreEqual(11, dequeued);
(dequeued, _) = sess.run((x, inc));
Assert.AreEqual(11, dequeued);


(dequeued, _) = sess.run((x, inc));
Assert.AreEqual(21, dequeued);
(dequeued, _) = sess.run((x, inc));
Assert.AreEqual(21, dequeued);


// thread will hang or block if you run sess.run(x) again
// until queue has more element.
}
// thread will hang or block if you run sess.run(x) again
// until queue has more element.
} }


[TestMethod] [TestMethod]
@@ -75,19 +71,17 @@ namespace TensorFlowNET.UnitTest.Basics
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" });
var x = queue.dequeue(); var x = queue.dequeue();


using (var sess = tf.Session())
{
init.run();
var sess = tf.Session();
init.run();


var result = sess.run(x);
Assert.AreEqual(result[0], 2L);
var result = sess.run(x);
Assert.AreEqual(result[0], 2L);


result = sess.run(x);
Assert.AreEqual(result[0], 3L);
result = sess.run(x);
Assert.AreEqual(result[0], 3L);


result = sess.run(x);
Assert.AreEqual(result[0], 4L);
}
result = sess.run(x);
Assert.AreEqual(result[0], 4L);
} }


[TestMethod] [TestMethod]
@@ -98,16 +92,14 @@ namespace TensorFlowNET.UnitTest.Basics
var x = queue.dequeue(); var x = queue.dequeue();


string results = ""; string results = "";
using (var sess = tf.Session())
{
init.run();
var sess = tf.Session();
init.run();


foreach (var i in range(9))
results += (int)sess.run(x) + ".";
foreach (var i in range(9))
results += (int)sess.run(x) + ".";


// output in random order
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9.");
}
// output in random order
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9.");
} }
} }
} }

+ 9
- 15
test/TensorFlowNET.Graph.UnitTest/Basics/SessionTest.cs View File

@@ -19,11 +19,9 @@ namespace TensorFlowNET.UnitTest.Basics
var a = constant_op.constant(np.array(3.0).reshape((1, 1))); var a = constant_op.constant(np.array(3.0).reshape((1, 1)));
var b = constant_op.constant(np.array(2.0).reshape((1, 1))); var b = constant_op.constant(np.array(2.0).reshape((1, 1)));
var c = math_ops.matmul(a, b, name: "matmul"); var c = math_ops.matmul(a, b, name: "matmul");
using (var sess = tf.Session())
{
var result = c.eval(sess);
Assert.AreEqual(result[0], 6.0);
}
var sess = tf.Session();
var result = c.eval(sess);
Assert.AreEqual(result[0], 6.0);
} }
} }


@@ -32,11 +30,9 @@ namespace TensorFlowNET.UnitTest.Basics
{ {
var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
var c = tf.strings.substr(a, 4, 8); var c = tf.strings.substr(a, 4, 8);
using (var sess = tf.Session())
{
var result = c.eval(sess).StringData();
Assert.AreEqual(result[0], "heythere");
}
var sess = tf.Session();
var result = c.eval(sess).StringData();
Assert.AreEqual(result[0], "heythere");
} }


[TestMethod] [TestMethod]
@@ -47,11 +43,9 @@ namespace TensorFlowNET.UnitTest.Basics
const int size = 30_000; const int size = 30_000;
var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
var c = tf.strings.substr(a, 0, size - 5000); var c = tf.strings.substr(a, 0, size - 5000);
using (var sess = tf.Session())
{
var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
Console.WriteLine(result);
}
var sess = tf.Session();
var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
Console.WriteLine(result);
} }
} }




+ 21
- 29
test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs View File

@@ -16,15 +16,13 @@ namespace TensorFlowNET.UnitTest.Basics
var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1); var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1);
var st = tf.concat(values: new[] { indices, labels }, axis: 1); var st = tf.concat(values: new[] { indices, labels }, axis: 1);
var onehot = tf.sparse_to_dense(st, (5, 5), 1); var onehot = tf.sparse_to_dense(st, (5, 5), 1);
using (var sess = tf.Session())
{
var result = sess.run(onehot);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>()));
};
var sess = tf.Session();
var result = sess.run(onehot);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>()));
} }


[TestMethod, Ignore] [TestMethod, Ignore]
@@ -39,13 +37,11 @@ namespace TensorFlowNET.UnitTest.Basics
new[] { 3L, 4L }); new[] { 3L, 4L });


var onehot = tf.sparse_tensor_to_dense(decoded_list); var onehot = tf.sparse_tensor_to_dense(decoded_list);
using (var sess = tf.Session())
{
var result = sess.run(onehot);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>()));
}
var sess = tf.Session();
var result = sess.run(onehot);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>()));
} }


[TestMethod] [TestMethod]
@@ -56,14 +52,12 @@ namespace TensorFlowNET.UnitTest.Basics
int[,] crops = { { 0, 0 }, { 0, 0 } }; int[,] crops = { { 0, 0 }, { 0, 0 } };
var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); var tensor = tf.batch_to_space_nd(inputs, block_shape, crops);


using (var sess = tf.Session())
{
var result = sess.run(tensor);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
}
var sess = tf.Session();
var result = sess.run(tensor);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>()));
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
} }


[TestMethod, Ignore] [TestMethod, Ignore]
@@ -72,11 +66,9 @@ namespace TensorFlowNET.UnitTest.Basics
var tensor = new[] { 0, 1, 2, 3 }; var tensor = new[] { 0, 1, 2, 3 };
var mask = np.array(new[] { true, false, true, false }); var mask = np.array(new[] { true, false, true, false });
var masked = tf.boolean_mask(tensor, mask); var masked = tf.boolean_mask(tensor, mask);
using (var sess = tf.Session())
{
var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
}
var sess = tf.Session();
var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>()));
} }
} }
} }

+ 1
- 1
test/TensorFlowNET.Graph.UnitTest/Basics/VariableTest.cs View File

@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics
var v = tf.Variable(new[] { 1, 2 }); var v = tf.Variable(new[] { 1, 2 });
var init = tf.compat.v1.global_variables_initializer(); var init = tf.compat.v1.global_variables_initializer();


using var sess = tf.compat.v1.Session();
var sess = tf.compat.v1.Session();
sess.run(init); sess.run(init);
// Usage passing the session explicitly. // Usage passing the session explicitly.
print(v.eval(sess)); print(v.eval(sess));


+ 18
- 22
test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs View File

@@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest
{ {
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();


using (var sess = tf.Session(graph))
{
var x = tf.constant(2, name: "x");
var y = tf.constant(5, name: "y");

var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t22"),
() => tf.constant(55, name: "f55"));

int result = z.eval(sess);
assertEquals(result, 22);
}
var sess = tf.Session(graph);
var x = tf.constant(2, name: "x");
var y = tf.constant(5, name: "y");

var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t22"),
() => tf.constant(55, name: "f55"));

int result = z.eval(sess);
assertEquals(result, 22);
} }


[TestMethod] [TestMethod]
@@ -35,18 +33,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest
{ {
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();


using (var sess = tf.Session(graph))
{
var x = tf.constant(2, name: "x");
var y = tf.constant(1, name: "y");
var sess = tf.Session(graph);
var x = tf.constant(2, name: "x");
var y = tf.constant(1, name: "y");


var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t22"),
() => tf.constant(11, name: "f11"));
var z = control_flow_ops.cond(tf.less(x, y),
() => tf.constant(22, name: "t22"),
() => tf.constant(11, name: "f11"));


int result = z.eval(sess);
assertEquals(result, 11);
}
int result = z.eval(sess);
assertEquals(result, 11);
} }


[Ignore("Dependent on UpdateEdge")] [Ignore("Dependent on UpdateEdge")]


+ 12
- 14
test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs View File

@@ -23,21 +23,19 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest
private void _testWhileContextHelper(int maximum_iterations) private void _testWhileContextHelper(int maximum_iterations)
{ {
// TODO: implement missing code dependencies // TODO: implement missing code dependencies
using (var sess = this.cached_session())
var sess = this.cached_session();
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
//control_flow_ops.while_loop(
// c, b, i , maximum_iterations: tf.constant(maximum_iterations));
foreach (Operation op in sess.graph.get_operations())
{ {
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
//control_flow_ops.while_loop(
// c, b, i , maximum_iterations: tf.constant(maximum_iterations));
foreach (Operation op in sess.graph.get_operations())
{
var control_flow_context = op._get_control_flow_context();
/*if (control_flow_context != null)
self.assertProtoEquals(control_flow_context.to_proto(),
WhileContext.from_proto(
control_flow_context.to_proto()).to_proto(), "");*/
}
var control_flow_context = op._get_control_flow_context();
/*if (control_flow_context != null)
self.assertProtoEquals(control_flow_context.to_proto(),
WhileContext.from_proto(
control_flow_context.to_proto()).to_proto(), "");*/
} }
} }




+ 47
- 66
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -18,11 +18,9 @@ namespace TensorFlowNET.UnitTest.Gradient
var y = tf.broadcast_to(x, (2, 4, 3)); var y = tf.broadcast_to(x, (2, 4, 3));
var grad = tf.gradients(y, x); var grad = tf.gradients(y, x);


using (var sess = tf.Session(graph))
{
float result = sess.run(grad[0]);
Assert.AreEqual(result, 24.0f);
}
var sess = tf.Session(graph);
float result = sess.run(grad[0]);
Assert.AreEqual(result, 24.0f);
} }


[TestMethod] [TestMethod]
@@ -33,11 +31,9 @@ namespace TensorFlowNET.UnitTest.Gradient
var z = tf.cumsum(y, axis: 1); var z = tf.cumsum(y, axis: 1);
var grad = tf.gradients(z, x); var grad = tf.gradients(z, x);


using (var sess = tf.Session(graph))
{
float result = sess.run(grad[0]);
Assert.AreEqual(result, 60.0f);
}
var sess = tf.Session(graph);
float result = sess.run(grad[0]);
Assert.AreEqual(result, 60.0f);
} }


[TestMethod, Ignore] [TestMethod, Ignore]
@@ -78,14 +74,12 @@ namespace TensorFlowNET.UnitTest.Gradient
42.0f, 42.0f, 42.0f, 42.0f, 42.0f, 42.0f,
45.0f, 45.0f, 45.0f 45.0f, 45.0f, 45.0f
}; };
using (var sess = tf.Session())
{
var result = sess.run(g);
var resultList = result[0].ToArray<float>().ToList();
resultList.AddRange(result[1].ToArray<float>());
Console.WriteLine(result.ToString());
CollectionAssert.AreEqual(resultList.ToArray(), checkG);
}
var sess = tf.Session();
var result = sess.run(g);
var resultList = result[0].ToArray<float>().ToList();
resultList.AddRange(result[1].ToArray<float>());
Console.WriteLine(result.ToString());
CollectionAssert.AreEqual(resultList.ToArray(), checkG);
} }


[TestMethod] [TestMethod]
@@ -97,11 +91,9 @@ namespace TensorFlowNET.UnitTest.Gradient
var y = f(x); var y = f(x);
var g = tf.gradients(y, x); var g = tf.gradients(y, x);


using (var session = tf.Session())
{
var result = session.run(new[] { y, g[0] });
return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]);
}
var session = tf.Session();
var result = session.run(new[] { y, g[0] });
return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]);
} }


void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values) void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values)
@@ -197,13 +189,11 @@ namespace TensorFlowNET.UnitTest.Gradient
var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0];
var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0];


using (var session = tf.Session())
{
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } }));
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
}
var session = tf.Session();
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } }));
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
} }


[TestMethod] [TestMethod]
@@ -212,12 +202,10 @@ namespace TensorFlowNET.UnitTest.Gradient
var a = tf.constant(1f); var a = tf.constant(1f);
var b = tf.tanh(a); var b = tf.tanh(a);
var g = tf.gradients(b, a); var g = tf.gradients(b, a);
using (var sess = tf.Session())
{
var result = sess.run(g);
var actual = result[0];
Assert.AreEqual(actual, 0.41997434127f);
}
var sess = tf.Session();
var result = sess.run(g);
var actual = result[0];
Assert.AreEqual(actual, 0.41997434127f);
} }




@@ -227,14 +215,12 @@ namespace TensorFlowNET.UnitTest.Gradient
var a = tf.constant(5f); var a = tf.constant(5f);
var b = tf.lgamma(a); var b = tf.lgamma(a);
var g = tf.gradients(b, a); var g = tf.gradients(b, a);
using (var sess = tf.Session())
{
var result = sess.run(new object[] { g, b });
var actualDeriv = result[0];
var actual = result[1];
Assert.AreEqual(actualDeriv, 1.5061177f);
Assert.AreEqual(actual, 3.17805386f);
}
var sess = tf.Session();
var result = sess.run(new object[] { g, b });
var actualDeriv = result[0];
var actual = result[1];
Assert.AreEqual(actualDeriv, 1.5061177f);
Assert.AreEqual(actual, 3.17805386f);
} }


[TestMethod] [TestMethod]
@@ -247,14 +233,12 @@ namespace TensorFlowNET.UnitTest.Gradient
tf.constant(new[] { 1 }, tf.int32, new[] { 1 }) tf.constant(new[] { 1 }, tf.int32, new[] { 1 })
); );
var g = tf.gradients(b, a); var g = tf.gradients(b, a);
using (var sess = tf.Session())
{
var result = sess.run(new object[] { g, b });
var actualDeriv = np.squeeze(result[0]);
var actual = np.squeeze(result[1]);
Assert.AreEqual(actualDeriv, new float[] { 1, 0 });
Assert.AreEqual(actual, 0.9640276f);
}
var sess = tf.Session();
var result = sess.run(new object[] { g, b });
var actualDeriv = np.squeeze(result[0]);
var actual = np.squeeze(result[1]);
Assert.AreEqual(actualDeriv, new float[] { 1, 0 });
Assert.AreEqual(actual, 0.9640276f);
} }


[TestMethod] [TestMethod]
@@ -264,14 +248,12 @@ namespace TensorFlowNET.UnitTest.Gradient
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0); var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0);
var g = tf.gradients(a, a1); var g = tf.gradients(a, a1);
using (var sess = tf.Session())
{
var result = sess.run(new object[] { g, a });
var actualDeriv = result[0][0];
var actual = result[1][0];
Assert.AreEqual(actualDeriv, 1f);
Assert.AreEqual(actual, 2f);
}
var sess = tf.Session();
var result = sess.run(new object[] { g, a });
var actualDeriv = result[0][0];
var actual = result[1][0];
Assert.AreEqual(actualDeriv, 1f);
Assert.AreEqual(actual, 2f);
} }


[TestMethod] [TestMethod]
@@ -280,13 +262,12 @@ namespace TensorFlowNET.UnitTest.Gradient
var ap = tf.constant(1f); var ap = tf.constant(1f);
var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap);
var g = tf.gradients(b, ap); var g = tf.gradients(b, ap);
using (var sess = tf.Session())
{
var result = sess.run(g);
var actual = result[0];
Assert.AreEqual(actual, 0.41997434127f);
}
var sess = tf.Session();
var result = sess.run(g);
var actual = result[0];
Assert.AreEqual(actual, 0.41997434127f);
} }

[Ignore("TODO")] [Ignore("TODO")]
[TestMethod] [TestMethod]
public void testUnusedOutput() public void testUnusedOutput()


+ 12
- 14
test/TensorFlowNET.Graph.UnitTest/ImageTest.cs View File

@@ -74,23 +74,21 @@ namespace TensorFlowNET.UnitTest
var cropSize2_2 = tf.Variable(np.array(4, 4)); var cropSize2_2 = tf.Variable(np.array(4, 4));


var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();
using (Session sess = tf.Session())
{
sess.run(init);
var sess = tf.Session();
sess.run(init);


var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1);
var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1);


var result = sess.run(cropped);
// check if cropped to 1x1 center was succesfull
Assert.AreEqual(result.size, 1ul);
Assert.AreEqual(result[0, 0, 0, 0], 4f);
var result = sess.run(cropped);
// check if cropped to 1x1 center was succesfull
Assert.AreEqual(result.size, 1ul);
Assert.AreEqual(result[0, 0, 0, 0], 4f);


cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2);
result = sess.run(cropped);
// check if flipped and no cropping occured
Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f);
}
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2);
result = sess.run(cropped);
// check if flipped and no cropping occured
Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f);
} }
} }
} }

+ 8
- 8
test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs View File

@@ -24,7 +24,7 @@ namespace TensorFlowNET.UnitTest
{ {
Assert.IsNull(tf.peak_default_graph()); Assert.IsNull(tf.peak_default_graph());


using var sess = tf.Session();
var sess = tf.Session();
var default_graph = tf.get_default_graph(); var default_graph = tf.get_default_graph();
var sess_graph = sess.graph; var sess_graph = sess.graph;
Assert.IsNotNull(default_graph); Assert.IsNotNull(default_graph);
@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest
{ {
Assert.IsNull(tf.peak_default_graph()); Assert.IsNull(tf.peak_default_graph());
//tf.Session created an other graph //tf.Session created an other graph
using var sess = tf.Session();
var sess = tf.Session();
var default_graph = tf.get_default_graph(); var default_graph = tf.get_default_graph();
var sess_graph = sess.graph; var sess_graph = sess.graph;
Assert.IsNotNull(default_graph); Assert.IsNotNull(default_graph);
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest
beforehand.as_default(); beforehand.as_default();
Assert.IsNotNull(tf.peak_default_graph()); Assert.IsNotNull(tf.peak_default_graph());


using var sess = tf.Session();
var sess = tf.Session();
var default_graph = tf.peak_default_graph(); var default_graph = tf.peak_default_graph();
var sess_graph = sess.graph; var sess_graph = sess.graph;
Assert.IsNotNull(default_graph); Assert.IsNotNull(default_graph);
@@ -102,7 +102,7 @@ namespace TensorFlowNET.UnitTest
//the core method //the core method
void Core(int tid) void Core(int tid)
{ {
using var sess = tf.Session();
var sess = tf.Session();
for (int i = 0; i < 100; i++) for (int i = 0; i < 100; i++)
{ {
var t = new Tensor(1); var t = new Tensor(1);
@@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest
void Core(int tid) void Core(int tid)
{ {
//tf.Session created an other graph //tf.Session created an other graph
using var sess = tf.Session();
var sess = tf.Session();
for (int i = 0; i < 100; i++) for (int i = 0; i < 100; i++)
{ {
var t = new Tensor(new int[] { 1, 2, 3 }); var t = new Tensor(new int[] { 1, 2, 3 });
@@ -142,7 +142,7 @@ namespace TensorFlowNET.UnitTest
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2; var math = a1 + a2;
using var sess = tf.Session(graph);
var sess = tf.Session(graph);
for (int i = 0; i < 100; i++) for (int i = 0; i < 100; i++)
{ {
var result = sess.run(math); var result = sess.run(math);
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest
tf.compat.v1.disable_eager_execution(); tf.compat.v1.disable_eager_execution();
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();


using var sess = tf.Session(graph);
var sess = tf.Session(graph);
Assert.IsNotNull(tf.get_default_graph()); Assert.IsNotNull(tf.get_default_graph());
//graph is created automatically to perform create these operations //graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
@@ -182,7 +182,7 @@ namespace TensorFlowNET.UnitTest
//the core method //the core method
void Core(int tid) void Core(int tid)
{ {
using var sess = tf.Session();
var sess = tf.Session();
Assert.IsNotNull(tf.get_default_graph()); Assert.IsNotNull(tf.get_default_graph());
//graph is created automatically to perform create these operations //graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });


+ 453
- 675
test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs
File diff suppressed because it is too large
View File


+ 14
- 16
test/TensorFlowNET.Graph.UnitTest/PythonTest.cs View File

@@ -182,23 +182,21 @@ namespace TensorFlowNET.UnitTest
// return self._eval_helper(tensors) // return self._eval_helper(tensors)
// else: // else:
{ {
using (var sess = tf.Session())
var sess = tf.Session();
var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double))
{ {
var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double))
{
double x = ndarray;
result = x;
}
else if (typeof(T) == typeof(int))
{
int x = ndarray;
result = x;
}
else
{
result = ndarray;
}
double x = ndarray;
result = x;
}
else if (typeof(T) == typeof(int))
{
int x = ndarray;
result = x;
}
else
{
result = ndarray;
} }


return (T)result; return (T)result;


+ 2
- 4
test/TensorFlowNET.Native.UnitTest/Attributes/AttributesTestcs.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.Native.UnitTest


private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size)
{ {
var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_.Handle);
var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_);
EXPECT_EQ(TF_Code.TF_OK, s_.Code); EXPECT_EQ(TF_Code.TF_OK, s_.Code);
char e = expected_list_size >= 0 ? (char)1 : (char)0; char e = expected_list_size >= 0 ? (char)1 : (char)0;
/*EXPECT_EQ(e, m.is_list); /*EXPECT_EQ(e, m.is_list);
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest
var desc = init("string"); var desc = init("string");
c_api.TF_SetAttrString(desc, "v", "bunny", 5); c_api.TF_SetAttrString(desc, "v", "bunny", 5);


var oper = c_api.TF_FinishOperation(desc, s_.Handle);
var oper = c_api.TF_FinishOperation(desc, s_);
//ASSERT_EQ(TF_Code.TF_OK, s_.Code); //ASSERT_EQ(TF_Code.TF_OK, s_.Code);
//EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5);
//var value = new char[5]; //var value = new char[5];
@@ -86,8 +86,6 @@ namespace Tensorflow.Native.UnitTest


public void Dispose() public void Dispose()
{ {
graph_.Dispose();
s_.Dispose();
} }
} }
} }

+ 1
- 3
test/TensorFlowNET.Native.UnitTest/CApiColocationTest.cs View File

@@ -59,7 +59,7 @@ namespace Tensorflow.Native.UnitTest


private void VerifyCollocation(Operation op, string[] expected) private void VerifyCollocation(Operation op, string[] expected)
{ {
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_.Handle);
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_);
TF_AttrMetadata m = new TF_AttrMetadata(); TF_AttrMetadata m = new TF_AttrMetadata();
if (expected.Length == 0) if (expected.Length == 0)
{ {
@@ -98,8 +98,6 @@ namespace Tensorflow.Native.UnitTest


public void Dispose() public void Dispose()
{ {
graph_.Dispose();
s_.Dispose();
} }
} }
} }

+ 2
- 2
test/TensorFlowNET.Native.UnitTest/CApiTest.cs View File

@@ -45,10 +45,10 @@ namespace Tensorflow.Native.UnitTest
=> c_api.TF_AddInput(desc, input); => c_api.TF_AddInput(desc, input);


protected Operation TF_FinishOperation(OperationDescription desc, Status s) protected Operation TF_FinishOperation(OperationDescription desc, Status s)
=> c_api.TF_FinishOperation(desc, s.Handle);
=> c_api.TF_FinishOperation(desc, s);


protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s)
=> c_api.TF_SetAttrTensor(desc, attrName, value, s.Handle);
=> c_api.TF_SetAttrTensor(desc, attrName, value, s);


protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
=> c_api.TF_SetAttrType(desc, attrName, dtype); => c_api.TF_SetAttrType(desc, attrName, dtype);


+ 3
- 3
test/TensorFlowNET.Native.UnitTest/Functions/FunctionTest.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Native.UnitTest
string func_name_ = "MyFunc"; string func_name_ = "MyFunc";
string func_node_name_ = "MyFunc_0"; string func_node_name_ = "MyFunc_0";
Status s_; Status s_;
IntPtr func_;
SafeFuncGraphHandle func_;


[TestInitialize] [TestInitialize]
public void Initialize() public void Initialize()
@@ -402,7 +402,7 @@ namespace Tensorflow.Native.UnitTest
inputs.Length, inputs.ToArray(), inputs.Length, inputs.ToArray(),
outputs.Length, outputs.ToArray(), outputs.Length, outputs.ToArray(),
output_names == null || output_names.Length == 0 ? null : output_names, output_names == null || output_names.Length == 0 ? null : output_names,
IntPtr.Zero, null, s_.Handle);
IntPtr.Zero, null, s_);


if (expect_failure) if (expect_failure)
{ {
@@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest
ASSERT_EQ(TF_OK, s_.Code, s_.Message); ASSERT_EQ(TF_OK, s_.Code, s_.Message);
ASSERT_NE(func_, IntPtr.Zero); ASSERT_NE(func_, IntPtr.Zero);
ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle);
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_);
ASSERT_EQ(TF_OK, s_.Code, s_.Message); ASSERT_EQ(TF_OK, s_.Code, s_.Message);
} }




+ 10
- 17
test/TensorFlowNET.Native.UnitTest/Gradients/GradientsTest.cs View File

@@ -44,18 +44,14 @@ namespace Tensorflow.Native.UnitTest
private bool GetGraphDef(Graph graph, out GraphDef graph_def) private bool GetGraphDef(Graph graph, out GraphDef graph_def)
{ {
graph_def = null; graph_def = null;
using (var s = new Status())
{
using (var buffer = new Buffer())
{
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s));
if (ret)
graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray());
return ret;
}
}
var s = new Status();
var buffer = new Buffer();
c_api.TF_GraphToGraphDef(graph, buffer, s);
bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s));
if (ret)
graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray());
return ret;
} }


private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs)
@@ -111,9 +107,9 @@ namespace Tensorflow.Native.UnitTest


IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero }; IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero };
c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs,
ninputs, grad_inputs, s_.Handle, handles);
ninputs, grad_inputs, s_, handles);


var op = new Operation(handles[0]);
// var op = new Operation(handles[0]);
} }
else else
{ {
@@ -275,9 +271,6 @@ namespace Tensorflow.Native.UnitTest


public void Dispose() public void Dispose()
{ {
graph_.Dispose();
expected_graph_.Dispose();
s_.Dispose();
} }
} }
} }

+ 1
- 1
test/TensorFlowNET.Native.UnitTest/Graphs/GraphBuildTest.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Native.UnitTest
[TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] [TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")]
public void UpdateEdge() public void UpdateEdge()
{ {
using var graph = new Graph().as_default();
var graph = new Graph().as_default();


var one = tf.constant(1, name: "one"); var one = tf.constant(1, name: "one");
var two = tf.constant(2, name: "two"); var two = tf.constant(2, name: "two");


+ 15
- 27
test/TensorFlowNET.Native.UnitTest/Graphs/GraphTest.cs View File

@@ -35,7 +35,7 @@ namespace Tensorflow.Native.UnitTest
EXPECT_EQ(attr_value.Type, DataType.DtInt32); EXPECT_EQ(attr_value.Type, DataType.DtInt32);


// Test not found errors in TF_Operation*() query functions. // Test not found errors in TF_Operation*() query functions.
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s.Handle));
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code);
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message);
@@ -191,9 +191,6 @@ namespace Tensorflow.Native.UnitTest
ASSERT_TRUE(found_scalar_const); ASSERT_TRUE(found_scalar_const);
ASSERT_TRUE(found_add); ASSERT_TRUE(found_add);
ASSERT_TRUE(found_neg); ASSERT_TRUE(found_neg);

graph.Dispose();
s.Dispose();
} }


/// <summary> /// <summary>
@@ -213,16 +210,15 @@ namespace Tensorflow.Native.UnitTest


// Export to a GraphDef. // Export to a GraphDef.
var graph_def = new Buffer(); var graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle);
c_api.TF_GraphToGraphDef(graph, graph_def, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


// Import it, with a prefix, in a fresh graph. // Import it, with a prefix, in a fresh graph.
graph.Dispose();
graph = new Graph().as_default(); graph = new Graph().as_default();
using (var opts = c_api.TF_NewImportGraphDefOptions()) using (var opts = c_api.TF_NewImportGraphDefOptions())
{ {
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);
} }


@@ -265,7 +261,7 @@ namespace Tensorflow.Native.UnitTest
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


return results; return results;
@@ -305,7 +301,7 @@ namespace Tensorflow.Native.UnitTest
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);
} }


@@ -330,7 +326,7 @@ namespace Tensorflow.Native.UnitTest


// Export to a graph def so we can import a graph with control dependencies // Export to a graph def so we can import a graph with control dependencies
graph_def = new Buffer(); graph_def = new Buffer();
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle);
c_api.TF_GraphToGraphDef(graph, graph_def, s);
EXPECT_EQ(TF_Code.TF_OK, s.Code); EXPECT_EQ(TF_Code.TF_OK, s.Code);


// Import again, with remapped control dependency, into the same graph // Import again, with remapped control dependency, into the same graph
@@ -338,7 +334,7 @@ namespace Tensorflow.Native.UnitTest
{ {
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);
} }


@@ -380,7 +376,6 @@ namespace Tensorflow.Native.UnitTest
ASSERT_EQ(TF_Code.TF_OK, s.Code); ASSERT_EQ(TF_Code.TF_OK, s.Code);


// Import it in a fresh graph with return outputs. // Import it in a fresh graph with return outputs.
graph.Dispose();
graph = new Graph().as_default(); graph = new Graph().as_default();
var opts = new ImportGraphDefOptions(); var opts = new ImportGraphDefOptions();
opts.AddReturnOutput("feed", 0); opts.AddReturnOutput("feed", 0);
@@ -401,11 +396,6 @@ namespace Tensorflow.Native.UnitTest
EXPECT_EQ(0, return_outputs[0].index); EXPECT_EQ(0, return_outputs[0].index);
EXPECT_EQ(scalar, return_outputs[1].oper); EXPECT_EQ(scalar, return_outputs[1].oper);
EXPECT_EQ(0, return_outputs[1].index); EXPECT_EQ(0, return_outputs[1].index);

opts.Dispose();
graph_def.Dispose();
graph.Dispose();
s.Dispose();
} }


/// <summary> /// <summary>
@@ -422,16 +412,14 @@ namespace Tensorflow.Native.UnitTest
public void ImportGraphMeta() public void ImportGraphMeta()
{ {
var dir = "my-save-dir/"; var dir = "my-save-dir/";
using (var sess = tf.Session())
{
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
new_saver.restore(sess, dir + "my-model-10000");
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
var batch_size = tf.size(labels);
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor;
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
logits: logits);
}
var sess = tf.Session();
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
new_saver.restore(sess, dir + "my-model-10000");
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
var batch_size = tf.size(labels);
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor;
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
logits: logits);
} }
} }
} }

+ 4
- 7
test/TensorFlowNET.Native.UnitTest/Sessions/CSession.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest
/// </summary> /// </summary>
public class CSession public class CSession
{ {
private IntPtr session_;
private SafeSessionHandle session_;


private List<TF_Output> inputs_ = new List<TF_Output>(); private List<TF_Output> inputs_ = new List<TF_Output>();
private List<Tensor> input_values_ = new List<Tensor>(); private List<Tensor> input_values_ = new List<Tensor>();
@@ -22,11 +22,8 @@ namespace Tensorflow.Native.UnitTest


public CSession(Graph graph, Status s, bool user_XLA = false) public CSession(Graph graph, Status s, bool user_XLA = false)
{ {
lock (Locks.ProcessWide)
{
var config = new ConfigProto { InterOpParallelismThreads = 4 };
session_ = new Session(graph, config, s);
}
var config = new ConfigProto { InterOpParallelismThreads = 4 };
session_ = new Session(graph, config, s);
} }


public void SetInputs(Dictionary<Operation, Tensor> inputs) public void SetInputs(Dictionary<Operation, Tensor> inputs)
@@ -85,7 +82,7 @@ namespace Tensorflow.Native.UnitTest
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
outputs_ptr, output_values_ptr, outputs_.Count, outputs_ptr, output_values_ptr, outputs_.Count,
targets_ptr, targets_.Count, targets_ptr, targets_.Count,
IntPtr.Zero, s.Handle);
IntPtr.Zero, s);


s.Check(); s.Check();




+ 2
- 2
test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs View File

@@ -14,8 +14,8 @@ namespace Tensorflow.Native.UnitTest.Sessions
[TestMethod] [TestMethod]
public void Session() public void Session()
{ {
using var s = new Status();
using var graph = new Graph();
var s = new Status();
var graph = new Graph();


// Make a placeholder operation. // Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s); var feed = c_test_util.Placeholder(graph, s);


+ 16
- 17
test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs View File

@@ -139,45 +139,45 @@ namespace Tensorflow.Native.UnitTest.Tensors
var feed_out_0 = new TF_Output(feed, 0); var feed_out_0 = new TF_Output(feed, 0);


// Fetch the shape, it should be completely unknown. // Fetch the shape, it should be completely unknown.
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);


Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(-1, num_dims); EXPECT_EQ(-1, num_dims);


// Set the shape to be unknown, expect no change. // Set the shape to be unknown, expect no change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
EXPECT_EQ(-1, num_dims); EXPECT_EQ(-1, num_dims);


// Set the shape to be 2 x Unknown // Set the shape to be 2 x Unknown
long[] dims = { 2, -1 }; long[] dims = { 2, -1 };
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);


// Get the dimension vector appropriately. // Get the dimension vector appropriately.
var returned_dims = new long[dims.Length]; var returned_dims = new long[dims.Length];
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));


// Set to a new valid shape: [2, 3] // Set to a new valid shape: [2, 3]
dims[1] = 3; dims[1] = 3;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle);
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);


// Fetch and see that the new value is returned. // Fetch and see that the new value is returned.
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims));


// Try to set 'unknown' with unknown rank on the shape and see that // Try to set 'unknown' with unknown rank on the shape and see that
// it doesn't change. // it doesn't change.
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle);
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]); EXPECT_EQ(2, (int)returned_dims[0]);
@@ -187,21 +187,21 @@ namespace Tensorflow.Native.UnitTest.Tensors
// it doesn't change. // it doesn't change.
dims[0] = -1; dims[0] = -1;
dims[1] = -1; dims[1] = -1;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]); EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]); EXPECT_EQ(3, (int)returned_dims[1]);


// Try to fetch a shape with the wrong num_dims // Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);


// Try to set an invalid shape (cannot change 2x3 to a 2x5). // Try to set an invalid shape (cannot change 2x3 to a 2x5).
dims[1] = 5; dims[1] = 5;
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle);
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);


// Test for a scalar. // Test for a scalar.
@@ -209,14 +209,13 @@ namespace Tensorflow.Native.UnitTest.Tensors
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
var three_out_0 = new TF_Output(three, 0); var three_out_0 = new TF_Output(three, 0);


num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(0, num_dims); EXPECT_EQ(0, num_dims);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle);
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT);


graph.Exit(); graph.Exit();
s.Dispose();
} }
} }
} }

+ 20
- 28
test/TensorFlowNET.Native.UnitTest/c_test_util.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow.Native.UnitTest


c_api.TF_AddInputList(desc, inputs, inputs.Length); c_api.TF_AddInputList(desc, inputs, inputs.Length);


var op = c_api.TF_FinishOperation(desc, s.Handle);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();


return op; return op;
@@ -33,37 +33,29 @@ namespace Tensorflow.Native.UnitTest
[SuppressMessage("ReSharper", "RedundantAssignment")] [SuppressMessage("ReSharper", "RedundantAssignment")]
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
{ {
lock (Locks.ProcessWide)
{
using (var buffer = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer.Handle, s.Handle);
attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray());
}
var buffer = new Buffer();


return s.Code == TF_Code.TF_OK;
}
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray());

return s.Code == TF_Code.TF_OK;
} }


public static GraphDef GetGraphDef(Graph graph) public static GraphDef GetGraphDef(Graph graph)
{ {
lock (Locks.ProcessWide)
{
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle);
s.Check();
return GraphDef.Parser.ParseFrom(buffer.ToArray());
}
}
var s = new Status();
var buffer = new Buffer();

c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
return GraphDef.Parser.ParseFrom(buffer.ToArray());
} }


public static FunctionDef GetFunctionDef(IntPtr func)
public static FunctionDef GetFunctionDef(SafeFuncGraphHandle func)
{ {
using var s = new Status();
using var buffer = new Buffer();
c_api.TF_FunctionToFunctionDef(func, buffer.Handle, s.Handle);
var s = new Status();
var buffer = new Buffer();
c_api.TF_FunctionToFunctionDef(func, buffer, s);
s.Check(true); s.Check(true);
var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray()); var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray());
return func_def; return func_def;
@@ -192,7 +184,7 @@ namespace Tensorflow.Native.UnitTest
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name);
var neg_input = new TF_Output(n, 0); var neg_input = new TF_Output(n, 0);
c_api.TF_AddInput(desc, neg_input); c_api.TF_AddInput(desc, neg_input);
var op = c_api.TF_FinishOperation(desc, s.Handle);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();


return op; return op;
@@ -210,7 +202,7 @@ namespace Tensorflow.Native.UnitTest
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length);
} }


var op = c_api.TF_FinishOperation(desc, s.Handle);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();


return op; return op;
@@ -222,10 +214,10 @@ namespace Tensorflow.Native.UnitTest
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
{ {
var desc = c_api.TF_NewOperation(graph, "Const", name); var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t, s.Handle);
c_api.TF_SetAttrTensor(desc, "value", t, s);
s.Check(); s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype); c_api.TF_SetAttrType(desc, "dtype", t.dtype);
var op = c_api.TF_FinishOperation(desc, s.Handle);
var op = c_api.TF_FinishOperation(desc, s);
s.Check(); s.Check();


return op; return op;


+ 16
- 22
test/TensorFlowNET.UnitTest/Basics/TrainSaverTest.cs View File

@@ -17,10 +17,8 @@ namespace TensorFlowNET.UnitTest.Basics


public void ImportGraph() public void ImportGraph()
{ {
using (var sess = tf.Session())
{
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");
}
var sess = tf.Session();
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta");


//tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); //tf.train.export_meta_graph(filename: "linear_regression.meta.bin");
// import meta // import meta
@@ -60,14 +58,12 @@ namespace TensorFlowNET.UnitTest.Basics
// Add ops to save and restore all the variables. // Add ops to save and restore all the variables.
var saver = tf.train.Saver(); var saver = tf.train.Saver();


using (var sess = tf.Session())
{
sess.run(init_op);
var sess = tf.Session();
sess.run(init_op);


// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model1.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
}
// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model1.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
} }


public void Save2() public void Save2()
@@ -84,17 +80,15 @@ namespace TensorFlowNET.UnitTest.Basics
// Add ops to save and restore all the variables. // Add ops to save and restore all the variables.
var saver = tf.train.Saver(); var saver = tf.train.Saver();


using (var sess = tf.Session())
{
sess.run(init_op);
// o some work with the model.
inc_v1.op.run();
dec_v2.op.run();

// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model2.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
}
var sess = tf.Session();
sess.run(init_op);
// o some work with the model.
inc_v1.op.run();
dec_v2.op.run();

// Save the variables to disk.
var save_path = saver.save(sess, "/tmp/model2.ckpt");
Console.WriteLine($"Model saved in path: {save_path}");
} }
} }
} }

+ 4
- 6
test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs View File

@@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6)); var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6));
var scan = tf.scan(fn, input); var scan = tf.scan(fn, input);


using (var sess = tf.Session())
{
sess.run(tf.global_variables_initializer());
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6)));
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>());
}
var sess = tf.Session();
sess.run(tf.global_variables_initializer());
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6)));
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>());
} }
} }
} }

+ 14
- 16
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -196,23 +196,21 @@ namespace TensorFlowNET.UnitTest
// return self._eval_helper(tensors) // return self._eval_helper(tensors)
// else: // else:
{ {
using (var sess = tf.Session())
var sess = tf.Session();
var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double))
{ {
var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double))
{
double x = ndarray;
result = x;
}
else if (typeof(T) == typeof(int))
{
int x = ndarray;
result = x;
}
else
{
result = ndarray;
}
double x = ndarray;
result = x;
}
else if (typeof(T) == typeof(int))
{
int x = ndarray;
result = x;
}
else
{
result = ndarray;
} }


return (T)result; return (T)result;


+ 0
- 1
test/TensorFlowNET.UnitTest/StatusTest.cs View File

@@ -28,7 +28,6 @@ namespace TensorFlowNET.UnitTest.Basics
public void DeleteStatus() public void DeleteStatus()
{ {
var s = new Status(); var s = new Status();
s.Dispose();
} }
} }
} }

Loading…
Cancel
Save