Browse Source

add LoadFromSavedModel

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
4eccba4cf5
7 changed files with 90 additions and 4 deletions
  1. +10
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +19
    -0
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  3. +13
    -1
      src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
  4. +17
    -0
      src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs
  5. +19
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  6. +3
    -3
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  7. +9
    -0
      test/TensorFlowNET.UnitTest/TrainSaverTest.cs

+ 10
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -47,6 +47,16 @@ namespace Tensorflow
_graph_key = $"grap-key-{ops.uid()}/"; _graph_key = $"grap-key-{ops.uid()}/";
} }


public Graph(IntPtr handle)
{
_handle = handle;
Status = new Status();
_nodes_by_id = new Dictionary<int, ITensorOrOperation>();
_nodes_by_name = new Dictionary<string, ITensorOrOperation>();
_names_in_use = new Dictionary<string, int>();
_graph_key = $"grap-key-{ops.uid()}/";
}

public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
{ {
return _as_graph_element_locked(obj, allow_tensor, allow_operation); return _as_graph_element_locked(obj, allow_tensor, allow_operation);


+ 19
- 0
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -254,6 +254,25 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs);


/// <summary>
/// This function creates a new TF_Session (which is created on success) using
/// `session_options`, and then initializes state (restoring tensors and other
/// assets) using `run_options`.
/// </summary>
/// <param name="session_options">const TF_SessionOptions*</param>
/// <param name="run_options">const TF_Buffer*</param>
/// <param name="export_dir">const char*</param>
/// <param name="tags">const char* const*</param>
/// <param name="tags_len">int</param>
/// <param name="graph">TF_Graph*</param>
/// <param name="meta_graph_def">TF_Buffer*</param>
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options,
string export_dir, string[] tags, int tags_len,
IntPtr graph, ref TF_Buffer meta_graph_def, IntPtr status);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewGraph(); public static extern IntPtr TF_NewGraph();




+ 13
- 1
src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs View File

@@ -17,11 +17,23 @@ namespace Tensorflow
(logits, labels, weights)), (logits, labels, weights)),
namescope => namescope =>
{ {

(labels, logits, weights) = _remove_squeezable_dimensions(
labels, logits, weights, expected_rank_diff: 1);


}); });


throw new NotImplementedException("sparse_softmax_cross_entropy"); throw new NotImplementedException("sparse_softmax_cross_entropy");
} }

public (Tensor, Tensor, float) _remove_squeezable_dimensions(Tensor labels,
Tensor predictions,
float weights = 0,
int expected_rank_diff = 0)
{
(labels, predictions, weights) = confusion_matrix.remove_squeezable_dimensions(
labels, predictions, expected_rank_diff: expected_rank_diff);

throw new NotImplementedException("_remove_squeezable_dimensions");
}
} }
} }

+ 17
- 0
src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class confusion_matrix
{
public static (Tensor, Tensor, float) remove_squeezable_dimensions(Tensor labels,
Tensor predictions,
int expected_rank_diff = 0,
string name = "")
{
throw new NotImplementedException("remove_squeezable_dimensions");
}
}
}

+ 19
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -36,6 +36,25 @@ namespace Tensorflow
Status.Check(true); Status.Check(true);
} }


public static Session LoadFromSavedModel(string path)
{
var graph = c_api.TF_NewGraph();
var status = new Status();
var opt = c_api.TF_NewSessionOptions();

var buffer = new TF_Buffer();
var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, new string[0], 0, graph, ref buffer, status);

//var bytes = new Buffer(buffer.data).Data;
//var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes);

status.Check();

tf.g = new Graph(graph);

return sess;
}

public static implicit operator IntPtr(Session session) => session._handle; public static implicit operator IntPtr(Session session) => session._handle;
public static implicit operator Session(IntPtr handle) => new Session(handle); public static implicit operator Session(IntPtr handle) => new Session(handle);




+ 3
- 3
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -4,7 +4,7 @@
<TargetFramework>netstandard2.0</TargetFramework> <TargetFramework>netstandard2.0</TargetFramework>
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<Version>0.2.0</Version>
<Version>0.3.0</Version>
<Authors>Haiping Chen</Authors> <Authors>Haiping Chen</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> <GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -16,13 +16,13 @@
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags>
<Description>Google's TensorFlow binding in .NET Standard. <Description>Google's TensorFlow binding in .NET Standard.
Docs: https://tensorflownet.readthedocs.io</Description> Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.2.0.0</AssemblyVersion>
<AssemblyVersion>0.3.0.0</AssemblyVersion>
<PackageReleaseNotes>Added a bunch of APIs. <PackageReleaseNotes>Added a bunch of APIs.
Fixed String tensor creation bug. Fixed String tensor creation bug.
Upgraded to TensorFlow 1.13 RC-1. Upgraded to TensorFlow 1.13 RC-1.
</PackageReleaseNotes> </PackageReleaseNotes>
<LangVersion>7.2</LangVersion> <LangVersion>7.2</LangVersion>
<FileVersion>0.2.0.0</FileVersion>
<FileVersion>0.3.0.0</FileVersion>
</PropertyGroup> </PropertyGroup>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


+ 9
- 0
test/TensorFlowNET.UnitTest/TrainSaverTest.cs View File

@@ -26,6 +26,15 @@ namespace TensorFlowNET.UnitTest
}); });
} }


[TestMethod]
public void ImportSavedModel()
{
with<Session>(Session.LoadFromSavedModel("mobilenet"), sess =>
{
});
}

[TestMethod] [TestMethod]
public void Save1() public void Save1()
{ {


Loading…
Cancel
Save