Browse Source

solve confilict.

tags/v0.12
Oceania2018 6 years ago
parent
commit
a2dade50fc
4 changed files with 83 additions and 11 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Export.cs
  2. +28
    -9
      src/TensorFlowNET.Core/Sessions/Session.cs
  3. +45
    -1
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs
  4. +6
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 4
- 1
src/TensorFlowNET.Core/Graphs/Graph.Export.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Google.Protobuf;
using System.IO; using System.IO;
using Tensorflow.Util; using Tensorflow.Util;


@@ -37,7 +38,9 @@ namespace Tensorflow
using (var buffer = ToGraphDef(status)) using (var buffer = ToGraphDef(status))
{ {
status.Check(true); status.Check(true);
def = GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
// limit size to 250M, recursion to max 100
var inputStream = CodedInputStream.CreateWithLimits(buffer.MemoryBlock.Stream(), 250 * 1024 * 1024, 100);
def = GraphDef.Parser.ParseFrom(inputStream);
} }


// Strip the experimental library field iff it's empty. // Strip the experimental library field iff it's empty.


+ 28
- 9
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -15,6 +15,8 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.IO;
using System.Runtime.CompilerServices;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -39,6 +41,7 @@ namespace Tensorflow
return this; return this;
} }


[MethodImpl(MethodImplOptions.NoOptimization)]
public static Session LoadFromSavedModel(string path) public static Session LoadFromSavedModel(string path)
{ {
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
@@ -50,20 +53,36 @@ namespace Tensorflow
var tags = new string[] {"serve"}; var tags = new string[] {"serve"};
var buffer = new TF_Buffer(); var buffer = new TF_Buffer();


var sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
ref buffer,
status);
IntPtr sess;
try
{
sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
path,
tags,
tags.Length,
graph,
ref buffer,
status);
status.Check(true);
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
{
status = new Status();
sess = c_api.TF_LoadSessionFromSavedModel(opt,
IntPtr.Zero,
Path.GetFullPath(path),
tags,
tags.Length,
graph,
ref buffer,
status);
status.Check(true);
}


// load graph bytes // load graph bytes
// var data = new byte[buffer.length]; // var data = new byte[buffer.length];
// Marshal.Copy(buffer.data, data, 0, (int)buffer.length); // Marshal.Copy(buffer.data, data, 0, (int)buffer.length);
// var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/
status.Check(true);


return new Session(sess, g: new Graph(graph)).as_default(); return new Session(sess, g: new Graph(graph)).as_default();
} }


+ 45
- 1
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -1,8 +1,11 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using FluentAssertions; using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow; using Tensorflow;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -260,7 +263,7 @@ namespace TensorFlowNET.UnitTest
} }
} }


[TestMethod] [TestMethod]
public void TF_GraphOperationByName() public void TF_GraphOperationByName()
{ {
@@ -280,5 +283,46 @@ namespace TensorFlowNET.UnitTest
} }
} }
} }

private static string modelPath = "./model/";

[TestMethod]
public void TF_GraphOperationByName_FromModel()
{
MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
Console.WriteLine();
for (int j = 0; j < 100; j++)
{
var sess = Session.LoadFromSavedModel(modelPath).as_default();
var inputs = new[] {"sp", "fuel"};

var inp = inputs.Select(name => sess.graph.OperationByName(name).output).ToArray();
var outp = sess.graph.OperationByName("softmax_tensor").output;

for (var i = 0; i < 100; i++)
{
{
var data = new float[96];
FeedItem[] feeds = new FeedItem[2];

for (int f = 0; f < 2; f++)
feeds[f] = new FeedItem(inp[f], new NDArray(data));

try
{
sess.run(outp, feeds);
} catch (Exception ex)
{
Console.WriteLine(ex);
}
}
}
}
}
}
} }
} }

+ 6
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -42,4 +42,10 @@
<ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" /> <ProjectReference Include="..\TensorFlowNET.Examples\TensorFlowNET.Examples.csproj" />
</ItemGroup> </ItemGroup>


<ItemGroup>
<None Update="model\saved_model.pb">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>

</Project> </Project>

Loading…
Cancel
Save