From 7d2d186fe38aded107aa7e3114f8dcad4e29ed75 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 27 Jun 2020 13:44:34 -0500 Subject: [PATCH] Use shared status object. --- .../Eager/EagerOperation.cs | 3 +- .../Eager/EagerTensor.Creation.cs | 6 +-- src/TensorFlowNET.Core/Eager/EagerTensor.cs | 13 +++---- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Operations/Operation.Input.cs | 8 ++-- .../Operations/Operation.Output.cs | 8 +--- .../Operations/Operation.cs | 19 +++++---- .../Operations/array_ops.cs | 26 ++++++------- .../Sessions/BaseSession.cs | 4 +- src/TensorFlowNET.Core/Sessions/Session.cs | 3 +- .../Tensors/EagerTensorV2.cs | 8 ++-- .../Tensors/Tensor.Conversions.cs | 8 ++-- .../Tensors/Tensor.Creation.cs | 21 ++++------ .../Tensors/Tensor.Value.cs | 21 +++++----- src/TensorFlowNET.Core/Tensors/Tensor.cs | 31 +++++---------- src/TensorFlowNET.Core/ops.cs | 39 +++++++++---------- src/TensorFlowNET.Core/tensorflow.cs | 6 +-- test/TensorFlowNET.UnitTest/ConstantTest.cs | 4 +- .../NativeAPI/Eager/GradientEagerTest.cs | 21 +++++++++- 19 files changed, 115 insertions(+), 136 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/EagerOperation.cs b/src/TensorFlowNET.Core/Eager/EagerOperation.cs index ce742904..fe0054cf 100644 --- a/src/TensorFlowNET.Core/Eager/EagerOperation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerOperation.cs @@ -53,8 +53,7 @@ namespace Tensorflow.Eager { object value = null; byte isList = 0; - using var status = new Status(); - var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, status); + var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status); switch (attrType) { case TF_AttrType.TF_ATTR_BOOL: diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index b754c913..ce227185 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -22,13 +22,13 @@ namespace Tensorflow.Eager public EagerTensor(string value, string device_name) : base(value) { - EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); + EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); Resolve(); } public EagerTensor(NDArray value, string device_name) : base(value) { - EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); + EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); Resolve(); } @@ -37,7 +37,7 @@ namespace Tensorflow.Eager _id = get_uid(); if (_handle == IntPtr.Zero) - _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, status); + _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status); //print($"new Tensor {Id} {_handle.ToString("x16")}"); //print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}"); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index eac0aece..d084ffae 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -8,26 +8,23 @@ namespace Tensorflow.Eager { public partial class EagerTensor : Tensor { - Status status = new Status(); public IntPtr EagerTensorHandle; - public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, status)); + public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status)); - public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, status); + public override int rank => c_api.TFE_TensorHandleNumDims(EagerTensorHandle, tf.status); public static int GetRank(IntPtr handle) { var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); - using var status = new Status(); - return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status); + return c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.status); } public static int[] GetDims(IntPtr handle) { var tfe_tensor_handle = c_api.TFE_EagerTensorHandle(handle); - using var status = new Status(); - var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status)]; + var dims = new int[c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, tf.status)]; for (int i = 0; i < dims.Length; i++) - dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, status); + dims[i] = c_api.TFE_TensorHandleDim(tfe_tensor_handle, i, tf.status); return dims; } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 8ae3a15c..de47d15a 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -512,7 +512,7 @@ namespace Tensorflow public TensorShape GetTensorShape(TF_Output output) { - var status = new Status(); + var status = tf.status; var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); status.Check(); diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 48f1800b..3941425d 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -17,6 +17,7 @@ using System; using System.Linq; using System.Runtime.InteropServices; +using static Tensorflow.Binding; namespace Tensorflow { @@ -30,11 +31,8 @@ namespace Tensorflow public int InputListLength(string name) { int num = 0; - using(var status = new Status()) - { - num = c_api.TF_OperationInputListLength(_handle, name, status); - status.Check(true); - } + num = c_api.TF_OperationInputListLength(_handle, name, tf.status); + tf.status.Check(true); return num; } public int NumInputs => c_api.TF_OperationNumInputs(_handle); diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index b283d988..72bd3db0 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -28,12 +28,8 @@ namespace Tensorflow public int OutputListLength(string name) { - int num = 0; - using (var status = new Status()) - { - num = c_api.TF_OperationOutputListLength(_handle, name, status); - status.Check(true); - } + int num = c_api.TF_OperationOutputListLength(_handle, name, tf.status); + tf.status.Check(true); return num; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 7f2466ee..4d927f76 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using Tensorflow.Util; +using static Tensorflow.Binding; namespace Tensorflow { @@ -233,14 +234,13 @@ namespace Tensorflow AttrValue x = null; lock (Locks.ProcessWide) - using (var status = new Status()) - using (var buf = new Buffer()) - { - c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); - status.Check(true); + { + using var buf = new Buffer(); + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.status); + tf.status.Check(true); - x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); - } + x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); + } string oneof_value = x.ValueCase.ToString(); if (string.IsNullOrEmpty(oneof_value)) @@ -295,11 +295,10 @@ namespace Tensorflow // after the c_api call next time _inputs is accessed // the updated inputs are reloaded from the c_api lock (Locks.ProcessWide) - using (var status = new Status()) { - c_api.UpdateEdge(_graph, output, input, status); + c_api.UpdateEdge(_graph, output, input, tf.status); //var updated_inputs = inputs; - status.Check(); + tf.status.Check(); } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 23c6febb..dab9d3ec 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -43,30 +43,26 @@ namespace Tensorflow allow_broadcast: false); public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - { - dtype = dtype.as_base_dtype(); - return tf_with(ops.name_scope(name, "zeros", shape), scope => + => tf_with(ops.name_scope(name, "zeros", shape), scope => { + dtype = dtype.as_base_dtype(); name = scope; + var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); + Tensor zeros = null; switch (dtype) { - case TF_DataType.TF_BOOL: - return _constant_if_small(false, shape, dtype, name); case TF_DataType.TF_DOUBLE: - return _constant_if_small(0.0D, shape, dtype, name); + zeros = constant(0d); + break; case TF_DataType.TF_FLOAT: - return _constant_if_small(0.0F, shape, dtype, name); - case TF_DataType.TF_INT64: - return _constant_if_small(0L, shape, dtype, name); - case TF_DataType.TF_INT32: - return _constant_if_small(0, shape, dtype, name); - case TF_DataType.TF_INT8: - return _constant_if_small(0, shape, dtype, name); + zeros = constant(0f); + break; default: - throw new TypeError("can't find type for zeros"); + zeros = constant(0); + break; } + return fill(shape_tensor, zeros, name: name); }); - } public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) { diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 5fcdc547..7992db2d 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -22,7 +22,7 @@ using System.Linq; using System.Numerics; using System.Text; using Google.Protobuf; -using NumSharp.Backends; +using static Tensorflow.Binding; using Tensorflow.Util; namespace Tensorflow @@ -236,7 +236,7 @@ namespace Tensorflow // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); - var status = new Status(); + var status = tf.status; var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index c18df439..a38764fc 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -46,7 +46,7 @@ namespace Tensorflow lock (Locks.ProcessWide) { var graph = c_api.TF_NewGraph(); - var status = new Status(); + using var status = new Status(); var opt = new SessionOptions(); var tags = new string[] {"serve"}; @@ -66,7 +66,6 @@ namespace Tensorflow status.Check(true); } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) { - status = new Status(); sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, Path.GetFullPath(path), diff --git a/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs b/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs index e4be9811..08e9e964 100644 --- a/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs +++ b/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs @@ -13,14 +13,12 @@ namespace Tensorflow public class EagerTensorV2 : DisposableObject, ITensor { IntPtr EagerTensorHandle; - public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, status)); - - static Status status = new Status(); + public string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(EagerTensorHandle, tf.status)); public EagerTensorV2(IntPtr handle) { EagerTensorHandle = c_api.TFE_EagerTensorHandle(handle); - _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, status); + _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.status); } public unsafe EagerTensorV2(NDArray nd, string device_name = "") @@ -40,7 +38,7 @@ namespace Tensorflow }, IntPtr.Zero); - EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, status); + EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.status); } /*public unsafe EagerTensorV2(float[,] value) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs index eb04814c..e83978fc 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Conversions.cs @@ -21,6 +21,7 @@ using System.Globalization; using System.Runtime.CompilerServices; using System.Text; using NumSharp.Utilities; +using static Tensorflow.Binding; namespace Tensorflow { @@ -69,11 +70,8 @@ namespace Tensorflow IntPtr stringStartAddress = IntPtr.Zero; UIntPtr dstLen = UIntPtr.Zero; - using (var status = new Status()) - { - c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, status); - status.Check(true); - } + c_api.TF_StringDecode((byte*) this.buffer + 8, (UIntPtr) (this.bytesize), (byte**) &stringStartAddress, &dstLen, tf.status); + tf.status.Check(true); var dstLenInt = checked((int) dstLen); var value = Encoding.UTF8.GetString((byte*) stringStartAddress, dstLenInt); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 1f01f709..f87f80bb 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -451,7 +451,6 @@ namespace Tensorflow /// public unsafe Tensor(string str) { - var status = new Status(); var buffer = Encoding.UTF8.GetBytes(str); var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); @@ -460,9 +459,9 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, status); + c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(long)), size, tf.status); _handle = handle; - status.Check(true); + tf.status.Check(true); } public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) @@ -483,10 +482,8 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); - var status = new Status(); - c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); - - status.Check(true); + c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, tf.status); + tf.status.Check(true); _handle = handle; } else { @@ -498,11 +495,10 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); - var status = new Status(); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, tf.status); - status.Check(true); + tf.status.Check(true); _handle = handle; } @@ -607,11 +603,10 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); - var status = new Status(); fixed (byte* src = buffer) - c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(long)), size, tf.status); - status.Check(true); + tf.status.Check(true); return handle; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs index d70e9555..ea9e68ee 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Value.cs @@ -3,7 +3,7 @@ using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using NumSharp.Utilities; using System; -using System.Collections.Generic; +using static Tensorflow.Binding; using System.Runtime.InteropServices; using System.Text; @@ -237,18 +237,15 @@ namespace Tensorflow var src = c_api.TF_TensorData(_handle); var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); src += (int)(size * 8); - using (var status = new Status()) + for (int i = 0; i < buffer.Length; i++) { - for (int i = 0; i < buffer.Length; i++) - { - IntPtr dst = IntPtr.Zero; - UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); - status.Check(true); - buffer[i] = new byte[(int)dstLen]; - Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); - src += (int)read; - } + IntPtr dst = IntPtr.Zero; + UIntPtr dstLen = UIntPtr.Zero; + var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, tf.status); + tf.status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + src += (int)read; } var _str = new string[buffer.Length]; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 2b8407a7..506defba 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -22,7 +22,7 @@ using System.Globalization; using System.Linq; using System.Runtime.InteropServices; using System.Text; -using System.Threading.Tasks; +using static Tensorflow.Binding; using Tensorflow.Framework; namespace Tensorflow @@ -109,11 +109,7 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { - using (var status = new Status()) - { - c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); - status.Check(); - } + c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, tf.status); } else { @@ -126,15 +122,12 @@ namespace Tensorflow set { - using (var status = new Status()) - { - if (value == null) - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status); - else - c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); + if (value == null) + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.status); + else + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.status); - status.Check(true); - } + tf.status.Check(true); } } @@ -178,13 +171,9 @@ namespace Tensorflow { if (_handle == IntPtr.Zero) { - using (var status = new Status()) - { - var output = _as_tf_output(); - int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); - status.Check(); - return ndim; - } + var output = _as_tf_output(); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, tf.status); + return ndim; } return c_api.TF_NumDims(_handle); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 061631ae..e06265fd 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -176,30 +176,29 @@ namespace Tensorflow throw new NotImplementedException("_create_c_op"); } - using (var status = new Status()) + var status = tf.status; + + // Add control inputs + foreach (var control_input in control_inputs) + c_api.TF_AddControlInput(op_desc, control_input); + + // Add attrs + foreach (var attr in node_def.Attr) { - // Add control inputs - foreach (var control_input in control_inputs) - c_api.TF_AddControlInput(op_desc, control_input); - - // Add attrs - foreach (var attr in node_def.Attr) - { - var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. - var protoHandle = Marshal.AllocHGlobal(bytes.Length); - Marshal.Copy(bytes, 0, protoHandle, bytes.Length); - uint len = (uint)bytes.Length; - c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); - status.Check(true); - Marshal.FreeHGlobal(protoHandle); - } + var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. + var protoHandle = Marshal.AllocHGlobal(bytes.Length); + Marshal.Copy(bytes, 0, protoHandle, bytes.Length); + uint len = (uint)bytes.Length; + c_api.TF_SetAttrValueProto(op_desc, attr.Key, protoHandle, proto_len: len, status: status); + status.Check(true); + Marshal.FreeHGlobal(protoHandle); + } - var c_op = c_api.TF_FinishOperation(op_desc, status); + var c_op = c_api.TF_FinishOperation(op_desc, status); - status.Check(true); + status.Check(true); - return c_op; - } + return c_op; } } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 422ff1a0..535be9ff 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -22,7 +22,6 @@ using System.Runtime.InteropServices; using System.Threading; using Tensorflow.Eager; using Tensorflow.Gradients; -using static Tensorflow.Binding; namespace Tensorflow { @@ -42,11 +41,12 @@ namespace Tensorflow public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); + public Status status = new Status(); public OpDefLibrary _op_def_lib = new OpDefLibrary(); + public Context context = new Context(new ContextOptions(), new Status()); public Execute _execute = new Execute(); public IEagerRunner Runner = new EagerRunner(); - public Context context = new Context(new ContextOptions(), new Status()); - + public tensorflow() { enable_eager_execution(); diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index dc3f72f0..45dc64fe 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -96,14 +96,14 @@ namespace TensorFlowNET.UnitTest.Basics public void ZerosConst() { // small size - var tensor = tf.zeros(new Shape(3, 2), tf.int32, "small"); + var tensor = tf.zeros((3, 2), tf.int32, "small"); Assert.AreEqual(tensor.shape[0], 3); Assert.AreEqual(tensor.shape[1], 2); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray())); // big size - tensor = tf.zeros(new Shape(200, 100), tf.int32, "big"); + tensor = tf.zeros((200, 100), tf.int32, "big"); Assert.AreEqual(tensor.shape[0], 200); Assert.AreEqual(tensor.shape[1], 100); diff --git a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs index a39a55ea..c2d61e1f 100644 --- a/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs +++ b/test/TensorFlowNET.UnitTest/NativeAPI/Eager/GradientEagerTest.cs @@ -35,7 +35,26 @@ namespace TensorFlowNET.UnitTest.Gradient var dz_dx = tape.gradient(z, x); var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; - Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.numpy().ToArray(), expected)); + Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); + } + + [TestMethod] + public void PersistentTape() + { + var x = tf.ones((2, 2)); + using var tape = tf.GradientTape(persistent: true); + tape.watch(x); + var y = tf.reduce_sum(x); + var z = tf.multiply(y, y); + var dz_dx = tape.gradient(z, x); + + var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; + Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); + + var dz_dy = tape.gradient(z, y); + + expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f }; + Assert.IsTrue(Enumerable.SequenceEqual(dz_dx.ToArray(), expected)); } } }