Browse Source

Use a local Status variable

Using a local reference ensure that the Status object cannot be disposed before the Dispose. This way it's also possible to use an external Status instance instead of the static one, if needed.
tags/v0.100.4-load-saved-model
Superpiffer Haiping 2 years ago
parent
commit
a7c9a75954
1 changed files with 9 additions and 20 deletions
  1. +9
    -20
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 9
- 20
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -30,6 +30,7 @@ namespace Tensorflow
public class BaseSession : DisposableObject
{
protected Graph _graph;
protected Status _status;
public Graph graph => _graph;

public BaseSession(IntPtr handle, Graph g)
@@ -48,9 +49,9 @@ namespace Tensorflow
}
using var opts = new SessionOptions(target, config);
status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
status.Check(true);
_status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle);
_status.Check(true);
}

public virtual void run(Operation op, params FeedItem[] feed_dict)
@@ -217,8 +218,6 @@ namespace Tensorflow
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();

var status = tf.Status;

var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

c_api.TF_SessionRun(_handle,
@@ -232,9 +231,9 @@ namespace Tensorflow
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: status.Handle);
status: _status.Handle);

status.Check(true);
_status.Check(true);

var result = new NDArray[fetch_list.Length];

@@ -246,8 +245,6 @@ namespace Tensorflow

public unsafe Tensor eval(Tensor tensor)
{
var status = tf.Status;

var output_values = new IntPtr[1];
var fetch_list = new[] { tensor._as_tf_output() };

@@ -262,9 +259,9 @@ namespace Tensorflow
target_opers: new IntPtr[0],
ntargets: 0,
run_metadata: IntPtr.Zero,
status: status.Handle);
status: _status.Handle);

status.Check(true);
_status.Check(true);

return new Tensor(new SafeTensorHandle(output_values[0]));
}
@@ -291,15 +288,7 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle)
{
// c_api.TF_CloseSession(handle, tf.Status.Handle);
if (tf.Status == null || tf.Status.Handle.IsInvalid)
{
using var status = new Status();
c_api.TF_DeleteSession(handle, status.Handle);
}
else
{
c_api.TF_DeleteSession(handle, tf.Status.Handle);
}
c_api.TF_DeleteSession(handle, _status.Handle);
}
}
}

Loading…
Cancel
Save