@@ -47,6 +47,16 @@ namespace Tensorflow | |||||
_graph_key = $"grap-key-{ops.uid()}/"; | _graph_key = $"grap-key-{ops.uid()}/"; | ||||
} | } | ||||
public Graph(IntPtr handle) | |||||
{ | |||||
_handle = handle; | |||||
Status = new Status(); | |||||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||||
_names_in_use = new Dictionary<string, int>(); | |||||
_graph_key = $"grap-key-{ops.uid()}/"; | |||||
} | |||||
public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | ||||
{ | { | ||||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | return _as_graph_element_locked(obj, allow_tensor, allow_operation); | ||||
@@ -254,6 +254,25 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | public static extern void TF_ImportGraphDefResultsReturnOutputs(IntPtr results, ref int num_outputs, ref IntPtr outputs); | ||||
/// <summary> | |||||
/// This function creates a new TF_Session (which is created on success) using | |||||
/// `session_options`, and then initializes state (restoring tensors and other | |||||
/// assets) using `run_options`. | |||||
/// </summary> | |||||
/// <param name="session_options">const TF_SessionOptions*</param> | |||||
/// <param name="run_options">const TF_Buffer*</param> | |||||
/// <param name="export_dir">const char*</param> | |||||
/// <param name="tags">const char* const*</param> | |||||
/// <param name="tags_len">int</param> | |||||
/// <param name="graph">TF_Graph*</param> | |||||
/// <param name="meta_graph_def">TF_Buffer*</param> | |||||
/// <param name="status">TF_Status*</param> | |||||
/// <returns></returns> | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options, | |||||
string export_dir, string[] tags, int tags_len, | |||||
IntPtr graph, ref TF_Buffer meta_graph_def, IntPtr status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_NewGraph(); | public static extern IntPtr TF_NewGraph(); | ||||
@@ -17,11 +17,23 @@ namespace Tensorflow | |||||
(logits, labels, weights)), | (logits, labels, weights)), | ||||
namescope => | namescope => | ||||
{ | { | ||||
(labels, logits, weights) = _remove_squeezable_dimensions( | |||||
labels, logits, weights, expected_rank_diff: 1); | |||||
}); | }); | ||||
throw new NotImplementedException("sparse_softmax_cross_entropy"); | throw new NotImplementedException("sparse_softmax_cross_entropy"); | ||||
} | } | ||||
public (Tensor, Tensor, float) _remove_squeezable_dimensions(Tensor labels, | |||||
Tensor predictions, | |||||
float weights = 0, | |||||
int expected_rank_diff = 0) | |||||
{ | |||||
(labels, predictions, weights) = confusion_matrix.remove_squeezable_dimensions( | |||||
labels, predictions, expected_rank_diff: expected_rank_diff); | |||||
throw new NotImplementedException("_remove_squeezable_dimensions"); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,17 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class confusion_matrix | |||||
{ | |||||
public static (Tensor, Tensor, float) remove_squeezable_dimensions(Tensor labels, | |||||
Tensor predictions, | |||||
int expected_rank_diff = 0, | |||||
string name = "") | |||||
{ | |||||
throw new NotImplementedException("remove_squeezable_dimensions"); | |||||
} | |||||
} | |||||
} |
@@ -36,6 +36,25 @@ namespace Tensorflow | |||||
Status.Check(true); | Status.Check(true); | ||||
} | } | ||||
public static Session LoadFromSavedModel(string path) | |||||
{ | |||||
var graph = c_api.TF_NewGraph(); | |||||
var status = new Status(); | |||||
var opt = c_api.TF_NewSessionOptions(); | |||||
var buffer = new TF_Buffer(); | |||||
var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, new string[0], 0, graph, ref buffer, status); | |||||
//var bytes = new Buffer(buffer.data).Data; | |||||
//var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); | |||||
status.Check(); | |||||
tf.g = new Graph(graph); | |||||
return sess; | |||||
} | |||||
public static implicit operator IntPtr(Session session) => session._handle; | public static implicit operator IntPtr(Session session) => session._handle; | ||||
public static implicit operator Session(IntPtr handle) => new Session(handle); | public static implicit operator Session(IntPtr handle) => new Session(handle); | ||||
@@ -4,7 +4,7 @@ | |||||
<TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<Version>0.2.0</Version> | |||||
<Version>0.3.0</Version> | |||||
<Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
<Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
@@ -16,13 +16,13 @@ | |||||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET</PackageTags> | ||||
<Description>Google's TensorFlow binding in .NET Standard. | <Description>Google's TensorFlow binding in .NET Standard. | ||||
Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
<AssemblyVersion>0.2.0.0</AssemblyVersion> | |||||
<AssemblyVersion>0.3.0.0</AssemblyVersion> | |||||
<PackageReleaseNotes>Added a bunch of APIs. | <PackageReleaseNotes>Added a bunch of APIs. | ||||
Fixed String tensor creation bug. | Fixed String tensor creation bug. | ||||
Upgraded to TensorFlow 1.13 RC-1. | Upgraded to TensorFlow 1.13 RC-1. | ||||
</PackageReleaseNotes> | </PackageReleaseNotes> | ||||
<LangVersion>7.2</LangVersion> | <LangVersion>7.2</LangVersion> | ||||
<FileVersion>0.2.0.0</FileVersion> | |||||
<FileVersion>0.3.0.0</FileVersion> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
@@ -26,6 +26,15 @@ namespace TensorFlowNET.UnitTest | |||||
}); | }); | ||||
} | } | ||||
[TestMethod] | |||||
public void ImportSavedModel() | |||||
{ | |||||
with<Session>(Session.LoadFromSavedModel("mobilenet"), sess => | |||||
{ | |||||
}); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Save1() | public void Save1() | ||||
{ | { | ||||