@@ -14,6 +14,7 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Google.Protobuf; | |||||
using System.IO; | using System.IO; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
@@ -37,7 +38,9 @@ namespace Tensorflow | |||||
using (var buffer = ToGraphDef(status)) | using (var buffer = ToGraphDef(status)) | ||||
{ | { | ||||
status.Check(true); | status.Check(true); | ||||
def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
// limit size to 250M, recursion to max 100 | |||||
var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100); | |||||
def = GraphDef.Parser.ParseFrom(inputStream); | |||||
} | } | ||||
// Strip the experimental library field iff it's empty. | // Strip the experimental library field iff it's empty. | ||||
@@ -15,6 +15,8 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using System.IO; | |||||
using System.Runtime.CompilerServices; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -39,6 +41,7 @@ namespace Tensorflow | |||||
return this; | return this; | ||||
} | } | ||||
[MethodImpl(MethodImplOptions.NoOptimization)] | |||||
public static Session LoadFromSavedModel(string path) | public static Session LoadFromSavedModel(string path) | ||||
{ | { | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
@@ -50,20 +53,36 @@ namespace Tensorflow | |||||
var tags = new string[] {"serve"}; | var tags = new string[] {"serve"}; | ||||
var buffer = new TF_Buffer(); | var buffer = new TF_Buffer(); | ||||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
IntPtr.Zero, | |||||
path, | |||||
tags, | |||||
tags.Length, | |||||
graph, | |||||
ref buffer, | |||||
status); | |||||
IntPtr sess; | |||||
try | |||||
{ | |||||
sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
IntPtr.Zero, | |||||
path, | |||||
tags, | |||||
tags.Length, | |||||
graph, | |||||
ref buffer, | |||||
status); | |||||
status.Check(true); | |||||
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||||
{ | |||||
status = new Status(); | |||||
sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
IntPtr.Zero, | |||||
Path.GetFullPath(path), | |||||
tags, | |||||
tags.Length, | |||||
graph, | |||||
ref buffer, | |||||
status); | |||||
status.Check(true); | |||||
} | |||||
// load graph bytes | // load graph bytes | ||||
// var data = new byte[buffer.length]; | // var data = new byte[buffer.length]; | ||||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | ||||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | ||||
status.Check(true); | |||||
return new Session(sess, g: new Graph(graph)).as_default(); | return new Session(sess, g: new Graph(graph)).as_default(); | ||||
} | } | ||||
@@ -1,8 +1,11 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using FluentAssertions; | using FluentAssertions; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | |||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
public void TF_GraphOperationByName() | public void TF_GraphOperationByName() | ||||
{ | { | ||||
@@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
} | } | ||||
} | } | ||||
private static string modelPath = "./model/"; | |||||
[TestMethod] | |||||
public void TF_GraphOperationByName_FromModel() | |||||
{ | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
Console.WriteLine(); | |||||
for (int j = 0; j < 100; j++) | |||||
{ | |||||
var sess = Session.LoadFromSavedModel(modelPath).as_default(); | |||||
var inputs = new[] {"sp", "fuel"}; | |||||
var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); | |||||
var outp = sess.graph.OperationByName("softmax_tensor").output; | |||||
for (var i = 0; i < 100; i++) | |||||
{ | |||||
{ | |||||
var data = new float[96]; | |||||
FeedItem[] feeds = new FeedItem[2]; | |||||
for (int f = 0; f < 2; f++) | |||||
feeds[f] = new FeedItem(inp[f], new NDArray(data)); | |||||
try | |||||
{ | |||||
sess.run(outp, feeds); | |||||
} catch (Exception ex) | |||||
{ | |||||
Console.WriteLine(ex); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -42,4 +42,10 @@ | |||||
<ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | <ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<None Update="model\saved_model.pb"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
</ItemGroup> | |||||
</Project> | </Project> |