@@ -31,8 +31,9 @@ namespace Tensorflow | |||
private GraphDef _as_graph_def(bool add_shapes = false) | |||
{ | |||
var buffer = ToGraphDef(Status); | |||
Status.Check(); | |||
var status = new Status(); | |||
var buffer = ToGraphDef(status); | |||
status.Check(); | |||
var def = GraphDef.Parser.ParseFrom(buffer); | |||
buffer.Dispose(); | |||
@@ -43,16 +43,20 @@ namespace Tensorflow | |||
var bytes = File.ReadAllBytes(file_path); | |||
var graph_def = new Tensorflow.Buffer(bytes); | |||
var opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, Status); | |||
return Status; | |||
var status = new Status(); | |||
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); | |||
return status; | |||
} | |||
public Status Import(byte[] bytes) | |||
public Status Import(byte[] bytes, string prefix = "") | |||
{ | |||
var graph_def = new Tensorflow.Buffer(bytes); | |||
var opts = c_api.TF_NewImportGraphDefOptions(); | |||
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, Status); | |||
return Status; | |||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); | |||
var status = new Status(); | |||
c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); | |||
c_api.TF_DeleteImportGraphDefOptions(opts); | |||
return status; | |||
} | |||
public static Graph ImportFromPB(string file_path, string name = null) | |||
@@ -88,8 +88,7 @@ namespace Tensorflow | |||
private string _graph_key; | |||
public string graph_key => _graph_key; | |||
public string _last_loss_reduction; | |||
public bool _is_loss_scaled_by_optimizer { get; set; } | |||
public Status Status { get; } | |||
public bool _is_loss_scaled_by_optimizer { get; set; } | |||
/// <summary> | |||
/// True if the graph is considered "finalized". In that case no | |||
@@ -107,7 +106,6 @@ namespace Tensorflow | |||
public Graph() | |||
{ | |||
_handle = c_api.TF_NewGraph(); | |||
Status = new Status(); | |||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
_names_in_use = new Dictionary<string, int>(); | |||
@@ -117,7 +115,6 @@ namespace Tensorflow | |||
public Graph(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
Status = new Status(); | |||
_nodes_by_id = new Dictionary<int, ITensorOrOperation>(); | |||
_nodes_by_name = new Dictionary<string, ITensorOrOperation>(); | |||
_names_in_use = new Dictionary<string, int>(); | |||
@@ -448,7 +445,12 @@ namespace Tensorflow | |||
public void Dispose() | |||
{ | |||
// c_api.TF_DeleteGraph(_handle); | |||
if (_handle != IntPtr.Zero) | |||
c_api.TF_DeleteGraph(_handle); | |||
_handle = IntPtr.Zero; | |||
GC.SuppressFinalize(this); | |||
} | |||
/// <summary> | |||
@@ -32,20 +32,19 @@ namespace Tensorflow | |||
protected int _current_version; | |||
protected byte[] _target; | |||
protected IntPtr _session; | |||
public Status Status; | |||
public Graph graph => _graph; | |||
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
{ | |||
_graph = g is null ? ops.get_default_graph() : g; | |||
_graph.as_default(); | |||
_target = UTF8Encoding.UTF8.GetBytes(target); | |||
SessionOptions newOpts = null; | |||
if (opts == null) | |||
newOpts = c_api.TF_NewSessionOptions(); | |||
Status = new Status(); | |||
var Status = new Status(); | |||
_session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); | |||
@@ -37,7 +37,7 @@ namespace Tensorflow | |||
: base("", g, opts) | |||
{ | |||
if (s == null) | |||
s = Status; | |||
s = new Status(); | |||
} | |||
public Session as_default() | |||
@@ -83,8 +83,19 @@ namespace Tensorflow | |||
public void Dispose() | |||
{ | |||
c_api.TF_DeleteSession(_session, Status); | |||
Status.Dispose(); | |||
if (_session != IntPtr.Zero) | |||
{ | |||
var status = new Status(); | |||
c_api.TF_DeleteSession(_session, status); | |||
} | |||
_session = IntPtr.Zero; | |||
GC.SuppressFinalize(this); | |||
} | |||
~Session() | |||
{ | |||
Dispose(); | |||
} | |||
public void __enter__() | |||
@@ -5,7 +5,7 @@ | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>1.14.0</TargetTensorFlow> | |||
<Version>0.10.4</Version> | |||
<Version>0.10.7</Version> | |||
<Authors>Haiping Chen, Meinrad Recheis</Authors> | |||
<Company>SciSharp STACK</Company> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
@@ -17,7 +17,7 @@ | |||
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | |||
<Description>Google's TensorFlow full binding in .NET Standard. | |||
Docs: https://tensorflownet.readthedocs.io</Description> | |||
<AssemblyVersion>0.10.4.0</AssemblyVersion> | |||
<AssemblyVersion>0.10.7.0</AssemblyVersion> | |||
<PackageReleaseNotes>Changes since v0.9.0: | |||
1. Added full connected Convolution Neural Network example. | |||
@@ -31,9 +31,12 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||
9. Fix strided_slice_grad type convention error. | |||
10. Add AbsGrad. | |||
11. Fix Session.LoadFromSavedModel(string). | |||
12. Add Tensor operator overloads.</PackageReleaseNotes> | |||
12. Add Tensor operator overloads. | |||
13. Fix default graph and operation issue when import model. | |||
14. Fix TF_String endcode and decode. | |||
15. Fix Tensor memory leak.</PackageReleaseNotes> | |||
<LangVersion>7.2</LangVersion> | |||
<FileVersion>0.10.4.0</FileVersion> | |||
<FileVersion>0.10.7.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||
@@ -19,6 +19,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Python; | |||
@@ -48,8 +49,6 @@ namespace Tensorflow | |||
private int _value_index; | |||
public int value_index => _value_index; | |||
private Status status = new Status(); | |||
private TF_DataType _dtype = TF_DataType.DtInvalid; | |||
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | |||
@@ -76,6 +75,7 @@ namespace Tensorflow | |||
if (_handle == IntPtr.Zero) | |||
{ | |||
var status = new Status(); | |||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||
status.Check(); | |||
} | |||
@@ -90,6 +90,8 @@ namespace Tensorflow | |||
set | |||
{ | |||
var status = new Status(); | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||
else | |||
@@ -131,6 +133,7 @@ namespace Tensorflow | |||
{ | |||
if (_handle == IntPtr.Zero) | |||
{ | |||
var status = new Status(); | |||
var output = _as_tf_output(); | |||
return c_api.TF_GraphGetTensorNumDims(op.graph, output, status); | |||
} | |||
@@ -184,6 +187,41 @@ namespace Tensorflow | |||
return data; | |||
} | |||
public unsafe string[] StringData() | |||
{ | |||
// | |||
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] | |||
// | |||
long size = 1; | |||
foreach (var s in TensorShape.Dimensions) | |||
size *= s; | |||
var buffer = new byte[size][]; | |||
var src = c_api.TF_TensorData(_handle); | |||
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||
src += (int)(size * 8); | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
using (var status = new Status()) | |||
{ | |||
IntPtr dst = IntPtr.Zero; | |||
UIntPtr dstLen = UIntPtr.Zero; | |||
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||
status.Check(true); | |||
buffer[i] = new byte[(int)dstLen]; | |||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
src += (int)read; | |||
} | |||
} | |||
var _str = new string[buffer.Length]; | |||
for (int i = 0; i < _str.Length; i++) | |||
_str[i] = Encoding.UTF8.GetString(buffer[i]); | |||
return _str; | |||
} | |||
public Tensor MaybeMove() | |||
{ | |||
var tensor = c_api.TF_TensorMaybeMove(_handle); | |||
@@ -364,7 +402,7 @@ namespace Tensorflow | |||
} | |||
if (h != IntPtr.Zero) | |||
c_api.TF_DeleteTensor(h); | |||
status.Dispose(); | |||
GC.SuppressFinalize(this); | |||
} | |||
@@ -32,6 +32,9 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, IntPtr dims, int num_dims, UIntPtr len); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, UIntPtr len); | |||
/// <summary> | |||
/// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | |||
/// </summary> | |||
@@ -150,5 +153,8 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, IntPtr status); | |||
} | |||
} |