@@ -289,7 +289,7 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||
string export_dir, string[] tags, int tags_len, | |||
IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); | |||
IntPtr graph, IntPtr meta_graph_def, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_NewGraph(); | |||
@@ -36,6 +36,12 @@ namespace Tensorflow | |||
protected byte[] _target; | |||
public Graph graph => _graph; | |||
public BaseSession(IntPtr handle, Graph g) | |||
{ | |||
_handle = handle; | |||
_graph = g ?? ops.get_default_graph(); | |||
} | |||
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) | |||
{ | |||
_graph = g ?? ops.get_default_graph(); | |||
@@ -291,12 +297,8 @@ namespace Tensorflow | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
lock (Locks.ProcessWide) | |||
using (var status = new Status()) | |||
{ | |||
c_api.TF_DeleteSession(handle, status.Handle); | |||
status.Check(true); | |||
} | |||
// c_api.TF_CloseSession(handle, tf.Status.Handle); | |||
c_api.TF_DeleteSession(handle, tf.Status.Handle); | |||
} | |||
} | |||
} |
@@ -26,10 +26,8 @@ namespace Tensorflow | |||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||
{ } | |||
public Session(IntPtr handle, Graph g = null) : base("", g, null) | |||
{ | |||
_handle = handle; | |||
} | |||
public Session(IntPtr handle, Graph g = null) : base(handle, g) | |||
{ } | |||
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) | |||
{ } | |||
@@ -39,51 +37,29 @@ namespace Tensorflow | |||
return ops.set_default_session(this); | |||
} | |||
[MethodImpl(MethodImplOptions.NoOptimization)] | |||
public static Session LoadFromSavedModel(string path) | |||
{ | |||
lock (Locks.ProcessWide) | |||
{ | |||
var graph = c_api.TF_NewGraph(); | |||
using var status = new Status(); | |||
var opt = new SessionOptions(); | |||
var tags = new string[] { "serve" }; | |||
var buffer = new TF_Buffer(); | |||
IntPtr sess; | |||
try | |||
{ | |||
sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||
IntPtr.Zero, | |||
path, | |||
tags, | |||
tags.Length, | |||
graph, | |||
ref buffer, | |||
status.Handle); | |||
status.Check(true); | |||
} | |||
catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||
{ | |||
sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||
IntPtr.Zero, | |||
Path.GetFullPath(path), | |||
tags, | |||
tags.Length, | |||
graph, | |||
ref buffer, | |||
status.Handle); | |||
status.Check(true); | |||
} | |||
// load graph bytes | |||
// var data = new byte[buffer.length]; | |||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
return new Session(sess, g: new Graph(graph)).as_default(); | |||
} | |||
using var graph = new Graph(); | |||
using var status = new Status(); | |||
using var opt = c_api.TF_NewSessionOptions(); | |||
var tags = new string[] { "serve" }; | |||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
IntPtr.Zero, | |||
path, | |||
tags, | |||
tags.Length, | |||
graph, | |||
IntPtr.Zero, | |||
status.Handle); | |||
status.Check(true); | |||
// load graph bytes | |||
// var data = new byte[buffer.length]; | |||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
return new Session(sess, g: graph); | |||
} | |||
public static implicit operator IntPtr(Session session) => session._handle; | |||
@@ -21,6 +21,18 @@ namespace Tensorflow | |||
{ | |||
public partial class c_api | |||
{ | |||
/// <summary> | |||
/// Close a session. | |||
/// | |||
/// Contacts any other processes associated with the session, if applicable. | |||
/// May not be called after TF_DeleteSession(). | |||
/// </summary> | |||
/// <param name="s"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_CloseSession(IntPtr session, SafeStatusHandle status); | |||
/// <summary> | |||
/// Destroy a session object. | |||
/// | |||
@@ -6,6 +6,7 @@ using System.Linq; | |||
using System.Reflection; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Benchmark.Leak | |||
{ | |||
@@ -18,13 +19,9 @@ namespace Tensorflow.Benchmark.Leak | |||
var modelDir = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location); | |||
var ClassifierModelPath = Path.Combine(modelDir, "Leak", "TestModel", "saved_model"); | |||
for (var i = 0; i < 50; i++) | |||
{ | |||
var session = Session.LoadFromSavedModel(ClassifierModelPath); | |||
session.graph.Exit(); | |||
session.graph.Dispose(); | |||
session.Dispose(); | |||
for (var i = 0; i < 1024; i++) | |||
{ | |||
using var sess = Session.LoadFromSavedModel(ClassifierModelPath); | |||
} | |||
} | |||
} | |||
@@ -13,7 +13,9 @@ namespace TensorFlowBenchmark | |||
static void Main(string[] args) | |||
{ | |||
print(tf.VERSION); | |||
/*new RepeatDataSetCrash().Run(); | |||
/*new SavedModelCleanup().Run(); | |||
new RepeatDataSetCrash().Run(); | |||
new GpuLeakByCNN().Run();*/ | |||
if (args?.Length > 0) | |||
@@ -37,7 +37,7 @@ | |||
<ItemGroup> | |||
<PackageReference Include="BenchmarkDotNet" Version="0.13.0" /> | |||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||