@@ -23,11 +23,9 @@ namespace Tensorflow | |||
var x = tf.placeholder(tf.float64, shape: (1024, 1024)); | |||
var log = tf.log(x); | |||
using (var sess = tf.Session()) | |||
{ | |||
var ones = np.ones((1024, 1024), dtype: np.float64); | |||
var o = sess.run(log, new FeedItem(x, ones)); | |||
} | |||
var sess = tf.Session(); | |||
var ones = np.ones((1024, 1024), dtype: np.float64); | |||
var o = sess.run(log, new FeedItem(x, ones)); | |||
// Thread.Sleep(1); | |||
} | |||
@@ -25,15 +25,15 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Represents a TF_Buffer that can be passed to Tensorflow. | |||
/// </summary> | |||
public sealed class Buffer : IDisposable | |||
public sealed class Buffer | |||
{ | |||
public SafeBufferHandle Handle { get; } | |||
SafeBufferHandle _handle; | |||
/// <remarks> | |||
/// <inheritdoc cref="SafeHandleLease" path="/devdoc/usage"/> | |||
/// </remarks> | |||
private unsafe ref readonly TF_Buffer DangerousBuffer | |||
=> ref Unsafe.AsRef<TF_Buffer>(Handle.DangerousGetHandle().ToPointer()); | |||
=> ref Unsafe.AsRef<TF_Buffer>(_handle.DangerousGetHandle().ToPointer()); | |||
/// <summary> | |||
/// The memory block representing this buffer. | |||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||
{ | |||
get | |||
{ | |||
using (Handle.Lease()) | |||
using (_handle.Lease()) | |||
{ | |||
return DangerousBuffer.length; | |||
} | |||
@@ -67,13 +67,13 @@ namespace Tensorflow | |||
} | |||
public Buffer() | |||
=> Handle = TF_NewBuffer(); | |||
=> _handle = TF_NewBuffer(); | |||
public Buffer(SafeBufferHandle handle) | |||
=> Handle = handle; | |||
=> _handle = handle; | |||
public Buffer(byte[] data) | |||
=> Handle = _toBuffer(data); | |||
=> _handle = _toBuffer(data); | |||
private static SafeBufferHandle _toBuffer(byte[] data) | |||
{ | |||
@@ -92,7 +92,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public unsafe byte[] ToArray() | |||
{ | |||
using (Handle.Lease()) | |||
using (_handle.Lease()) | |||
{ | |||
ref readonly TF_Buffer buffer = ref DangerousBuffer; | |||
@@ -107,7 +107,12 @@ namespace Tensorflow | |||
} | |||
} | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
public override string ToString() | |||
=> $"0x{_handle.DangerousGetHandle():x16}"; | |||
public static implicit operator SafeBufferHandle(Buffer buffer) | |||
{ | |||
return buffer._handle; | |||
} | |||
} | |||
} |
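The Buffer change above is the template for most of the files in this diff: the public `Handle` property and the `IDisposable` implementation are replaced by a private SafeHandle field plus an implicit conversion, so P/Invoke call sites take the wrapper directly and the SafeHandle's critical finalizer frees the native object. A minimal sketch of that pattern, using a made-up `SafeDemoHandle`/`NativeMethods` pair rather than anything from TensorFlow.NET:

```csharp
using System;
using System.Runtime.InteropServices;

// Plays the role that SafeBufferHandle plays for Buffer.
sealed class SafeDemoHandle : SafeHandle
{
    public SafeDemoHandle() : base(IntPtr.Zero, ownsHandle: true) { }
    public override bool IsInvalid => handle == IntPtr.Zero;
    protected override bool ReleaseHandle()
    {
        NativeMethods.Demo_Delete(handle); // runs at most once, even from the finalizer
        return true;
    }
}

static class NativeMethods
{
    [DllImport("demo_native")] public static extern SafeDemoHandle Demo_New();
    [DllImport("demo_native")] public static extern void Demo_Delete(IntPtr demo);
    [DllImport("demo_native")] public static extern ulong Demo_Length(SafeDemoHandle demo);
}

sealed class Demo
{
    readonly SafeDemoHandle _handle = NativeMethods.Demo_New();

    // Call sites write NativeMethods.Demo_Length(demo) instead of demo.Handle,
    // exactly like the Buffer/Status/Graph wrappers in the rest of this diff.
    public static implicit operator SafeDemoHandle(Demo demo) => demo._handle;
}
```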
@@ -11,7 +11,7 @@ public class CheckpointReader | |||
Status status = new Status(); | |||
VariableToDataTypeMap = new Dictionary<string, TF_DataType>(); | |||
VariableToShapeMap = new Dictionary<string, Shape>(); | |||
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); | |||
_handle = c_api.TF_NewCheckpointReader(filename, status); | |||
status.Check(true); | |||
ReadAllShapeAndType(); | |||
} | |||
@@ -38,7 +38,7 @@ public class CheckpointReader | |||
int num_dims = GetVariableNumDims(name); | |||
long[] dims = new long[num_dims]; | |||
Status status = new Status(); | |||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); | |||
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status); | |||
status.Check(true); | |||
return new Shape(dims); | |||
} | |||
@@ -49,7 +49,7 @@ public class CheckpointReader | |||
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
Status status = new Status(); | |||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); | |||
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status); | |||
status.Check(true); | |||
return new Tensor(tensor); | |||
} | |||
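For context, the way CheckpointReader is consumed does not change; only the Status arguments lose their `.Handle`. A hypothetical usage sketch (the checkpoint path and the iteration are made up; the members are the ones assigned and called above):

```csharp
var reader = new CheckpointReader("/tmp/model.ckpt");   // made-up path
foreach (var entry in reader.VariableToDataTypeMap)
{
    string name = entry.Key;
    Shape shape = reader.VariableToShapeMap[name];
    Tensor value = reader.GetTensor(name, entry.Value);
    // ... consume the restored tensor ...
}
```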
@@ -37,7 +37,7 @@ namespace Tensorflow.Contexts | |||
public void log_device_placement(bool enable) | |||
{ | |||
if (_handle != null) | |||
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle); | |||
c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status); | |||
_log_device_placement = enable; | |||
// _thread_local_data.function_call_options = null; | |||
} | |||
@@ -60,15 +60,15 @@ namespace Tensorflow.Contexts | |||
public PhysicalDevice[] list_physical_devices(string device_type = null) | |||
{ | |||
using var opts = c_api.TFE_NewContextOptions(); | |||
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle); | |||
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle); | |||
using var ctx = c_api.TFE_NewContext(opts, tf.Status); | |||
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status); | |||
tf.Status.Check(true); | |||
int num_devices = c_api.TF_DeviceListCount(devices); | |||
var results = new List<PhysicalDevice>(); | |||
for (int i = 0; i < num_devices; ++i) | |||
{ | |||
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle)); | |||
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status)); | |||
tf.Status.Check(true); | |||
if (dev_type.StartsWith("XLA")) | |||
@@ -76,7 +76,7 @@ namespace Tensorflow.Contexts | |||
if (device_type == null || dev_type == device_type) | |||
{ | |||
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle); | |||
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status); | |||
tf.Status.Check(true); | |||
results.Add(new PhysicalDevice | |||
@@ -28,7 +28,7 @@ namespace Tensorflow.Contexts | |||
/// <summary> | |||
/// Environment in which eager operations execute. | |||
/// </summary> | |||
public sealed partial class Context : IDisposable | |||
public sealed partial class Context | |||
{ | |||
public const int GRAPH_MODE = 0; | |||
public const int EAGER_MODE = 1; | |||
@@ -41,15 +41,7 @@ namespace Tensorflow.Contexts | |||
public FunctionCallOptions FunctionCallOptions { get; } | |||
SafeContextHandle _handle; | |||
public SafeContextHandle Handle | |||
{ | |||
get | |||
{ | |||
if (_handle == null) | |||
ensure_initialized(); | |||
return _handle; | |||
} | |||
} | |||
int? _seed; | |||
Random _rng; | |||
@@ -59,6 +51,7 @@ namespace Tensorflow.Contexts | |||
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | |||
initialized = false; | |||
FunctionCallOptions = new FunctionCallOptions(); | |||
ensure_initialized(); | |||
} | |||
/// <summary> | |||
@@ -72,12 +65,12 @@ namespace Tensorflow.Contexts | |||
Config = MergeConfig(); | |||
FunctionCallOptions.Config = Config; | |||
var config_str = Config.ToByteArray(); | |||
using var opts = new ContextOptions(); | |||
using var status = new Status(); | |||
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle); | |||
var opts = new ContextOptions(); | |||
var status = new Status(); | |||
c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status); | |||
status.Check(true); | |||
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy); | |||
_handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); | |||
_handle = c_api.TFE_NewContext(opts, status); | |||
status.Check(true); | |||
initialized = true; | |||
} | |||
@@ -178,10 +171,14 @@ namespace Tensorflow.Contexts | |||
tf.Context.ensure_initialized(); | |||
if (_handle != null) | |||
{ | |||
c_api.TFE_ContextClearCaches(_handle); | |||
} | |||
} | |||
public void Dispose() | |||
=> _handle.Dispose(); | |||
public static implicit operator SafeContextHandle(Context ctx) | |||
{ | |||
return ctx._handle; | |||
} | |||
} | |||
} |
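Context can drop `IDisposable` because the native TFE_Context is now owned by `SafeContextHandle`. That handle type is not part of this diff; a sketch of what its release is expected to look like, following the shape of the `SafeGraphHandle`/`SafeFuncGraphHandle` classes added further down and assuming `TFE_DeleteContext` is the deleter:

```csharp
// Assumed implementation; the real SafeContextHandle may differ.
public sealed class SafeContextHandle : SafeTensorflowHandle
{
    private SafeContextHandle() { }

    public SafeContextHandle(IntPtr handle) : base(handle) { }

    protected override bool ReleaseHandle()
    {
        c_api.TFE_DeleteContext(handle); // assumed deleter for the TFE_Context*
        SetHandle(IntPtr.Zero);
        return true;
    }
}
```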
@@ -14,21 +14,21 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Contexts | |||
namespace Tensorflow.Contexts; | |||
public sealed class ContextOptions | |||
{ | |||
public sealed class ContextOptions : IDisposable | |||
{ | |||
public SafeContextOptionsHandle Handle { get; } | |||
SafeContextOptionsHandle _handle { get; } | |||
public ContextOptions() | |||
{ | |||
Handle = c_api.TFE_NewContextOptions(); | |||
} | |||
public ContextOptions() | |||
{ | |||
_handle = c_api.TFE_NewContextOptions(); | |||
} | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
public static implicit operator SafeContextOptionsHandle(ContextOptions opt) | |||
{ | |||
return opt._handle; | |||
} | |||
} |
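The call-site effect of these implicit conversions, taking `ensure_initialized` above as the example; note that the `using` declarations also disappear, so cleanup of the options and status objects now relies on their SafeHandle finalizers instead of deterministic `Dispose`:

```csharp
// before
using var opts = new ContextOptions();
using var status = new Status();
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
_handle = c_api.TFE_NewContext(opts.Handle, status.Handle);

// after
var opts = new ContextOptions();
var status = new Status();
c_api.TFE_ContextOptionsSetConfig(opts, config_str, (ulong)config_str.Length, status);
_handle = c_api.TFE_NewContext(opts, status);
```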
@@ -43,7 +43,7 @@ namespace Tensorflow.Eager | |||
{ | |||
var status = tf.Status; | |||
var op = GetOp(ctx, op_name, status); | |||
c_api.TFE_OpSetDevice(op, device_name, status.Handle); | |||
c_api.TFE_OpSetDevice(op, device_name, status); | |||
if (status.ok()) | |||
{ | |||
for (int i = 0; i < inputs.Length; ++i) | |||
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager | |||
Tensor nd => nd.EagerTensorHandle, | |||
_ => throw new NotImplementedException("Eager tensor handle has not been allocated.") | |||
}; | |||
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); | |||
c_api.TFE_OpAddInput(op, tensor_handle, status); | |||
status.Check(true); | |||
} | |||
} | |||
@@ -64,7 +64,7 @@ namespace Tensorflow.Eager | |||
var outputs = new SafeEagerTensorHandle[num_outputs]; | |||
if (status.ok()) | |||
{ | |||
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle); | |||
c_api.TFE_Execute(op, outputs, out num_outputs, status); | |||
status.Check(true); | |||
} | |||
return outputs.Select(x => new EagerTensor(x)).ToArray(); | |||
@@ -104,7 +104,7 @@ namespace Tensorflow.Eager | |||
var eager_tensor = ops.convert_to_tensor(fast_input_array[j]); | |||
attr_values[j] = eager_tensor.dtype; | |||
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status.Handle); | |||
c_api.TFE_OpAddInput(op, eager_tensor.EagerTensorHandle, status); | |||
if (op_exec_info.run_callbacks) | |||
{ | |||
@@ -142,7 +142,7 @@ namespace Tensorflow.Eager | |||
} | |||
var retVals = new SafeEagerTensorHandle[num_retvals]; | |||
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle); | |||
c_api.TFE_Execute(op, retVals, out num_retvals, status); | |||
status.Check(true); | |||
var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray(); | |||
@@ -160,10 +160,10 @@ namespace Tensorflow.Eager | |||
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||
{ | |||
if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | |||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | |||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status); | |||
else | |||
{ | |||
op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle); | |||
op = c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
thread_local_eager_operation_map[op_or_function_name] = op; | |||
} | |||
@@ -219,7 +219,7 @@ namespace Tensorflow.Eager | |||
flattened_attrs.Add(dtype); | |||
} | |||
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle); | |||
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status); | |||
status.Check(true); | |||
return true; | |||
@@ -235,7 +235,7 @@ namespace Tensorflow.Eager | |||
var value = attrs[i + 1]; | |||
byte is_list = 0; | |||
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status.Handle); | |||
var type = c_api.TFE_OpGetAttrType(op, key, ref is_list, status); | |||
if (!status.ok()) return; | |||
if (is_list != 0) | |||
SetOpAttrList(tf.Context, op, key, value as object[], type, null, status); | |||
@@ -264,7 +264,7 @@ namespace Tensorflow.Eager | |||
Status status) | |||
{ | |||
byte is_list = 0; | |||
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status.Handle); | |||
var type = c_api.TFE_OpGetAttrType(op, attr_name, ref is_list, status); | |||
if (status.Code != TF_Code.TF_OK) return; | |||
if (attr_value == null) | |||
@@ -305,7 +305,7 @@ namespace Tensorflow.Eager | |||
tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long)); | |||
} | |||
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle); | |||
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status); | |||
Array.ForEach(dims, x => Marshal.FreeHGlobal(x)); | |||
} | |||
else if (type == TF_AttrType.TF_ATTR_TYPE && values is TF_DataType[] values2) | |||
@@ -353,7 +353,7 @@ namespace Tensorflow.Eager | |||
break; | |||
case TF_AttrType.TF_ATTR_SHAPE: | |||
var dims = (value as long[]).ToArray(); | |||
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle); | |||
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); | |||
status.Check(true); | |||
break; | |||
case TF_AttrType.TF_ATTR_FUNC: | |||
@@ -54,7 +54,7 @@ namespace Tensorflow.Eager | |||
void NewEagerTensorHandle(SafeTensorHandle h) | |||
{ | |||
_id = ops.uid(); | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status); | |||
#if TRACK_TENSOR_LIFE | |||
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); | |||
#endif | |||
@@ -65,7 +65,7 @@ namespace Tensorflow.Eager | |||
{ | |||
if (_handle != null) | |||
return; | |||
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); | |||
_handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status); | |||
tf.Status.Check(true); | |||
} | |||
@@ -24,10 +24,10 @@ namespace Tensorflow.Eager | |||
} | |||
} | |||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status.Handle)); | |||
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(_eagerTensorHandle, tf.Status)); | |||
public override TF_DataType dtype => c_api.TFE_TensorHandleDataType(_eagerTensorHandle); | |||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status.Handle); | |||
public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.Status); | |||
public override ulong bytesize | |||
{ | |||
@@ -49,9 +49,9 @@ namespace Tensorflow.Eager | |||
protected override Shape GetShapeInternal() | |||
{ | |||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status.Handle)]; | |||
var dims = new int[c_api.TFE_TensorHandleNumDims(_eagerTensorHandle, tf.Status)]; | |||
for (int i = 0; i < dims.Length; i++) | |||
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status.Handle); | |||
dims[i] = c_api.TFE_TensorHandleDim(_eagerTensorHandle, i, tf.Status); | |||
return dims; | |||
} | |||
@@ -64,15 +64,15 @@ namespace Tensorflow.Eager | |||
public static int GetRank(IntPtr handle) | |||
{ | |||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle); | |||
return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status); | |||
} | |||
public static int[] GetDims(IntPtr handle) | |||
{ | |||
var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); | |||
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status.Handle)]; | |||
var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.Status)]; | |||
for (int i = 0; i < dims.Length; i++) | |||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status.Handle); | |||
dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.Status); | |||
return dims; | |||
} | |||
@@ -114,7 +114,7 @@ namespace Tensorflow | |||
/// <param name="function"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, IntPtr function, SafeStatusHandle status); | |||
public static extern void TFE_ContextAddFunction(SafeContextHandle ctx, SafeFuncGraphHandle function, SafeStatusHandle status); | |||
/// <summary> | |||
/// Removes a function from the context. Once removed, you can no longer | |||
@@ -56,15 +56,14 @@ namespace Tensorflow | |||
TF_ImportGraphDefResults results = null; | |||
var bytes = graph_def.ToByteString().ToArray(); | |||
using (var buffer = c_api_util.tf_buffer(bytes)) | |||
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | |||
using (var status = new Status()) | |||
{ | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
// need to create a class ImportGraphDefWithResults with IDisposal | |||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer.Handle, scoped_options.Handle, status.Handle)); | |||
status.Check(true); | |||
} | |||
var buffer = c_api_util.tf_buffer(bytes); | |||
var scoped_options = c_api_util.ScopedTFImportGraphDefOptions(); | |||
var status = new Status(); | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
// need to create a class ImportGraphDefWithResults with IDisposal | |||
results = new TF_ImportGraphDefResults(c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status)); | |||
status.Check(true); | |||
_ProcessNewOps(graph); | |||
@@ -116,13 +115,13 @@ namespace Tensorflow | |||
Dictionary<string, Tensor> input_map, | |||
string[] return_elements) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(options.Handle, prefix); | |||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options.Handle, (char)1); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix); | |||
c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, (char)1); | |||
foreach (var input in input_map) | |||
{ | |||
var (src_name, src_index) = _ParseTensorName(input.Key); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options.Handle, src_name, src_index, input.Value._as_tf_output()); | |||
c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name, src_index, input.Value._as_tf_output()); | |||
} | |||
if (return_elements == null) | |||
@@ -133,11 +132,11 @@ namespace Tensorflow | |||
if (name.Contains(":")) | |||
{ | |||
var (op_name, index) = _ParseTensorName(name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options.Handle, op_name, index); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
} | |||
else | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options.Handle, name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
} | |||
} | |||
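Dropping the `using` blocks here means the temporary buffer, options and status objects are reclaimed whenever their SafeHandle finalizers run, not at the end of the block. If deterministic release is ever wanted at a call site, the implicit conversions still expose the underlying SafeHandle; a sketch (not code from this PR):

```csharp
var buffer = c_api_util.tf_buffer(bytes);
try
{
    // ... pass buffer to the native importer as above ...
}
finally
{
    SafeBufferHandle raw = buffer;  // implicit conversion declared in Buffer
    raw.Dispose();                  // frees the TF_Buffer now instead of at finalization
}
```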
@@ -33,7 +33,7 @@ namespace Tensorflow | |||
if (_registered_ops.Count > 0) | |||
return _registered_ops; | |||
using var buffer = new Buffer(c_api.TF_GetAllOpList()); | |||
var buffer = new Buffer(c_api.TF_GetAllOpList()); | |||
var op_list = OpList.Parser.ParseFrom(buffer.ToArray()); | |||
foreach (var op_def in op_list.Op) | |||
_registered_ops[op_def.Name] = op_def; | |||
@@ -56,8 +56,8 @@ namespace Tensorflow.Framework | |||
if (pred_value is null) | |||
{ | |||
var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray(); | |||
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle); | |||
if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK) | |||
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status); | |||
if (!evaluated || c_api.TF_GetCode(tf.Status) != TF_Code.TF_OK) | |||
return null; | |||
else | |||
throw new NotImplementedException(""); | |||
@@ -34,10 +34,10 @@ namespace Tensorflow | |||
/// <param name="output_func_def"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_FunctionToFunctionDef(IntPtr func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||
public static extern void TF_FunctionToFunctionDef(SafeFuncGraphHandle func, SafeBufferHandle output_func_def, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_GraphToFunction(IntPtr fn_body, string fn_name, | |||
public static extern SafeFuncGraphHandle TF_GraphToFunction(SafeGraphHandle fn_body, string fn_name, | |||
bool append_hash_to_fn_name, | |||
int num_opers, IntPtr[] opers, | |||
int ninputs, TF_Output[] inputs, | |||
@@ -48,12 +48,12 @@ namespace Tensorflow | |||
SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
public static extern IntPtr TF_FunctionSetAttrValueProto(SafeFuncGraphHandle func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_FunctionName(IntPtr func); | |||
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphCopyFunction(IntPtr g, IntPtr func, IntPtr grad, SafeStatusHandle status); | |||
public static extern void TF_GraphCopyFunction(SafeGraphHandle g, SafeFuncGraphHandle func, IntPtr grad, SafeStatusHandle status); | |||
} | |||
} |
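Re-typing `IntPtr` parameters and return values in these `[DllImport]` declarations to SafeHandle-derived types needs no other marshaling changes: the interop layer still passes the raw pointer to native code, but it ref-counts the SafeHandle for the duration of the call, which keeps the handle from being finalized mid-call. `TF_FunctionName` shows the minimal diff shape:

```csharp
// before
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FunctionName(IntPtr func);

// after: same native export, only the managed parameter type changes
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FunctionName(SafeFuncGraphHandle func);
```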
@@ -37,7 +37,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <param name="dy">TF_Output*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_AddGradientsWithPrefix(IntPtr g, string prefix, TF_Output[] y, int ny, | |||
public static extern void TF_AddGradientsWithPrefix(SafeGraphHandle g, string prefix, TF_Output[] y, int ny, | |||
TF_Output[] x, int nx, TF_Output[] dx, SafeStatusHandle status, IntPtr[] dy); | |||
} | |||
} |
@@ -22,21 +22,19 @@ namespace Tensorflow | |||
var inputs_string = string.Join(",", inputs); | |||
var outputs_string = string.Join(",", outputs); | |||
var transforms_string = string.Join(" ", transforms); | |||
using (var status = new Status()) | |||
{ | |||
var buffer = new Buffer(); | |||
var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, | |||
input_graph_def_string.Length, | |||
inputs_string, | |||
outputs_string, | |||
transforms_string, | |||
buffer.Handle, | |||
status.Handle); | |||
var status = new Status(); | |||
var buffer = new Buffer(); | |||
var len = c_api.TransformGraphWithStringInputs(input_graph_def_string, | |||
input_graph_def_string.Length, | |||
inputs_string, | |||
outputs_string, | |||
transforms_string, | |||
buffer, | |||
status); | |||
status.Check(false); | |||
var bytes = buffer.ToArray(); | |||
return GraphDef.Parser.ParseFrom(bytes); | |||
} | |||
status.Check(false); | |||
var bytes = buffer.ToArray(); | |||
return GraphDef.Parser.ParseFrom(bytes); | |||
} | |||
} | |||
} |
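The shape used here recurs throughout the diff (again in `_as_graph_def` and `GetNodeDef` below): allocate an empty Buffer, let the native call fill the underlying TF_Buffer, then parse the protobuf out of `ToArray()`. Condensed, with `graph` assumed to be an existing Graph:

```csharp
var status = new Status();
var buffer = new Buffer();                          // empty TF_Buffer, owned by its SafeBufferHandle
c_api.TF_GraphToGraphDef(graph, buffer, status);    // any extern with a SafeBufferHandle output fits here
status.Check(true);
var def = GraphDef.Parser.ParseFrom(buffer.ToArray());
```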
@@ -37,11 +37,9 @@ namespace Tensorflow.Graphs | |||
1); | |||
return result[0]; | |||
} | |||
using (var s = tf.Session(input.graph)) | |||
{ | |||
var output = func(input); | |||
return output; | |||
} | |||
var s = tf.Session(input.graph); | |||
var output = func(input); | |||
return output; | |||
}; | |||
} | |||
@@ -75,12 +73,10 @@ namespace Tensorflow.Graphs | |||
1); | |||
return result[0]; | |||
} | |||
using (var s = tf.Session(a.graph)) | |||
{ | |||
Debug.Assert(a.graph == b.graph); | |||
var output = func(a, b); | |||
return output; | |||
} | |||
var s = tf.Session(a.graph); | |||
Debug.Assert(a.graph == b.graph); | |||
var output = func(a, b); | |||
return output; | |||
}; | |||
} | |||
} | |||
@@ -1,258 +1,252 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Exceptions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs | |||
namespace Tensorflow.Graphs; | |||
/// <summary> | |||
/// Graph representing a function body. | |||
/// </summary> | |||
public class FuncGraph : Graph, IDisposable | |||
{ | |||
SafeFuncGraphHandle _func_graph_handle; | |||
public string FuncName => _graph_key; | |||
public Tensors Inputs { get; set; } = new Tensors(); | |||
public Tensors Outputs { get; set; } = new Tensors(); | |||
public Dictionary<string, string> Attrs { get; set; } | |||
Dictionary<long, (Tensor, Tensor)> _captures | |||
= new Dictionary<long, (Tensor, Tensor)>(); | |||
public Tensor[] external_captures | |||
=> _captures.Select(x => x.Value.Item1).ToArray(); | |||
public (Tensor, Tensor)[] captures | |||
=> _captures.Values.Select(x => x).ToArray(); | |||
public Tensor[] internal_captures | |||
=> _captures.Select(x => x.Value.Item2).ToArray(); | |||
public Tensor[] captured_inputs | |||
=> external_captures; | |||
/// <summary> | |||
/// Graph representing a function body. | |||
/// Construct a new FuncGraph. | |||
/// </summary> | |||
public class FuncGraph : Graph | |||
public FuncGraph(string name) : base() | |||
{ | |||
IntPtr _func_graph_handle; | |||
public string FuncName => _graph_key; | |||
public Tensors Inputs { get; set; } = new Tensors(); | |||
public Tensors Outputs { get; set; } = new Tensors(); | |||
public Dictionary<string, string> Attrs { get; set; } | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
} | |||
Dictionary<long, (Tensor, Tensor)> _captures | |||
= new Dictionary<long, (Tensor, Tensor)>(); | |||
public FuncGraph(SafeGraphHandle handle, string name, Dictionary<string, string> attrs) : base() | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
Attrs = attrs; | |||
// Will to test if FuncGraph has memory leak | |||
// c_api.TF_DeleteGraph(_handle); | |||
_handle = handle; | |||
} | |||
public Tensor[] external_captures | |||
=> _captures.Select(x => x.Value.Item1).ToArray(); | |||
public (Tensor, Tensor)[] captures | |||
=> _captures.Values.Select(x => x).ToArray(); | |||
public void ToGraph(Operation[] opers, | |||
Tensor[] inputs, Tensor[] outputs, | |||
string[] output_names) | |||
{ | |||
var status = new Status(); | |||
_func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
_graph_key, | |||
false, | |||
opers.Length, | |||
opers.Select(x => (IntPtr)x).ToArray(), | |||
inputs.Length, | |||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
outputs.Length, | |||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
output_names, | |||
IntPtr.Zero, | |||
null, | |||
status); | |||
status.Check(true); | |||
SetAttrs(); | |||
// c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); | |||
// status.Check(true); | |||
c_api.TFE_ContextAddFunction(tf.Context, _func_graph_handle, status); | |||
status.Check(true); | |||
_graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); | |||
Inputs = inputs; | |||
// mark_as_return | |||
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||
} | |||
public Tensor[] internal_captures | |||
=> _captures.Select(x => x.Value.Item2).ToArray(); | |||
public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true) | |||
{ | |||
foreach(var (i, inp) in enumerate(inputs)) | |||
inputs[i] = capture(inp); | |||
public Tensor[] captured_inputs | |||
=> external_captures; | |||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
} | |||
/// <summary> | |||
/// Construct a new FuncGraph. | |||
/// </summary> | |||
public FuncGraph(string name) : base() | |||
const int _EAGER_CONST_THRESHOLD = 128; | |||
public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||
{ | |||
if(tensor is EagerTensor) | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
if (name == null) | |||
name = ops.uid().ToString(); | |||
// Small EagerTensors are captured with Const ops | |||
if (dtypes.is_value_dtype(tensor.dtype) | |||
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||
return capture_eager_tensor(tensor, name); | |||
// Large EagerTensors and resources are captured with Placeholder ops | |||
return _capture_helper(tensor, name, shape: shape); | |||
} | |||
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||
if(tensor.graph != this) | |||
{ | |||
outer_graph = ops.get_default_graph(); | |||
while (outer_graph.building_function) | |||
outer_graph = outer_graph.OuterGraph; | |||
_graph_key = name; | |||
building_function = true; | |||
Attrs = attrs; | |||
// Will to test if FuncGraph has memory leak | |||
// c_api.TF_DeleteGraph(_handle); | |||
_handle = handle; | |||
if (name == null) | |||
name = tensor.op.name; | |||
var inner_graph = tensor.graph; | |||
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) | |||
{ | |||
if (inner_graph == this) | |||
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + | |||
" in another function or code block. Use return values," + | |||
" explicit Python locals or TensorFlow collections to access" + | |||
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); | |||
inner_graph = inner_func_graph.outer_graph; | |||
} | |||
return _capture_helper(tensor, name); | |||
} | |||
public void ToGraph(Operation[] opers, | |||
Tensor[] inputs, Tensor[] outputs, | |||
string[] output_names) | |||
return tensor; | |||
} | |||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||
{ | |||
Tensor graph_const = null; | |||
if (!_captures.ContainsKey(tensor.Id)) | |||
{ | |||
var status = new Status(); | |||
_func_graph_handle = c_api.TF_GraphToFunction(_handle, | |||
_graph_key, | |||
false, | |||
opers.Length, | |||
opers.Select(x => (IntPtr)x).ToArray(), | |||
inputs.Length, | |||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
outputs.Length, | |||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
output_names == null || output_names.Length == 0 ? null : output_names, | |||
IntPtr.Zero, | |||
null, | |||
status.Handle); | |||
status.Check(true); | |||
SetAttrs(); | |||
// c_api.TF_GraphCopyFunction(outer_graph, _func_graph_handle, IntPtr.Zero, status.Handle); | |||
// status.Check(true); | |||
c_api.TFE_ContextAddFunction(tf.Context.Handle, _func_graph_handle, status.Handle); | |||
status.Check(true); | |||
_graph_key = c_api.StringPiece(c_api.TF_FunctionName(_func_graph_handle)); | |||
Inputs = inputs; | |||
// mark_as_return | |||
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray(); | |||
graph_const = tf_with(ops.control_dependencies(null), ctl | |||
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||
add_capture(tensor, graph_const); | |||
} | |||
public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true) | |||
else | |||
{ | |||
foreach(var (i, inp) in enumerate(inputs)) | |||
inputs[i] = capture(inp); | |||
return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
graph_const = _captures[tensor.Id].Item2; | |||
} | |||
const int _EAGER_CONST_THRESHOLD = 128; | |||
public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
{ | |||
if(tensor is EagerTensor) | |||
{ | |||
if (name == null) | |||
name = ops.uid().ToString(); | |||
// Small EagerTensors are captured with Const ops | |||
if (dtypes.is_value_dtype(tensor.dtype) | |||
&& (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) | |||
return capture_eager_tensor(tensor, name); | |||
return output_grads; | |||
}; | |||
// Large EagerTensors and resources are captured with Placeholder ops | |||
return _capture_helper(tensor, name, shape: shape); | |||
} | |||
tf.Runner.RecordGradient("captured_value", | |||
new[] { graph_const }, null, | |||
new[] { tensor }, | |||
getBackwardFunction: _backward_function_wrapper | |||
/*getForwardFunction: forward_function*/); | |||
if(tensor.graph != this) | |||
{ | |||
if (name == null) | |||
name = tensor.op.name; | |||
var inner_graph = tensor.graph; | |||
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph) | |||
{ | |||
if (inner_graph == this) | |||
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" + | |||
" in another function or code block. Use return values," + | |||
" explicit Python locals or TensorFlow collections to access" + | |||
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}."); | |||
inner_graph = inner_func_graph.outer_graph; | |||
} | |||
return _capture_helper(tensor, name); | |||
} | |||
return graph_const; | |||
} | |||
return tensor; | |||
Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) | |||
{ | |||
Tensor placeholder = null; | |||
if (!_captures.ContainsKey(tensor.Id)) | |||
{ | |||
placeholder = _create_substitute_placeholder(tensor, | |||
name: name, | |||
dtype: tensor.dtype, | |||
shape: shape); | |||
add_capture(tensor, placeholder); | |||
} | |||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||
else | |||
{ | |||
Tensor graph_const = null; | |||
if (!_captures.ContainsKey(tensor.Id)) | |||
{ | |||
graph_const = tf_with(ops.control_dependencies(null), ctl | |||
=> constant_op.constant(tensor.numpy(), dtype: tensor.dtype, shape: tensor.shape, name: name)); | |||
add_capture(tensor, graph_const); | |||
} | |||
else | |||
{ | |||
graph_const = _captures[tensor.Id].Item2; | |||
} | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
{ | |||
return output_grads; | |||
}; | |||
tf.Runner.RecordGradient("captured_value", | |||
new[] { graph_const }, null, | |||
new[] { tensor }, | |||
getBackwardFunction: _backward_function_wrapper | |||
/*getForwardFunction: forward_function*/); | |||
return graph_const; | |||
placeholder = _captures[tensor.Id].Item2; | |||
} | |||
Tensor _capture_helper(Tensor tensor, string name, Shape shape = null) | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
{ | |||
Tensor placeholder = null; | |||
if (!_captures.ContainsKey(tensor.Id)) | |||
{ | |||
placeholder = _create_substitute_placeholder(tensor, | |||
name: name, | |||
dtype: tensor.dtype, | |||
shape: shape); | |||
add_capture(tensor, placeholder); | |||
} | |||
else | |||
{ | |||
placeholder = _captures[tensor.Id].Item2; | |||
} | |||
return output_grads; | |||
}; | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
{ | |||
return output_grads; | |||
}; | |||
tf.Runner.RecordGradient("captured_value", | |||
new[] { placeholder }, null, | |||
new[] { tensor }, | |||
getBackwardFunction: _backward_function_wrapper | |||
/*getForwardFunction: forward_function*/); | |||
tf.Runner.RecordGradient("captured_value", | |||
new[] { placeholder }, null, | |||
new[] { tensor }, | |||
getBackwardFunction: _backward_function_wrapper | |||
/*getForwardFunction: forward_function*/); | |||
return placeholder; | |||
} | |||
return placeholder; | |||
} | |||
void add_capture(Tensor tensor, Tensor placeholder) | |||
{ | |||
_captures.Add(tensor.Id, (tensor, placeholder)); | |||
Inputs.Add(placeholder); | |||
} | |||
void add_capture(Tensor tensor, Tensor placeholder) | |||
{ | |||
_captures.Add(tensor.Id, (tensor, placeholder)); | |||
Inputs.Add(placeholder); | |||
} | |||
Tensor _create_substitute_placeholder(Tensor value, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
Shape shape = null) | |||
{ | |||
if (shape is null) | |||
shape = value.shape; | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = value.dtype; | |||
var placeholder = tf_with(ops.control_dependencies(null), ctl | |||
=> array_ops.placeholder(dtype, shape: shape, name: name)); | |||
// custom_gradient.copy_handle_data(value, placeholder) | |||
return placeholder; | |||
} | |||
Tensor _create_substitute_placeholder(Tensor value, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
Shape shape = null) | |||
{ | |||
if (shape is null) | |||
shape = value.shape; | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = value.dtype; | |||
var placeholder = tf_with(ops.control_dependencies(null), ctl | |||
=> array_ops.placeholder(dtype, shape: shape, name: name)); | |||
// custom_gradient.copy_handle_data(value, placeholder) | |||
return placeholder; | |||
} | |||
void SetAttrs() | |||
{ | |||
if (Attrs == null) | |||
return; | |||
void SetAttrs() | |||
foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
{ | |||
if (Attrs == null) | |||
return; | |||
foreach (var (_name, attr_value) in enumerate(Attrs)) | |||
var serialized = new AttrValue | |||
{ | |||
var serialized = new AttrValue | |||
{ | |||
S = ByteString.CopyFromUtf8(attr_value) | |||
}.ToByteArray(); | |||
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
} | |||
S = ByteString.CopyFromUtf8(attr_value) | |||
}.ToByteArray(); | |||
c_api.TF_FunctionSetAttrValueProto(_func_graph_handle, _name, serialized, serialized.Length, tf.Status); | |||
tf.Status.Check(true); | |||
} | |||
} | |||
public override Graph as_default() | |||
{ | |||
tf.Context.graph_mode(isFunc: true); | |||
ops.set_default_graph(this); | |||
return this; | |||
} | |||
public override Graph as_default() | |||
{ | |||
tf.Context.graph_mode(isFunc: true); | |||
ops.set_default_graph(this); | |||
return this; | |||
} | |||
public override void Exit() | |||
{ | |||
tf.Context.restore_mode(); | |||
ops.pop_graph(); | |||
} | |||
public override void Exit() | |||
{ | |||
tf.Context.restore_mode(); | |||
ops.pop_graph(); | |||
} | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, _graph_key, tf.Status.Handle); | |||
c_api.TF_DeleteFunction(_func_graph_handle); | |||
base.DisposeUnmanagedResources(handle); | |||
} | |||
public void Dispose() | |||
{ | |||
c_api.TFE_ContextRemoveFunction(tf.Context, _graph_key, tf.Status); | |||
} | |||
} |
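The lifetime split after this rewrite: `Dispose` only deregisters the function from the eager context, while the native objects are released by their SafeHandles (`SafeFuncGraphHandle` and `SafeGraphHandle`, both added below). A usage sketch with a made-up function name; the body-building step is elided:

```csharp
var fg = new FuncGraph("f");   // made-up name
fg.as_default();
// ... build the function body, then ToGraph(...) registers it via TFE_ContextAddFunction ...
fg.Exit();
fg.Dispose();                  // TFE_ContextRemoveFunction only; TF_DeleteFunction and
                               // TF_DeleteGraph run later in the SafeHandles' ReleaseHandle
```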
@@ -24,7 +24,7 @@ namespace Tensorflow | |||
public Buffer ToGraphDef(Status s) | |||
{ | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(_handle, buffer.Handle, s.Handle); | |||
c_api.TF_GraphToGraphDef(_handle, buffer, s); | |||
s.Check(true); | |||
return buffer; | |||
@@ -33,14 +33,12 @@ namespace Tensorflow | |||
private GraphDef _as_graph_def(bool add_shapes = false) | |||
{ | |||
GraphDef def; | |||
using (var status = new Status()) | |||
using (var buffer = ToGraphDef(status)) | |||
{ | |||
status.Check(true); | |||
// limit size to 250M, recursion to max 100 | |||
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); | |||
def = GraphDef.Parser.ParseFrom(inputStream); | |||
} | |||
var status = new Status(); | |||
var buffer = ToGraphDef(status); | |||
status.Check(true); | |||
// limit size to 250M, recursion to max 100 | |||
var inputStream = CodedInputStream.CreateWithLimits(buffer.DangerousMemoryBlock, 250 * 1024 * 1024, 100); | |||
def = GraphDef.Parser.ParseFrom(inputStream); | |||
// Strip the experimental library field iff it's empty. | |||
// if(def.Library.Function.Count == 0) | |||
@@ -29,7 +29,7 @@ namespace Tensorflow | |||
int size = Marshal.SizeOf<TF_Output>(); | |||
var return_output_handle = Marshal.AllocHGlobal(size * num_return_outputs); | |||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def.Handle, opts.Handle, return_output_handle, num_return_outputs, s.Handle); | |||
c_api.TF_GraphImportGraphDefWithReturnOutputs(_handle, graph_def, opts, return_output_handle, num_return_outputs, s); | |||
var tf_output_ptr = (TF_Output*)return_output_handle; | |||
for (int i = 0; i < num_return_outputs; i++) | |||
@@ -48,15 +48,14 @@ namespace Tensorflow | |||
public bool Import(byte[] bytes, string prefix = "") | |||
{ | |||
using (var opts = new ImportGraphDefOptions()) | |||
using (var status = new Status()) | |||
using (var graph_def = new Buffer(bytes)) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts.Handle, prefix); | |||
c_api.TF_GraphImportGraphDef(_handle, graph_def.Handle, opts.Handle, status.Handle); | |||
status.Check(true); | |||
return status.Code == TF_Code.TF_OK; | |||
} | |||
var opts = new ImportGraphDefOptions(); | |||
var status = new Status(); | |||
var graph_def = new Buffer(bytes); | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); | |||
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); | |||
status.Check(true); | |||
return status.Code == TF_Code.TF_OK; | |||
} | |||
public Graph ImportGraphDef(string file_path, string name = null) | |||
@@ -75,9 +75,9 @@ namespace Tensorflow | |||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
/// </summary> | |||
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||
public partial class Graph : DisposableObject | |||
, IEnumerable<Operation> | |||
public partial class Graph : IEnumerable<Operation> | |||
{ | |||
protected new SafeGraphHandle _handle; | |||
private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||
private Dictionary<string, int> _names_in_use; | |||
@@ -130,15 +130,6 @@ namespace Tensorflow | |||
_graph_key = $"graph-{ops.GraphUniqueId()}/"; | |||
} | |||
public Graph(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
_names_in_use = new Dictionary<string, int>(); | |||
_graph_key = $"grap-{ops.GraphUniqueId()}/"; | |||
} | |||
public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
{ | |||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||
@@ -486,16 +477,6 @@ namespace Tensorflow | |||
_unfetchable_ops.Add(op); | |||
} | |||
protected override void DisposeManagedResources() | |||
{ | |||
} | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
c_api.TF_DeleteGraph(handle); | |||
} | |||
public Tensor get_tensor_by_tf_output(TF_Output tf_output) | |||
{ | |||
var op = _get_operation_by_tf_operation(tf_output.oper); | |||
@@ -517,14 +498,14 @@ namespace Tensorflow | |||
public Shape GetTensorShape(TF_Output output) | |||
{ | |||
var status = tf.Status; | |||
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status.Handle); | |||
var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); | |||
status.Check(); | |||
if (ndim == -1) | |||
return Shape.Null; | |||
var dims = new long[ndim]; | |||
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status.Handle); | |||
c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); | |||
status.Check(); | |||
return new Shape(dims.Select(x => (int)x).ToArray()); | |||
@@ -539,7 +520,7 @@ namespace Tensorflow | |||
string debugString = string.Empty; | |||
public override string ToString() | |||
{ | |||
return $"{graph_key}, 0x{_handle.ToString("x16")}"; | |||
return $"{graph_key}, 0x{_handle.DangerousGetHandle().ToString("x16")}"; | |||
/*if (string.IsNullOrEmpty(debugString)) | |||
{ | |||
int len = 0; | |||
@@ -558,7 +539,7 @@ namespace Tensorflow | |||
IEnumerator IEnumerable.GetEnumerator() | |||
=> throw new NotImplementedException(); | |||
public static implicit operator IntPtr(Graph graph) | |||
public static implicit operator SafeGraphHandle(Graph graph) | |||
{ | |||
return graph._handle; | |||
} | |||
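Graph no longer derives from DisposableObject: `TF_DeleteGraph` moves into `SafeGraphHandle.ReleaseHandle` (new file below), `ToString` goes through `DangerousGetHandle` because `_handle` is no longer a raw pointer, and the old `implicit operator IntPtr` becomes an `implicit operator SafeGraphHandle`. At external call sites the change looks like this line from the Tensorflow.Framework hunk earlier in the diff (`tf.Status` is passed directly thanks to a conversion elsewhere in the same PR):

```csharp
// before: Graph converted implicitly to IntPtr
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle);

// after: Graph converts implicitly to SafeGraphHandle
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status);
```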
@@ -14,28 +14,27 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
namespace Tensorflow; | |||
namespace Tensorflow | |||
public sealed class ImportGraphDefOptions | |||
{ | |||
public sealed class ImportGraphDefOptions : IDisposable | |||
{ | |||
public SafeImportGraphDefOptionsHandle Handle { get; } | |||
SafeImportGraphDefOptionsHandle _handle { get; } | |||
public int NumReturnOutputs | |||
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(Handle); | |||
public int NumReturnOutputs | |||
=> c_api.TF_ImportGraphDefOptionsNumReturnOutputs(_handle); | |||
public ImportGraphDefOptions() | |||
{ | |||
Handle = c_api.TF_NewImportGraphDefOptions(); | |||
} | |||
public ImportGraphDefOptions() | |||
{ | |||
_handle = c_api.TF_NewImportGraphDefOptions(); | |||
} | |||
public void AddReturnOutput(string name, int index) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(Handle, name, index); | |||
} | |||
public void AddReturnOutput(string name, int index) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
} | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
public static implicit operator SafeImportGraphDefOptionsHandle(ImportGraphDefOptions opt) | |||
{ | |||
return opt._handle; | |||
} | |||
} |
@@ -0,0 +1,22 @@ | |||
using Tensorflow.Util; | |||
namespace Tensorflow; | |||
public sealed class SafeFuncGraphHandle : SafeTensorflowHandle | |||
{ | |||
private SafeFuncGraphHandle() | |||
{ | |||
} | |||
public SafeFuncGraphHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TF_DeleteFunction(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} |
@@ -0,0 +1,22 @@ | |||
using Tensorflow.Util; | |||
namespace Tensorflow; | |||
public sealed class SafeGraphHandle : SafeTensorflowHandle | |||
{ | |||
private SafeGraphHandle() | |||
{ | |||
} | |||
public SafeGraphHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
protected override bool ReleaseHandle() | |||
{ | |||
c_api.TF_DeleteGraph(handle); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} |
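Both new handle classes derive from `SafeTensorflowHandle` in `Tensorflow.Util`, which is not shown in this diff. A sketch of the base these definitions imply, assuming it owns the handle and treats `IntPtr.Zero` as invalid (the real class may differ):

```csharp
using System;
using System.Runtime.InteropServices;

namespace Tensorflow.Util;

public abstract class SafeTensorflowHandle : SafeHandle
{
    protected SafeTensorflowHandle() : base(IntPtr.Zero, ownsHandle: true) { }

    protected SafeTensorflowHandle(IntPtr handle) : this()
        => SetHandle(handle);

    public override bool IsInvalid => handle == IntPtr.Zero;
}
```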
@@ -60,7 +60,7 @@ namespace Tensorflow | |||
/// <param name="num_dims"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
public static extern void TF_GraphGetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. | |||
@@ -78,7 +78,7 @@ namespace Tensorflow | |||
/// <param name="num_return_outputs">int</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||
public static extern unsafe void TF_GraphImportGraphDefWithReturnOutputs(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, IntPtr return_outputs, int num_return_outputs, SafeStatusHandle status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and | |||
@@ -92,7 +92,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>TF_ImportGraphDefResults*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
public static extern SafeImportGraphDefResultsHandle TF_GraphImportGraphDefWithResults(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
/// <summary> | |||
/// Import the graph serialized in `graph_def` into `graph`. | |||
@@ -102,7 +102,7 @@ namespace Tensorflow | |||
/// <param name="options">TF_ImportGraphDefOptions*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphImportGraphDef(IntPtr graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
public static extern void TF_GraphImportGraphDef(SafeGraphHandle graph, SafeBufferHandle graph_def, SafeImportGraphDefOptionsHandle options, SafeStatusHandle status); | |||
/// <summary> | |||
/// Iterate through the operations of a graph. | |||
@@ -111,7 +111,7 @@ namespace Tensorflow | |||
/// <param name="pos"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_GraphNextOperation(IntPtr graph, ref uint pos); | |||
public static extern IntPtr TF_GraphNextOperation(SafeGraphHandle graph, ref uint pos); | |||
/// <summary> | |||
/// Returns the operation in the graph with `oper_name`. Returns nullptr if | |||
@@ -121,14 +121,14 @@ namespace Tensorflow | |||
/// <param name="oper_name"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_GraphOperationByName(IntPtr graph, string oper_name); | |||
public static extern IntPtr TF_GraphOperationByName(SafeGraphHandle graph, string oper_name); | |||
/// <summary> | |||
/// Sets the shape of the Tensor referenced by `output` in `graph` to | |||
/// the shape described by `dims` and `num_dims`. | |||
/// </summary> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
public static extern void TF_GraphSetTensorShape(SafeGraphHandle graph, TF_Output output, long[] dims, int num_dims, SafeStatusHandle status); | |||
/// <summary> | |||
/// Write out a serialized representation of `graph` (as a GraphDef protocol | |||
@@ -138,7 +138,7 @@ namespace Tensorflow | |||
/// <param name="output_graph_def">TF_Buffer*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||
public static extern void TF_GraphToGraphDef(SafeGraphHandle graph, SafeBufferHandle output_graph_def, SafeStatusHandle status); | |||
/// <summary> | |||
/// Returns the number of dimensions of the Tensor referenced by `output` | |||
@@ -151,7 +151,7 @@ namespace Tensorflow | |||
/// <param name="status"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, SafeStatusHandle status); | |||
public static extern int TF_GraphGetTensorNumDims(SafeGraphHandle graph, TF_Output output, SafeStatusHandle status); | |||
/// <summary> | |||
/// Cause the imported graph to have a control dependency on `oper`. `oper` | |||
@@ -287,12 +287,12 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||
public static extern SafeSessionHandle TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||
string export_dir, string[] tags, int tags_len, | |||
IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||
SafeGraphHandle graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_NewGraph(); | |||
public static extern SafeGraphHandle TF_NewGraph(); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions(); | |||
@@ -334,6 +334,6 @@ namespace Tensorflow | |||
/// <param name="status"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); | |||
public static extern bool TF_TryEvaluateConstant(SafeGraphHandle graph, TF_Output output, IntPtr[] result, SafeStatusHandle status); | |||
} | |||
} |
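Note the asymmetry left in these signatures: externs that hand ownership to the caller (`TF_NewGraph`, `TF_GraphImportGraphDefWithResults`, `TF_LoadSessionFromSavedModel`) now return SafeHandles, while `TF_GraphNextOperation` and `TF_GraphOperationByName` keep returning raw `IntPtr`, presumably because `TF_Operation*` pointers are owned by their graph and must not be deleted by the caller. A short sketch of the resulting call shape (the operation name is made up, and the native TensorFlow library is required to actually run this):

```csharp
var graph = c_api.TF_NewGraph();                               // SafeGraphHandle owns the TF_Graph
uint pos = 0;
IntPtr oper = c_api.TF_GraphNextOperation(graph, ref pos);     // non-owning TF_Operation*, stays IntPtr
IntPtr byName = c_api.TF_GraphOperationByName(graph, "my_op"); // "my_op" is hypothetical
```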
@@ -61,7 +61,7 @@ namespace Tensorflow.NumPy | |||
{ | |||
if (_handle is not null) | |||
{ | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status); | |||
} | |||
} | |||
} | |||
@@ -31,7 +31,7 @@ namespace Tensorflow | |||
public int InputListLength(string name) | |||
{ | |||
int num = 0; | |||
num = c_api.TF_OperationInputListLength(_handle, name, tf.Status.Handle); | |||
num = c_api.TF_OperationInputListLength(_handle, name, tf.Status); | |||
tf.Status.Check(true); | |||
return num; | |||
} | |||
@@ -28,7 +28,7 @@ namespace Tensorflow | |||
public int OutputListLength(string name) | |||
{ | |||
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status.Handle); | |||
int num = c_api.TF_OperationOutputListLength(_handle, name, tf.Status); | |||
tf.Status.Check(true); | |||
return num; | |||
@@ -187,8 +187,8 @@ namespace Tensorflow | |||
if (tf.executing_eagerly()) | |||
return (T[])get_attr(name); | |||
using var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||
tf.Status.Check(true); | |||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
@@ -210,8 +210,8 @@ namespace Tensorflow | |||
public virtual object get_attr(string name) | |||
{ | |||
using var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||
tf.Status.Check(true); | |||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
@@ -235,13 +235,13 @@ namespace Tensorflow | |||
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) | |||
{ | |||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s.Handle); | |||
return c_api.TF_OperationGetAttrMetadata(_handle, attr_name, s); | |||
} | |||
private NodeDef GetNodeDef() | |||
{ | |||
using var buffer = new Buffer(); | |||
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle); | |||
var buffer = new Buffer(); | |||
c_api.TF_OperationToNodeDef(_handle, buffer, tf.Status); | |||
tf.Status.Check(throwException: true); | |||
return NodeDef.Parser.ParseFrom(buffer.ToArray()); | |||
} | |||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||
public Operation FinishOperation(Status status) | |||
{ | |||
return c_api.TF_FinishOperation(_handle, status.Handle); | |||
return c_api.TF_FinishOperation(_handle, status); | |||
} | |||
public static implicit operator OperationDescription(IntPtr handle) | |||
@@ -96,7 +96,7 @@ namespace Tensorflow | |||
/// <param name="oper_name">const char*</param> | |||
/// <returns>TF_OperationDescription*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
public static extern IntPtr TF_NewOperation(SafeGraphHandle graph, string opType, string oper_name); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_OperationDevice(IntPtr oper); | |||
@@ -14,281 +14,272 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using Tensorflow.NumPy; | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Numerics; | |||
using System.Text; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
namespace Tensorflow; | |||
public class BaseSession : IDisposable | |||
{ | |||
public class BaseSession : DisposableObject | |||
protected SafeSessionHandle _handle; | |||
protected Graph _graph; | |||
protected Status _status; | |||
public Graph graph => _graph; | |||
public BaseSession(SafeSessionHandle handle, Graph g) | |||
{ | |||
protected Graph _graph; | |||
protected Status _status; | |||
public Graph graph => _graph; | |||
_handle = handle; | |||
_graph = g ?? ops.get_default_graph(); | |||
} | |||
public BaseSession(IntPtr handle, Graph g) | |||
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||
{ | |||
_graph = g ?? ops.get_default_graph(); | |||
if (!_graph.building_function) | |||
{ | |||
_handle = handle; | |||
_graph = g ?? ops.get_default_graph(); | |||
if (ops.get_default_graph() != _graph) | |||
_graph.as_default(); | |||
} | |||
var opts = new SessionOptions(target, config); | |||
_status = status ?? tf.Status; | |||
_handle = c_api.TF_NewSession(_graph, opts, _status); | |||
_status.Check(true); | |||
} | |||
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||
{ | |||
_graph = g ?? ops.get_default_graph(); | |||
if (!_graph.building_function) | |||
{ | |||
if (ops.get_default_graph() != _graph) | |||
_graph.as_default(); | |||
} | |||
using var opts = new SessionOptions(target, config); | |||
_status = status ?? tf.Status; | |||
_handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle); | |||
_status.Check(true); | |||
} | |||
public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
{ | |||
_run(op, feed_dict); | |||
} | |||
public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
{ | |||
_run(op, feed_dict); | |||
} | |||
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
{ | |||
return _run(fetche, feed_dict)[0]; | |||
} | |||
public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
{ | |||
return _run(fetche, feed_dict)[0]; | |||
} | |||
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(fetche, feed_dict); | |||
return fetche is Tensor ? results[0] : null; | |||
} | |||
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(fetche, feed_dict); | |||
return fetche is Tensor ? results[0] : null; | |||
} | |||
public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( | |||
(ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, | |||
params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); | |||
return (results[0], results[1], results[2], results[3], results[4]); | |||
} | |||
public virtual (NDArray, NDArray, NDArray, NDArray, NDArray) run( | |||
(ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, | |||
params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4, fetches.Item5 }, feed_dict); | |||
return (results[0], results[1], results[2], results[3], results[4]); | |||
} | |||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
return (results[0], results[1], results[2], results[3]); | |||
} | |||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
return (results[0], results[1], results[2], results[3]); | |||
} | |||
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
return (results[0], results[1], results[2]); | |||
} | |||
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
return (results[0], results[1], results[2]); | |||
} | |||
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
return (results[0], results[1]); | |||
} | |||
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
return (results[0], results[1]); | |||
} | |||
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
{ | |||
return _run(fetches, feed_dict); | |||
} | |||
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
{ | |||
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
return _run(fetches, feed_items); | |||
} | |||
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | |||
{ | |||
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
return _run(fetches, feed_items); | |||
} | |||
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||
{ | |||
var feed_dict_tensor = new Dictionary<object, object>(); | |||
//var feed_map = new Dictionary<object, object>(); | |||
private NDArray[] _run(object fetches, FeedItem[] feed_dict = null) | |||
// Validate and process feed_dict. | |||
if (feed_dict != null) | |||
{ | |||
var feed_dict_tensor = new Dictionary<object, object>(); | |||
//var feed_map = new Dictionary<object, object>(); | |||
// Validate and process feed_dict. | |||
if (feed_dict != null) | |||
foreach (var subfeed in feed_dict) | |||
{ | |||
foreach (var subfeed in feed_dict) | |||
{ | |||
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||
feed_dict_tensor[subfeed_t] = subfeed.Value; | |||
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||
} | |||
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false); | |||
//var target_dtype = subfeed_t.dtype.as_numpy_typecode(); // subfeed_dtype was never used | |||
feed_dict_tensor[subfeed_t] = subfeed.Value; | |||
//feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value); | |||
} | |||
} | |||
// Create a fetch handler to take care of the structure of fetches. | |||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
// Create a fetch handler to take care of the structure of fetches. | |||
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor); | |||
// Run request and get response. | |||
// We need to keep the returned movers alive for the following _do_run(). | |||
// These movers are no longer needed when _do_run() completes, and | |||
// are deleted when `movers` goes out of scope when this _run() ends. | |||
var _ = _update_with_movers(); | |||
var final_fetches = fetch_handler.fetches(); | |||
var final_targets = fetch_handler.targets(); | |||
// Run request and get response. | |||
// We need to keep the returned movers alive for the following _do_run(). | |||
// These movers are no longer needed when _do_run() completes, and | |||
// are deleted when `movers` goes out of scope when this _run() ends. | |||
var _ = _update_with_movers(); | |||
var final_fetches = fetch_handler.fetches(); | |||
var final_targets = fetch_handler.targets(); | |||
// We only want to really perform the run if fetches or targets are provided, | |||
// or if the call is a partial run that specifies feeds. | |||
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
// We only want to really perform the run if fetches or targets are provided, | |||
// or if the call is a partial run that specifies feeds. | |||
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||
return fetch_handler.build_results(this, results); | |||
} | |||
return fetch_handler.build_results(this, results); | |||
} | |||
/// <summary> | |||
/// Runs a step based on the given fetches and feeds. | |||
/// </summary> | |||
/// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
/// <param name="fetch_list"></param> | |||
/// <param name="feed_dict"></param> | |||
/// <returns> | |||
/// A list of numpy ndarrays, corresponding to the elements of | |||
/// `fetch_list`. If the ith element of `fetch_list` contains the | |||
/// name of an operation, the first Tensor output of that operation | |||
/// will be returned for that element. | |||
/// </returns> | |||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
/// <summary> | |||
/// Runs a step based on the given fetches and feeds. | |||
/// </summary> | |||
/// <param name="target_list">A list of operations to be run, but not fetched.</param> | |||
/// <param name="fetch_list"></param> | |||
/// <param name="feed_dict"></param> | |||
/// <returns> | |||
/// A list of numpy ndarrays, corresponding to the elements of | |||
/// `fetch_list`. If the ith element of `fetch_list` contains the | |||
/// name of an operation, the first Tensor output of that operation | |||
/// will be returned for that element. | |||
/// </returns> | |||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
{ | |||
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||
int i = 0; | |||
foreach (var x in feed_dict) | |||
{ | |||
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | |||
int i = 0; | |||
foreach (var x in feed_dict) | |||
if (x.Key is Tensor key) | |||
{ | |||
if (x.Key is Tensor key) | |||
switch (x.Value) | |||
{ | |||
switch (x.Value) | |||
{ | |||
case Tensor v: | |||
if (v.dtype != key.dtype) | |||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||
break; | |||
case SafeTensorHandle v: | |||
var tensor = new Tensor(v); | |||
if (tensor.dtype != key.dtype) | |||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||
break; | |||
case bool v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case byte v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case int v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case long v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case float v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case double v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case string v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case Array v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape())); | |||
break; | |||
default: | |||
throw new NotImplementedException(""); | |||
} | |||
case Tensor v: | |||
if (v.dtype != key.dtype) | |||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {v.dtype}"); | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), v); | |||
break; | |||
case SafeTensorHandle v: | |||
var tensor = new Tensor(v); | |||
if (tensor.dtype != key.dtype) | |||
throw new ValueError($"Tensor {v} does not match the expected dtype {key.dtype}, actual dtype: {tensor.dtype}"); | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), tensor); | |||
break; | |||
case bool v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case byte v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case int v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case long v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case float v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case double v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case string v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v)); | |||
break; | |||
case Array v: | |||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(key._as_tf_output(), new Tensor(v, v.GetShape())); | |||
break; | |||
default: | |||
throw new NotImplementedException(""); | |||
} | |||
else | |||
throw new NotImplementedException(""); | |||
} | |||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
//var targets = target_list; | |||
return _call_tf_sessionrun(feeds, fetches, target_list); | |||
else | |||
throw new NotImplementedException(""); | |||
} | |||
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); | |||
//var targets = target_list; | |||
return _call_tf_sessionrun(feeds, fetches, target_list); | |||
} | |||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
{ | |||
// Ensure any changes to the graph are reflected in the runtime. | |||
_extend_graph(); | |||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) | |||
{ | |||
// Ensure any changes to the graph are reflected in the runtime. | |||
_extend_graph(); | |||
c_api.TF_SessionRun(_handle, | |||
run_options: null, | |||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), | |||
ninputs: feed_dict.Length, | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: fetch_list.Length, | |||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
ntargets: target_list.Count, | |||
run_metadata: IntPtr.Zero, | |||
status: _status.Handle); | |||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
_status.Check(true); | |||
c_api.TF_SessionRun(_handle, | |||
run_options: null, | |||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
input_values: feed_dict.Select(f => f.Value.Handle.DangerousGetHandle()).ToArray(), | |||
ninputs: feed_dict.Length, | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: fetch_list.Length, | |||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||
ntargets: target_list.Count, | |||
run_metadata: IntPtr.Zero, | |||
status: _status); | |||
var result = new NDArray[fetch_list.Length]; | |||
_status.Check(true); | |||
for (int i = 0; i < fetch_list.Length; i++) | |||
result[i] = fetchValue(new SafeTensorHandle(output_values[i])); | |||
var result = new NDArray[fetch_list.Length]; | |||
return result; | |||
} | |||
for (int i = 0; i < fetch_list.Length; i++) | |||
result[i] = fetchValue(new SafeTensorHandle(output_values[i])); | |||
public unsafe Tensor eval(Tensor tensor) | |||
{ | |||
var output_values = new IntPtr[1]; | |||
var fetch_list = new[] { tensor._as_tf_output() }; | |||
c_api.TF_SessionRun(_handle, | |||
run_options: null, | |||
inputs: new TF_Output[0], | |||
input_values: new IntPtr[0], | |||
ninputs: 0, | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: 1, | |||
target_opers: new IntPtr[0], | |||
ntargets: 0, | |||
run_metadata: IntPtr.Zero, | |||
status: _status.Handle); | |||
_status.Check(true); | |||
return new Tensor(new SafeTensorHandle(output_values[0])); | |||
} | |||
return result; | |||
} | |||
private static unsafe NDArray fetchValue(SafeTensorHandle output) | |||
{ | |||
var tensor = new Tensor(output); | |||
return tensor.numpy(); | |||
} | |||
public unsafe Tensor eval(Tensor tensor) | |||
{ | |||
var output_values = new IntPtr[1]; | |||
var fetch_list = new[] { tensor._as_tf_output() }; | |||
c_api.TF_SessionRun(_handle, | |||
run_options: null, | |||
inputs: new TF_Output[0], | |||
input_values: new IntPtr[0], | |||
ninputs: 0, | |||
outputs: fetch_list, | |||
output_values: output_values, | |||
noutputs: 1, | |||
target_opers: new IntPtr[0], | |||
ntargets: 0, | |||
run_metadata: IntPtr.Zero, | |||
status: _status); | |||
_status.Check(true); | |||
return new Tensor(new SafeTensorHandle(output_values[0])); | |||
} | |||
/// <summary> | |||
/// If a tensor handle is fed to a device-incompatible placeholder, | |||
/// we move the tensor to the right device, generate a new tensor handle, | |||
/// and update feed_dict to use the new handle. | |||
/// </summary> | |||
private List<object> _update_with_movers() | |||
{ | |||
return new List<object> { }; | |||
} | |||
private static unsafe NDArray fetchValue(SafeTensorHandle output) | |||
{ | |||
var tensor = new Tensor(output); | |||
return tensor.numpy(); | |||
} | |||
private void _extend_graph() | |||
{ } | |||
/// <summary> | |||
/// If a tensor handle is fed to a device-incompatible placeholder, | |||
/// we move the tensor to the right device, generate a new tensor handle, | |||
/// and update feed_dict to use the new handle. | |||
/// </summary> | |||
private List<object> _update_with_movers() | |||
{ | |||
return new List<object> { }; | |||
} | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | |||
c_api.TF_DeleteSession(handle, _status.Handle); | |||
} | |||
private void _extend_graph() | |||
{ } | |||
public void Dispose() | |||
{ | |||
} | |||
} |
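In the rewritten BaseSession, DisposeUnmanagedResources is gone and Dispose is left empty; the native session's lifetime now rests with the SafeSessionHandle introduced in the next file, whose ReleaseHandle calls TF_DeleteSession. A minimal sketch of that ownership split follows, using invented names; the reading that the empty Dispose is there to keep `using` call sites compiling is an assumption, not something stated in the diff.

    using System;
    using System.Runtime.InteropServices;

    // Invented types for illustration; not the library's real API.
    sealed class OwnedSessionHandle : SafeHandle
    {
        public OwnedSessionHandle() : base(IntPtr.Zero, ownsHandle: true) { }
        public override bool IsInvalid => handle == IntPtr.Zero;
        protected override bool ReleaseHandle() => true; // the native delete call would go here
    }

    sealed class SessionLikeWrapper : IDisposable
    {
        readonly OwnedSessionHandle _handle = new OwnedSessionHandle();

        // Empty, mirroring the change above: existing `using (var sess = ...)`
        // blocks still compile, while actual cleanup happens when the handle
        // is disposed or finalized, not when the wrapper is disposed.
        public void Dispose() { }
    }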
@@ -0,0 +1,46 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Net.NetworkInformation; | |||
using Tensorflow.Util; | |||
namespace Tensorflow | |||
{ | |||
public sealed class SafeSessionHandle : SafeTensorflowHandle | |||
{ | |||
private SafeSessionHandle() | |||
{ | |||
} | |||
public SafeSessionHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
} | |||
public override string ToString() | |||
=> $"0x{handle:x16}"; | |||
protected override bool ReleaseHandle() | |||
{ | |||
var status = new Status(); | |||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | |||
c_api.TF_DeleteSession(handle, status); | |||
SetHandle(IntPtr.Zero); | |||
return true; | |||
} | |||
} | |||
} |
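SafeSessionHandle follows the usual .NET shape for handles returned from native code: a parameterless constructor (private here) so the marshaler can materialize the return value of TF_NewSession, and a ReleaseHandle backed by SafeHandle's critical finalizer, so the delete call still runs when nothing disposes the handle explicitly. A self-contained sketch of the same shape, using made-up names (DemoSessionHandle, demo_native, Demo_NewSession):

    using System;
    using System.Runtime.InteropServices;

    sealed class DemoSessionHandle : SafeHandle
    {
        // The interop marshaler constructs the handle itself when a
        // [DllImport] returns it, so a parameterless constructor is required;
        // private visibility is enough and keeps callers from newing one up.
        private DemoSessionHandle() : base(IntPtr.Zero, ownsHandle: true) { }

        public override bool IsInvalid => handle == IntPtr.Zero;

        protected override bool ReleaseHandle()
        {
            // The native delete call belongs here. SafeHandle is a critical
            // finalizer object, so this runs even if Dispose is never called.
            SetHandle(IntPtr.Zero);
            return true;
        }
    }

    static class DemoInterop
    {
        // Returning the SafeHandle subclass transfers ownership to the runtime
        // as soon as the call returns, rather than exposing a raw IntPtr that
        // could be lost before it is wrapped.
        [DllImport("demo_native")]
        public static extern DemoSessionHandle Demo_NewSession();
    }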
@@ -14,75 +14,49 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.IO; | |||
using System.Runtime.CompilerServices; | |||
using Tensorflow.Util; | |||
namespace Tensorflow; | |||
namespace Tensorflow | |||
public class Session : BaseSession | |||
{ | |||
public class Session : BaseSession | |||
{ | |||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||
{ } | |||
public Session(IntPtr handle, Graph g = null) : base(handle, g) | |||
{ } | |||
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||
{ } | |||
public Session as_default() | |||
{ | |||
return ops.set_default_session(this); | |||
} | |||
public static Session LoadFromSavedModel(string path) | |||
{ | |||
var graph = new Graph(); | |||
using var status = new Status(); | |||
using var opt = c_api.TF_NewSessionOptions(); | |||
var tags = new string[] { "serve" }; | |||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
IntPtr.Zero, | |||
path, | |||
tags, | |||
tags.Length, | |||
graph, | |||
IntPtr.Zero, | |||
status.Handle); | |||
status.Check(true); | |||
// load graph bytes | |||
// var data = new byte[buffer.length]; | |||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
return new Session(sess, g: graph); | |||
} | |||
public static implicit operator IntPtr(Session session) => session._handle; | |||
public static implicit operator Session(IntPtr handle) => new Session(handle); | |||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||
{ } | |||
public void __enter__() | |||
{ | |||
public Session(SafeSessionHandle handle, Graph g = null) : base(handle, g) | |||
{ } | |||
} | |||
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||
{ } | |||
public void __exit__() | |||
{ | |||
} | |||
public void __init__() | |||
{ | |||
} | |||
public void __del__() | |||
{ | |||
public Session as_default() | |||
{ | |||
return ops.set_default_session(this); | |||
} | |||
} | |||
public static Session LoadFromSavedModel(string path) | |||
{ | |||
var graph = new Graph(); | |||
var status = new Status(); | |||
using var opt = c_api.TF_NewSessionOptions(); | |||
var tags = new string[] { "serve" }; | |||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
IntPtr.Zero, | |||
path, | |||
tags, | |||
tags.Length, | |||
graph, | |||
IntPtr.Zero, | |||
status); | |||
status.Check(true); | |||
// load graph bytes | |||
// var data = new byte[buffer.length]; | |||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
return new Session(sess, g: graph); | |||
} | |||
public static implicit operator SafeSessionHandle(Session session) => session._handle; | |||
public static implicit operator Session(SafeSessionHandle handle) => new Session(handle); | |||
} |
@@ -19,33 +19,33 @@ using System; | |||
namespace Tensorflow | |||
{ | |||
internal sealed class SessionOptions : IDisposable | |||
internal sealed class SessionOptions | |||
{ | |||
public SafeSessionOptionsHandle Handle { get; } | |||
SafeSessionOptionsHandle _handle { get; } | |||
public SessionOptions(string target = "", ConfigProto config = null) | |||
{ | |||
Handle = c_api.TF_NewSessionOptions(); | |||
c_api.TF_SetTarget(Handle, target); | |||
_handle = c_api.TF_NewSessionOptions(); | |||
c_api.TF_SetTarget(_handle, target); | |||
if (config != null) | |||
SetConfig(config); | |||
} | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
private unsafe void SetConfig(ConfigProto config) | |||
{ | |||
var bytes = config.ToByteArray(); | |||
fixed (byte* proto2 = bytes) | |||
{ | |||
using (var status = new Status()) | |||
{ | |||
c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); | |||
status.Check(false); | |||
} | |||
var status = new Status(); | |||
c_api.TF_SetConfig(_handle, (IntPtr)proto2, (ulong)bytes.Length, status); | |||
status.Check(false); | |||
} | |||
} | |||
public static implicit operator SafeSessionOptionsHandle(SessionOptions opt) | |||
{ | |||
return opt._handle; | |||
} | |||
} | |||
} |
@@ -62,7 +62,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns>TF_Session*</returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||
public static extern SafeSessionHandle TF_NewSession(SafeGraphHandle graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||
/// <summary> | |||
/// Return a new options object. | |||
@@ -110,7 +110,7 @@ namespace Tensorflow | |||
/// <param name="run_metadata">TF_Buffer*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe void TF_SessionRun(IntPtr session, TF_Buffer* run_options, | |||
public static extern unsafe void TF_SessionRun(SafeSessionHandle session, TF_Buffer* run_options, | |||
TF_Output[] inputs, IntPtr[] input_values, int ninputs, | |||
TF_Output[] outputs, IntPtr[] output_values, int noutputs, | |||
IntPtr[] target_opers, int ntargets, | |||
@@ -26,7 +26,7 @@ namespace Tensorflow | |||
/// TF_Status holds error information. It either has an OK code, or | |||
/// else an error code with an associated error message. | |||
/// </summary> | |||
public sealed class Status : IDisposable | |||
public sealed class Status | |||
{ | |||
/// <summary> | |||
/// Error message | |||
@@ -35,9 +35,9 @@ namespace Tensorflow | |||
{ | |||
get | |||
{ | |||
using (Handle.Lease()) | |||
using (_handle.Lease()) | |||
{ | |||
return StringPiece(TF_Message(Handle)); | |||
return StringPiece(TF_Message(_handle)); | |||
} | |||
} | |||
} | |||
@@ -45,23 +45,23 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Error code | |||
/// </summary> | |||
public TF_Code Code => TF_GetCode(Handle); | |||
public TF_Code Code => TF_GetCode(_handle); | |||
public SafeStatusHandle Handle { get; } | |||
SafeStatusHandle _handle { get; } | |||
public Status() | |||
{ | |||
Handle = TF_NewStatus(); | |||
_handle = TF_NewStatus(); | |||
} | |||
public Status(SafeStatusHandle handle) | |||
{ | |||
Handle = handle ?? throw new ArgumentNullException(nameof(handle)); | |||
_handle = handle ?? throw new ArgumentNullException(nameof(handle)); | |||
} | |||
public void SetStatus(TF_Code code, string msg) | |||
{ | |||
TF_SetStatus(Handle, code, msg); | |||
TF_SetStatus(_handle, code, msg); | |||
} | |||
public bool ok() => Code == TF_Code.TF_OK; | |||
@@ -94,10 +94,12 @@ namespace Tensorflow | |||
} | |||
} | |||
public void Dispose() | |||
=> Handle.Dispose(); | |||
public override string ToString() | |||
=> $"{Code} 0x{Handle.DangerousGetHandle():x16}"; | |||
=> $"{Code} 0x{_handle.DangerousGetHandle():x16}"; | |||
public static implicit operator SafeStatusHandle(Status status) | |||
{ | |||
return status._handle; | |||
} | |||
} | |||
} |
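Status, like SessionOptions just above, now keeps its SafeStatusHandle private and exposes an implicit conversion to it, which is why call sites change from `tf.Status.Handle` to plain `tf.Status`. The following self-contained illustration uses invented names (DemoStatus, DemoStatusHandle) rather than the library's types:

    using System;
    using System.Runtime.InteropServices;

    sealed class DemoStatusHandle : SafeHandle
    {
        public DemoStatusHandle() : base(IntPtr.Zero, ownsHandle: true) { }
        public override bool IsInvalid => handle == IntPtr.Zero;
        protected override bool ReleaseHandle() => true; // native delete goes here
    }

    sealed class DemoStatus
    {
        readonly DemoStatusHandle _handle = new DemoStatusHandle();

        // The wrapper converts implicitly to its handle, so any method that
        // takes a DemoStatusHandle accepts a DemoStatus directly.
        public static implicit operator DemoStatusHandle(DemoStatus status)
            => status._handle;
    }

    static class DemoUsage
    {
        static void Check(DemoStatusHandle status)
        {
            // A real implementation would read the native error code here.
            Console.WriteLine(status.IsInvalid ? "no native status allocated" : "OK");
        }

        static void Main()
        {
            var status = new DemoStatus();
            Check(status); // implicit conversion: no .Handle at the call site
        }
    }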
@@ -121,7 +121,7 @@ namespace Tensorflow | |||
if (_handle == null) | |||
{ | |||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status.Handle); | |||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.Status); | |||
} | |||
else | |||
{ | |||
@@ -135,9 +135,9 @@ namespace Tensorflow | |||
protected virtual void SetShapeInternal(Shape value) | |||
{ | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status); | |||
else | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status); | |||
} | |||
public int[] _shape_tuple() | |||
@@ -176,7 +176,7 @@ namespace Tensorflow | |||
if (_handle == null) | |||
{ | |||
var output = _as_tf_output(); | |||
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status.Handle); | |||
int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.Status); | |||
return ndim; | |||
} | |||
@@ -94,18 +94,16 @@ namespace Tensorflow | |||
string output_pb = Path.GetFullPath(Path.Combine(checkpoint_dir, "../", $"{output_pb_name}.pb")); | |||
using (var graph = tf.Graph()) | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); | |||
saver.restore(sess, checkpoint); | |||
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||
graph.as_graph_def(), | |||
output_node_names); | |||
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||
return output_pb; | |||
} | |||
var graph = tf.Graph(); | |||
var sess = tf.Session(graph); | |||
var saver = tf.train.import_meta_graph($"{checkpoint}.meta", clear_devices: true); | |||
saver.restore(sess, checkpoint); | |||
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||
graph.as_graph_def(), | |||
output_node_names); | |||
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||
return output_pb; | |||
} | |||
public static Graph load_graph(string freeze_graph_pb, string name = "") | |||
@@ -164,7 +164,7 @@ namespace Tensorflow | |||
result._as_tf_output(), | |||
shape.dims, | |||
shape.ndim, | |||
tf.Status.Handle); | |||
tf.Status); | |||
tf.Status.Check(true); | |||
} | |||
@@ -247,7 +247,7 @@ namespace Tensorflow | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status); | |||
status.Check(true); | |||
} | |||
@@ -23,16 +23,14 @@ namespace Tensorflow.Benchmark.Leak | |||
var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | |||
for (var i = 0; i < 1024; i++) | |||
{ | |||
using (var sess = Session.LoadFromSavedModel(ClassifierModelPath)) { | |||
using (var g = sess.graph.as_default()) { | |||
var inputOp = g.OperationByName("inference_input"); | |||
var outputOp = g.OperationByName("StatefulPartitionedCall"); | |||
{ | |||
var sess = Session.LoadFromSavedModel(ClassifierModelPath); | |||
var g = sess.graph.as_default(); | |||
var inputOp = g.OperationByName("inference_input"); | |||
var outputOp = g.OperationByName("StatefulPartitionedCall"); | |||
var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); | |||
sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); | |||
} | |||
} | |||
var inp = np.zeros(new Shape(new int[] { 1, 2, 96 }), TF_DataType.TF_FLOAT); | |||
sess.run(outputOp.outputs[0], new FeedItem(inputOp.outputs[0], inp)); | |||
} | |||
} | |||
} | |||
@@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var enqueue = queue.enqueue(numbers); | |||
var dequeue_many = queue.dequeue_many(n: 3); | |||
using (var sess = tf.Session()) | |||
{ | |||
sess.run(enqueue, (numbers, new[] { 1 })); | |||
sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||
sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||
var result = sess.run(dequeue_many[0]); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||
} | |||
var sess = tf.Session(); | |||
sess.run(enqueue, (numbers, new[] { 1 })); | |||
sess.run(enqueue, (numbers, new[] { 2, 3 })); | |||
sess.run(enqueue, (numbers, new[] { 3, 4, 5 })); | |||
var result = sess.run(dequeue_many[0]); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 4, 5 }, result[2].ToArray<int>())); | |||
} | |||
[TestMethod] | |||
@@ -45,27 +43,25 @@ namespace TensorFlowNET.UnitTest.Basics | |||
// push back into queue | |||
var inc = queue.enqueue(y); | |||
using (var sess = tf.Session()) | |||
{ | |||
// init queue | |||
init.run(); | |||
var sess = tf.Session(); | |||
// init queue | |||
init.run(); | |||
// pop out first element and push back calculated y | |||
(int dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(10, dequeued); | |||
// pop out first element and push back calculated y | |||
(int dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(10, dequeued); | |||
(dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(20, dequeued); | |||
(dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(20, dequeued); | |||
(dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(11, dequeued); | |||
(dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(11, dequeued); | |||
(dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(21, dequeued); | |||
(dequeued, _) = sess.run((x, inc)); | |||
Assert.AreEqual(21, dequeued); | |||
// thread will hang or block if you run sess.run(x) again | |||
// until the queue has more elements. | |||
} | |||
// thread will hang or block if you run sess.run(x) again | |||
// until the queue has more elements. | |||
} | |||
[TestMethod] | |||
@@ -75,19 +71,17 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var init = queue.enqueue_many(new[] { 2L, 4L, 3L }, new[] { "p1", "p2", "p3" }); | |||
var x = queue.dequeue(); | |||
using (var sess = tf.Session()) | |||
{ | |||
init.run(); | |||
var sess = tf.Session(); | |||
init.run(); | |||
var result = sess.run(x); | |||
Assert.AreEqual(result[0], 2L); | |||
var result = sess.run(x); | |||
Assert.AreEqual(result[0], 2L); | |||
result = sess.run(x); | |||
Assert.AreEqual(result[0], 3L); | |||
result = sess.run(x); | |||
Assert.AreEqual(result[0], 3L); | |||
result = sess.run(x); | |||
Assert.AreEqual(result[0], 4L); | |||
} | |||
result = sess.run(x); | |||
Assert.AreEqual(result[0], 4L); | |||
} | |||
[TestMethod] | |||
@@ -98,16 +92,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var x = queue.dequeue(); | |||
string results = ""; | |||
using (var sess = tf.Session()) | |||
{ | |||
init.run(); | |||
var sess = tf.Session(); | |||
init.run(); | |||
foreach (var i in range(9)) | |||
results += (int)sess.run(x) + "."; | |||
foreach (var i in range(9)) | |||
results += (int)sess.run(x) + "."; | |||
// output in random order | |||
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); | |||
} | |||
// output in random order | |||
Assert.IsFalse(results == "1.2.3.4.5.6.7.8.9."); | |||
} | |||
} | |||
} |
@@ -19,11 +19,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var a = constant_op.constant(np.array(3.0).reshape((1, 1))); | |||
var b = constant_op.constant(np.array(2.0).reshape((1, 1))); | |||
var c = math_ops.matmul(a, b, name: "matmul"); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = c.eval(sess); | |||
Assert.AreEqual(result[0], 6.0); | |||
} | |||
var sess = tf.Session(); | |||
var result = c.eval(sess); | |||
Assert.AreEqual(result[0], 6.0); | |||
} | |||
} | |||
@@ -32,11 +30,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING); | |||
var c = tf.strings.substr(a, 4, 8); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = c.eval(sess).StringData(); | |||
Assert.AreEqual(result[0], "heythere"); | |||
} | |||
var sess = tf.Session(); | |||
var result = c.eval(sess).StringData(); | |||
Assert.AreEqual(result[0], "heythere"); | |||
} | |||
[TestMethod] | |||
@@ -47,11 +43,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
const int size = 30_000; | |||
var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING); | |||
var c = tf.strings.substr(a, 0, size - 5000); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); | |||
Console.WriteLine(result); | |||
} | |||
var sess = tf.Session(); | |||
var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray()); | |||
Console.WriteLine(result); | |||
} | |||
} | |||
@@ -16,15 +16,13 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var labels = tf.expand_dims(tf.constant(new[] { 0, 1, 2, 3, 4 }), 1); | |||
var st = tf.concat(values: new[] { indices, labels }, axis: 1); | |||
var onehot = tf.sparse_to_dense(st, (5, 5), 1); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(onehot); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>())); | |||
}; | |||
var sess = tf.Session(); | |||
var result = sess.run(onehot); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 1, 0, 0, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 1, 0, 0 }, result[2].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 1, 0 }, result[3].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 1 }, result[4].ToArray<int>())); | |||
} | |||
[TestMethod, Ignore] | |||
@@ -39,13 +37,11 @@ namespace TensorFlowNET.UnitTest.Basics | |||
new[] { 3L, 4L }); | |||
var onehot = tf.sparse_tensor_to_dense(decoded_list); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(onehot); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(onehot); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 1, 0, 0, 0 }, result[0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 2, 0 }, result[1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | |||
} | |||
[TestMethod] | |||
@@ -56,14 +52,12 @@ namespace TensorFlowNET.UnitTest.Basics | |||
int[,] crops = { { 0, 0 }, { 0, 0 } }; | |||
var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(tensor); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(tensor); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||
} | |||
[TestMethod, Ignore] | |||
@@ -72,11 +66,9 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var tensor = new[] { 0, 1, 2, 3 }; | |||
var mask = np.array(new[] { true, false, true, false }); | |||
var masked = tf.boolean_mask(tensor, mask); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(masked); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(masked); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray<int>())); | |||
} | |||
} | |||
} |
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var v = tf.Variable(new[] { 1, 2 }); | |||
var init = tf.compat.v1.global_variables_initializer(); | |||
using var sess = tf.compat.v1.Session(); | |||
var sess = tf.compat.v1.Session(); | |||
sess.run(init); | |||
// Usage passing the session explicitly. | |||
print(v.eval(sess)); | |||
@@ -16,18 +16,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
{ | |||
var graph = tf.Graph().as_default(); | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
var x = tf.constant(2, name: "x"); | |||
var y = tf.constant(5, name: "y"); | |||
var z = control_flow_ops.cond(tf.less(x, y), | |||
() => tf.constant(22, name: "t22"), | |||
() => tf.constant(55, name: "f55")); | |||
int result = z.eval(sess); | |||
assertEquals(result, 22); | |||
} | |||
var sess = tf.Session(graph); | |||
var x = tf.constant(2, name: "x"); | |||
var y = tf.constant(5, name: "y"); | |||
var z = control_flow_ops.cond(tf.less(x, y), | |||
() => tf.constant(22, name: "t22"), | |||
() => tf.constant(55, name: "f55")); | |||
int result = z.eval(sess); | |||
assertEquals(result, 22); | |||
} | |||
[TestMethod] | |||
@@ -35,18 +33,16 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
{ | |||
var graph = tf.Graph().as_default(); | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
var x = tf.constant(2, name: "x"); | |||
var y = tf.constant(1, name: "y"); | |||
var sess = tf.Session(graph); | |||
var x = tf.constant(2, name: "x"); | |||
var y = tf.constant(1, name: "y"); | |||
var z = control_flow_ops.cond(tf.less(x, y), | |||
() => tf.constant(22, name: "t22"), | |||
() => tf.constant(11, name: "f11")); | |||
var z = control_flow_ops.cond(tf.less(x, y), | |||
() => tf.constant(22, name: "t22"), | |||
() => tf.constant(11, name: "f11")); | |||
int result = z.eval(sess); | |||
assertEquals(result, 11); | |||
} | |||
int result = z.eval(sess); | |||
assertEquals(result, 11); | |||
} | |||
[Ignore("Dependent on UpdateEdge")] | |||
@@ -23,21 +23,19 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
private void _testWhileContextHelper(int maximum_iterations) | |||
{ | |||
// TODO: implement missing code dependencies | |||
using (var sess = this.cached_session()) | |||
var sess = this.cached_session(); | |||
var i = constant_op.constant(0, name: "i"); | |||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
//control_flow_ops.while_loop( | |||
// c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
foreach (Operation op in sess.graph.get_operations()) | |||
{ | |||
var i = constant_op.constant(0, name: "i"); | |||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
//control_flow_ops.while_loop( | |||
// c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
foreach (Operation op in sess.graph.get_operations()) | |||
{ | |||
var control_flow_context = op._get_control_flow_context(); | |||
/*if (control_flow_context != null) | |||
self.assertProtoEquals(control_flow_context.to_proto(), | |||
WhileContext.from_proto( | |||
control_flow_context.to_proto()).to_proto(), "");*/ | |||
} | |||
var control_flow_context = op._get_control_flow_context(); | |||
/*if (control_flow_context != null) | |||
self.assertProtoEquals(control_flow_context.to_proto(), | |||
WhileContext.from_proto( | |||
control_flow_context.to_proto()).to_proto(), "");*/ | |||
} | |||
} | |||
@@ -18,11 +18,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var y = tf.broadcast_to(x, (2, 4, 3)); | |||
var grad = tf.gradients(y, x); | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
float result = sess.run(grad[0]); | |||
Assert.AreEqual(result, 24.0f); | |||
} | |||
var sess = tf.Session(graph); | |||
float result = sess.run(grad[0]); | |||
Assert.AreEqual(result, 24.0f); | |||
} | |||
[TestMethod] | |||
@@ -33,11 +31,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var z = tf.cumsum(y, axis: 1); | |||
var grad = tf.gradients(z, x); | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
float result = sess.run(grad[0]); | |||
Assert.AreEqual(result, 60.0f); | |||
} | |||
var sess = tf.Session(graph); | |||
float result = sess.run(grad[0]); | |||
Assert.AreEqual(result, 60.0f); | |||
} | |||
[TestMethod, Ignore] | |||
@@ -78,14 +74,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
42.0f, 42.0f, 42.0f, | |||
45.0f, 45.0f, 45.0f | |||
}; | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(g); | |||
var resultList = result[0].ToArray<float>().ToList(); | |||
resultList.AddRange(result[1].ToArray<float>()); | |||
Console.WriteLine(result.ToString()); | |||
CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(g); | |||
var resultList = result[0].ToArray<float>().ToList(); | |||
resultList.AddRange(result[1].ToArray<float>()); | |||
Console.WriteLine(result.ToString()); | |||
CollectionAssert.AreEqual(resultList.ToArray(), checkG); | |||
} | |||
[TestMethod] | |||
@@ -97,11 +91,9 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var y = f(x); | |||
var g = tf.gradients(y, x); | |||
using (var session = tf.Session()) | |||
{ | |||
var result = session.run(new[] { y, g[0] }); | |||
return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]); | |||
} | |||
var session = tf.Session(); | |||
var result = session.run(new[] { y, g[0] }); | |||
return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]); | |||
} | |||
void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values) | |||
@@ -197,13 +189,11 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0]; | |||
var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0]; | |||
using (var session = tf.Session()) | |||
{ | |||
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); | |||
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); | |||
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); | |||
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); | |||
} | |||
var session = tf.Session(); | |||
var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } })); | |||
self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); | |||
self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); | |||
self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); | |||
} | |||
[TestMethod] | |||
@@ -212,12 +202,10 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var a = tf.constant(1f); | |||
var b = tf.tanh(a); | |||
var g = tf.gradients(b, a); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(g); | |||
var actual = result[0]; | |||
Assert.AreEqual(actual, 0.41997434127f); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(g); | |||
var actual = result[0]; | |||
Assert.AreEqual(actual, 0.41997434127f); | |||
} | |||
@@ -227,14 +215,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var a = tf.constant(5f); | |||
var b = tf.lgamma(a); | |||
var g = tf.gradients(b, a); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(new object[] { g, b }); | |||
var actualDeriv = result[0]; | |||
var actual = result[1]; | |||
Assert.AreEqual(actualDeriv, 1.5061177f); | |||
Assert.AreEqual(actual, 3.17805386f); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(new object[] { g, b }); | |||
var actualDeriv = result[0]; | |||
var actual = result[1]; | |||
Assert.AreEqual(actualDeriv, 1.5061177f); | |||
Assert.AreEqual(actual, 3.17805386f); | |||
} | |||
[TestMethod] | |||
@@ -247,14 +233,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
tf.constant(new[] { 1 }, tf.int32, new[] { 1 }) | |||
); | |||
var g = tf.gradients(b, a); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(new object[] { g, b }); | |||
var actualDeriv = np.squeeze(result[0]); | |||
var actual = np.squeeze(result[1]); | |||
Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); | |||
Assert.AreEqual(actual, 0.9640276f); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(new object[] { g, b }); | |||
var actualDeriv = np.squeeze(result[0]); | |||
var actual = np.squeeze(result[1]); | |||
Assert.AreEqual(actualDeriv, new float[] { 1, 0 }); | |||
Assert.AreEqual(actual, 0.9640276f); | |||
} | |||
[TestMethod] | |||
@@ -264,14 +248,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0); | |||
var g = tf.gradients(a, a1); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(new object[] { g, a }); | |||
var actualDeriv = result[0][0]; | |||
var actual = result[1][0]; | |||
Assert.AreEqual(actualDeriv, 1f); | |||
Assert.AreEqual(actual, 2f); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(new object[] { g, a }); | |||
var actualDeriv = result[0][0]; | |||
var actual = result[1][0]; | |||
Assert.AreEqual(actualDeriv, 1f); | |||
Assert.AreEqual(actual, 2f); | |||
} | |||
[TestMethod] | |||
@@ -280,13 +262,12 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
var ap = tf.constant(1f); | |||
var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); | |||
var g = tf.gradients(b, ap); | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(g); | |||
var actual = result[0]; | |||
Assert.AreEqual(actual, 0.41997434127f); | |||
} | |||
var sess = tf.Session(); | |||
var result = sess.run(g); | |||
var actual = result[0]; | |||
Assert.AreEqual(actual, 0.41997434127f); | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testUnusedOutput() | |||
@@ -74,23 +74,21 @@ namespace TensorFlowNET.UnitTest | |||
var cropSize2_2 = tf.Variable(np.array(4, 4)); | |||
var init = tf.global_variables_initializer(); | |||
using (Session sess = tf.Session()) | |||
{ | |||
sess.run(init); | |||
var sess = tf.Session(); | |||
sess.run(init); | |||
var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); | |||
var cropped = tf.image.crop_and_resize(image, box, boxInd, cropSize1_1); | |||
var result = sess.run(cropped); | |||
// check if cropped to 1x1 center was successful | |||
Assert.AreEqual(result.size, 1ul); | |||
Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||
var result = sess.run(cropped); | |||
// check if cropped to 1x1 center was successful | |||
Assert.AreEqual(result.size, 1ul); | |||
Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||
result = sess.run(cropped); | |||
// check if flipped and no cropping occurred | |||
Assert.AreEqual(result.size, 16ul); | |||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||
} | |||
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||
result = sess.run(cropped); | |||
// check if flipped and no cropping occurred | |||
Assert.AreEqual(result.size, 16ul); | |||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||
} | |||
} | |||
} |
@@ -24,7 +24,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
Assert.IsNull(tf.peak_default_graph()); | |||
using var sess = tf.Session(); | |||
var sess = tf.Session(); | |||
var default_graph = tf.get_default_graph(); | |||
var sess_graph = sess.graph; | |||
Assert.IsNotNull(default_graph); | |||
@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
Assert.IsNull(tf.peak_default_graph()); | |||
//tf.Session created another graph | |||
using var sess = tf.Session(); | |||
var sess = tf.Session(); | |||
var default_graph = tf.get_default_graph(); | |||
var sess_graph = sess.graph; | |||
Assert.IsNotNull(default_graph); | |||
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest | |||
beforehand.as_default(); | |||
Assert.IsNotNull(tf.peak_default_graph()); | |||
using var sess = tf.Session(); | |||
var sess = tf.Session(); | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.graph; | |||
Assert.IsNotNull(default_graph); | |||
@@ -102,7 +102,7 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
using var sess = tf.Session(); | |||
var sess = tf.Session(); | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
var t = new Tensor(1); | |||
@@ -119,7 +119,7 @@ namespace TensorFlowNET.UnitTest | |||
void Core(int tid) | |||
{ | |||
//tf.Session created another graph | |||
using var sess = tf.Session(); | |||
var sess = tf.Session(); | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
var t = new Tensor(new int[] { 1, 2, 3 }); | |||
@@ -142,7 +142,7 @@ namespace TensorFlowNET.UnitTest | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
var math = a1 + a2; | |||
using var sess = tf.Session(graph); | |||
var sess = tf.Session(graph); | |||
for (int i = 0; i < 100; i++) | |||
{ | |||
var result = sess.run(math); | |||
@@ -162,7 +162,7 @@ namespace TensorFlowNET.UnitTest | |||
tf.compat.v1.disable_eager_execution(); | |||
var graph = tf.Graph().as_default(); | |||
using var sess = tf.Session(graph); | |||
var sess = tf.Session(graph); | |||
Assert.IsNotNull(tf.get_default_graph()); | |||
//graph is created automatically to create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
@@ -182,7 +182,7 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
using var sess = tf.Session(); | |||
var sess = tf.Session(); | |||
Assert.IsNotNull(tf.get_default_graph()); | |||
//graph is created automatically to create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
@@ -182,23 +182,21 @@ namespace TensorFlowNET.UnitTest | |||
// return self._eval_helper(tensors) | |||
// else: | |||
{ | |||
using (var sess = tf.Session()) | |||
var sess = tf.Session(); | |||
var ndarray = tensor.eval(sess); | |||
if (typeof(T) == typeof(double)) | |||
{ | |||
var ndarray = tensor.eval(sess); | |||
if (typeof(T) == typeof(double)) | |||
{ | |||
double x = ndarray; | |||
result = x; | |||
} | |||
else if (typeof(T) == typeof(int)) | |||
{ | |||
int x = ndarray; | |||
result = x; | |||
} | |||
else | |||
{ | |||
result = ndarray; | |||
} | |||
double x = ndarray; | |||
result = x; | |||
} | |||
else if (typeof(T) == typeof(int)) | |||
{ | |||
int x = ndarray; | |||
result = x; | |||
} | |||
else | |||
{ | |||
result = ndarray; | |||
} | |||
return (T)result; | |||
@@ -48,7 +48,7 @@ namespace Tensorflow.Native.UnitTest | |||
private void EXPECT_TF_META(Operation oper, string attr_name, int expected_list_size, TF_AttrType expected_type, uint expected_total_size) | |||
{ | |||
var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_.Handle); | |||
var m = c_api.TF_OperationGetAttrMetadata(oper, attr_name, s_); | |||
EXPECT_EQ(TF_Code.TF_OK, s_.Code); | |||
char e = expected_list_size >= 0 ? (char)1 : (char)0; | |||
/*EXPECT_EQ(e, m.is_list); | |||
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest | |||
var desc = init("string"); | |||
c_api.TF_SetAttrString(desc, "v", "bunny", 5); | |||
var oper = c_api.TF_FinishOperation(desc, s_.Handle); | |||
var oper = c_api.TF_FinishOperation(desc, s_); | |||
//ASSERT_EQ(TF_Code.TF_OK, s_.Code); | |||
//EXPECT_TF_META(oper, "v", -1, TF_AttrType.TF_ATTR_STRING, 5); | |||
//var value = new char[5]; | |||
@@ -86,8 +86,6 @@ namespace Tensorflow.Native.UnitTest | |||
public void Dispose() | |||
{ | |||
graph_.Dispose(); | |||
s_.Dispose(); | |||
} | |||
} | |||
} |
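The hunks above replace every `s_.Handle` argument with the `Status` object itself, which only compiles if the wrapper converts to the safe-handle type the `c_api` imports now expect. A minimal sketch of that conversion pattern follows; `SafeStatusHandle` and the exact P/Invoke declarations are assumptions made for illustration (modeled on the other `Safe*Handle` types in this change), not the literal classes from the bindings.

```csharp
using System;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;

// Assumed shape of the status safe handle. TF_NewStatus/TF_DeleteStatus are
// real TensorFlow C API entry points; the class itself is illustrative.
public sealed class SafeStatusHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    public SafeStatusHandle() : base(ownsHandle: true) { }

    protected override bool ReleaseHandle()
    {
        TF_DeleteStatus(handle);   // frees the native TF_Status*
        return true;
    }

    [DllImport("tensorflow")]
    private static extern void TF_DeleteStatus(IntPtr status);
}

public sealed class Status
{
    public SafeStatusHandle Handle { get; } = TF_NewStatus();

    // This conversion is what lets call sites pass `s_` where the extern
    // declaration takes a SafeStatusHandle, instead of spelling out `s_.Handle`.
    public static implicit operator SafeStatusHandle(Status status) => status.Handle;

    [DllImport("tensorflow")]
    private static extern SafeStatusHandle TF_NewStatus();
}
```

Whether the real bindings use an implicit operator or call overloads, the effect at the call sites is the same: the `.Handle` indirection disappears from the tests.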
@@ -59,7 +59,7 @@ namespace Tensorflow.Native.UnitTest | |||
private void VerifyCollocation(Operation op, string[] expected) | |||
{ | |||
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_.Handle); | |||
var handle = c_api.TF_OperationGetAttrMetadata(op, "_class", s_); | |||
TF_AttrMetadata m = new TF_AttrMetadata(); | |||
if (expected.Length == 0) | |||
{ | |||
@@ -98,8 +98,6 @@ namespace Tensorflow.Native.UnitTest | |||
public void Dispose() | |||
{ | |||
graph_.Dispose(); | |||
s_.Dispose(); | |||
} | |||
} | |||
} |
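The teardown hunks above delete the explicit `graph_.Dispose()` / `s_.Dispose()` calls. With `SafeHandle`-backed wrappers the usual justification is that a safe handle is critical-finalizable: if nobody disposes it, the runtime still invokes `ReleaseHandle` during finalization, so the native object is reclaimed either way. A small lifecycle sketch, reusing the assumed `Status`/`SafeStatusHandle` shapes from the previous sketch:

```csharp
public static class StatusLifetimeDemo
{
    // Deterministic path: dispose the safe handle as soon as you are done.
    public static void DisposeExplicitly()
    {
        var status = new Status();              // allocates a native TF_Status
        // ... pass `status` to c_api calls ...
        status.Handle.Dispose();                // TF_DeleteStatus runs here
    }

    // Lazy path: do nothing. Once `status` becomes unreachable, the handle's
    // critical finalizer calls ReleaseHandle -> TF_DeleteStatus, so nothing
    // leaks -- presumably why the Dispose() calls in the test fixtures above
    // could be dropped.
    public static void RelyOnFinalizer()
    {
        var status = new Status();
        // ... pass `status` to c_api calls ...
    }
}
```

The trade-off is timing, not correctness: finalization frees the native memory eventually, while explicit disposal frees it at a known point.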
@@ -45,10 +45,10 @@ namespace Tensorflow.Native.UnitTest | |||
=> c_api.TF_AddInput(desc, input); | |||
protected Operation TF_FinishOperation(OperationDescription desc, Status s) | |||
=> c_api.TF_FinishOperation(desc, s.Handle); | |||
=> c_api.TF_FinishOperation(desc, s); | |||
protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s) | |||
=> c_api.TF_SetAttrTensor(desc, attrName, value, s.Handle); | |||
=> c_api.TF_SetAttrTensor(desc, attrName, value, s); | |||
protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype) | |||
=> c_api.TF_SetAttrType(desc, attrName, dtype); | |||
@@ -18,7 +18,7 @@ namespace Tensorflow.Native.UnitTest | |||
string func_name_ = "MyFunc"; | |||
string func_node_name_ = "MyFunc_0"; | |||
Status s_; | |||
IntPtr func_; | |||
SafeFuncGraphHandle func_; | |||
[TestInitialize] | |||
public void Initialize() | |||
@@ -402,7 +402,7 @@ namespace Tensorflow.Native.UnitTest | |||
inputs.Length, inputs.ToArray(), | |||
outputs.Length, outputs.ToArray(), | |||
output_names == null || output_names.Length == 0 ? null : output_names, | |||
IntPtr.Zero, null, s_.Handle); | |||
IntPtr.Zero, null, s_); | |||
if (expect_failure) | |||
{ | |||
@@ -413,7 +413,7 @@ namespace Tensorflow.Native.UnitTest | |||
ASSERT_EQ(TF_OK, s_.Code, s_.Message); | |||
ASSERT_NE(func_, IntPtr.Zero); | |||
ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_))); | |||
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle); | |||
c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_); | |||
ASSERT_EQ(TF_OK, s_.Code, s_.Message); | |||
} | |||
@@ -44,18 +44,14 @@ namespace Tensorflow.Native.UnitTest | |||
private bool GetGraphDef(Graph graph, out GraphDef graph_def) | |||
{ | |||
graph_def = null; | |||
using (var s = new Status()) | |||
{ | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle); | |||
bool ret = TF_GetCode(s) == TF_OK; | |||
EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
if (ret) | |||
graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
return ret; | |||
} | |||
} | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
bool ret = TF_GetCode(s) == TF_OK; | |||
EXPECT_EQ(TF_OK, TF_GetCode(s)); | |||
if (ret) | |||
graph_def = GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
return ret; | |||
} | |||
private void RunGraphsAndCompareOutputs(TF_Output[] grad_outputs, TF_Output[] expected_grad_outputs) | |||
@@ -111,9 +107,9 @@ namespace Tensorflow.Native.UnitTest | |||
IntPtr[] handles = new IntPtr[2] { IntPtr.Zero, IntPtr.Zero }; | |||
c_api.TF_AddGradientsWithPrefix(graph_, prefix, outputs, noutputs, inputs, | |||
ninputs, grad_inputs, s_.Handle, handles); | |||
ninputs, grad_inputs, s_, handles); | |||
var op = new Operation(handles[0]); | |||
// var op = new Operation(handles[0]); | |||
} | |||
else | |||
{ | |||
@@ -275,9 +271,6 @@ namespace Tensorflow.Native.UnitTest | |||
public void Dispose() | |||
{ | |||
graph_.Dispose(); | |||
expected_graph_.Dispose(); | |||
s_.Dispose(); | |||
} | |||
} | |||
} |
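Alongside the call sites, the field and parameter types change as well (`IntPtr func_` becomes `SafeFuncGraphHandle func_`, and `TF_GraphToGraphDef` now receives the `Buffer` and `Status` wrappers directly). That implies the extern declarations themselves accept safe-handle types. A hedged before/after sketch of one such import follows; the placeholder handle classes (re-declared here so the snippet stands alone) and the marshaling details are assumptions, while the native entry point and its parameter order come from the TensorFlow C API.

```csharp
using System;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;

// Placeholder handle shapes for the sketch; the real ones live in the bindings.
public sealed class SafeBufferHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    public SafeBufferHandle() : base(true) { }
    protected override bool ReleaseHandle() { /* TF_DeleteBuffer(handle) */ return true; }
}

public sealed class SafeStatusHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    public SafeStatusHandle() : base(true) { }
    protected override bool ReleaseHandle() { /* TF_DeleteStatus(handle) */ return true; }
}

internal static class c_api_sketch
{
    // Before: raw pointers, so every caller had to write wrapper.Handle and
    // keep the wrapper alive across the call on its own.
    [DllImport("tensorflow", EntryPoint = "TF_GraphToGraphDef")]
    internal static extern void TF_GraphToGraphDef_old(IntPtr graph, IntPtr output_graph_def, IntPtr status);

    // After: SafeHandle-typed parameters. The interop marshaler ref-counts the
    // handles for the duration of the call, so the native objects cannot be
    // released while the C API is still using them.
    [DllImport("tensorflow", EntryPoint = "TF_GraphToGraphDef")]
    internal static extern void TF_GraphToGraphDef(IntPtr graph, SafeBufferHandle output_graph_def, SafeStatusHandle status);
}
```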
@@ -9,7 +9,7 @@ namespace Tensorflow.Native.UnitTest | |||
[TestMethod, Ignore("Waiting to merge https://github.com/tensorflow/tensorflow/pull/43383")] | |||
public void UpdateEdge() | |||
{ | |||
using var graph = new Graph().as_default(); | |||
var graph = new Graph().as_default(); | |||
var one = tf.constant(1, name: "one"); | |||
var two = tf.constant(2, name: "two"); | |||
@@ -35,7 +35,7 @@ namespace Tensorflow.Native.UnitTest | |||
EXPECT_EQ(attr_value.Type, DataType.DtInt32); | |||
// Test not found errors in TF_Operation*() query functions. | |||
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s.Handle)); | |||
EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||
EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||
Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||
EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message); | |||
@@ -191,9 +191,6 @@ namespace Tensorflow.Native.UnitTest | |||
ASSERT_TRUE(found_scalar_const); | |||
ASSERT_TRUE(found_add); | |||
ASSERT_TRUE(found_neg); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
/// <summary> | |||
@@ -213,16 +210,15 @@ namespace Tensorflow.Native.UnitTest | |||
// Export to a GraphDef. | |||
var graph_def = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle); | |||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Import it, with a prefix, in a fresh graph. | |||
graph.Dispose(); | |||
graph = new Graph().as_default(); | |||
using (var opts = c_api.TF_NewImportGraphDefOptions()) | |||
{ | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
} | |||
@@ -265,7 +261,7 @@ namespace Tensorflow.Native.UnitTest | |||
EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts)); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); | |||
EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts)); | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle); | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
return results; | |||
@@ -305,7 +301,7 @@ namespace Tensorflow.Native.UnitTest | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3"); | |||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed); | |||
c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
} | |||
@@ -330,7 +326,7 @@ namespace Tensorflow.Native.UnitTest | |||
// Export to a graph def so we can import a graph with control dependencies | |||
graph_def = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle); | |||
c_api.TF_GraphToGraphDef(graph, graph_def, s); | |||
EXPECT_EQ(TF_Code.TF_OK, s.Code); | |||
// Import again, with remapped control dependency, into the same graph | |||
@@ -338,7 +334,7 @@ namespace Tensorflow.Native.UnitTest | |||
{ | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4"); | |||
c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle); | |||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
} | |||
@@ -380,7 +376,6 @@ namespace Tensorflow.Native.UnitTest | |||
ASSERT_EQ(TF_Code.TF_OK, s.Code); | |||
// Import it in a fresh graph with return outputs. | |||
graph.Dispose(); | |||
graph = new Graph().as_default(); | |||
var opts = new ImportGraphDefOptions(); | |||
opts.AddReturnOutput("feed", 0); | |||
@@ -401,11 +396,6 @@ namespace Tensorflow.Native.UnitTest | |||
EXPECT_EQ(0, return_outputs[0].index); | |||
EXPECT_EQ(scalar, return_outputs[1].oper); | |||
EXPECT_EQ(0, return_outputs[1].index); | |||
opts.Dispose(); | |||
graph_def.Dispose(); | |||
graph.Dispose(); | |||
s.Dispose(); | |||
} | |||
/// <summary> | |||
@@ -422,16 +412,14 @@ namespace Tensorflow.Native.UnitTest | |||
public void ImportGraphMeta() | |||
{ | |||
var dir = "my-save-dir/"; | |||
using (var sess = tf.Session()) | |||
{ | |||
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); | |||
new_saver.restore(sess, dir + "my-model-10000"); | |||
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | |||
var batch_size = tf.size(labels); | |||
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor; | |||
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||
logits: logits); | |||
} | |||
var sess = tf.Session(); | |||
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta"); | |||
new_saver.restore(sess, dir + "my-model-10000"); | |||
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels"); | |||
var batch_size = tf.size(labels); | |||
var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor; | |||
var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels, | |||
logits: logits); | |||
} | |||
} | |||
} |
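The test bodies above drop `using (var sess = tf.Session())` in favor of a plain local, and the same pattern recurs throughout this change. That only stays leak-free if the session's native resources are owned by a finalizable handle. A sketch of what the release side of a `SafeSessionHandle` could look like, closing and deleting the `TF_Session` with a throwaway status; the class body is an assumption for illustration, while the four native entry points and their signatures are the real C API ones.

```csharp
using System;
using System.Runtime.InteropServices;
using Microsoft.Win32.SafeHandles;

// Assumed release logic behind `var sess = tf.Session();` with no using block.
public sealed class SafeSessionHandle : SafeHandleZeroOrMinusOneIsInvalid
{
    public SafeSessionHandle() : base(true) { }

    protected override bool ReleaseHandle()
    {
        // Both TF_CloseSession and TF_DeleteSession require a TF_Status*.
        var status = TF_NewStatus();
        TF_CloseSession(handle, status);
        TF_DeleteSession(handle, status);
        TF_DeleteStatus(status);
        return true;
    }

    [DllImport("tensorflow")] private static extern IntPtr TF_NewStatus();
    [DllImport("tensorflow")] private static extern void TF_DeleteStatus(IntPtr status);
    [DllImport("tensorflow")] private static extern void TF_CloseSession(IntPtr session, IntPtr status);
    [DllImport("tensorflow")] private static extern void TF_DeleteSession(IntPtr session, IntPtr status);
}
```

For short-lived unit tests the delayed, finalizer-driven cleanup is a reasonable trade; a long-running service would more likely want to close sessions deterministically.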
@@ -11,7 +11,7 @@ namespace Tensorflow.Native.UnitTest | |||
/// </summary> | |||
public class CSession | |||
{ | |||
private IntPtr session_; | |||
private SafeSessionHandle session_; | |||
private List<TF_Output> inputs_ = new List<TF_Output>(); | |||
private List<Tensor> input_values_ = new List<Tensor>(); | |||
@@ -22,11 +22,8 @@ namespace Tensorflow.Native.UnitTest | |||
public CSession(Graph graph, Status s, bool user_XLA = false) | |||
{ | |||
lock (Locks.ProcessWide) | |||
{ | |||
var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||
session_ = new Session(graph, config, s); | |||
} | |||
var config = new ConfigProto { InterOpParallelismThreads = 4 }; | |||
session_ = new Session(graph, config, s); | |||
} | |||
public void SetInputs(Dictionary<Operation, Tensor> inputs) | |||
@@ -85,7 +82,7 @@ namespace Tensorflow.Native.UnitTest | |||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | |||
outputs_ptr, output_values_ptr, outputs_.Count, | |||
targets_ptr, targets_.Count, | |||
IntPtr.Zero, s.Handle); | |||
IntPtr.Zero, s); | |||
s.Check(); | |||
@@ -14,8 +14,8 @@ namespace Tensorflow.Native.UnitTest.Sessions | |||
[TestMethod] | |||
public void Session() | |||
{ | |||
using var s = new Status(); | |||
using var graph = new Graph(); | |||
var s = new Status(); | |||
var graph = new Graph(); | |||
// Make a placeholder operation. | |||
var feed = c_test_util.Placeholder(graph, s); | |||
@@ -139,45 +139,45 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
var feed_out_0 = new TF_Output(feed, 0); | |||
// Fetch the shape, it should be completely unknown. | |||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
EXPECT_EQ(-1, num_dims); | |||
// Set the shape to be unknown, expect no change. | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
EXPECT_EQ(-1, num_dims); | |||
// Set the shape to be 2 x Unknown | |||
long[] dims = { 2, -1 }; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s.Handle); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | |||
EXPECT_EQ(2, num_dims); | |||
// Get the dimension vector appropriately. | |||
var returned_dims = new long[dims.Length]; | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
// Set to a new valid shape: [2, 3] | |||
dims[1] = 3; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
// Fetch and see that the new value is returned. | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||
// Try to set 'unknown' with unknown rank on the shape and see that | |||
// it doesn't change. | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
EXPECT_EQ(2, num_dims); | |||
EXPECT_EQ(2, (int)returned_dims[0]); | |||
@@ -187,21 +187,21 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
// it doesn't change. | |||
dims[0] = -1; | |||
dims[1] = -1; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s.Handle); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
EXPECT_EQ(2, num_dims); | |||
EXPECT_EQ(2, (int)returned_dims[0]); | |||
EXPECT_EQ(3, (int)returned_dims[1]); | |||
// Try to fetch a shape with the wrong num_dims | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s.Handle); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
// Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||
dims[1] = 5; | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s.Handle); | |||
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
// Test for a scalar. | |||
@@ -209,14 +209,13 @@ namespace Tensorflow.Native.UnitTest.Tensors | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
var three_out_0 = new TF_Output(three, 0); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s.Handle); | |||
num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
EXPECT_EQ(0, num_dims); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s.Handle); | |||
c_api.TF_GraphGetTensorShape(graph, feed_out_0, dims, num_dims, s); | |||
Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||
graph.Exit(); | |||
s.Dispose(); | |||
} | |||
} | |||
} |
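The shape test above exercises a two-step protocol: ask `TF_GraphGetTensorNumDims` for the rank, then call `TF_GraphGetTensorShape` with a buffer of exactly that length (a mismatched `num_dims` is rejected with `TF_INVALID_ARGUMENT`, which is what the "wrong num_dims" assertion checks). A compact helper sketch of that protocol; it uses raw `IntPtr`s in place of the managed `Graph`/`Status` wrappers purely to keep the snippet self-contained.

```csharp
using System;
using System.Runtime.InteropServices;

internal static class ShapeQuery
{
    // TF_Output is { TF_Operation* oper; int index; } in the C API.
    [StructLayout(LayoutKind.Sequential)]
    public struct TF_Output
    {
        public IntPtr oper;
        public int index;
    }

    [DllImport("tensorflow")]
    private static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status);

    [DllImport("tensorflow")]
    private static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);

    public static long[] GetShape(IntPtr graph, TF_Output output, IntPtr status)
    {
        // Step 1: rank first; -1 means the shape is completely unknown.
        int numDims = TF_GraphGetTensorNumDims(graph, output, status);
        if (numDims < 0)
            return null;

        // Step 2: fetch the dimensions into a buffer of exactly `numDims`
        // entries; unknown dimensions come back as -1.
        var dims = new long[numDims];
        TF_GraphGetTensorShape(graph, output, dims, numDims, status);
        return dims;
    }
}
```

(Status checking after each call is omitted here; the tests above do it through the managed `Status` wrapper.)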
@@ -23,7 +23,7 @@ namespace Tensorflow.Native.UnitTest | |||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
@@ -33,37 +33,29 @@ namespace Tensorflow.Native.UnitTest | |||
[SuppressMessage("ReSharper", "RedundantAssignment")] | |||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | |||
{ | |||
lock (Locks.ProcessWide) | |||
{ | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer.Handle, s.Handle); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); | |||
} | |||
var buffer = new Buffer(); | |||
return s.Code == TF_Code.TF_OK; | |||
} | |||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||
attr_value = AttrValue.Parser.ParseFrom(buffer.ToArray()); | |||
return s.Code == TF_Code.TF_OK; | |||
} | |||
public static GraphDef GetGraphDef(Graph graph) | |||
{ | |||
lock (Locks.ProcessWide) | |||
{ | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_GraphToGraphDef(graph, buffer.Handle, s.Handle); | |||
s.Check(); | |||
return GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
} | |||
} | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||
s.Check(); | |||
return GraphDef.Parser.ParseFrom(buffer.ToArray()); | |||
} | |||
public static FunctionDef GetFunctionDef(IntPtr func) | |||
public static FunctionDef GetFunctionDef(SafeFuncGraphHandle func) | |||
{ | |||
using var s = new Status(); | |||
using var buffer = new Buffer(); | |||
c_api.TF_FunctionToFunctionDef(func, buffer.Handle, s.Handle); | |||
var s = new Status(); | |||
var buffer = new Buffer(); | |||
c_api.TF_FunctionToFunctionDef(func, buffer, s); | |||
s.Check(true); | |||
var func_def = FunctionDef.Parser.ParseFrom(buffer.ToArray()); | |||
return func_def; | |||
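`GetAttrValue`, `GetGraphDef`, and `GetFunctionDef` above all follow the same recipe: let the C API serialize a native object into a `TF_Buffer`, then hand the bytes to the generated protobuf parser. A generic sketch of that recipe; `ParseBuffer` is a hypothetical helper, not part of this change, and it assumes the `Buffer`/`Status` wrappers and `c_api` imports from the bindings.

```csharp
using System;
using Google.Protobuf;

namespace Tensorflow.Native.UnitTest
{
    // Hypothetical helper illustrating the serialize-into-TF_Buffer-then-parse
    // pattern shared by the three utilities above.
    internal static class ProtoRoundTrip
    {
        public static T ParseBuffer<T>(Action<Buffer, Status> fillBuffer, MessageParser<T> parser)
            where T : IMessage<T>
        {
            var s = new Status();
            var buffer = new Buffer();
            fillBuffer(buffer, s);   // e.g. (b, st) => c_api.TF_GraphToGraphDef(graph, b, st)
            s.Check();
            return parser.ParseFrom(buffer.ToArray());
        }
    }
}
```

Usage, mirroring `GetGraphDef` above: `var graph_def = ProtoRoundTrip.ParseBuffer((b, st) => c_api.TF_GraphToGraphDef(graph, b, st), GraphDef.Parser);`.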
@@ -192,7 +184,7 @@ namespace Tensorflow.Native.UnitTest | |||
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||
var neg_input = new TF_Output(n, 0); | |||
c_api.TF_AddInput(desc, neg_input); | |||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
@@ -210,7 +202,7 @@ namespace Tensorflow.Native.UnitTest | |||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||
} | |||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
@@ -222,10 +214,10 @@ namespace Tensorflow.Native.UnitTest | |||
lock (Locks.ProcessWide) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
c_api.TF_SetAttrTensor(desc, "value", t, s.Handle); | |||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||
s.Check(); | |||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
var op = c_api.TF_FinishOperation(desc, s.Handle); | |||
var op = c_api.TF_FinishOperation(desc, s); | |||
s.Check(); | |||
return op; | |||
@@ -17,10 +17,8 @@ namespace TensorFlowNET.UnitTest.Basics | |||
public void ImportGraph() | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||
} | |||
var sess = tf.Session(); | |||
var new_saver = tf.train.import_meta_graph("C:/tmp/my-model.meta"); | |||
//tf.train.export_meta_graph(filename: "linear_regression.meta.bin"); | |||
// import meta | |||
@@ -60,14 +58,12 @@ namespace TensorFlowNET.UnitTest.Basics | |||
// Add ops to save and restore all the variables. | |||
var saver = tf.train.Saver(); | |||
using (var sess = tf.Session()) | |||
{ | |||
sess.run(init_op); | |||
var sess = tf.Session(); | |||
sess.run(init_op); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
} | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model1.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
} | |||
public void Save2() | |||
@@ -84,17 +80,15 @@ namespace TensorFlowNET.UnitTest.Basics | |||
// Add ops to save and restore all the variables. | |||
var saver = tf.train.Saver(); | |||
using (var sess = tf.Session()) | |||
{ | |||
sess.run(init_op); | |||
// Do some work with the model. | |||
inc_v1.op.run(); | |||
dec_v2.op.run(); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
} | |||
var sess = tf.Session(); | |||
sess.run(init_op); | |||
// Do some work with the model. | |||
inc_v1.op.run(); | |||
dec_v2.op.run(); | |||
// Save the variables to disk. | |||
var save_path = saver.save(sess, "/tmp/model2.ckpt"); | |||
Console.WriteLine($"Model saved in path: {save_path}"); | |||
} | |||
} | |||
} |
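The `Save`/`Save2` tests above only exercise the write side of the checkpoint workflow. For completeness, a hypothetical counterpart that restores the values written to `/tmp/model2.ckpt`; it is a sketch of the usual Saver round trip, not code from this change, and it assumes the variable graph is rebuilt exactly as in `Save2` (those definitions sit outside the hunk above).

```csharp
public void Restore2()
{
    // Rebuild the same variables as Save2 so the Saver knows what to restore into.
    // ... variable definitions identical to Save2 (not shown in the hunk) ...

    var saver = tf.train.Saver();
    var sess = tf.Session();

    // restore() loads the checkpointed values, so no init_op is run here.
    saver.restore(sess, "/tmp/model2.ckpt");
    Console.WriteLine("Model restored from /tmp/model2.ckpt");
}
```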
@@ -57,12 +57,10 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6)); | |||
var scan = tf.scan(fn, input); | |||
using (var sess = tf.Session()) | |||
{ | |||
sess.run(tf.global_variables_initializer()); | |||
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||
} | |||
var sess = tf.Session(); | |||
sess.run(tf.global_variables_initializer()); | |||
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); | |||
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); | |||
} | |||
} | |||
} |
@@ -196,23 +196,21 @@ namespace TensorFlowNET.UnitTest | |||
// return self._eval_helper(tensors) | |||
// else: | |||
{ | |||
using (var sess = tf.Session()) | |||
var sess = tf.Session(); | |||
var ndarray = tensor.eval(sess); | |||
if (typeof(T) == typeof(double)) | |||
{ | |||
var ndarray = tensor.eval(sess); | |||
if (typeof(T) == typeof(double)) | |||
{ | |||
double x = ndarray; | |||
result = x; | |||
} | |||
else if (typeof(T) == typeof(int)) | |||
{ | |||
int x = ndarray; | |||
result = x; | |||
} | |||
else | |||
{ | |||
result = ndarray; | |||
} | |||
double x = ndarray; | |||
result = x; | |||
} | |||
else if (typeof(T) == typeof(int)) | |||
{ | |||
int x = ndarray; | |||
result = x; | |||
} | |||
else | |||
{ | |||
result = ndarray; | |||
} | |||
return (T)result; | |||
@@ -28,7 +28,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||
public void DeleteStatus() | |||
{ | |||
var s = new Status(); | |||
s.Dispose(); | |||
} | |||
} | |||
} |