@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using System.IO; | |||
using Tensorflow.Util; | |||
@@ -37,7 +38,9 @@ namespace Tensorflow | |||
using (var buffer = ToGraphDef(status)) | |||
{ | |||
status.Check(true); | |||
def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||
// limit size to 250M, recursion to max 100 | |||
var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100); | |||
def = GraphDef.Parser.ParseFrom(inputStream); | |||
} | |||
// Strip the experimental library field iff it's empty. | |||
@@ -15,6 +15,8 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.IO; | |||
using System.Runtime.CompilerServices; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -39,6 +41,7 @@ namespace Tensorflow | |||
return this; | |||
} | |||
[MethodImpl(MethodImplOptions.NoOptimization)] | |||
public static Session LoadFromSavedModel(string path) | |||
{ | |||
lock (Locks.ProcessWide) | |||
@@ -50,20 +53,36 @@ namespace Tensorflow | |||
var tags = new string[] {"serve"}; | |||
var buffer = new TF_Buffer(); | |||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
IntPtr.Zero, | |||
path, | |||
tags, | |||
tags.Length, | |||
graph, | |||
ref buffer, | |||
status); | |||
IntPtr sess; | |||
try | |||
{ | |||
sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
IntPtr.Zero, | |||
path, | |||
tags, | |||
tags.Length, | |||
graph, | |||
ref buffer, | |||
status); | |||
status.Check(true); | |||
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||
{ | |||
status = new Status(); | |||
sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
IntPtr.Zero, | |||
Path.GetFullPath(path), | |||
tags, | |||
tags.Length, | |||
graph, | |||
ref buffer, | |||
status); | |||
status.Check(true); | |||
} | |||
// load graph bytes | |||
// var data = new byte[buffer.length]; | |||
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); | |||
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ | |||
status.Check(true); | |||
return new Session(sess, g: new Graph(graph)).as_default(); | |||
} | |||
@@ -1,8 +1,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using FluentAssertions; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
} | |||
[TestMethod] | |||
public void TF_GraphOperationByName() | |||
{ | |||
@@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
} | |||
} | |||
private static string modelPath = "./model/"; | |||
[TestMethod] | |||
public void TF_GraphOperationByName_FromModel() | |||
{ | |||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
Console.WriteLine(); | |||
for (int j = 0; j < 100; j++) | |||
{ | |||
var sess = Session.LoadFromSavedModel(modelPath).as_default(); | |||
var inputs = new[] {"sp", "fuel"}; | |||
var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray(); | |||
var outp = sess.graph.OperationByName("softmax_tensor").output; | |||
for (var i = 0; i < 100; i++) | |||
{ | |||
{ | |||
var data = new float[96]; | |||
FeedItem[] feeds = new FeedItem[2]; | |||
for (int f = 0; f < 2; f++) | |||
feeds[f] = new FeedItem(inp[f], new NDArray(data)); | |||
try | |||
{ | |||
sess.run(outp, feeds); | |||
} catch (Exception ex) | |||
{ | |||
Console.WriteLine(ex); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -42,4 +42,10 @@ | |||
<ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<None Update="model\saved_model.pb"> | |||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||
</None> | |||
</ItemGroup> | |||
</Project> |