diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
index 2f5af971..6eb8f367 100644
--- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
@@ -289,7 +289,7 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
string export_dir, string[] tags, int tags_len,
- IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status);
+ IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewGraph();
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 3c994a6e..a740226f 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -36,6 +36,12 @@ namespace Tensorflow
protected byte[] _target;
public Graph graph => _graph;
+ public BaseSession(IntPtr handle, Graph g)
+ {
+ _handle = handle;
+ _graph = g ?? ops.get_default_graph();
+ }
+
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{
_graph = g ?? ops.get_default_graph();
@@ -291,12 +297,8 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle)
{
- lock (Locks.ProcessWide)
- using (var status = new Status())
- {
- c_api.TF_DeleteSession(handle, status.Handle);
- status.Check(true);
- }
+ // c_api.TF_CloseSession(handle, tf.Status.Handle);
+ c_api.TF_DeleteSession(handle, tf.Status.Handle);
}
}
}
diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs
index c48715a2..1e94b882 100644
--- a/src/TensorFlowNET.Core/Sessions/Session.cs
+++ b/src/TensorFlowNET.Core/Sessions/Session.cs
@@ -26,10 +26,8 @@ namespace Tensorflow
public Session(string target = "", Graph g = null) : base(target, g, null)
{ }
- public Session(IntPtr handle, Graph g = null) : base("", g, null)
- {
- _handle = handle;
- }
+ public Session(IntPtr handle, Graph g = null) : base(handle, g)
+ { }
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
{ }
@@ -39,51 +37,29 @@ namespace Tensorflow
return ops.set_default_session(this);
}
- [MethodImpl(MethodImplOptions.NoOptimization)]
public static Session LoadFromSavedModel(string path)
{
- lock (Locks.ProcessWide)
- {
- var graph = c_api.TF_NewGraph();
- using var status = new Status();
- var opt = new SessionOptions();
-
- var tags = new string[] { "serve" };
- var buffer = new TF_Buffer();
-
- IntPtr sess;
- try
- {
- sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
- IntPtr.Zero,
- path,
- tags,
- tags.Length,
- graph,
- ref buffer,
- status.Handle);
- status.Check(true);
- }
- catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
- {
- sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
- IntPtr.Zero,
- Path.GetFullPath(path),
- tags,
- tags.Length,
- graph,
- ref buffer,
- status.Handle);
- status.Check(true);
- }
-
- // load graph bytes
- // var data = new byte[buffer.length];
- // Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
- // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
-
- return new Session(sess, g: new Graph(graph)).as_default();
- }
+ using var graph = new Graph();
+ using var status = new Status();
+ using var opt = c_api.TF_NewSessionOptions();
+
+ var tags = new string[] { "serve" };
+
+ var sess = c_api.TF_LoadSessionFromSavedModel(opt,
+ IntPtr.Zero,
+ path,
+ tags,
+ tags.Length,
+ graph,
+ IntPtr.Zero,
+ status.Handle);
+ status.Check(true);
+
+ // load graph bytes
+ // var data = new byte[buffer.length];
+ // Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
+ // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
+ return new Session(sess, g: graph);
}
public static implicit operator IntPtr(Session session) => session._handle;
diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs
index 8ac4d53e..548d79e7 100644
--- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs
+++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs
@@ -21,6 +21,18 @@ namespace Tensorflow
{
public partial class c_api
{
+ ///
+ /// Close a session.
+ ///
+ /// Contacts any other processes associated with the session, if applicable.
+ /// May not be called after TF_DeleteSession().
+ ///
+ ///
+ ///
+
+ [DllImport(TensorFlowLibName)]
+ public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status);
+
///
/// Destroy a session object.
///
diff --git a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs
index 36b2c0ba..5cdb28f7 100644
--- a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs
+++ b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs
@@ -6,6 +6,7 @@ using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
+using static Tensorflow.Binding;
namespace Tensorflow.Benchmark.Leak
{
@@ -18,13 +19,9 @@ namespace Tensorflow.Benchmark.Leak
var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);
var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model");
- for (var i = 0; i < 50; i++)
- {
- var session = Session.LoadFromSavedModel(ClassifierModelPath);
-
- session.graph.Exit();
- session.graph.Dispose();
- session.Dispose();
+ for (var i = 0; i < 1024; i++)
+ {
+ using var sess = Session.LoadFromSavedModel(ClassifierModelPath);
}
}
}
diff --git a/src/TensorFlowNet.Benchmarks/Program.cs b/src/TensorFlowNet.Benchmarks/Program.cs
index 598d7a03..22abf730 100644
--- a/src/TensorFlowNet.Benchmarks/Program.cs
+++ b/src/TensorFlowNet.Benchmarks/Program.cs
@@ -13,7 +13,9 @@ namespace TensorFlowBenchmark
static void Main(string[] args)
{
print(tf.VERSION);
- /*new RepeatDataSetCrash().Run();
+
+ /*new SavedModelCleanup().Run();
+ new RepeatDataSetCrash().Run();
new GpuLeakByCNN().Run();*/
if (args?.Length > 0)
diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
index ea799b02..ceba6cbb 100644
--- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
+++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj
@@ -37,7 +37,7 @@
-
+