diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 2f5af971..6eb8f367 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -289,7 +289,7 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, string export_dir, string[] tags, int tags_len, - IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); + IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); [DllImport(TensorFlowLibName)] public static extern IntPtr TF_NewGraph(); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 3c994a6e..a740226f 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -36,6 +36,12 @@ namespace Tensorflow protected byte[] _target; public Graph graph => _graph; + public BaseSession(IntPtr handle, Graph g) + { + _handle = handle; + _graph = g ?? ops.get_default_graph(); + } + public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) { _graph = g ?? ops.get_default_graph(); @@ -291,12 +297,8 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) { - lock (Locks.ProcessWide) - using (var status = new Status()) - { - c_api.TF_DeleteSession(handle, status.Handle); - status.Check(true); - } + // c_api.TF_CloseSession(handle, tf.Status.Handle); + c_api.TF_DeleteSession(handle, tf.Status.Handle); } } } diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index c48715a2..1e94b882 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -26,10 +26,8 @@ namespace Tensorflow public Session(string target = "", Graph g = null) : base(target, g, null) { } - public Session(IntPtr handle, Graph g = null) : base("", g, null) - { - _handle = handle; - } + public Session(IntPtr handle, Graph g = null) : base(handle, g) + { } public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) { } @@ -39,51 +37,29 @@ namespace Tensorflow return ops.set_default_session(this); } - [MethodImpl(MethodImplOptions.NoOptimization)] public static Session LoadFromSavedModel(string path) { - lock (Locks.ProcessWide) - { - var graph = c_api.TF_NewGraph(); - using var status = new Status(); - var opt = new SessionOptions(); - - var tags = new string[] { "serve" }; - var buffer = new TF_Buffer(); - - IntPtr sess; - try - { - sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, - IntPtr.Zero, - path, - tags, - tags.Length, - graph, - ref buffer, - status.Handle); - status.Check(true); - } - catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) - { - sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, - IntPtr.Zero, - Path.GetFullPath(path), - tags, - tags.Length, - graph, - ref buffer, - status.Handle); - status.Check(true); - } - - // load graph bytes - // var data = new byte[buffer.length]; - // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); - // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ - - return new Session(sess, g: new Graph(graph)).as_default(); - } + using var graph = new Graph(); + using var status = new Status(); + using var opt = c_api.TF_NewSessionOptions(); + + var tags = new string[] { "serve" }; + + var sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + path, + tags, + tags.Length, + graph, + IntPtr.Zero, + status.Handle); + status.Check(true); + + // load graph bytes + // var data = new byte[buffer.length]; + // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ + return new Session(sess, g: graph); } public static implicit operator IntPtr(Session session) => session._handle; diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 8ac4d53e..548d79e7 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -21,6 +21,18 @@ namespace Tensorflow { public partial class c_api { + /// + /// Close a session. + /// + /// Contacts any other processes associated with the session, if applicable. + /// May not be called after TF_DeleteSession(). + /// + /// + /// + + [DllImport(TensorFlowLibName)] + public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status); + /// /// Destroy a session object. /// diff --git a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs index 36b2c0ba..5cdb28f7 100644 --- a/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs +++ b/src/TensorFlowNet.Benchmarks/Leak/SavedModelCleanup.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Reflection; using System.Text; using System.Threading.Tasks; +using static Tensorflow.Binding; namespace Tensorflow.Benchmark.Leak { @@ -18,13 +19,9 @@ namespace Tensorflow.Benchmark.Leak var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); - for (var i = 0; i < 50; i++) - { - var session = Session.LoadFromSavedModel(ClassifierModelPath); - - session.graph.Exit(); - session.graph.Dispose(); - session.Dispose(); + for (var i = 0; i < 1024; i++) + { + using var sess = Session.LoadFromSavedModel(ClassifierModelPath); } } } diff --git a/src/TensorFlowNet.Benchmarks/Program.cs b/src/TensorFlowNet.Benchmarks/Program.cs index 598d7a03..22abf730 100644 --- a/src/TensorFlowNet.Benchmarks/Program.cs +++ b/src/TensorFlowNet.Benchmarks/Program.cs @@ -13,7 +13,9 @@ namespace TensorFlowBenchmark static void Main(string[] args) { print(tf.VERSION); - /*new RepeatDataSetCrash().Run(); + + /*new SavedModelCleanup().Run(); + new RepeatDataSetCrash().Run(); new GpuLeakByCNN().Run();*/ if (args?.Length > 0) diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index ea799b02..ceba6cbb 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -37,7 +37,7 @@ - +