@@ -23,11 +23,9 @@ namespace Tensorflow | |||||
var x = tf.placeholder(tf.float64, shape: (1024, 1024)); | var x = tf.placeholder(tf.float64, shape: (1024, 1024)); | ||||
var log = tf.log(x); | var log = tf.log(x); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var ones = np.ones((1024, 1024), dtype: np.float64); | |||||
var o = sess.run(log, new FeedItem(x, ones)); | |||||
} | |||||
var sess = tf.Session(); | |||||
var ones = np.ones((1024, 1024), dtype: np.float64); | |||||
var o = sess.run(log, new FeedItem(x, ones)); | |||||
// Thread.Sleep(1); | // Thread.Sleep(1); | ||||
} | } | ||||
@@ -25,15 +25,15 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Represents a TF_Buffer that can be passed to Tensorflow. | /// Represents a TF_Buffer that can be passed to Tensorflow. | ||||
/// </summary> | /// </summary> | ||||
public sealed class Buffer : IDisposable | |||||
public sealed class Buffer | |||||
{ | { | ||||
public SafeBufferHandle Handle { get; } | |||||
SafeBufferHandle _handle; | |||||
/// <remarks> | /// <remarks> | ||||
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/> | /// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/> | ||||
/// </remarks> | /// </remarks> | ||||
private unsafe ref readonly TF_Buffer DangerousBuffer | private unsafe ref readonly TF_Buffer DangerousBuffer | ||||
=> ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer()); | |||||
=> ref Unsafe.AsRef<TF_Buffer>(_handle.DangerousGetHandle().ToPointer()); | |||||
/// <summary> | /// <summary> | ||||
/// The memory block representing this buffer. | /// The memory block representing this buffer. | ||||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
using (Handle.Lease()) | |||||
using (_handle.Lease()) | |||||
{ | { | ||||
return DangerousBuffer.length; | return DangerousBuffer.length; | ||||
} | } | ||||
@@ -67,13 +67,13 @@ namespace Tensorflow | |||||
} | } | ||||
public Buffer() | public Buffer() | ||||
=> Handle = TF_NewBuffer(); | |||||
=> _handle = TF_NewBuffer(); | |||||
public Buffer(SafeBufferHandle handle) | public Buffer(SafeBufferHandle handle) | ||||
=> Handle = handle; | |||||
=> _handle = handle; | |||||
public Buffer(byte[] data) | public Buffer(byte[] data) | ||||
=> Handle = _toBuffer(data); | |||||
=> _handle = _toBuffer(data); | |||||
private static SafeBufferHandle _toBuffer(byte[] data) | private static SafeBufferHandle _toBuffer(byte[] data) | ||||
{ | { | ||||
@@ -92,7 +92,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public unsafe byte[] ToArray() | public unsafe byte[] ToArray() | ||||
{ | { | ||||
using (Handle.Lease()) | |||||
using (_handle.Lease()) | |||||
{ | { | ||||
ref readonly TF_Buffer buffer = ref DangerousBuffer; | ref readonly TF_Buffer buffer = ref DangerousBuffer; | ||||
@@ -107,7 +107,12 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
public override string ToString() | |||||
=> $"0x{_handle.DangerousGetHandle():x16}"; | |||||
public static implicit operator SafeBufferHandle(Buffer buffer) | |||||
{ | |||||
return buffer._handle; | |||||
} | |||||
} | } | ||||
} | } |
@@ -11,7 +11,7 @@ public class CheckpointReader | |||||
Status status = new Status(); | Status status = new Status(); | ||||
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | ||||
VariableToShapeMap = new Dictionary<string, Shape>(); | VariableToShapeMap = new Dictionary<string, Shape>(); | ||||
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | |||||
_handle = c_api.TF_NewCheckpointReader(filename, status); | |||||
status.Check(true); | status.Check(true); | ||||
ReadAllShapeAndType(); | ReadAllShapeAndType(); | ||||
} | } | ||||
@@ -38,7 +38,7 @@ public class CheckpointReader | |||||
int num_dims = GetVariableNumDims(name); | int num_dims = GetVariableNumDims(name); | ||||
long[] dims = new long[num_dims]; | long[] dims = new long[num_dims]; | ||||
Status status = new Status(); | Status status = new Status(); | ||||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | |||||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status); | |||||
status.Check(true); | status.Check(true); | ||||
return new Shape(dims); | return new Shape(dims); | ||||
} | } | ||||
@@ -49,7 +49,7 @@ public class CheckpointReader | |||||
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | ||||
{ | { | ||||
Status status = new Status(); | Status status = new Status(); | ||||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); | |||||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status); | |||||
status.Check(true); | status.Check(true); | ||||
return new Tensor(tensor); | return new Tensor(tensor); | ||||
} | } | ||||
@@ -37,7 +37,7 @@ namespace Tensorflow.Contexts | |||||
public void log_device_placement(bool enable) | public void log_device_placement(bool enable) | ||||
{ | { | ||||
if (_handle != null) | if (_handle != null) | ||||
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle); | |||||
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status); | |||||
_log_device_placement = enable; | _log_device_placement = enable; | ||||
// _thread_local_data.function_call_options = null; | // _thread_local_data.function_call_options = null; | ||||
} | } | ||||
@@ -60,15 +60,15 @@ namespace Tensorflow.Contexts | |||||
public PhysicalDevice[] list_physical_devices(string device_type = null) | public PhysicalDevice[] list_physical_devices(string device_type = null) | ||||
{ | { | ||||
using var opts = c_api.TFE_NewContextOptions(); | using var opts = c_api.TFE_NewContextOptions(); | ||||
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle); | |||||
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle); | |||||
using var ctx = c_api.TFE_NewContext(opts, tf.Status); | |||||
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
int num_devices = c_api.TF_DeviceListCount(devices); | int num_devices = c_api.TF_DeviceListCount(devices); | ||||
var results = new List<PhysicalDevice>(); | var results = new List<PhysicalDevice>(); | ||||
for (int i = 0; i < num_devices; ++i) | for (int i = 0; i < num_devices; ++i) | ||||
{ | { | ||||
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle)); | |||||
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status)); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
if (dev_type.StartsWith("XLA")) | if (dev_type.StartsWith("XLA")) | ||||
@@ -76,7 +76,7 @@ namespace Tensorflow.Contexts | |||||
if (device_type == null || dev_type == device_type) | if (device_type == null || dev_type == device_type) | ||||
{ | { | ||||
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle); | |||||
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
results.Add(new PhysicalDevice | results.Add(new PhysicalDevice | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow.Contexts | |||||
/// <summary> | /// <summary> | ||||
/// Environment in which eager operations execute. | /// Environment in which eager operations execute. | ||||
/// </summary> | /// </summary> | ||||
public sealed partial class Context : IDisposable | |||||
public sealed partial class Context | |||||
{ | { | ||||
public const int GRAPH_MODE = 0; | public const int GRAPH_MODE = 0; | ||||
public const int EAGER_MODE = 1; | public const int EAGER_MODE = 1; | ||||
@@ -41,15 +41,7 @@ namespace Tensorflow.Contexts | |||||
public FunctionCallOptions FunctionCallOptions { get; } | public FunctionCallOptions FunctionCallOptions { get; } | ||||
SafeContextHandle _handle; | SafeContextHandle _handle; | ||||
public SafeContextHandle Handle | |||||
{ | |||||
get | |||||
{ | |||||
if (_handle == null) | |||||
ensure_initialized(); | |||||
return _handle; | |||||
} | |||||
} | |||||
int? _seed; | int? _seed; | ||||
Random _rng; | Random _rng; | ||||
@@ -59,6 +51,7 @@ namespace Tensorflow.Contexts | |||||
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | ||||
initialized = false; | initialized = false; | ||||
FunctionCallOptions = new FunctionCallOptions(); | FunctionCallOptions = new FunctionCallOptions(); | ||||
ensure_initialized(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -72,12 +65,12 @@ namespace Tensorflow.Contexts | |||||
Config = MergeConfig(); | Config = MergeConfig(); | ||||
FunctionCallOptions.Config = Config; | FunctionCallOptions.Config = Config; | ||||
var config_str = Config.ToByteArray(); | var config_str = Config.ToByteArray(); | ||||
using var opts = new ContextOptions(); | |||||
using var status = new Status(); | |||||
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle); | |||||
var opts = new ContextOptions(); | |||||
var status = new Status(); | |||||
c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status); | |||||
status.Check(true); | status.Check(true); | ||||
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy); | |||||
_handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||||
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); | |||||
_handle = c_api.TFE_NewContext(opts, status); | |||||
status.Check(true); | status.Check(true); | ||||
initialized = true; | initialized = true; | ||||
} | } | ||||
@@ -178,10 +171,14 @@ namespace Tensorflow.Contexts | |||||
tf.Context.ensure_initialized(); | tf.Context.ensure_initialized(); | ||||
if (_handle != null) | if (_handle != null) | ||||
{ | |||||
c_api.TFE_ContextClearCaches(_handle); | c_api.TFE_ContextClearCaches(_handle); | ||||
} | |||||
} | } | ||||
public void Dispose() | |||||
=> _handle.Dispose(); | |||||
public static implicit operator SafeContextHandle(Context ctx) | |||||
{ | |||||
return ctx._handle; | |||||
} | |||||
} | } | ||||
} | } |
@@ -14,21 +14,21 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
namespace Tensorflow.Contexts | |||||
namespace Tensorflow.Contexts; | |||||
public sealed class ContextOptions | |||||
{ | { | ||||
public sealed class ContextOptions : IDisposable | |||||
{ | |||||
public SafeContextOptionsHandle Handle { get; } | |||||
SafeContextOptionsHandle _handle { get; } | |||||
public ContextOptions() | |||||
{ | |||||
Handle = c_api.TFE_NewContextOptions(); | |||||
} | |||||
public ContextOptions() | |||||
{ | |||||
_handle = c_api.TFE_NewContextOptions(); | |||||
} | |||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
public static implicit operator SafeContextOptionsHandle(ContextOptions opt) | |||||
{ | |||||
return opt._handle; | |||||
} | } | ||||
} | } |
@@ -43,7 +43,7 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
var status = tf.Status; | var status = tf.Status; | ||||
var op = GetOp(ctx, op_name, status); | var op = GetOp(ctx, op_name, status); | ||||
c_api.TFE_OpSetDevice(op, device_name, status.Handle); | |||||
c_api.TFE_OpSetDevice(op, device_name, status); | |||||
if (status.ok()) | if (status.ok()) | ||||
{ | { | ||||
for (int i = 0; i < inputs.Length; ++i) | for (int i = 0; i < inputs.Length; ++i) | ||||
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager | |||||
Tensor nd => nd.EagerTensorHandle, | Tensor nd => nd.EagerTensorHandle, | ||||
_ => throw new NotImplementedException("Eager tensor handle has not been allocated.") | _ => throw new NotImplementedException("Eager tensor handle has not been allocated.") | ||||
}; | }; | ||||
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); | |||||
c_api.TFE_OpAddInput(op, tensor_handle, status); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
} | } | ||||
@@ -64,7 +64,7 @@ namespace Tensorflow.Eager | |||||
var outputs = new SafeEagerTensorHandle[num_outputs]; | var outputs = new SafeEagerTensorHandle[num_outputs]; | ||||
if (status.ok()) | if (status.ok()) | ||||
{ | { | ||||
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); | |||||
c_api.TFE_Execute(op, outputs, out num_outputs, status); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
return outputs.Select(x => new EagerTensor(x)).ToArray(); | return outputs.Select(x => new EagerTensor(x)).ToArray(); | ||||
@@ -104,7 +104,7 @@ namespace Tensorflow.Eager | |||||
var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); | var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); | ||||
attr_values[j] = eager_tensor.dtype; | attr_values[j] = eager_tensor.dtype; | ||||
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle); | |||||
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status); | |||||
if (op_exec_info.run_callbacks) | if (op_exec_info.run_callbacks) | ||||
{ | { | ||||
@@ -142,7 +142,7 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
var retVals = new SafeEagerTensorHandle[num_retvals]; | var retVals = new SafeEagerTensorHandle[num_retvals]; | ||||
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); | |||||
c_api.TFE_Execute(op, retVals, out num_retvals, status); | |||||
status.Check(true); | status.Check(true); | ||||
var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | ||||
@@ -160,10 +160,10 @@ namespace Tensorflow.Eager | |||||
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | ||||
{ | { | ||||
if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | ||||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | |||||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status); | |||||
else | else | ||||
{ | { | ||||
op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | |||||
op = c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||||
thread_local_eager_operation_map[op_or_function_name] = op; | thread_local_eager_operation_map[op_or_function_name] = op; | ||||
} | } | ||||
@@ -219,7 +219,7 @@ namespace Tensorflow.Eager | |||||
flattened_attrs.Add(dtype); | flattened_attrs.Add(dtype); | ||||
} | } | ||||
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle); | |||||
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status); | |||||
status.Check(true); | status.Check(true); | ||||
return true; | return true; | ||||
@@ -235,7 +235,7 @@ namespace Tensorflow.Eager | |||||
var value = attrs[i + 1]; | var value = attrs[i + 1]; | ||||
byte is_list = 0; | byte is_list = 0; | ||||
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle); | |||||
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status); | |||||
if (!status.ok()) return; | if (!status.ok()) return; | ||||
if (is_list != 0) | if (is_list != 0) | ||||
SetOpAttrList(tf.Context, op, key, value as object[], type, null, status); | SetOpAttrList(tf.Context, op, key, value as object[], type, null, status); | ||||
@@ -264,7 +264,7 @@ namespace Tensorflow.Eager | |||||
Status status) | Status status) | ||||
{ | { | ||||
byte is_list = 0; | byte is_list = 0; | ||||
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle); | |||||
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status); | |||||
if (status.Code != TF_Code.TF_OK) return; | if (status.Code != TF_Code.TF_OK) return; | ||||
if (attr_value == null) | if (attr_value == null) | ||||
@@ -305,7 +305,7 @@ namespace Tensorflow.Eager | |||||
tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long)); | tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long)); | ||||
} | } | ||||
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle); | |||||
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status); | |||||
Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); | Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); | ||||
} | } | ||||
else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) | else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) | ||||
@@ -353,7 +353,7 @@ namespace Tensorflow.Eager | |||||
break; | break; | ||||
case TF_AttrType.TF_ATTR_SHAPE: | case TF_AttrType.TF_ATTR_SHAPE: | ||||
var dims = (value as long[]).ToArray(); | var dims = (value as long[]).ToArray(); | ||||
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle); | |||||
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | |||||
status.Check(true); | status.Check(true); | ||||
break; | break; | ||||
case TF_AttrType.TF_ATTR_FUNC: | case TF_AttrType.TF_ATTR_FUNC: | ||||
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager | |||||
void NewEagerTensorHandle(SafeTensorHandle h) | void NewEagerTensorHandle(SafeTensorHandle h) | ||||
{ | { | ||||
_id = ops.uid(); | _id = ops.uid(); | ||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); | |||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status); | |||||
#if TRACK_TENSOR_LIFE | #if TRACK_TENSOR_LIFE | ||||
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); | Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); | ||||
#endif | #endif | ||||
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
if (_handle != null) | if (_handle != null) | ||||
return; | return; | ||||
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | |||||
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
@@ -24,10 +24,10 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
} | } | ||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | |||||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status)); | |||||
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | ||||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | |||||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status); | |||||
public override ulong bytesize | public override ulong bytesize | ||||
{ | { | ||||
@@ -49,9 +49,9 @@ namespace Tensorflow.Eager | |||||
protected override Shape GetShapeInternal() | protected override Shape GetShapeInternal() | ||||
{ | { | ||||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | |||||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status)]; | |||||
for (int i = 0; i < dims.Length; i++) | for (int i = 0; i < dims.Length; i++) | ||||
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle); | |||||
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status); | |||||
return dims; | return dims; | ||||
} | } | ||||
@@ -64,15 +64,15 @@ namespace Tensorflow.Eager | |||||
public static int GetRank(IntPtr handle) | public static int GetRank(IntPtr handle) | ||||
{ | { | ||||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle); | |||||
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status); | |||||
} | } | ||||
public static int[] GetDims(IntPtr handle) | public static int[] GetDims(IntPtr handle) | ||||
{ | { | ||||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | ||||
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle)]; | |||||
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status)]; | |||||
for (int i = 0; i < dims.Length; i++) | for (int i = 0; i < dims.Length; i++) | ||||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | |||||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status); | |||||
return dims; | return dims; | ||||
} | } | ||||
@@ -114,7 +114,7 @@ namespace Tensorflow | |||||
/// <param name="function"></param> | /// <param name="function"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status); | |||||
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Removes a function from the context. Once removed, you can no longer | /// Removes a function from the context. Once removed, you can no longer | ||||
@@ -56,15 +56,14 @@ namespace Tensorflow | |||||
TF_ImportGraphDefResults results = null; | TF_ImportGraphDefResults results = null; | ||||
var bytes = graph_def.ToByteString().ToArray(); | var bytes = graph_def.ToByteString().ToArray(); | ||||
using (var buffer = c_api_util.tf_buffer(bytes)) | |||||
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | |||||
using (var status = new Status()) | |||||
{ | |||||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||||
// need to create a class ImportGraphDefWithResults with IDisposal | |||||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle)); | |||||
status.Check(true); | |||||
} | |||||
var buffer = c_api_util.tf_buffer(bytes); | |||||
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||||
var status = new Status(); | |||||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||||
// need to create a class ImportGraphDefWithResults with IDisposal | |||||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||||
status.Check(true); | |||||
_ProcessNewOps(graph); | _ProcessNewOps(graph); | ||||
@@ -116,13 +115,13 @@ namespace Tensorflow | |||||
Dictionary<string, Tensor> input_map, | Dictionary<string, Tensor> input_map, | ||||
string[] return_elements) | string[] return_elements) | ||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix); | |||||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1); | |||||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||||
foreach (var input in input_map) | foreach (var input in input_map) | ||||
{ | { | ||||
var (src_name, src_index) = _ParseTensorName(input.Key); | var (src_name, src_index) = _ParseTensorName(input.Key); | ||||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Handle, src_name, src_index, input.Value._as_tf_output()); | |||||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||||
} | } | ||||
if (return_elements == null) | if (return_elements == null) | ||||
@@ -133,11 +132,11 @@ namespace Tensorflow | |||||
if (name.Contains(":")) | if (name.Contains(":")) | ||||
{ | { | ||||
var (op_name, index) = _ParseTensorName(name); | var (op_name, index) = _ParseTensorName(name); | ||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index); | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name); | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||||
} | } | ||||
} | } | ||||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||||
if (_registered_ops.Count > 0) | if (_registered_ops.Count > 0) | ||||
return _registered_ops; | return _registered_ops; | ||||
using var buffer = new Buffer(c_api.TF_GetAllOpList()); | |||||
var buffer = new Buffer(c_api.TF_GetAllOpList()); | |||||
var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); | var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); | ||||
foreach (var op_def in op_list.Op) | foreach (var op_def in op_list.Op) | ||||
_registered_ops[op_def.Name] = op_def; | _registered_ops[op_def.Name] = op_def; | ||||
@@ -56,8 +56,8 @@ namespace Tensorflow.Framework | |||||
if (pred_value is null) | if (pred_value is null) | ||||
{ | { | ||||
var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); | var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); | ||||
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle); | |||||
if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK) | |||||
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status); | |||||
if (!evaluated || c_api.TF_GetCode(tf.Status) != TF_Code.TF_OK) | |||||
return null; | return null; | ||||
else | else | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
@@ -34,10 +34,10 @@ namespace Tensorflow | |||||
/// <param name="output_func_def"></param> | /// <param name="output_func_def"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||||
public static extern void TF_FunctionToFunctionDef(SafeFuncGraphHandle func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name, | |||||
public static extern SafeFuncGraphHandle TF_GraphToFunction(SafeGraphHandle fn_body, string fn_name, | |||||
bool append_hash_to_fn_name, | bool append_hash_to_fn_name, | ||||
int num_opers, IntPtr[] opers, | int num_opers, IntPtr[] opers, | ||||
int ninputs, TF_Output[] inputs, | int ninputs, TF_Output[] inputs, | ||||
@@ -48,12 +48,12 @@ namespace Tensorflow | |||||
SafeStatusHandle status); | SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||||
public static extern IntPtr TF_FunctionSetAttrValueProto(SafeFuncGraphHandle func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_FunctionName(IntPtr func); | |||||
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status); | |||||
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||||
} | } | ||||
} | } |
@@ -37,7 +37,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <param name="dy">TF_Output*</param> | /// <param name="dy">TF_Output*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_AddGradientsWithPrefix(IntPtr g, string prefix, TF_Output[] y, int ny, | |||||
public static extern void TF_AddGradientsWithPrefix(SafeGraphHandle g, string prefix, TF_Output[] y, int ny, | |||||
TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy); | TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy); | ||||
} | } | ||||
} | } |
@@ -22,21 +22,19 @@ namespace Tensorflow | |||||
var inputs_string = string.Join(",", inputs); | var inputs_string = string.Join(",", inputs); | ||||
var outputs_string = string.Join(",", outputs); | var outputs_string = string.Join(",", outputs); | ||||
var transforms_string = string.Join(" ", transforms); | var transforms_string = string.Join(" ", transforms); | ||||
using (var status = new Status()) | |||||
{ | |||||
var buffer = new Buffer(); | |||||
var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, | |||||
input_graph_def_string.Length, | |||||
inputs_string, | |||||
outputs_string, | |||||
transforms_string, | |||||
buffer.Handle, | |||||
status.Handle); | |||||
var status = new Status(); | |||||
var buffer = new Buffer(); | |||||
var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, | |||||
input_graph_def_string.Length, | |||||
inputs_string, | |||||
outputs_string, | |||||
transforms_string, | |||||
buffer, | |||||
status); | |||||
status.Check(false); | |||||
var bytes = buffer.ToArray(); | |||||
return GraphDef.Parser.ParseFrom(bytes); | |||||
} | |||||
status.Check(false); | |||||
var bytes = buffer.ToArray(); | |||||
return GraphDef.Parser.ParseFrom(bytes); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -37,11 +37,9 @@ namespace Tensorflow.Graphs | |||||
1); | 1); | ||||
return result[0]; | return result[0]; | ||||
} | } | ||||
using (var s = tf.Session(input.graph)) | |||||
{ | |||||
var output = func(input); | |||||
return output; | |||||
} | |||||
var s = tf.Session(input.graph); | |||||
var output = func(input); | |||||
return output; | |||||
}; | }; | ||||
} | } | ||||
@@ -75,12 +73,10 @@ namespace Tensorflow.Graphs | |||||
1); | 1); | ||||
return result[0]; | return result[0]; | ||||
} | } | ||||
using (var s = tf.Session(a.graph)) | |||||
{ | |||||
Debug.Assert(a.graph == b.graph); | |||||
var output = func(a, b); | |||||
return output; | |||||
} | |||||
var s = tf.Session(a.graph); | |||||
Debug.Assert(a.graph == b.graph); | |||||
var output = func(a, b); | |||||
return output; | |||||
}; | }; | ||||
} | } | ||||
} | } | ||||
@@ -1,258 +1,252 @@ | |||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Exceptions; | using Tensorflow.Exceptions; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Graphs | |||||
namespace Tensorflow.Graphs; | |||||
/// <summary> | |||||
/// Graph representing a function body. | |||||
/// </summary> | |||||
public class FuncGraph : Graph, IDisposable | |||||
{ | { | ||||
SafeFuncGraphHandle _func_graph_handle; | |||||
public string FuncName => _graph_key; | |||||
public Tensors Inputs { get; set; } = new Tensors(); | |||||
public Tensors Outputs { get; set; } = new Tensors(); | |||||
public Dictionary<string, string> Attrs { get; set; } | |||||
Dictionary<long, (Tensor, Tensor)> _captures | |||||
= new Dictionary<long, (Tensor, Tensor)>(); | |||||
public Tensor[] external_captures | |||||
=> _captures.Select(x => x.Value.Item1).ToArray(); | |||||
public (Tensor, Tensor)[] captures | |||||
=> _captures.Values.Select(x => x).ToArray(); | |||||
public Tensor[] internal_captures | |||||
=> _captures.Select(x => x.Value.Item2).ToArray(); | |||||
public Tensor[] captured_inputs | |||||
=> external_captures; | |||||
/// <summary> | /// <summary> | ||||
/// Graph representing a function body. | |||||
/// Construct a new FuncGraph. | |||||
/// </summary> | /// </summary> | ||||
public class FuncGraph : Graph | |||||
public FuncGraph(string name) : base() | |||||
{ | { | ||||
IntPtr _func_graph_handle; | |||||
public string FuncName => _graph_key; | |||||
public Tensors Inputs { get; set; } = new Tensors(); | |||||
public Tensors Outputs { get; set; } = new Tensors(); | |||||
public Dictionary<string, string> Attrs { get; set; } | |||||
outer_graph = ops.get_default_graph(); | |||||
while (outer_graph.building_function) | |||||
outer_graph = outer_graph.OuterGraph; | |||||
_graph_key = name; | |||||
building_function = true; | |||||
} | |||||
Dictionary<long, (Tensor, Tensor)> _captures | |||||
= new Dictionary<long, (Tensor, Tensor)>(); | |||||
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||||
{ | |||||
outer_graph = ops.get_default_graph(); | |||||
while (outer_graph.building_function) | |||||
outer_graph = outer_graph.OuterGraph; | |||||
_graph_key = name; | |||||
building_function = true; | |||||
Attrs = attrs; | |||||
// Will to test if FuncGraph has memory leak | |||||
// c_api.TF_DeleteGraph(_handle); | |||||
_handle = handle; | |||||
} | |||||
public Tensor[] external_captures | |||||
=> _captures.Select(x => x.Value.Item1).ToArray(); | |||||
public (Tensor, Tensor)[] captures | |||||
=> _captures.Values.Select(x => x).ToArray(); | |||||
public void ToGraph(Operation[] opers, | |||||
Tensor[] inputs, Tensor[] outputs, | |||||
string[] output_names) | |||||
{ | |||||
var status = new Status(); | |||||
_func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||||
_graph_key, | |||||
false, | |||||
opers.Length, | |||||
opers.Select(x => (IntPtr)x).ToArray(), | |||||
inputs.Length, | |||||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
outputs.Length, | |||||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
output_names, | |||||
IntPtr.Zero, | |||||
null, | |||||
status); | |||||
status.Check(true); | |||||
SetAttrs(); | |||||
// c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); | |||||
// status.Check(true); | |||||
c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status); | |||||
status.Check(true); | |||||
_graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); | |||||
Inputs = inputs; | |||||
// mark_as_return | |||||
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||||
} | |||||
public Tensor[] internal_captures | |||||
=> _captures.Select(x => x.Value.Item2).ToArray(); | |||||
public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true) | |||||
{ | |||||
foreach(var (i, inp) in enumerate(inputs)) | |||||
inputs[i] = capture(inp); | |||||
public Tensor[] captured_inputs | |||||
=> external_captures; | |||||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||||
} | |||||
/// <summary> | |||||
/// Construct a new FuncGraph. | |||||
/// </summary> | |||||
public FuncGraph(string name) : base() | |||||
const int _EAGER_CONST_THRESHOLD = 128; | |||||
public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||||
{ | |||||
if(tensor is EagerTensor) | |||||
{ | { | ||||
outer_graph = ops.get_default_graph(); | |||||
while (outer_graph.building_function) | |||||
outer_graph = outer_graph.OuterGraph; | |||||
_graph_key = name; | |||||
building_function = true; | |||||
if (name == null) | |||||
name = ops.uid().ToString(); | |||||
// Small EagerTensors are captured with Const ops | |||||
if (dtypes.is_value_dtype(tensor.dtype) | |||||
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||||
return capture_eager_tensor(tensor, name); | |||||
// Large EagerTensors and resources are captured with Placeholder ops | |||||
return _capture_helper(tensor, name, shape: shape); | |||||
} | } | ||||
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||||
if(tensor.graph != this) | |||||
{ | { | ||||
outer_graph = ops.get_default_graph(); | |||||
while (outer_graph.building_function) | |||||
outer_graph = outer_graph.OuterGraph; | |||||
_graph_key = name; | |||||
building_function = true; | |||||
Attrs = attrs; | |||||
// Will to test if FuncGraph has memory leak | |||||
// c_api.TF_DeleteGraph(_handle); | |||||
_handle = handle; | |||||
if (name == null) | |||||
name = tensor.op.name; | |||||
var inner_graph = tensor.graph; | |||||
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) | |||||
{ | |||||
if (inner_graph == this) | |||||
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + | |||||
" in another function or code block. Use return values," + | |||||
" explicit Python locals or TensorFlow collections to access" + | |||||
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); | |||||
inner_graph = inner_func_graph.outer_graph; | |||||
} | |||||
return _capture_helper(tensor, name); | |||||
} | } | ||||
public void ToGraph(Operation[] opers, | |||||
Tensor[] inputs, Tensor[] outputs, | |||||
string[] output_names) | |||||
return tensor; | |||||
} | |||||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||||
{ | |||||
Tensor graph_const = null; | |||||
if (!_captures.ContainsKey(tensor.Id)) | |||||
{ | { | ||||
var status = new Status(); | |||||
_func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||||
_graph_key, | |||||
false, | |||||
opers.Length, | |||||
opers.Select(x => (IntPtr)x).ToArray(), | |||||
inputs.Length, | |||||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
outputs.Length, | |||||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
output_names == null || output_names.Length == 0 ? null : output_names, | |||||
IntPtr.Zero, | |||||
null, | |||||
status.Handle); | |||||
status.Check(true); | |||||
SetAttrs(); | |||||
// c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); | |||||
// status.Check(true); | |||||
c_api.TFE_ContextAddFunction(tf.Context.Handle, _func_graph_handle, status.Handle); | |||||
status.Check(true); | |||||
_graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); | |||||
Inputs = inputs; | |||||
// mark_as_return | |||||
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||||
graph_const = tf_with(ops.control_dependencies(null), ctl | |||||
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||||
add_capture(tensor, graph_const); | |||||
} | } | ||||
public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true) | |||||
else | |||||
{ | { | ||||
foreach(var (i, inp) in enumerate(inputs)) | |||||
inputs[i] = capture(inp); | |||||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||||
graph_const = _captures[tensor.Id].Item2; | |||||
} | } | ||||
const int _EAGER_CONST_THRESHOLD = 128; | |||||
public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
{ | { | ||||
if(tensor is EagerTensor) | |||||
{ | |||||
if (name == null) | |||||
name = ops.uid().ToString(); | |||||
// Small EagerTensors are captured with Const ops | |||||
if (dtypes.is_value_dtype(tensor.dtype) | |||||
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||||
return capture_eager_tensor(tensor, name); | |||||
return output_grads; | |||||
}; | |||||
// Large EagerTensors and resources are captured with Placeholder ops | |||||
return _capture_helper(tensor, name, shape: shape); | |||||
} | |||||
tf.Runner.RecordGradient("captured_value", | |||||
new[] { graph_const }, null, | |||||
new[] { tensor }, | |||||
getBackwardFunction: _backward_function_wrapper | |||||
/*getForwardFunction: forward_function*/); | |||||
if(tensor.graph != this) | |||||
{ | |||||
if (name == null) | |||||
name = tensor.op.name; | |||||
var inner_graph = tensor.graph; | |||||
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) | |||||
{ | |||||
if (inner_graph == this) | |||||
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + | |||||
" in another function or code block. Use return values," + | |||||
" explicit Python locals or TensorFlow collections to access" + | |||||
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); | |||||
inner_graph = inner_func_graph.outer_graph; | |||||
} | |||||
return _capture_helper(tensor, name); | |||||
} | |||||
return graph_const; | |||||
} | |||||
return tensor; | |||||
Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) | |||||
{ | |||||
Tensor placeholder = null; | |||||
if (!_captures.ContainsKey(tensor.Id)) | |||||
{ | |||||
placeholder = _create_substitute_placeholder(tensor, | |||||
name: name, | |||||
dtype: tensor.dtype, | |||||
shape: shape); | |||||
add_capture(tensor, placeholder); | |||||
} | } | ||||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||||
else | |||||
{ | { | ||||
Tensor graph_const = null; | |||||
if (!_captures.ContainsKey(tensor.Id)) | |||||
{ | |||||
graph_const = tf_with(ops.control_dependencies(null), ctl | |||||
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||||
add_capture(tensor, graph_const); | |||||
} | |||||
else | |||||
{ | |||||
graph_const = _captures[tensor.Id].Item2; | |||||
} | |||||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
{ | |||||
return output_grads; | |||||
}; | |||||
tf.Runner.RecordGradient("captured_value", | |||||
new[] { graph_const }, null, | |||||
new[] { tensor }, | |||||
getBackwardFunction: _backward_function_wrapper | |||||
/*getForwardFunction: forward_function*/); | |||||
return graph_const; | |||||
placeholder = _captures[tensor.Id].Item2; | |||||
} | } | ||||
Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) | |||||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
{ | { | ||||
Tensor placeholder = null; | |||||
if (!_captures.ContainsKey(tensor.Id)) | |||||
{ | |||||
placeholder = _create_substitute_placeholder(tensor, | |||||
name: name, | |||||
dtype: tensor.dtype, | |||||
shape: shape); | |||||
add_capture(tensor, placeholder); | |||||
} | |||||
else | |||||
{ | |||||
placeholder = _captures[tensor.Id].Item2; | |||||
} | |||||
return output_grads; | |||||
}; | |||||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
{ | |||||
return output_grads; | |||||
}; | |||||
tf.Runner.RecordGradient("captured_value", | |||||
new[] { placeholder }, null, | |||||
new[] { tensor }, | |||||
getBackwardFunction: _backward_function_wrapper | |||||
/*getForwardFunction: forward_function*/); | |||||
tf.Runner.RecordGradient("captured_value", | |||||
new[] { placeholder }, null, | |||||
new[] { tensor }, | |||||
getBackwardFunction: _backward_function_wrapper | |||||
/*getForwardFunction: forward_function*/); | |||||
return placeholder; | |||||
} | |||||
return placeholder; | |||||
} | |||||
void add_capture(Tensor tensor, Tensor placeholder) | |||||
{ | |||||
_captures.Add(tensor.Id, (tensor, placeholder)); | |||||
Inputs.Add(placeholder); | |||||
} | |||||
void add_capture(Tensor tensor, Tensor placeholder) | |||||
{ | |||||
_captures.Add(tensor.Id, (tensor, placeholder)); | |||||
Inputs.Add(placeholder); | |||||
} | |||||
Tensor _create_substitute_placeholder(Tensor value, | |||||
string name = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
Shape shape = null) | |||||
{ | |||||
if (shape is null) | |||||
shape = value.shape; | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = value.dtype; | |||||
var placeholder = tf_with(ops.control_dependencies(null), ctl | |||||
=> array_ops.placeholder(dtype, shape: shape, name: name)); | |||||
// custom_gradient.copy_handle_data(value, placeholder) | |||||
return placeholder; | |||||
} | |||||
Tensor _create_substitute_placeholder(Tensor value, | |||||
string name = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
Shape shape = null) | |||||
{ | |||||
if (shape is null) | |||||
shape = value.shape; | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = value.dtype; | |||||
var placeholder = tf_with(ops.control_dependencies(null), ctl | |||||
=> array_ops.placeholder(dtype, shape: shape, name: name)); | |||||
// custom_gradient.copy_handle_data(value, placeholder) | |||||
return placeholder; | |||||
} | |||||
void SetAttrs() | |||||
{ | |||||
if (Attrs == null) | |||||
return; | |||||
void SetAttrs() | |||||
foreach (var (_name, attr_value) in enumerate(Attrs)) | |||||
{ | { | ||||
if (Attrs == null) | |||||
return; | |||||
foreach (var (_name, attr_value) in enumerate(Attrs)) | |||||
var serialized = new AttrValue | |||||
{ | { | ||||
var serialized = new AttrValue | |||||
{ | |||||
S = ByteString.CopyFromUtf8(attr_value) | |||||
}.ToByteArray(); | |||||
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
} | |||||
S = ByteString.CopyFromUtf8(attr_value) | |||||
}.ToByteArray(); | |||||
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | |||||
tf.Status.Check(true); | |||||
} | } | ||||
} | |||||
public override Graph as_default() | |||||
{ | |||||
tf.Context.graph_mode(isFunc: true); | |||||
ops.set_default_graph(this); | |||||
return this; | |||||
} | |||||
public override Graph as_default() | |||||
{ | |||||
tf.Context.graph_mode(isFunc: true); | |||||
ops.set_default_graph(this); | |||||
return this; | |||||
} | |||||
public override void Exit() | |||||
{ | |||||
tf.Context.restore_mode(); | |||||
ops.pop_graph(); | |||||
} | |||||
public override void Exit() | |||||
{ | |||||
tf.Context.restore_mode(); | |||||
ops.pop_graph(); | |||||
} | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
{ | |||||
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, _graph_key, tf.Status.Handle); | |||||
c_api.TF_DeleteFunction(_func_graph_handle); | |||||
base.DisposeUnmanagedResources(handle); | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | |||||
} | } | ||||
} | } |
@@ -24,7 +24,7 @@ namespace Tensorflow | |||||
public Buffer ToGraphDef(Status s) | public Buffer ToGraphDef(Status s) | ||||
{ | { | ||||
var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle); | |||||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||||
s.Check(true); | s.Check(true); | ||||
return buffer; | return buffer; | ||||
@@ -33,14 +33,12 @@ namespace Tensorflow | |||||
private GraphDef _as_graph_def(bool add_shapes = false) | private GraphDef _as_graph_def(bool add_shapes = false) | ||||
{ | { | ||||
GraphDef def; | GraphDef def; | ||||
using (var status = new Status()) | |||||
using (var buffer = ToGraphDef(status)) | |||||
{ | |||||
status.Check(true); | |||||
// limit size to 250M, recursion to max 100 | |||||
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); | |||||
def = GraphDef.Parser.ParseFrom(inputStream); | |||||
} | |||||
var status = new Status(); | |||||
var buffer = ToGraphDef(status); | |||||
status.Check(true); | |||||
// limit size to 250M, recursion to max 100 | |||||
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); | |||||
def = GraphDef.Parser.ParseFrom(inputStream); | |||||
// Strip the experimental library field iff it's empty. | // Strip the experimental library field iff it's empty. | ||||
// if(def.Library.Function.Count == 0) | // if(def.Library.Function.Count == 0) | ||||
@@ -29,7 +29,7 @@ namespace Tensorflow | |||||
int size = Marshal.SizeOf<TF_Output>(); | int size = Marshal.SizeOf<TF_Output>(); | ||||
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | ||||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle); | |||||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||||
var tf_output_ptr = (TF_Output*)return_output_handle; | var tf_output_ptr = (TF_Output*)return_output_handle; | ||||
for (int i = 0; i < num_return_outputs; i++) | for (int i = 0; i < num_return_outputs; i++) | ||||
@@ -48,15 +48,14 @@ namespace Tensorflow | |||||
public bool Import(byte[] bytes, string prefix = "") | public bool Import(byte[] bytes, string prefix = "") | ||||
{ | { | ||||
using (var opts = new ImportGraphDefOptions()) | |||||
using (var status = new Status()) | |||||
using (var graph_def = new Buffer(bytes)) | |||||
{ | |||||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); | |||||
c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle); | |||||
status.Check(true); | |||||
return status.Code == TF_Code.TF_OK; | |||||
} | |||||
var opts = new ImportGraphDefOptions(); | |||||
var status = new Status(); | |||||
var graph_def = new Buffer(bytes); | |||||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); | |||||
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); | |||||
status.Check(true); | |||||
return status.Code == TF_Code.TF_OK; | |||||
} | } | ||||
public Graph ImportGraphDef(string file_path, string name = null) | public Graph ImportGraphDef(string file_path, string name = null) | ||||
@@ -75,9 +75,9 @@ namespace Tensorflow | |||||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
/// </summary> | /// </summary> | ||||
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | /// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | ||||
public partial class Graph : DisposableObject | |||||
, IEnumerable<Operation> | |||||
public partial class Graph : IEnumerable<Operation> | |||||
{ | { | ||||
protected new SafeGraphHandle _handle; | |||||
private Dictionary<int, ITensorOrOperation> _nodes_by_id; | private Dictionary<int, ITensorOrOperation> _nodes_by_id; | ||||
public Dictionary<string, ITensorOrOperation> _nodes_by_name; | public Dictionary<string, ITensorOrOperation> _nodes_by_name; | ||||
private Dictionary<string, int> _names_in_use; | private Dictionary<string, int> _names_in_use; | ||||
@@ -130,15 +130,6 @@ namespace Tensorflow | |||||
_graph_key = $"graph-{ops.GraphUniqueId()}/"; | _graph_key = $"graph-{ops.GraphUniqueId()}/"; | ||||
} | } | ||||
public Graph(IntPtr handle) | |||||
{ | |||||
_handle = handle; | |||||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||||
_names_in_use = new Dictionary<string, int>(); | |||||
_graph_key = $"grap-{ops.GraphUniqueId()}/"; | |||||
} | |||||
public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | ||||
{ | { | ||||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
@@ -486,16 +477,6 @@ namespace Tensorflow | |||||
_unfetchable_ops.Add(op); | _unfetchable_ops.Add(op); | ||||
} | } | ||||
protected override void DisposeManagedResources() | |||||
{ | |||||
} | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
{ | |||||
c_api.TF_DeleteGraph(handle); | |||||
} | |||||
public Tensor get_tensor_by_tf_output(TF_Output tf_output) | public Tensor get_tensor_by_tf_output(TF_Output tf_output) | ||||
{ | { | ||||
var op = _get_operation_by_tf_operation(tf_output.oper); | var op = _get_operation_by_tf_operation(tf_output.oper); | ||||
@@ -517,14 +498,14 @@ namespace Tensorflow | |||||
public Shape GetTensorShape(TF_Output output) | public Shape GetTensorShape(TF_Output output) | ||||
{ | { | ||||
var status = tf.Status; | var status = tf.Status; | ||||
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status.Handle); | |||||
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | |||||
status.Check(); | status.Check(); | ||||
if (ndim == -1) | if (ndim == -1) | ||||
return Shape.Null; | return Shape.Null; | ||||
var dims = new long[ndim]; | var dims = new long[ndim]; | ||||
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status.Handle); | |||||
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); | |||||
status.Check(); | status.Check(); | ||||
return new Shape(dims.Select(x => (int)x).ToArray()); | return new Shape(dims.Select(x => (int)x).ToArray()); | ||||
@@ -539,7 +520,7 @@ namespace Tensorflow | |||||
string debugString = string.Empty; | string debugString = string.Empty; | ||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
return $"{graph_key}, 0x{_handle.ToString("x16")}"; | |||||
return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}"; | |||||
/*if (string.IsNullOrEmpty(debugString)) | /*if (string.IsNullOrEmpty(debugString)) | ||||
{ | { | ||||
int len = 0; | int len = 0; | ||||
@@ -558,7 +539,7 @@ namespace Tensorflow | |||||
IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
=> throw new NotImplementedException(); | => throw new NotImplementedException(); | ||||
public static implicit operator IntPtr(Graph graph) | |||||
public static implicit operator SafeGraphHandle(Graph graph) | |||||
{ | { | ||||
return graph._handle; | return graph._handle; | ||||
} | } | ||||
@@ -14,28 +14,27 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
namespace Tensorflow; | |||||
namespace Tensorflow | |||||
public sealed class ImportGraphDefOptions | |||||
{ | { | ||||
public sealed class ImportGraphDefOptions : IDisposable | |||||
{ | |||||
public SafeImportGraphDefOptionsHandle Handle { get; } | |||||
SafeImportGraphDefOptionsHandle _handle { get; } | |||||
public int NumReturnOutputs | |||||
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle); | |||||
public int NumReturnOutputs | |||||
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||||
public ImportGraphDefOptions() | |||||
{ | |||||
Handle = c_api.TF_NewImportGraphDefOptions(); | |||||
} | |||||
public ImportGraphDefOptions() | |||||
{ | |||||
_handle = c_api.TF_NewImportGraphDefOptions(); | |||||
} | |||||
public void AddReturnOutput(string name, int index) | |||||
{ | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index); | |||||
} | |||||
public void AddReturnOutput(string name, int index) | |||||
{ | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||||
} | |||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
public static implicit operator SafeImportGraphDefOptionsHandle(ImportGraphDefOptions opt) | |||||
{ | |||||
return opt._handle; | |||||
} | } | ||||
} | } |
@@ -0,0 +1,22 @@ | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow; | |||||
public sealed class SafeFuncGraphHandle : SafeTensorflowHandle | |||||
{ | |||||
private SafeFuncGraphHandle() | |||||
{ | |||||
} | |||||
public SafeFuncGraphHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TF_DeleteFunction(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} |
@@ -0,0 +1,22 @@ | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow; | |||||
public sealed class SafeGraphHandle : SafeTensorflowHandle | |||||
{ | |||||
private SafeGraphHandle() | |||||
{ | |||||
} | |||||
public SafeGraphHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
c_api.TF_DeleteGraph(handle); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} |
@@ -60,7 +60,7 @@ namespace Tensorflow | |||||
/// <param name="num_dims"></param> | /// <param name="num_dims"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||||
public static extern void TF_GraphGetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Import the graph serialized in `graph_def` into `graph`. | /// Import the graph serialized in `graph_def` into `graph`. | ||||
@@ -78,7 +78,7 @@ namespace Tensorflow | |||||
/// <param name="num_return_outputs">int</param> | /// <param name="num_return_outputs">int</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | /// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | ||||
@@ -92,7 +92,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns>TF_ImportGraphDefResults*</returns> | /// <returns>TF_ImportGraphDefResults*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Import the graph serialized in `graph_def` into `graph`. | /// Import the graph serialized in `graph_def` into `graph`. | ||||
@@ -102,7 +102,7 @@ namespace Tensorflow | |||||
/// <param name="options">TF_ImportGraphDefOptions*</param> | /// <param name="options">TF_ImportGraphDefOptions*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphImportGraphDef(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
public static extern void TF_GraphImportGraphDef(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Iterate through the operations of a graph. | /// Iterate through the operations of a graph. | ||||
@@ -111,7 +111,7 @@ namespace Tensorflow | |||||
/// <param name="pos"></param> | /// <param name="pos"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos); | |||||
public static extern IntPtr TF_GraphNextOperation(SafeGraphHandle graph, ref uint pos); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the operation in the graph with `oper_name`. Returns nullptr if | /// Returns the operation in the graph with `oper_name`. Returns nullptr if | ||||
@@ -121,14 +121,14 @@ namespace Tensorflow | |||||
/// <param name="oper_name"></param> | /// <param name="oper_name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name); | |||||
public static extern IntPtr TF_GraphOperationByName(SafeGraphHandle graph, string oper_name); | |||||
/// <summary> | /// <summary> | ||||
/// Sets the shape of the Tensor referenced by `output` in `graph` to | /// Sets the shape of the Tensor referenced by `output` in `graph` to | ||||
/// the shape described by `dims` and `num_dims`. | /// the shape described by `dims` and `num_dims`. | ||||
/// </summary> | /// </summary> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||||
public static extern void TF_GraphSetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Write out a serialized representation of `graph` (as a GraphDef protocol | /// Write out a serialized representation of `graph` (as a GraphDef protocol | ||||
@@ -138,7 +138,7 @@ namespace Tensorflow | |||||
/// <param name="output_graph_def">TF_Buffer*</param> | /// <param name="output_graph_def">TF_Buffer*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||||
public static extern void TF_GraphToGraphDef(SafeGraphHandle graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the number of dimensions of the Tensor referenced by `output` | /// Returns the number of dimensions of the Tensor referenced by `output` | ||||
@@ -151,7 +151,7 @@ namespace Tensorflow | |||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, SafeStatusHandle status); | |||||
public static extern int TF_GraphGetTensorNumDims(SafeGraphHandle graph, TF_Output output, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Cause the imported graph to have a control dependency on `oper`. `oper` | /// Cause the imported graph to have a control dependency on `oper`. `oper` | ||||
@@ -287,12 +287,12 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||||
public static extern SafeSessionHandle TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||||
string export_dir, string[] tags, int tags_len, | string export_dir, string[] tags, int tags_len, | ||||
IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||||
SafeGraphHandle graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_NewGraph(); | |||||
public static extern SafeGraphHandle TF_NewGraph(); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); | public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); | ||||
@@ -334,6 +334,6 @@ namespace Tensorflow | |||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); | |||||
public static extern bool TF_TryEvaluateConstant(SafeGraphHandle graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); | |||||
} | } | ||||
} | } |
@@ -61,7 +61,7 @@ namespace Tensorflow.NumPy | |||||
{ | { | ||||
if (_handle is not null) | if (_handle is not null) | ||||
{ | { | ||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -31,7 +31,7 @@ namespace Tensorflow | |||||
public int InputListLength(string name) | public int InputListLength(string name) | ||||
{ | { | ||||
int num = 0; | int num = 0; | ||||
num = c_api.TF_OperationInputListLength(_handle, name, tf.Status.Handle); | |||||
num = c_api.TF_OperationInputListLength(_handle, name, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
return num; | return num; | ||||
} | } | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow | |||||
public int OutputListLength(string name) | public int OutputListLength(string name) | ||||
{ | { | ||||
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status.Handle); | |||||
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
return num; | return num; | ||||
@@ -187,8 +187,8 @@ namespace Tensorflow | |||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
return (T[])get_attr(name); | return (T[])get_attr(name); | ||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||||
var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | ||||
@@ -210,8 +210,8 @@ namespace Tensorflow | |||||
public virtual object get_attr(string name) | public virtual object get_attr(string name) | ||||
{ | { | ||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||||
var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | ||||
@@ -235,13 +235,13 @@ namespace Tensorflow | |||||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | ||||
{ | { | ||||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s.Handle); | |||||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||||
} | } | ||||
private NodeDef GetNodeDef() | private NodeDef GetNodeDef() | ||||
{ | { | ||||
using var buffer = new Buffer(); | |||||
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle); | |||||
var buffer = new Buffer(); | |||||
c_api.TF_OperationToNodeDef(_handle, buffer, tf.Status); | |||||
tf.Status.Check(throwException: true); | tf.Status.Check(throwException: true); | ||||
return NodeDef.Parser.ParseFrom(buffer.ToArray()); | return NodeDef.Parser.ParseFrom(buffer.ToArray()); | ||||
} | } | ||||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||||
public Operation FinishOperation(Status status) | public Operation FinishOperation(Status status) | ||||
{ | { | ||||
return c_api.TF_FinishOperation(_handle, status.Handle); | |||||
return c_api.TF_FinishOperation(_handle, status); | |||||
} | } | ||||
public static implicit operator OperationDescription(IntPtr handle) | public static implicit operator OperationDescription(IntPtr handle) | ||||
@@ -96,7 +96,7 @@ namespace Tensorflow | |||||
/// <param name="oper_name">const char*</param> | /// <param name="oper_name">const char*</param> | ||||
/// <returns>TF_OperationDescription*</returns> | /// <returns>TF_OperationDescription*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||||
public static extern IntPtr TF_NewOperation(SafeGraphHandle graph, string opType, string oper_name); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_OperationDevice(IntPtr oper); | public static extern IntPtr TF_OperationDevice(IntPtr oper); | ||||
@@ -14,281 +14,272 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using System; | |||||
using System.Collections; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Numerics; | |||||
using System.Text; | |||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | |||||
namespace Tensorflow; | |||||
public class BaseSession : IDisposable | |||||
{ | { | ||||
public class BaseSession : DisposableObject | |||||
protected SafeSessionHandle _handle; | |||||
protected Graph _graph; | |||||
protected Status _status; | |||||
public Graph graph => _graph; | |||||
public BaseSession(SafeSessionHandle handle, Graph g) | |||||
{ | { | ||||
protected Graph _graph; | |||||
protected Status _status; | |||||
public Graph graph => _graph; | |||||
_handle = handle; | |||||
_graph = g ?? ops.get_default_graph(); | |||||
} | |||||
public BaseSession(IntPtr handle, Graph g) | |||||
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||||
{ | |||||
_graph = g ?? ops.get_default_graph(); | |||||
if (!_graph.building_function) | |||||
{ | { | ||||
_handle = handle; | |||||
_graph = g ?? ops.get_default_graph(); | |||||
if (ops.get_default_graph() != _graph) | |||||
_graph.as_default(); | |||||
} | } | ||||
var opts = new SessionOptions(target, config); | |||||
_status = status ?? tf.Status; | |||||
_handle = c_api.TF_NewSession(_graph, opts, _status); | |||||
_status.Check(true); | |||||
} | |||||
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||||
{ | |||||
_graph = g ?? ops.get_default_graph(); | |||||
if (!_graph.building_function) | |||||
{ | |||||
if (ops.get_default_graph() != _graph) | |||||
_graph.as_default(); | |||||
} | |||||
using var opts = new SessionOptions(target, config); | |||||
_status = status ?? tf.Status; | |||||
_handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle); | |||||
_status.Check(true); | |||||
} | |||||
public virtual void run(Operation op, params FeedItem[] feed_dict) | |||||
{ | |||||
_run(op, feed_dict); | |||||
} | |||||
public virtual void run(Operation op, params FeedItem[] feed_dict) | |||||
{ | |||||
_run(op, feed_dict); | |||||
} | |||||
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||||
{ | |||||
return _run(fetche, feed_dict)[0]; | |||||
} | |||||
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||||
{ | |||||
return _run(fetche, feed_dict)[0]; | |||||
} | |||||
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(fetche, feed_dict); | |||||
return fetche is Tensor ? results[0] : null; | |||||
} | |||||
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(fetche, feed_dict); | |||||
return fetche is Tensor ? results[0] : null; | |||||
} | |||||
public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( | |||||
(ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, | |||||
params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); | |||||
return (results[0], results[1], results[2], results[3], results[4]); | |||||
} | |||||
public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( | |||||
(ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, | |||||
params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); | |||||
return (results[0], results[1], results[2], results[3], results[4]); | |||||
} | |||||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
return (results[0], results[1], results[2], results[3]); | |||||
} | |||||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
return (results[0], results[1], results[2], results[3]); | |||||
} | |||||
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
return (results[0], results[1], results[2]); | |||||
} | |||||
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
return (results[0], results[1], results[2]); | |||||
} | |||||
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
return (results[0], results[1]); | |||||
} | |||||
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
return (results[0], results[1]); | |||||
} | |||||
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
return _run(fetches, feed_dict); | |||||
} | |||||
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||||
{ | |||||
return _run(fetches, feed_dict); | |||||
} | |||||
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||||
{ | |||||
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
return _run(fetches, feed_items); | |||||
} | |||||
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||||
{ | |||||
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
return _run(fetches, feed_items); | |||||
} | |||||
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||||
{ | |||||
var feed_dict_tensor = new Dictionary<object, object>(); | |||||
//var feed_map = new Dictionary<object, object>(); | |||||
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||||
// Validate and process feed_dict. | |||||
if (feed_dict != null) | |||||
{ | { | ||||
var feed_dict_tensor = new Dictionary<object, object>(); | |||||
//var feed_map = new Dictionary<object, object>(); | |||||
// Validate and process feed_dict. | |||||
if (feed_dict != null) | |||||
foreach (var subfeed in feed_dict) | |||||
{ | { | ||||
foreach (var subfeed in feed_dict) | |||||
{ | |||||
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||||
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||||
feed_dict_tensor[subfeed_t] = subfeed.Value; | |||||
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||||
} | |||||
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||||
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||||
feed_dict_tensor[subfeed_t] = subfeed.Value; | |||||
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||||
} | } | ||||
} | |||||
// Create a fetch handler to take care of the structure of fetches. | |||||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
// Create a fetch handler to take care of the structure of fetches. | |||||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||||
// Run request and get response. | |||||
// We need to keep the returned movers alive for the following _do_run(). | |||||
// These movers are no longer needed when _do_run() completes, and | |||||
// are deleted when `movers` goes out of scope when this _run() ends. | |||||
var _ = _update_with_movers(); | |||||
var final_fetches = fetch_handler.fetches(); | |||||
var final_targets = fetch_handler.targets(); | |||||
// Run request and get response. | |||||
// We need to keep the returned movers alive for the following _do_run(). | |||||
// These movers are no longer needed when _do_run() completes, and | |||||
// are deleted when `movers` goes out of scope when this _run() ends. | |||||
var _ = _update_with_movers(); | |||||
var final_fetches = fetch_handler.fetches(); | |||||
var final_targets = fetch_handler.targets(); | |||||
// We only want to really perform the run if fetches or targets are provided, | |||||
// or if the call is a partial run that specifies feeds. | |||||
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
// We only want to really perform the run if fetches or targets are provided, | |||||
// or if the call is a partial run that specifies feeds. | |||||
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
return fetch_handler.build_results(this, results); | |||||
} | |||||
return fetch_handler.build_results(this, results); | |||||
} | |||||
/// <summary> | |||||
/// Runs a step based on the given fetches and feeds. | |||||
/// </summary> | |||||
/// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
/// <param name="fetch_list"></param> | |||||
/// <param name="feed_dict"></param> | |||||
/// <returns> | |||||
/// A list of numpy ndarrays, corresponding to the elements of | |||||
/// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
/// name of an operation, the first Tensor output of that operation | |||||
/// will be returned for that element. | |||||
/// </returns> | |||||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
/// <summary> | |||||
/// Runs a step based on the given fetches and feeds. | |||||
/// </summary> | |||||
/// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||||
/// <param name="fetch_list"></param> | |||||
/// <param name="feed_dict"></param> | |||||
/// <returns> | |||||
/// A list of numpy ndarrays, corresponding to the elements of | |||||
/// `fetch_list`. If the ith element of `fetch_list` contains the | |||||
/// name of an operation, the first Tensor output of that operation | |||||
/// will be returned for that element. | |||||
/// </returns> | |||||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||||
{ | |||||
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||||
int i = 0; | |||||
foreach (var x in feed_dict) | |||||
{ | { | ||||
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||||
int i = 0; | |||||
foreach (var x in feed_dict) | |||||
if (x.Key is Tensor key) | |||||
{ | { | ||||
if (x.Key is Tensor key) | |||||
switch (x.Value) | |||||
{ | { | ||||
switch (x.Value) | |||||
{ | |||||
case Tensor v: | |||||
if (v.dtype != key.dtype) | |||||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||||
break; | |||||
case SafeTensorHandle v: | |||||
var tensor = new Tensor(v); | |||||
if (tensor.dtype != key.dtype) | |||||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||||
break; | |||||
case bool v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case byte v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case int v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case long v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case float v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case double v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case string v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case Array v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape())); | |||||
break; | |||||
default: | |||||
throw new NotImplementedException(""); | |||||
} | |||||
case Tensor v: | |||||
if (v.dtype != key.dtype) | |||||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||||
break; | |||||
case SafeTensorHandle v: | |||||
var tensor = new Tensor(v); | |||||
if (tensor.dtype != key.dtype) | |||||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||||
break; | |||||
case bool v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case byte v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case int v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case long v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case float v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case double v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case string v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
case Array v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape())); | |||||
break; | |||||
default: | |||||
throw new NotImplementedException(""); | |||||
} | } | ||||
else | |||||
throw new NotImplementedException(""); | |||||
} | } | ||||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
//var targets = target_list; | |||||
return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
else | |||||
throw new NotImplementedException(""); | |||||
} | } | ||||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||||
//var targets = target_list; | |||||
return _call_tf_sessionrun(feeds, fetches, target_list); | |||||
} | |||||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
{ | |||||
// Ensure any changes to the graph are reflected in the runtime. | |||||
_extend_graph(); | |||||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||||
{ | |||||
// Ensure any changes to the graph are reflected in the runtime. | |||||
_extend_graph(); | |||||
c_api.TF_SessionRun(_handle, | |||||
run_options: null, | |||||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), | |||||
ninputs: feed_dict.Length, | |||||
outputs: fetch_list, | |||||
output_values: output_values, | |||||
noutputs: fetch_list.Length, | |||||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
ntargets: target_list.Count, | |||||
run_metadata: IntPtr.Zero, | |||||
status: _status.Handle); | |||||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||||
_status.Check(true); | |||||
c_api.TF_SessionRun(_handle, | |||||
run_options: null, | |||||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||||
input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), | |||||
ninputs: feed_dict.Length, | |||||
outputs: fetch_list, | |||||
output_values: output_values, | |||||
noutputs: fetch_list.Length, | |||||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
ntargets: target_list.Count, | |||||
run_metadata: IntPtr.Zero, | |||||
status: _status); | |||||
var result = new NDArray[fetch_list.Length]; | |||||
_status.Check(true); | |||||
for (int i = 0; i < fetch_list.Length; i++) | |||||
result[i] = fetchValue(new SafeTensorHandle(output_values[i])); | |||||
var result = new NDArray[fetch_list.Length]; | |||||
return result; | |||||
} | |||||
for (int i = 0; i < fetch_list.Length; i++) | |||||
result[i] = fetchValue(new SafeTensorHandle(output_values[i])); | |||||
public unsafe Tensor eval(Tensor tensor) | |||||
{ | |||||
var output_values = new IntPtr[1]; | |||||
var fetch_list = new[] { tensor._as_tf_output() }; | |||||
c_api.TF_SessionRun(_handle, | |||||
run_options: null, | |||||
inputs: new TF_Output[0], | |||||
input_values: new IntPtr[0], | |||||
ninputs: 0, | |||||
outputs: fetch_list, | |||||
output_values: output_values, | |||||
noutputs: 1, | |||||
target_opers: new IntPtr[0], | |||||
ntargets: 0, | |||||
run_metadata: IntPtr.Zero, | |||||
status: _status.Handle); | |||||
_status.Check(true); | |||||
return new Tensor(new SafeTensorHandle(output_values[0])); | |||||
} | |||||
return result; | |||||
} | |||||
private static unsafe NDArray fetchValue(SafeTensorHandle output) | |||||
{ | |||||
var tensor = new Tensor(output); | |||||
return tensor.numpy(); | |||||
} | |||||
public unsafe Tensor eval(Tensor tensor) | |||||
{ | |||||
var output_values = new IntPtr[1]; | |||||
var fetch_list = new[] { tensor._as_tf_output() }; | |||||
c_api.TF_SessionRun(_handle, | |||||
run_options: null, | |||||
inputs: new TF_Output[0], | |||||
input_values: new IntPtr[0], | |||||
ninputs: 0, | |||||
outputs: fetch_list, | |||||
output_values: output_values, | |||||
noutputs: 1, | |||||
target_opers: new IntPtr[0], | |||||
ntargets: 0, | |||||
run_metadata: IntPtr.Zero, | |||||
status: _status); | |||||
_status.Check(true); | |||||
return new Tensor(new SafeTensorHandle(output_values[0])); | |||||
} | |||||
/// <summary> | |||||
/// If a tensor handle that is fed to a device incompatible placeholder, | |||||
/// we move the tensor to the right device, generate a new tensor handle, | |||||
/// and update feed_dict to use the new handle. | |||||
/// </summary> | |||||
private List<object> _update_with_movers() | |||||
{ | |||||
return new List<object> { }; | |||||
} | |||||
private static unsafe NDArray fetchValue(SafeTensorHandle output) | |||||
{ | |||||
var tensor = new Tensor(output); | |||||
return tensor.numpy(); | |||||
} | |||||
private void _extend_graph() | |||||
{ } | |||||
/// <summary> | |||||
/// If a tensor handle that is fed to a device incompatible placeholder, | |||||
/// we move the tensor to the right device, generate a new tensor handle, | |||||
/// and update feed_dict to use the new handle. | |||||
/// </summary> | |||||
private List<object> _update_with_movers() | |||||
{ | |||||
return new List<object> { }; | |||||
} | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
{ | |||||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | |||||
c_api.TF_DeleteSession(handle, _status.Handle); | |||||
} | |||||
private void _extend_graph() | |||||
{ } | |||||
public void Dispose() | |||||
{ | |||||
} | } | ||||
} | } |
@@ -0,0 +1,46 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Net.NetworkInformation; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow | |||||
{ | |||||
public sealed class SafeSessionHandle : SafeTensorflowHandle | |||||
{ | |||||
private SafeSessionHandle() | |||||
{ | |||||
} | |||||
public SafeSessionHandle(IntPtr handle) | |||||
: base(handle) | |||||
{ | |||||
} | |||||
public override string ToString() | |||||
=> $"0x{handle:x16}"; | |||||
protected override bool ReleaseHandle() | |||||
{ | |||||
var status = new Status(); | |||||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | |||||
c_api.TF_DeleteSession(handle, status); | |||||
SetHandle(IntPtr.Zero); | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -14,75 +14,49 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | |||||
using System.IO; | |||||
using System.Runtime.CompilerServices; | |||||
using Tensorflow.Util; | |||||
namespace Tensorflow; | |||||
namespace Tensorflow | |||||
public class Session : BaseSession | |||||
{ | { | ||||
public class Session : BaseSession | |||||
{ | |||||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||||
{ } | |||||
public Session(IntPtr handle, Graph g = null) : base(handle, g) | |||||
{ } | |||||
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||||
{ } | |||||
public Session as_default() | |||||
{ | |||||
return ops.set_default_session(this); | |||||
} | |||||
public static Session LoadFromSavedModel(string path) | |||||
{ | |||||
var graph = new Graph(); | |||||
using var status = new Status(); | |||||
using var opt = c_api.TF_NewSessionOptions(); | |||||
var tags = new string[] { "serve" }; | |||||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
IntPtr.Zero, | |||||
path, | |||||
tags, | |||||
tags.Length, | |||||
graph, | |||||
IntPtr.Zero, | |||||
status.Handle); | |||||
status.Check(true); | |||||
// load graph bytes | |||||
// var data = new byte[buffer.length]; | |||||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||||
return new Session(sess, g: graph); | |||||
} | |||||
public static implicit operator IntPtr(Session session) => session._handle; | |||||
public static implicit operator Session(IntPtr handle) => new Session(handle); | |||||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||||
{ } | |||||
public void __enter__() | |||||
{ | |||||
public Session(SafeSessionHandle handle, Graph g = null) : base(handle, g) | |||||
{ } | |||||
} | |||||
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||||
{ } | |||||
public void __exit__() | |||||
{ | |||||
} | |||||
public void __init__() | |||||
{ | |||||
} | |||||
public void __del__() | |||||
{ | |||||
public Session as_default() | |||||
{ | |||||
return ops.set_default_session(this); | |||||
} | |||||
} | |||||
public static Session LoadFromSavedModel(string path) | |||||
{ | |||||
var graph = new Graph(); | |||||
var status = new Status(); | |||||
using var opt = c_api.TF_NewSessionOptions(); | |||||
var tags = new string[] { "serve" }; | |||||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
IntPtr.Zero, | |||||
path, | |||||
tags, | |||||
tags.Length, | |||||
graph, | |||||
IntPtr.Zero, | |||||
status); | |||||
status.Check(true); | |||||
// load graph bytes | |||||
// var data = new byte[buffer.length]; | |||||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||||
return new Session(sess, g: graph); | |||||
} | } | ||||
public static implicit operator SafeSessionHandle(Session session) => session._handle; | |||||
public static implicit operator Session(SafeSessionHandle handle) => new Session(handle); | |||||
} | } |
@@ -19,33 +19,33 @@ using System; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
internal sealed class SessionOptions : IDisposable | |||||
internal sealed class SessionOptions | |||||
{ | { | ||||
public SafeSessionOptionsHandle Handle { get; } | |||||
SafeSessionOptionsHandle _handle { get; } | |||||
public SessionOptions(string target = "", ConfigProto config = null) | public SessionOptions(string target = "", ConfigProto config = null) | ||||
{ | { | ||||
Handle = c_api.TF_NewSessionOptions(); | |||||
c_api.TF_SetTarget(Handle, target); | |||||
_handle = c_api.TF_NewSessionOptions(); | |||||
c_api.TF_SetTarget(_handle, target); | |||||
if (config != null) | if (config != null) | ||||
SetConfig(config); | SetConfig(config); | ||||
} | } | ||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
private unsafe void SetConfig(ConfigProto config) | private unsafe void SetConfig(ConfigProto config) | ||||
{ | { | ||||
var bytes = config.ToByteArray(); | var bytes = config.ToByteArray(); | ||||
fixed (byte* proto2 = bytes) | fixed (byte* proto2 = bytes) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); | |||||
status.Check(false); | |||||
} | |||||
var status = new Status(); | |||||
c_api.TF_SetConfig(_handle, (IntPtr)proto2, (ulong)bytes.Length, status); | |||||
status.Check(false); | |||||
} | } | ||||
} | } | ||||
public static implicit operator SafeSessionOptionsHandle(SessionOptions opt) | |||||
{ | |||||
return opt._handle; | |||||
} | |||||
} | } | ||||
} | } |
@@ -62,7 +62,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns>TF_Session*</returns> | /// <returns>TF_Session*</returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||||
public static extern SafeSessionHandle TF_NewSession(SafeGraphHandle graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Return a new options object. | /// Return a new options object. | ||||
@@ -110,7 +110,7 @@ namespace Tensorflow | |||||
/// <param name="run_metadata">TF_Buffer*</param> | /// <param name="run_metadata">TF_Buffer*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||||
public static extern unsafe void TF_SessionRun(SafeSessionHandle session, TF_Buffer* run_options, | |||||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | TF_Output[] inputs, IntPtr[] input_values, int ninputs, | ||||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | TF_Output[] outputs, IntPtr[] output_values, int noutputs, | ||||
IntPtr[] target_opers, int ntargets, | IntPtr[] target_opers, int ntargets, | ||||
@@ -26,7 +26,7 @@ namespace Tensorflow | |||||
/// TF_Status holds error information. It either has an OK code, or | /// TF_Status holds error information. It either has an OK code, or | ||||
/// else an error code with an associated error message. | /// else an error code with an associated error message. | ||||
/// </summary> | /// </summary> | ||||
public sealed class Status : IDisposable | |||||
public sealed class Status | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Error message | /// Error message | ||||
@@ -35,9 +35,9 @@ namespace Tensorflow | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
using (Handle.Lease()) | |||||
using (_handle.Lease()) | |||||
{ | { | ||||
return StringPiece(TF_Message(Handle)); | |||||
return StringPiece(TF_Message(_handle)); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -45,23 +45,23 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Error code | /// Error code | ||||
/// </summary> | /// </summary> | ||||
public TF_Code Code => TF_GetCode(Handle); | |||||
public TF_Code Code => TF_GetCode(_handle); | |||||
public SafeStatusHandle Handle { get; } | |||||
SafeStatusHandle _handle { get; } | |||||
public Status() | public Status() | ||||
{ | { | ||||
Handle = TF_NewStatus(); | |||||
_handle = TF_NewStatus(); | |||||
} | } | ||||
public Status(SafeStatusHandle handle) | public Status(SafeStatusHandle handle) | ||||
{ | { | ||||
Handle = handle ?? throw new ArgumentNullException(nameof(handle)); | |||||
_handle = handle ?? throw new ArgumentNullException(nameof(handle)); | |||||
} | } | ||||
public void SetStatus(TF_Code code, string msg) | public void SetStatus(TF_Code code, string msg) | ||||
{ | { | ||||
TF_SetStatus(Handle, code, msg); | |||||
TF_SetStatus(_handle, code, msg); | |||||
} | } | ||||
public bool ok() => Code == TF_Code.TF_OK; | public bool ok() => Code == TF_Code.TF_OK; | ||||
@@ -94,10 +94,12 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public void Dispose() | |||||
=> Handle.Dispose(); | |||||
public override string ToString() | public override string ToString() | ||||
=> $"{Code} 0x{Handle.DangerousGetHandle():x16}"; | |||||
=> $"{Code} 0x{_handle.DangerousGetHandle():x16}"; | |||||
public static implicit operator SafeStatusHandle(Status status) | |||||
{ | |||||
return status._handle; | |||||
} | |||||
} | } | ||||
} | } |
@@ -121,7 +121,7 @@ namespace Tensorflow | |||||
if (_handle == null) | if (_handle == null) | ||||
{ | { | ||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); | |||||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -135,9 +135,9 @@ namespace Tensorflow | |||||
protected virtual void SetShapeInternal(Shape value) | protected virtual void SetShapeInternal(Shape value) | ||||
{ | { | ||||
if (value == null) | if (value == null) | ||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status); | |||||
else | else | ||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status); | |||||
} | } | ||||
public int[] _shape_tuple() | public int[] _shape_tuple() | ||||
@@ -176,7 +176,7 @@ namespace Tensorflow | |||||
if (_handle == null) | if (_handle == null) | ||||
{ | { | ||||
var output = _as_tf_output(); | var output = _as_tf_output(); | ||||
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status.Handle); | |||||
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status); | |||||
return ndim; | return ndim; | ||||
} | } | ||||
@@ -94,18 +94,16 @@ namespace Tensorflow | |||||
string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); | string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); | ||||
using (var graph = tf.Graph()) | |||||
using (var sess = tf.Session(graph)) | |||||
{ | |||||
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); | |||||
saver.restore(sess, checkpoint); | |||||
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||||
graph.as_graph_def(), | |||||
output_node_names); | |||||
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||||
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||||
return output_pb; | |||||
} | |||||
var graph = tf.Graph(); | |||||
var sess = tf.Session(graph); | |||||
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); | |||||
saver.restore(sess, checkpoint); | |||||
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||||
graph.as_graph_def(), | |||||
output_node_names); | |||||
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||||
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||||
return output_pb; | |||||
} | } | ||||
public static Graph load_graph(string freeze_graph_pb, string name = "") | public static Graph load_graph(string freeze_graph_pb, string name = "") | ||||
@@ -164,7 +164,7 @@ namespace Tensorflow | |||||
result._as_tf_output(), | result._as_tf_output(), | ||||
shape.dims, | shape.dims, | ||||
shape.ndim, | shape.ndim, | ||||
tf.Status.Handle); | |||||
tf.Status); | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
@@ -247,7 +247,7 @@ namespace Tensorflow | |||||
foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
{ | { | ||||
var bytes = attr.Value.ToByteArray(); | var bytes = attr.Value.ToByteArray(); | ||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||
@@ -23,16 +23,14 @@ namespace Tensorflow.Benchmark.Leak | |||||
var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | ||||
for (var i = 0; i < 1024; i++) | for (var i = 0; i < 1024; i++) | ||||
{ | |||||
using (var sess = Session.LoadFromSavedModel(ClassifierModelPath)) { | |||||
using (var g = sess.graph.as_default()) { | |||||
var inputOp = g.OperationByName("inference_input"); | |||||
var outputOp = g.OperationByName("StatefulPartitionedCall"); | |||||
{ | |||||
var sess = Session.LoadFromSavedModel(ClassifierModelPath); | |||||
var g = sess.graph.as_default(); | |||||
var inputOp = g.OperationByName("inference_input"); | |||||
var outputOp = g.OperationByName("StatefulPartitionedCall"); | |||||
var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); | |||||
sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); | |||||
} | |||||
} | |||||
var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); | |||||
sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var enqueue = queue.enqueue(numbers); | var enqueue = queue.enqueue(numbers); | ||||
var dequeue_many = queue.dequeue_many(n: 3); | var dequeue_many = queue.dequeue_many(n: 3); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
sess.run(enqueue, (numbers, new[] { 1 })); | |||||
sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||||
sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||||
var result = sess.run(dequeue_many[0]); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||||
} | |||||
var sess = tf.Session(); | |||||
sess.run(enqueue, (numbers, new[] { 1 })); | |||||
sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||||
sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||||
var result = sess.run(dequeue_many[0]); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -45,27 +43,25 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
// push back into queue | // push back into queue | ||||
var inc = queue.enqueue(y); | var inc = queue.enqueue(y); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
// init queue | |||||
init.run(); | |||||
var sess = tf.Session(); | |||||
// init queue | |||||
init.run(); | |||||
// pop out first element and push back calculated y | |||||
(int dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(10, dequeued); | |||||
// pop out first element and push back calculated y | |||||
(int dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(10, dequeued); | |||||
(dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(20, dequeued); | |||||
(dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(20, dequeued); | |||||
(dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(11, dequeued); | |||||
(dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(11, dequeued); | |||||
(dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(21, dequeued); | |||||
(dequeued, _) = sess.run((x, inc)); | |||||
Assert.AreEqual(21, dequeued); | |||||
// thread will hang or block if you run sess.run(x) again | |||||
// until queue has more element. | |||||
} | |||||
// thread will hang or block if you run sess.run(x) again | |||||
// until queue has more element. | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -75,19 +71,17 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | ||||
var x = queue.dequeue(); | var x = queue.dequeue(); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
init.run(); | |||||
var sess = tf.Session(); | |||||
init.run(); | |||||
var result = sess.run(x); | |||||
Assert.AreEqual(result[0], 2L); | |||||
var result = sess.run(x); | |||||
Assert.AreEqual(result[0], 2L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0], 3L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0], 3L); | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0], 4L); | |||||
} | |||||
result = sess.run(x); | |||||
Assert.AreEqual(result[0], 4L); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -98,16 +92,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var x = queue.dequeue(); | var x = queue.dequeue(); | ||||
string results = ""; | string results = ""; | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
init.run(); | |||||
var sess = tf.Session(); | |||||
init.run(); | |||||
foreach (var i in range(9)) | |||||
results += (int)sess.run(x) + "."; | |||||
foreach (var i in range(9)) | |||||
results += (int)sess.run(x) + "."; | |||||
// output in random order | |||||
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); | |||||
} | |||||
// output in random order | |||||
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -19,11 +19,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var a = constant_op.constant(np.array(3.0).reshape((1, 1))); | var a = constant_op.constant(np.array(3.0).reshape((1, 1))); | ||||
var b = constant_op.constant(np.array(2.0).reshape((1, 1))); | var b = constant_op.constant(np.array(2.0).reshape((1, 1))); | ||||
var c = math_ops.matmul(a, b, name: "matmul"); | var c = math_ops.matmul(a, b, name: "matmul"); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = c.eval(sess); | |||||
Assert.AreEqual(result[0], 6.0); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = c.eval(sess); | |||||
Assert.AreEqual(result[0], 6.0); | |||||
} | } | ||||
} | } | ||||
@@ -32,11 +30,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
{ | { | ||||
var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); | var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); | ||||
var c = tf.strings.substr(a, 4, 8); | var c = tf.strings.substr(a, 4, 8); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = c.eval(sess).StringData(); | |||||
Assert.AreEqual(result[0], "heythere"); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = c.eval(sess).StringData(); | |||||
Assert.AreEqual(result[0], "heythere"); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -47,11 +43,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
const int size = 30_000; | const int size = 30_000; | ||||
var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); | var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); | ||||
var c = tf.strings.substr(a, 0, size - 5000); | var c = tf.strings.substr(a, 0, size - 5000); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); | |||||
Console.WriteLine(result); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); | |||||
Console.WriteLine(result); | |||||
} | } | ||||
} | } | ||||
@@ -16,15 +16,13 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1); | var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1); | ||||
var st = tf.concat(values: new[] { indices, labels }, axis: 1); | var st = tf.concat(values: new[] { indices, labels }, axis: 1); | ||||
var onehot = tf.sparse_to_dense(st, (5, 5), 1); | var onehot = tf.sparse_to_dense(st, (5, 5), 1); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(onehot); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>())); | |||||
}; | |||||
var sess = tf.Session(); | |||||
var result = sess.run(onehot); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>())); | |||||
} | } | ||||
[TestMethod, Ignore] | [TestMethod, Ignore] | ||||
@@ -39,13 +37,11 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
new[] { 3L, 4L }); | new[] { 3L, 4L }); | ||||
var onehot = tf.sparse_tensor_to_dense(decoded_list); | var onehot = tf.sparse_tensor_to_dense(decoded_list); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(onehot); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(onehot); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -56,14 +52,12 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
int[,] crops = { { 0, 0 }, { 0, 0 } }; | int[,] crops = { { 0, 0 }, { 0, 0 } }; | ||||
var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); | var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(tensor); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(tensor); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||||
} | } | ||||
[TestMethod, Ignore] | [TestMethod, Ignore] | ||||
@@ -72,11 +66,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var tensor = new[] { 0, 1, 2, 3 }; | var tensor = new[] { 0, 1, 2, 3 }; | ||||
var mask = np.array(new[] { true, false, true, false }); | var mask = np.array(new[] { true, false, true, false }); | ||||
var masked = tf.boolean_mask(tensor, mask); | var masked = tf.boolean_mask(tensor, mask); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(masked); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(masked); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
var v = tf.Variable(new[] { 1, 2 }); | var v = tf.Variable(new[] { 1, 2 }); | ||||
var init = tf.compat.v1.global_variables_initializer(); | var init = tf.compat.v1.global_variables_initializer(); | ||||
using var sess = tf.compat.v1.Session(); | |||||
var sess = tf.compat.v1.Session(); | |||||
sess.run(init); | sess.run(init); | ||||
// Usage passing the session explicitly. | // Usage passing the session explicitly. | ||||
print(v.eval(sess)); | print(v.eval(sess)); | ||||
@@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||||
{ | { | ||||
var graph = tf.Graph().as_default(); | var graph = tf.Graph().as_default(); | ||||
using (var sess = tf.Session(graph)) | |||||
{ | |||||
var x = tf.constant(2, name: "x"); | |||||
var y = tf.constant(5, name: "y"); | |||||
var z = control_flow_ops.cond(tf.less(x, y), | |||||
() => tf.constant(22, name: "t22"), | |||||
() => tf.constant(55, name: "f55")); | |||||
int result = z.eval(sess); | |||||
assertEquals(result, 22); | |||||
} | |||||
var sess = tf.Session(graph); | |||||
var x = tf.constant(2, name: "x"); | |||||
var y = tf.constant(5, name: "y"); | |||||
var z = control_flow_ops.cond(tf.less(x, y), | |||||
() => tf.constant(22, name: "t22"), | |||||
() => tf.constant(55, name: "f55")); | |||||
int result = z.eval(sess); | |||||
assertEquals(result, 22); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -35,18 +33,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||||
{ | { | ||||
var graph = tf.Graph().as_default(); | var graph = tf.Graph().as_default(); | ||||
using (var sess = tf.Session(graph)) | |||||
{ | |||||
var x = tf.constant(2, name: "x"); | |||||
var y = tf.constant(1, name: "y"); | |||||
var sess = tf.Session(graph); | |||||
var x = tf.constant(2, name: "x"); | |||||
var y = tf.constant(1, name: "y"); | |||||
var z = control_flow_ops.cond(tf.less(x, y), | |||||
() => tf.constant(22, name: "t22"), | |||||
() => tf.constant(11, name: "f11")); | |||||
var z = control_flow_ops.cond(tf.less(x, y), | |||||
() => tf.constant(22, name: "t22"), | |||||
() => tf.constant(11, name: "f11")); | |||||
int result = z.eval(sess); | |||||
assertEquals(result, 11); | |||||
} | |||||
int result = z.eval(sess); | |||||
assertEquals(result, 11); | |||||
} | } | ||||
[Ignore("Dependent on UpdateEdge")] | [Ignore("Dependent on UpdateEdge")] | ||||
@@ -23,21 +23,19 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||||
private void _testWhileContextHelper(int maximum_iterations) | private void _testWhileContextHelper(int maximum_iterations) | ||||
{ | { | ||||
// TODO: implement missing code dependencies | // TODO: implement missing code dependencies | ||||
using (var sess = this.cached_session()) | |||||
var sess = this.cached_session(); | |||||
var i = constant_op.constant(0, name: "i"); | |||||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||||
//control_flow_ops.while_loop( | |||||
// c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||||
foreach (Operation op in sess.graph.get_operations()) | |||||
{ | { | ||||
var i = constant_op.constant(0, name: "i"); | |||||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||||
//control_flow_ops.while_loop( | |||||
// c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||||
foreach (Operation op in sess.graph.get_operations()) | |||||
{ | |||||
var control_flow_context = op._get_control_flow_context(); | |||||
/*if (control_flow_context != null) | |||||
self.assertProtoEquals(control_flow_context.to_proto(), | |||||
WhileContext.from_proto( | |||||
control_flow_context.to_proto()).to_proto(), "");*/ | |||||
} | |||||
var control_flow_context = op._get_control_flow_context(); | |||||
/*if (control_flow_context != null) | |||||
self.assertProtoEquals(control_flow_context.to_proto(), | |||||
WhileContext.from_proto( | |||||
control_flow_context.to_proto()).to_proto(), "");*/ | |||||
} | } | ||||
} | } | ||||
@@ -18,11 +18,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var y = tf.broadcast_to(x, (2, 4, 3)); | var y = tf.broadcast_to(x, (2, 4, 3)); | ||||
var grad = tf.gradients(y, x); | var grad = tf.gradients(y, x); | ||||
using (var sess = tf.Session(graph)) | |||||
{ | |||||
float result = sess.run(grad[0]); | |||||
Assert.AreEqual(result, 24.0f); | |||||
} | |||||
var sess = tf.Session(graph); | |||||
float result = sess.run(grad[0]); | |||||
Assert.AreEqual(result, 24.0f); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -33,11 +31,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var z = tf.cumsum(y, axis: 1); | var z = tf.cumsum(y, axis: 1); | ||||
var grad = tf.gradients(z, x); | var grad = tf.gradients(z, x); | ||||
using (var sess = tf.Session(graph)) | |||||
{ | |||||
float result = sess.run(grad[0]); | |||||
Assert.AreEqual(result, 60.0f); | |||||
} | |||||
var sess = tf.Session(graph); | |||||
float result = sess.run(grad[0]); | |||||
Assert.AreEqual(result, 60.0f); | |||||
} | } | ||||
[TestMethod, Ignore] | [TestMethod, Ignore] | ||||
@@ -78,14 +74,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
42.0f, 42.0f, 42.0f, | 42.0f, 42.0f, 42.0f, | ||||
45.0f, 45.0f, 45.0f | 45.0f, 45.0f, 45.0f | ||||
}; | }; | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(g); | |||||
var resultList = result[0].ToArray<float>().ToList(); | |||||
resultList.AddRange(result[1].ToArray<float>()); | |||||
Console.WriteLine(result.ToString()); | |||||
CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(g); | |||||
var resultList = result[0].ToArray<float>().ToList(); | |||||
resultList.AddRange(result[1].ToArray<float>()); | |||||
Console.WriteLine(result.ToString()); | |||||
CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -97,11 +91,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var y = f(x); | var y = f(x); | ||||
var g = tf.gradients(y, x); | var g = tf.gradients(y, x); | ||||
using (var session = tf.Session()) | |||||
{ | |||||
var result = session.run(new[] { y, g[0] }); | |||||
return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]); | |||||
} | |||||
var session = tf.Session(); | |||||
var result = session.run(new[] { y, g[0] }); | |||||
return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]); | |||||
} | } | ||||
void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values) | void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values) | ||||
@@ -197,13 +189,11 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; | var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; | ||||
var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; | var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; | ||||
using (var session = tf.Session()) | |||||
{ | |||||
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); | |||||
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); | |||||
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); | |||||
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); | |||||
} | |||||
var session = tf.Session(); | |||||
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); | |||||
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); | |||||
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); | |||||
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -212,12 +202,10 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var a = tf.constant(1f); | var a = tf.constant(1f); | ||||
var b = tf.tanh(a); | var b = tf.tanh(a); | ||||
var g = tf.gradients(b, a); | var g = tf.gradients(b, a); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(g); | |||||
var actual = result[0]; | |||||
Assert.AreEqual(actual, 0.41997434127f); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(g); | |||||
var actual = result[0]; | |||||
Assert.AreEqual(actual, 0.41997434127f); | |||||
} | } | ||||
@@ -227,14 +215,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var a = tf.constant(5f); | var a = tf.constant(5f); | ||||
var b = tf.lgamma(a); | var b = tf.lgamma(a); | ||||
var g = tf.gradients(b, a); | var g = tf.gradients(b, a); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(new object[] { g, b }); | |||||
var actualDeriv = result[0]; | |||||
var actual = result[1]; | |||||
Assert.AreEqual(actualDeriv, 1.5061177f); | |||||
Assert.AreEqual(actual, 3.17805386f); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(new object[] { g, b }); | |||||
var actualDeriv = result[0]; | |||||
var actual = result[1]; | |||||
Assert.AreEqual(actualDeriv, 1.5061177f); | |||||
Assert.AreEqual(actual, 3.17805386f); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -247,14 +233,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
tf.constant(new[] { 1 }, tf.int32, new[] { 1 }) | tf.constant(new[] { 1 }, tf.int32, new[] { 1 }) | ||||
); | ); | ||||
var g = tf.gradients(b, a); | var g = tf.gradients(b, a); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(new object[] { g, b }); | |||||
var actualDeriv = np.squeeze(result[0]); | |||||
var actual = np.squeeze(result[1]); | |||||
Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); | |||||
Assert.AreEqual(actual, 0.9640276f); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(new object[] { g, b }); | |||||
var actualDeriv = np.squeeze(result[0]); | |||||
var actual = np.squeeze(result[1]); | |||||
Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); | |||||
Assert.AreEqual(actual, 0.9640276f); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -264,14 +248,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | ||||
var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0); | var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0); | ||||
var g = tf.gradients(a, a1); | var g = tf.gradients(a, a1); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(new object[] { g, a }); | |||||
var actualDeriv = result[0][0]; | |||||
var actual = result[1][0]; | |||||
Assert.AreEqual(actualDeriv, 1f); | |||||
Assert.AreEqual(actual, 2f); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(new object[] { g, a }); | |||||
var actualDeriv = result[0][0]; | |||||
var actual = result[1][0]; | |||||
Assert.AreEqual(actualDeriv, 1f); | |||||
Assert.AreEqual(actual, 2f); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -280,13 +262,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||||
var ap = tf.constant(1f); | var ap = tf.constant(1f); | ||||
var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); | var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); | ||||
var g = tf.gradients(b, ap); | var g = tf.gradients(b, ap); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var result = sess.run(g); | |||||
var actual = result[0]; | |||||
Assert.AreEqual(actual, 0.41997434127f); | |||||
} | |||||
var sess = tf.Session(); | |||||
var result = sess.run(g); | |||||
var actual = result[0]; | |||||
Assert.AreEqual(actual, 0.41997434127f); | |||||
} | } | ||||
[Ignore("TODO")] | [Ignore("TODO")] | ||||
[TestMethod] | [TestMethod] | ||||
public void testUnusedOutput() | public void testUnusedOutput() | ||||
@@ -74,23 +74,21 @@ namespace TensorFlowNET.UnitTest | |||||
var cropSize2_2 = tf.Variable(np.array(4, 4)); | var cropSize2_2 = tf.Variable(np.array(4, 4)); | ||||
var init = tf.global_variables_initializer(); | var init = tf.global_variables_initializer(); | ||||
using (Session sess = tf.Session()) | |||||
{ | |||||
sess.run(init); | |||||
var sess = tf.Session(); | |||||
sess.run(init); | |||||
var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); | |||||
var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); | |||||
var result = sess.run(cropped); | |||||
// check if cropped to 1x1 center was succesfull | |||||
Assert.AreEqual(result.size, 1ul); | |||||
Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||||
var result = sess.run(cropped); | |||||
// check if cropped to 1x1 center was succesfull | |||||
Assert.AreEqual(result.size, 1ul); | |||||
Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||||
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||||
result = sess.run(cropped); | |||||
// check if flipped and no cropping occured | |||||
Assert.AreEqual(result.size, 16ul); | |||||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||||
} | |||||
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||||
result = sess.run(cropped); | |||||
// check if flipped and no cropping occured | |||||
Assert.AreEqual(result.size, 16ul); | |||||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -24,7 +24,7 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
Assert.IsNull(tf.peak_default_graph()); | Assert.IsNull(tf.peak_default_graph()); | ||||
using var sess = tf.Session(); | |||||
var sess = tf.Session(); | |||||
var default_graph = tf.get_default_graph(); | var default_graph = tf.get_default_graph(); | ||||
var sess_graph = sess.graph; | var sess_graph = sess.graph; | ||||
Assert.IsNotNull(default_graph); | Assert.IsNotNull(default_graph); | ||||
@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
Assert.IsNull(tf.peak_default_graph()); | Assert.IsNull(tf.peak_default_graph()); | ||||
//tf.Session created an other graph | //tf.Session created an other graph | ||||
using var sess = tf.Session(); | |||||
var sess = tf.Session(); | |||||
var default_graph = tf.get_default_graph(); | var default_graph = tf.get_default_graph(); | ||||
var sess_graph = sess.graph; | var sess_graph = sess.graph; | ||||
Assert.IsNotNull(default_graph); | Assert.IsNotNull(default_graph); | ||||
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest | |||||
beforehand.as_default(); | beforehand.as_default(); | ||||
Assert.IsNotNull(tf.peak_default_graph()); | Assert.IsNotNull(tf.peak_default_graph()); | ||||
using var sess = tf.Session(); | |||||
var sess = tf.Session(); | |||||
var default_graph = tf.peak_default_graph(); | var default_graph = tf.peak_default_graph(); | ||||
var sess_graph = sess.graph; | var sess_graph = sess.graph; | ||||
Assert.IsNotNull(default_graph); | Assert.IsNotNull(default_graph); | ||||
@@ -102,7 +102,7 @@ namespace TensorFlowNET.UnitTest | |||||
//the core method | //the core method | ||||
void Core(int tid) | void Core(int tid) | ||||
{ | { | ||||
using var sess = tf.Session(); | |||||
var sess = tf.Session(); | |||||
for (int i = 0; i < 100; i++) | for (int i = 0; i < 100; i++) | ||||
{ | { | ||||
var t = new Tensor(1); | var t = new Tensor(1); | ||||
@@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||||
void Core(int tid) | void Core(int tid) | ||||
{ | { | ||||
//tf.Session created an other graph | //tf.Session created an other graph | ||||
using var sess = tf.Session(); | |||||
var sess = tf.Session(); | |||||
for (int i = 0; i < 100; i++) | for (int i = 0; i < 100; i++) | ||||
{ | { | ||||
var t = new Tensor(new int[] { 1, 2, 3 }); | var t = new Tensor(new int[] { 1, 2, 3 }); | ||||
@@ -142,7 +142,7 @@ namespace TensorFlowNET.UnitTest | |||||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | ||||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | ||||
var math = a1 + a2; | var math = a1 + a2; | ||||
using var sess = tf.Session(graph); | |||||
var sess = tf.Session(graph); | |||||
for (int i = 0; i < 100; i++) | for (int i = 0; i < 100; i++) | ||||
{ | { | ||||
var result = sess.run(math); | var result = sess.run(math); | ||||
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest | |||||
tf.compat.v1.disable_eager_execution(); | tf.compat.v1.disable_eager_execution(); | ||||
var graph = tf.Graph().as_default(); | var graph = tf.Graph().as_default(); | ||||
using var sess = tf.Session(graph); | |||||
var sess = tf.Session(graph); | |||||
Assert.IsNotNull(tf.get_default_graph()); | Assert.IsNotNull(tf.get_default_graph()); | ||||
//graph is created automatically to perform create these operations | //graph is created automatically to perform create these operations | ||||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | ||||
@@ -182,7 +182,7 @@ namespace TensorFlowNET.UnitTest | |||||
//the core method | //the core method | ||||
void Core(int tid) | void Core(int tid) | ||||
{ | { | ||||
using var sess = tf.Session(); | |||||
var sess = tf.Session(); | |||||
Assert.IsNotNull(tf.get_default_graph()); | Assert.IsNotNull(tf.get_default_graph()); | ||||
//graph is created automatically to perform create these operations | //graph is created automatically to perform create these operations | ||||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | ||||
@@ -182,23 +182,21 @@ namespace TensorFlowNET.UnitTest | |||||
// return self._eval_helper(tensors) | // return self._eval_helper(tensors) | ||||
// else: | // else: | ||||
{ | { | ||||
using (var sess = tf.Session()) | |||||
var sess = tf.Session(); | |||||
var ndarray = tensor.eval(sess); | |||||
if (typeof(T) == typeof(double)) | |||||
{ | { | ||||
var ndarray = tensor.eval(sess); | |||||
if (typeof(T) == typeof(double)) | |||||
{ | |||||
double x = ndarray; | |||||
result = x; | |||||
} | |||||
else if (typeof(T) == typeof(int)) | |||||
{ | |||||
int x = ndarray; | |||||
result = x; | |||||
} | |||||
else | |||||
{ | |||||
result = ndarray; | |||||
} | |||||
double x = ndarray; | |||||
result = x; | |||||
} | |||||
else if (typeof(T) == typeof(int)) | |||||
{ | |||||
int x = ndarray; | |||||
result = x; | |||||
} | |||||
else | |||||
{ | |||||
result = ndarray; | |||||
} | } | ||||
return (T)result; | return (T)result; | ||||
@@ -48,7 +48,7 @@ namespace Tensorflow.Native.UnitTest | |||||
private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) | private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) | ||||
{ | { | ||||
var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_.Handle); | |||||
var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_); | |||||
EXPECT_EQ(TF_Code.TF_OK, s_.Code); | EXPECT_EQ(TF_Code.TF_OK, s_.Code); | ||||
char e = expected_list_size >= 0 ? (char)1 : (char)0; | char e = expected_list_size >= 0 ? (char)1 : (char)0; | ||||
/*EXPECT_EQ(e, m.is_list); | /*EXPECT_EQ(e, m.is_list); | ||||
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest | |||||
var desc = init("string"); | var desc = init("string"); | ||||
c_api.TF_SetAttrString(desc, "v", "bunny", 5); | c_api.TF_SetAttrString(desc, "v", "bunny", 5); | ||||
var oper = c_api.TF_FinishOperation(desc, s_.Handle); | |||||
var oper = c_api.TF_FinishOperation(desc, s_); | |||||
//ASSERT_EQ(TF_Code.TF_OK, s_.Code); | //ASSERT_EQ(TF_Code.TF_OK, s_.Code); | ||||
//EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | //EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | ||||
//var value = new char[5]; | //var value = new char[5]; | ||||
@@ -86,8 +86,6 @@ namespace Tensorflow.Native.UnitTest | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
graph_.Dispose(); | |||||
s_.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -59,7 +59,7 @@ namespace Tensorflow.Native.UnitTest | |||||
private void VerifyCollocation(Operation op, string[] expected) | private void VerifyCollocation(Operation op, string[] expected) | ||||
{ | { | ||||
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_.Handle); | |||||
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_); | |||||
TF_AttrMetadata m = new TF_AttrMetadata(); | TF_AttrMetadata m = new TF_AttrMetadata(); | ||||
if (expected.Length == 0) | if (expected.Length == 0) | ||||
{ | { | ||||
@@ -98,8 +98,6 @@ namespace Tensorflow.Native.UnitTest | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
graph_.Dispose(); | |||||
s_.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -45,10 +45,10 @@ namespace Tensorflow.Native.UnitTest | |||||
=> c_api.TF_AddInput(desc, input); | => c_api.TF_AddInput(desc, input); | ||||
protected Operation TF_FinishOperation(OperationDescription desc, Status s) | protected Operation TF_FinishOperation(OperationDescription desc, Status s) | ||||
=> c_api.TF_FinishOperation(desc, s.Handle); | |||||
=> c_api.TF_FinishOperation(desc, s); | |||||
protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) | protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) | ||||
=> c_api.TF_SetAttrTensor(desc, attrName, value, s.Handle); | |||||
=> c_api.TF_SetAttrTensor(desc, attrName, value, s); | |||||
protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) | protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) | ||||
=> c_api.TF_SetAttrType(desc, attrName, dtype); | => c_api.TF_SetAttrType(desc, attrName, dtype); | ||||
@@ -18,7 +18,7 @@ namespace Tensorflow.Native.UnitTest | |||||
string func_name_ = "MyFunc"; | string func_name_ = "MyFunc"; | ||||
string func_node_name_ = "MyFunc_0"; | string func_node_name_ = "MyFunc_0"; | ||||
Status s_; | Status s_; | ||||
IntPtr func_; | |||||
SafeFuncGraphHandle func_; | |||||
[TestInitialize] | [TestInitialize] | ||||
public void Initialize() | public void Initialize() | ||||
@@ -402,7 +402,7 @@ namespace Tensorflow.Native.UnitTest | |||||
inputs.Length, inputs.ToArray(), | inputs.Length, inputs.ToArray(), | ||||
outputs.Length, outputs.ToArray(), | outputs.Length, outputs.ToArray(), | ||||
output_names == null || output_names.Length == 0 ? null : output_names, | output_names == null || output_names.Length == 0 ? null : output_names, | ||||
IntPtr.Zero, null, s_.Handle); | |||||
IntPtr.Zero, null, s_); | |||||
if (expect_failure) | if (expect_failure) | ||||
{ | { | ||||
@@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest | |||||
ASSERT_EQ(TF_OK, s_.Code, s_.Message); | ASSERT_EQ(TF_OK, s_.Code, s_.Message); | ||||
ASSERT_NE(func_, IntPtr.Zero); | ASSERT_NE(func_, IntPtr.Zero); | ||||
ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); | ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); | ||||
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle); | |||||
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_); | |||||
ASSERT_EQ(TF_OK, s_.Code, s_.Message); | ASSERT_EQ(TF_OK, s_.Code, s_.Message); | ||||
} | } | ||||
@@ -44,18 +44,14 @@ namespace Tensorflow.Native.UnitTest | |||||
private bool GetGraphDef(Graph graph, out GraphDef graph_def) | private bool GetGraphDef(Graph graph, out GraphDef graph_def) | ||||
{ | { | ||||
graph_def = null; | graph_def = null; | ||||
using (var s = new Status()) | |||||
{ | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle); | |||||
bool ret = TF_GetCode(s) == TF_OK; | |||||
EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||||
if (ret) | |||||
graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||||
return ret; | |||||
} | |||||
} | |||||
var s = new Status(); | |||||
var buffer = new Buffer(); | |||||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
bool ret = TF_GetCode(s) == TF_OK; | |||||
EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||||
if (ret) | |||||
graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||||
return ret; | |||||
} | } | ||||
private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | ||||
@@ -111,9 +107,9 @@ namespace Tensorflow.Native.UnitTest | |||||
IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero }; | IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero }; | ||||
c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, | c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, | ||||
ninputs, grad_inputs, s_.Handle, handles); | |||||
ninputs, grad_inputs, s_, handles); | |||||
var op = new Operation(handles[0]); | |||||
// var op = new Operation(handles[0]); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -275,9 +271,6 @@ namespace Tensorflow.Native.UnitTest | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
graph_.Dispose(); | |||||
expected_graph_.Dispose(); | |||||
s_.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -9,7 +9,7 @@ namespace Tensorflow.Native.UnitTest | |||||
[TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] | [TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] | ||||
public void UpdateEdge() | public void UpdateEdge() | ||||
{ | { | ||||
using var graph = new Graph().as_default(); | |||||
var graph = new Graph().as_default(); | |||||
var one = tf.constant(1, name: "one"); | var one = tf.constant(1, name: "one"); | ||||
var two = tf.constant(2, name: "two"); | var two = tf.constant(2, name: "two"); | ||||
@@ -35,7 +35,7 @@ namespace Tensorflow.Native.UnitTest | |||||
EXPECT_EQ(attr_value.Type, DataType.DtInt32); | EXPECT_EQ(attr_value.Type, DataType.DtInt32); | ||||
// Test not found errors in TF_Operation*() query functions. | // Test not found errors in TF_Operation*() query functions. | ||||
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s.Handle)); | |||||
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | ||||
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | ||||
EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | ||||
@@ -191,9 +191,6 @@ namespace Tensorflow.Native.UnitTest | |||||
ASSERT_TRUE(found_scalar_const); | ASSERT_TRUE(found_scalar_const); | ||||
ASSERT_TRUE(found_add); | ASSERT_TRUE(found_add); | ||||
ASSERT_TRUE(found_neg); | ASSERT_TRUE(found_neg); | ||||
graph.Dispose(); | |||||
s.Dispose(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -213,16 +210,15 @@ namespace Tensorflow.Native.UnitTest | |||||
// Export to a GraphDef. | // Export to a GraphDef. | ||||
var graph_def = new Buffer(); | var graph_def = new Buffer(); | ||||
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle); | |||||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
// Import it, with a prefix, in a fresh graph. | // Import it, with a prefix, in a fresh graph. | ||||
graph.Dispose(); | |||||
graph = new Graph().as_default(); | graph = new Graph().as_default(); | ||||
using (var opts = c_api.TF_NewImportGraphDefOptions()) | using (var opts = c_api.TF_NewImportGraphDefOptions()) | ||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | ||||
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
} | } | ||||
@@ -265,7 +261,7 @@ namespace Tensorflow.Native.UnitTest | |||||
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | ||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | ||||
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | ||||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
return results; | return results; | ||||
@@ -305,7 +301,7 @@ namespace Tensorflow.Native.UnitTest | |||||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | ||||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | ||||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | ||||
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
} | } | ||||
@@ -330,7 +326,7 @@ namespace Tensorflow.Native.UnitTest | |||||
// Export to a graph def so we can import a graph with control dependencies | // Export to a graph def so we can import a graph with control dependencies | ||||
graph_def = new Buffer(); | graph_def = new Buffer(); | ||||
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle); | |||||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | EXPECT_EQ(TF_Code.TF_OK, s.Code); | ||||
// Import again, with remapped control dependency, into the same graph | // Import again, with remapped control dependency, into the same graph | ||||
@@ -338,7 +334,7 @@ namespace Tensorflow.Native.UnitTest | |||||
{ | { | ||||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | ||||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | ||||
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
} | } | ||||
@@ -380,7 +376,6 @@ namespace Tensorflow.Native.UnitTest | |||||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | ASSERT_EQ(TF_Code.TF_OK, s.Code); | ||||
// Import it in a fresh graph with return outputs. | // Import it in a fresh graph with return outputs. | ||||
graph.Dispose(); | |||||
graph = new Graph().as_default(); | graph = new Graph().as_default(); | ||||
var opts = new ImportGraphDefOptions(); | var opts = new ImportGraphDefOptions(); | ||||
opts.AddReturnOutput("feed", 0); | opts.AddReturnOutput("feed", 0); | ||||
@@ -401,11 +396,6 @@ namespace Tensorflow.Native.UnitTest | |||||
EXPECT_EQ(0, return_outputs[0].index); | EXPECT_EQ(0, return_outputs[0].index); | ||||
EXPECT_EQ(scalar, return_outputs[1].oper); | EXPECT_EQ(scalar, return_outputs[1].oper); | ||||
EXPECT_EQ(0, return_outputs[1].index); | EXPECT_EQ(0, return_outputs[1].index); | ||||
opts.Dispose(); | |||||
graph_def.Dispose(); | |||||
graph.Dispose(); | |||||
s.Dispose(); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -422,16 +412,14 @@ namespace Tensorflow.Native.UnitTest | |||||
public void ImportGraphMeta() | public void ImportGraphMeta() | ||||
{ | { | ||||
var dir = "my-save-dir/"; | var dir = "my-save-dir/"; | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); | |||||
new_saver.restore(sess, dir + "my-model-10000"); | |||||
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | |||||
var batch_size = tf.size(labels); | |||||
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor; | |||||
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||||
logits: logits); | |||||
} | |||||
var sess = tf.Session(); | |||||
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); | |||||
new_saver.restore(sess, dir + "my-model-10000"); | |||||
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | |||||
var batch_size = tf.size(labels); | |||||
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor; | |||||
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||||
logits: logits); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest | |||||
/// </summary> | /// </summary> | ||||
public class CSession | public class CSession | ||||
{ | { | ||||
private IntPtr session_; | |||||
private SafeSessionHandle session_; | |||||
private List<TF_Output> inputs_ = new List<TF_Output>(); | private List<TF_Output> inputs_ = new List<TF_Output>(); | ||||
private List<Tensor> input_values_ = new List<Tensor>(); | private List<Tensor> input_values_ = new List<Tensor>(); | ||||
@@ -22,11 +22,8 @@ namespace Tensorflow.Native.UnitTest | |||||
public CSession(Graph graph, Status s, bool user_XLA = false) | public CSession(Graph graph, Status s, bool user_XLA = false) | ||||
{ | { | ||||
lock (Locks.ProcessWide) | |||||
{ | |||||
var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||||
session_ = new Session(graph, config, s); | |||||
} | |||||
var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||||
session_ = new Session(graph, config, s); | |||||
} | } | ||||
public void SetInputs(Dictionary<Operation, Tensor> inputs) | public void SetInputs(Dictionary<Operation, Tensor> inputs) | ||||
@@ -85,7 +82,7 @@ namespace Tensorflow.Native.UnitTest | |||||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | ||||
outputs_ptr, output_values_ptr, outputs_.Count, | outputs_ptr, output_values_ptr, outputs_.Count, | ||||
targets_ptr, targets_.Count, | targets_ptr, targets_.Count, | ||||
IntPtr.Zero, s.Handle); | |||||
IntPtr.Zero, s); | |||||
s.Check(); | s.Check(); | ||||
@@ -14,8 +14,8 @@ namespace Tensorflow.Native.UnitTest.Sessions | |||||
[TestMethod] | [TestMethod] | ||||
public void Session() | public void Session() | ||||
{ | { | ||||
using var s = new Status(); | |||||
using var graph = new Graph(); | |||||
var s = new Status(); | |||||
var graph = new Graph(); | |||||
// Make a placeholder operation. | // Make a placeholder operation. | ||||
var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
@@ -139,45 +139,45 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||||
var feed_out_0 = new TF_Output(feed, 0); | var feed_out_0 = new TF_Output(feed, 0); | ||||
// Fetch the shape, it should be completely unknown. | // Fetch the shape, it should be completely unknown. | ||||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
EXPECT_EQ(-1, num_dims); | EXPECT_EQ(-1, num_dims); | ||||
// Set the shape to be unknown, expect no change. | // Set the shape to be unknown, expect no change. | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||||
EXPECT_EQ(-1, num_dims); | EXPECT_EQ(-1, num_dims); | ||||
// Set the shape to be 2 x Unknown | // Set the shape to be 2 x Unknown | ||||
long[] dims = { 2, -1 }; | long[] dims = { 2, -1 }; | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||||
EXPECT_EQ(2, num_dims); | EXPECT_EQ(2, num_dims); | ||||
// Get the dimension vector appropriately. | // Get the dimension vector appropriately. | ||||
var returned_dims = new long[dims.Length]; | var returned_dims = new long[dims.Length]; | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | ||||
// Set to a new valid shape: [2, 3] | // Set to a new valid shape: [2, 3] | ||||
dims[1] = 3; | dims[1] = 3; | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
// Fetch and see that the new value is returned. | // Fetch and see that the new value is returned. | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | ||||
// Try to set 'unknown' with unknown rank on the shape and see that | // Try to set 'unknown' with unknown rank on the shape and see that | ||||
// it doesn't change. | // it doesn't change. | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
EXPECT_EQ(2, num_dims); | EXPECT_EQ(2, num_dims); | ||||
EXPECT_EQ(2, (int)returned_dims[0]); | EXPECT_EQ(2, (int)returned_dims[0]); | ||||
@@ -187,21 +187,21 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||||
// it doesn't change. | // it doesn't change. | ||||
dims[0] = -1; | dims[0] = -1; | ||||
dims[1] = -1; | dims[1] = -1; | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
EXPECT_EQ(2, num_dims); | EXPECT_EQ(2, num_dims); | ||||
EXPECT_EQ(2, (int)returned_dims[0]); | EXPECT_EQ(2, (int)returned_dims[0]); | ||||
EXPECT_EQ(3, (int)returned_dims[1]); | EXPECT_EQ(3, (int)returned_dims[1]); | ||||
// Try to fetch a shape with the wrong num_dims | // Try to fetch a shape with the wrong num_dims | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | ||||
// Try to set an invalid shape (cannot change 2x3 to a 2x5). | // Try to set an invalid shape (cannot change 2x3 to a 2x5). | ||||
dims[1] = 5; | dims[1] = 5; | ||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | ||||
// Test for a scalar. | // Test for a scalar. | ||||
@@ -209,14 +209,13 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
var three_out_0 = new TF_Output(three, 0); | var three_out_0 = new TF_Output(three, 0); | ||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
EXPECT_EQ(0, num_dims); | EXPECT_EQ(0, num_dims); | ||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle); | |||||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s); | |||||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | ||||
graph.Exit(); | graph.Exit(); | ||||
s.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -23,7 +23,7 @@ namespace Tensorflow.Native.UnitTest | |||||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | c_api.TF_AddInputList(desc, inputs, inputs.Length); | ||||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
return op; | return op; | ||||
@@ -33,37 +33,29 @@ namespace Tensorflow.Native.UnitTest | |||||
[SuppressMessage("ReSharper", "RedundantAssignment")] | [SuppressMessage("ReSharper", "RedundantAssignment")] | ||||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
{ | { | ||||
lock (Locks.ProcessWide) | |||||
{ | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer.Handle, s.Handle); | |||||
attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); | |||||
} | |||||
var buffer = new Buffer(); | |||||
return s.Code == TF_Code.TF_OK; | |||||
} | |||||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); | |||||
return s.Code == TF_Code.TF_OK; | |||||
} | } | ||||
public static GraphDef GetGraphDef(Graph graph) | public static GraphDef GetGraphDef(Graph graph) | ||||
{ | { | ||||
lock (Locks.ProcessWide) | |||||
{ | |||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle); | |||||
s.Check(); | |||||
return GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||||
} | |||||
} | |||||
var s = new Status(); | |||||
var buffer = new Buffer(); | |||||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
s.Check(); | |||||
return GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||||
} | } | ||||
public static FunctionDef GetFunctionDef(IntPtr func) | |||||
public static FunctionDef GetFunctionDef(SafeFuncGraphHandle func) | |||||
{ | { | ||||
using var s = new Status(); | |||||
using var buffer = new Buffer(); | |||||
c_api.TF_FunctionToFunctionDef(func, buffer.Handle, s.Handle); | |||||
var s = new Status(); | |||||
var buffer = new Buffer(); | |||||
c_api.TF_FunctionToFunctionDef(func, buffer, s); | |||||
s.Check(true); | s.Check(true); | ||||
var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray()); | var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray()); | ||||
return func_def; | return func_def; | ||||
@@ -192,7 +184,7 @@ namespace Tensorflow.Native.UnitTest | |||||
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | ||||
var neg_input = new TF_Output(n, 0); | var neg_input = new TF_Output(n, 0); | ||||
c_api.TF_AddInput(desc, neg_input); | c_api.TF_AddInput(desc, neg_input); | ||||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
return op; | return op; | ||||
@@ -210,7 +202,7 @@ namespace Tensorflow.Native.UnitTest | |||||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | ||||
} | } | ||||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
return op; | return op; | ||||
@@ -222,10 +214,10 @@ namespace Tensorflow.Native.UnitTest | |||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "Const", name); | var desc = c_api.TF_NewOperation(graph, "Const", name); | ||||
c_api.TF_SetAttrTensor(desc, "value", t, s.Handle); | |||||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||||
s.Check(); | s.Check(); | ||||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | c_api.TF_SetAttrType(desc, "dtype", t.dtype); | ||||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | s.Check(); | ||||
return op; | return op; | ||||
@@ -17,10 +17,8 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
public void ImportGraph() | public void ImportGraph() | ||||
{ | { | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||||
} | |||||
var sess = tf.Session(); | |||||
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||||
//tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); | //tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); | ||||
// import meta | // import meta | ||||
@@ -60,14 +58,12 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
// Add ops to save and restore all the variables. | // Add ops to save and restore all the variables. | ||||
var saver = tf.train.Saver(); | var saver = tf.train.Saver(); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
sess.run(init_op); | |||||
var sess = tf.Session(); | |||||
sess.run(init_op); | |||||
// Save the variables to disk. | |||||
var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | |||||
} | |||||
// Save the variables to disk. | |||||
var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | |||||
} | } | ||||
public void Save2() | public void Save2() | ||||
@@ -84,17 +80,15 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
// Add ops to save and restore all the variables. | // Add ops to save and restore all the variables. | ||||
var saver = tf.train.Saver(); | var saver = tf.train.Saver(); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
sess.run(init_op); | |||||
// o some work with the model. | |||||
inc_v1.op.run(); | |||||
dec_v2.op.run(); | |||||
// Save the variables to disk. | |||||
var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | |||||
} | |||||
var sess = tf.Session(); | |||||
sess.run(init_op); | |||||
// o some work with the model. | |||||
inc_v1.op.run(); | |||||
dec_v2.op.run(); | |||||
// Save the variables to disk. | |||||
var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||||
Console.WriteLine($"Model saved in path: {save_path}"); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6)); | var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6)); | ||||
var scan = tf.scan(fn, input); | var scan = tf.scan(fn, input); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
sess.run(tf.global_variables_initializer()); | |||||
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||||
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||||
} | |||||
var sess = tf.Session(); | |||||
sess.run(tf.global_variables_initializer()); | |||||
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||||
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -196,23 +196,21 @@ namespace TensorFlowNET.UnitTest | |||||
// return self._eval_helper(tensors) | // return self._eval_helper(tensors) | ||||
// else: | // else: | ||||
{ | { | ||||
using (var sess = tf.Session()) | |||||
var sess = tf.Session(); | |||||
var ndarray = tensor.eval(sess); | |||||
if (typeof(T) == typeof(double)) | |||||
{ | { | ||||
var ndarray = tensor.eval(sess); | |||||
if (typeof(T) == typeof(double)) | |||||
{ | |||||
double x = ndarray; | |||||
result = x; | |||||
} | |||||
else if (typeof(T) == typeof(int)) | |||||
{ | |||||
int x = ndarray; | |||||
result = x; | |||||
} | |||||
else | |||||
{ | |||||
result = ndarray; | |||||
} | |||||
double x = ndarray; | |||||
result = x; | |||||
} | |||||
else if (typeof(T) == typeof(int)) | |||||
{ | |||||
int x = ndarray; | |||||
result = x; | |||||
} | |||||
else | |||||
{ | |||||
result = ndarray; | |||||
} | } | ||||
return (T)result; | return (T)result; | ||||
@@ -28,7 +28,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
public void DeleteStatus() | public void DeleteStatus() | ||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
s.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |