From a7c9a75954d219cb606042fcbfbeb1b176781d7e Mon Sep 17 00:00:00 2001 From: Superpiffer Date: Mon, 6 Feb 2023 12:42:56 +0100 Subject: [PATCH] Use a local Status variable Using a local reference ensure that the Status object cannot be disposed before the Dispose. This way it's also possible to use an external Status instance instead of the static one, if needed. --- .../Sessions/BaseSession.cs | 29 ++++++------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 095187b9..4e131b36 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -30,6 +30,7 @@ namespace Tensorflow public class BaseSession : DisposableObject { protected Graph _graph; + protected Status _status; public Graph graph => _graph; public BaseSession(IntPtr handle, Graph g) @@ -48,9 +49,9 @@ namespace Tensorflow } using var opts = new SessionOptions(target, config); - status = status ?? tf.Status; - _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); - status.Check(true); + _status = status ?? tf.Status; + _handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle); + _status.Check(true); } public virtual void run(Operation op, params FeedItem[] feed_dict) @@ -217,8 +218,6 @@ namespace Tensorflow // Ensure any changes to the graph are reflected in the runtime. _extend_graph(); - var status = tf.Status; - var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); c_api.TF_SessionRun(_handle, @@ -232,9 +231,9 @@ namespace Tensorflow target_opers: target_list.Select(f => (IntPtr)f).ToArray(), ntargets: target_list.Count, run_metadata: IntPtr.Zero, - status: status.Handle); + status: _status.Handle); - status.Check(true); + _status.Check(true); var result = new NDArray[fetch_list.Length]; @@ -246,8 +245,6 @@ namespace Tensorflow public unsafe Tensor eval(Tensor tensor) { - var status = tf.Status; - var output_values = new IntPtr[1]; var fetch_list = new[] { tensor._as_tf_output() }; @@ -262,9 +259,9 @@ namespace Tensorflow target_opers: new IntPtr[0], ntargets: 0, run_metadata: IntPtr.Zero, - status: status.Handle); + status: _status.Handle); - status.Check(true); + _status.Check(true); return new Tensor(new SafeTensorHandle(output_values[0])); } @@ -291,15 +288,7 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { // c_api.TF_CloseSession(handle, tf.Status.Handle); - if (tf.Status == null || tf.Status.Handle.IsInvalid) - { - using var status = new Status(); - c_api.TF_DeleteSession(handle, status.Handle); - } - else - { - c_api.TF_DeleteSession(handle, tf.Status.Handle); - } + c_api.TF_DeleteSession(handle, _status.Handle); } } }