@@ -15,14 +15,15 @@ namespace Tensorflow.Hub | |||
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | |||
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | |||
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) | |||
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||
{ | |||
var loader = new MnistModelLoader(); | |||
var setting = new ModelLoadSetting | |||
{ | |||
TrainDir = trainDir, | |||
OneHot = oneHot | |||
OneHot = oneHot, | |||
ShowProgressInConsole = showProgressInConsole | |||
}; | |||
if (trainSize.HasValue) | |||
@@ -48,37 +49,37 @@ namespace Tensorflow.Hub | |||
sourceUrl = DEFAULT_SOURCE_URL; | |||
// load train images | |||
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES) | |||
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir) | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | |||
// load train labels | |||
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS) | |||
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir) | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | |||
// load test images | |||
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES) | |||
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir) | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | |||
// load test labels | |||
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS) | |||
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir) | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | |||
@@ -2,7 +2,7 @@ | |||
<PropertyGroup> | |||
<RootNamespace>Tensorflow.Hub</RootNamespace> | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
<Version>0.0.1</Version> | |||
<Version>0.0.2</Version> | |||
<Authors>Kerry Jiang</Authors> | |||
<Company>SciSharp STACK</Company> | |||
<Copyright>Apache 2.0</Copyright> | |||
@@ -13,7 +13,7 @@ | |||
<Description>TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models.</Description> | |||
<PackageId>SciSharp.TensorFlowHub</PackageId> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
<PackageReleaseNotes>1. Add MNIST loader.</PackageReleaseNotes> | |||
<PackageReleaseNotes></PackageReleaseNotes> | |||
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
@@ -19,7 +19,7 @@ namespace Tensorflow.Hub | |||
await modelLoader.DownloadAsync(url, dir, fileName); | |||
} | |||
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName) | |||
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) | |||
where TDataSet : IDataSet | |||
{ | |||
if (!Path.IsPathRooted(dirSaveTo)) | |||
@@ -27,18 +27,30 @@ namespace Tensorflow.Hub | |||
var fileSaveTo = Path.Combine(dirSaveTo, fileName); | |||
if (showProgressInConsole) | |||
{ | |||
Console.WriteLine($"Downloading {fileName}"); | |||
} | |||
if (File.Exists(fileSaveTo)) | |||
{ | |||
if (showProgressInConsole) | |||
{ | |||
Console.WriteLine($"The file {fileName} already exists"); | |||
} | |||
return; | |||
} | |||
Directory.CreateDirectory(dirSaveTo); | |||
using (var wc = new WebClient()) | |||
{ | |||
await wc.DownloadFileTaskAsync(url, fileSaveTo); | |||
await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); | |||
} | |||
} | |||
public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo) | |||
public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) | |||
where TDataSet : IDataSet | |||
{ | |||
if (!Path.IsPathRooted(saveTo)) | |||
@@ -49,67 +61,76 @@ namespace Tensorflow.Hub | |||
if (!Path.IsPathRooted(zipFile)) | |||
zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | |||
var destFilePath = Path.Combine(saveTo, Path.GetFileNameWithoutExtension(zipFile)); | |||
var destFileName = Path.GetFileNameWithoutExtension(zipFile); | |||
var destFilePath = Path.Combine(saveTo, destFileName); | |||
if (showProgressInConsole) | |||
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||
if (File.Exists(destFilePath)) | |||
File.Delete(destFilePath); | |||
{ | |||
if (showProgressInConsole) | |||
Console.WriteLine($"The file {destFileName} already exists"); | |||
} | |||
using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | |||
{ | |||
using (var destStream = File.Create(destFilePath)) | |||
{ | |||
await unzipStream.CopyToAsync(destStream); | |||
await destStream.FlushAsync(); | |||
await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); | |||
await destStream.FlushAsync().ConfigureAwait(false); | |||
destStream.Close(); | |||
} | |||
unzipStream.Close(); | |||
} | |||
} | |||
public static async Task ShowProgressInConsole(this Task task) | |||
{ | |||
await ShowProgressInConsole(task, true); | |||
} | |||
} | |||
public static async Task ShowProgressInConsole(this Task task, bool enable) | |||
{ | |||
if (!enable) | |||
{ | |||
await task; | |||
return; | |||
} | |||
var cts = new CancellationTokenSource(); | |||
var showProgressTask = ShowProgressInConsole(cts); | |||
try | |||
{ | |||
{ | |||
await task; | |||
} | |||
finally | |||
{ | |||
cts.Cancel(); | |||
cts.Cancel(); | |||
} | |||
await showProgressTask; | |||
Console.WriteLine("Done."); | |||
} | |||
private static async Task ShowProgressInConsole(CancellationTokenSource cts) | |||
{ | |||
var cols = 0; | |||
await Task.Delay(1000); | |||
while (!cts.IsCancellationRequested) | |||
{ | |||
await Task.Delay(1000); | |||
Console.Write("."); | |||
cols++; | |||
if (cols >= 50) | |||
if (cols % 50 == 0) | |||
{ | |||
cols = 0; | |||
Console.WriteLine(); | |||
} | |||
} | |||
Console.WriteLine(); | |||
if (cols > 0) | |||
Console.WriteLine(); | |||
} | |||
} | |||
} |
@@ -192,6 +192,12 @@ namespace Tensorflow | |||
public static Tensor logical_and(Tensor x, Tensor y, string name = null) | |||
=> gen_math_ops.logical_and(x, y, name); | |||
public static Tensor logical_not(Tensor x, string name = null) | |||
=> gen_math_ops.logical_not(x, name); | |||
public static Tensor logical_or(Tensor x, Tensor y, string name = null) | |||
=> gen_math_ops.logical_or(x, y, name); | |||
/// <summary> | |||
/// Clips tensor values to a specified min and max. | |||
/// </summary> | |||
@@ -34,10 +34,17 @@ namespace Tensorflow | |||
public Graph get_controller() | |||
{ | |||
if (stack.Count == 0) | |||
if (stack.Count(x => x.IsDefault) == 0) | |||
stack.Add(new StackModel { Graph = tf.Graph(), IsDefault = true }); | |||
return stack.First(x => x.IsDefault).Graph; | |||
return stack.Last(x => x.IsDefault).Graph; | |||
} | |||
public bool remove(Graph g) | |||
{ | |||
var sm = stack.FirstOrDefault(x => x.Graph == g); | |||
if (sm == null) return false; | |||
return stack.Remove(sm); | |||
} | |||
public void reset() | |||
@@ -73,9 +73,8 @@ namespace Tensorflow | |||
all variables that are created during the construction of a graph. The caller | |||
may define additional collections by specifying a new name. | |||
*/ | |||
public partial class Graph : IPython, IDisposable, IEnumerable<Operation> | |||
public partial class Graph : DisposableObject, IEnumerable<Operation> | |||
{ | |||
private IntPtr _handle; | |||
private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
public Dictionary<string, ITensorOrOperation> _nodes_by_name; | |||
private Dictionary<string, int> _names_in_use; | |||
@@ -121,10 +120,6 @@ namespace Tensorflow | |||
_graph_key = $"grap-key-{ops.uid()}/"; | |||
} | |||
public void __enter__() | |||
{ | |||
} | |||
public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) | |||
{ | |||
return _as_graph_element_locked(obj, allow_tensor, allow_operation); | |||
@@ -443,14 +438,15 @@ namespace Tensorflow | |||
_unfetchable_ops.Add(op); | |||
} | |||
public void Dispose() | |||
{ | |||
/*if (_handle != IntPtr.Zero) | |||
c_api.TF_DeleteGraph(_handle); | |||
_handle = IntPtr.Zero; | |||
GC.SuppressFinalize(this);*/ | |||
protected override void DisposeManagedState() | |||
{ | |||
ops.default_graph_stack.remove(this); | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
{ | |||
Console.WriteLine($"Destroy graph {handle}"); | |||
c_api.TF_DeleteGraph(handle); | |||
} | |||
/// <summary> | |||
@@ -481,17 +477,19 @@ namespace Tensorflow | |||
return new TensorShape(dims.Select(x => (int)x).ToArray()); | |||
} | |||
string debugString = string.Empty; | |||
public override string ToString() | |||
{ | |||
int len = 0; | |||
return c_api.TF_GraphDebugString(_handle, out len); | |||
return $"{graph_key}, ({_handle})"; | |||
/*if (string.IsNullOrEmpty(debugString)) | |||
{ | |||
int len = 0; | |||
debugString = c_api.TF_GraphDebugString(_handle, out len); | |||
} | |||
return debugString;*/ | |||
} | |||
public void __exit__() | |||
{ | |||
} | |||
private IEnumerable<Operation> GetEnumerable() | |||
=> c_api_util.tf_operations(this); | |||
@@ -84,7 +84,7 @@ namespace Tensorflow | |||
// Dict mapping op name to file and line information for op colocation | |||
// context managers. | |||
_control_flow_context = graph._get_control_flow_context(); | |||
_control_flow_context = _graph._get_control_flow_context(); | |||
// Note: _control_flow_post_processing() must not be called here, the caller is responsible for calling it when using this constructor. | |||
} | |||
@@ -357,6 +357,20 @@ namespace Tensorflow | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor logical_not(Tensor x, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("LogicalNot", name, args: new { x }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor logical_or(Tensor x, Tensor y, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("LogicalOr", name, args: new { x, y }); | |||
return _op.outputs[0]; | |||
} | |||
public static Tensor squared_difference(Tensor x, Tensor y, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("SquaredDifference", name, args: new { x, y, name }); | |||
@@ -31,7 +31,6 @@ namespace Tensorflow | |||
protected bool _closed; | |||
protected int _current_version; | |||
protected byte[] _target; | |||
protected IntPtr _session; | |||
public Graph graph => _graph; | |||
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||
@@ -46,7 +45,7 @@ namespace Tensorflow | |||
var status = new Status(); | |||
_session = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||
_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||
status.Check(true); | |||
} | |||
@@ -212,7 +211,7 @@ namespace Tensorflow | |||
var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray(); | |||
c_api.TF_SessionRun(_session, | |||
c_api.TF_SessionRun(_handle, | |||
run_options: null, | |||
inputs: feed_dict.Select(f => f.Key).ToArray(), | |||
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||
@@ -30,7 +30,7 @@ namespace Tensorflow | |||
public Session(IntPtr handle, Graph g = null) | |||
: base("", g, null) | |||
{ | |||
_session = handle; | |||
_handle = handle; | |||
} | |||
public Session(Graph g, SessionOptions opts = null, Status s = null) | |||
@@ -73,7 +73,7 @@ namespace Tensorflow | |||
return new Session(sess, g: new Graph(graph).as_default()); | |||
} | |||
public static implicit operator IntPtr(Session session) => session._session; | |||
public static implicit operator IntPtr(Session session) => session._handle; | |||
public static implicit operator Session(IntPtr handle) => new Session(handle); | |||
public void __enter__() | |||
@@ -506,7 +506,7 @@ namespace Tensorflow | |||
IsMemoryOwner = true; | |||
} | |||
private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) | |||
private unsafe IntPtr AllocateWithMemoryCopy(NDArray nd, TF_DataType? tensorDType = null) | |||
{ | |||
IntPtr dotHandle = IntPtr.Zero; | |||
int buffersize = 0; | |||
@@ -520,30 +520,30 @@ namespace Tensorflow | |||
var dataType = ToTFDataType(nd.dtype); | |||
// shape | |||
var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
var nd1 = nd.ravel(); | |||
// var nd1 = nd.ravel(); | |||
switch (nd.dtype.Name) | |||
{ | |||
case "Boolean": | |||
var boolVals = Array.ConvertAll(nd1.Data<bool>(), x => Convert.ToByte(x)); | |||
var boolVals = Array.ConvertAll(nd.Data<bool>(), x => Convert.ToByte(x)); | |||
Marshal.Copy(boolVals, 0, dotHandle, nd.size); | |||
break; | |||
case "Int16": | |||
Marshal.Copy(nd1.Data<short>(), 0, dotHandle, nd.size); | |||
Marshal.Copy(nd.Data<short>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Int32": | |||
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size); | |||
Marshal.Copy(nd.Data<int>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Int64": | |||
Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size); | |||
Marshal.Copy(nd.Data<long>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Single": | |||
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); | |||
Marshal.Copy(nd.Data<float>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Double": | |||
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); | |||
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size); | |||
break; | |||
case "Byte": | |||
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size); | |||
Marshal.Copy(nd.Data<byte>(), 0, dotHandle, nd.size); | |||
break; | |||
case "String": | |||
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING); | |||
@@ -559,6 +559,132 @@ namespace Tensorflow | |||
ref _deallocatorArgs); | |||
return tfHandle; | |||
} | |||
private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) | |||
{ | |||
IntPtr dotHandle = IntPtr.Zero; | |||
IntPtr tfHandle = IntPtr.Zero; | |||
int buffersize = nd.size * nd.dtypesize; | |||
var dataType = ToTFDataType(nd.dtype); | |||
// shape | |||
var dims = nd.shape.Select(x => (long)x).ToArray(); | |||
switch (nd.dtype.Name) | |||
{ | |||
case "Boolean": | |||
{ | |||
var boolVals = Array.ConvertAll(nd.Data<bool>(), x => Convert.ToByte(x)); | |||
var array = nd.Data<byte>(); | |||
fixed (byte* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "Int16": | |||
{ | |||
var array = nd.Data<short>(); | |||
fixed (short* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "Int32": | |||
{ | |||
var array = nd.Data<int>(); | |||
fixed (int* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "Int64": | |||
{ | |||
var array = nd.Data<long>(); | |||
fixed (long* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "Single": | |||
{ | |||
var array = nd.Data<float>(); | |||
fixed (float* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "Double": | |||
{ | |||
var array = nd.Data<double>(); | |||
fixed (double* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "Byte": | |||
{ | |||
var array = nd.Data<byte>(); | |||
fixed (byte* h = &array[0]) | |||
{ | |||
tfHandle = c_api.TF_NewTensor(dataType, | |||
dims, | |||
dims.Length, | |||
new IntPtr(h), | |||
(UIntPtr)buffersize, | |||
_nothingDeallocator, | |||
ref _deallocatorArgs); | |||
} | |||
} | |||
break; | |||
case "String": | |||
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)), TF_DataType.TF_STRING); | |||
default: | |||
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); | |||
} | |||
return tfHandle; | |||
} | |||
public unsafe Tensor(byte[][] buffer, long[] shape) | |||
@@ -70,7 +70,8 @@ namespace TensorFlowNET.Examples | |||
OneHot = true, | |||
TrainSize = train_size, | |||
ValidationSize = validation_size, | |||
TestSize = test_size | |||
TestSize = test_size, | |||
ShowProgressInConsole = true | |||
}; | |||
mnist = loader.LoadAsync(setting).Result; | |||
@@ -124,7 +124,7 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result; | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size, showProgressInConsole: true).Result; | |||
} | |||
public void SaveModel(Session sess) | |||
@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result; | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize, showProgressInConsole: true).Result; | |||
// In this example, we limit mnist data | |||
(Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | |||
(Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing | |||
@@ -310,7 +310,7 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||
(x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); | |||
(x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); | |||
(x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); | |||
@@ -121,7 +121,7 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||
} | |||
public void Train(Session sess) | |||
@@ -143,7 +143,7 @@ namespace TensorFlowNET.Examples | |||
public void PrepareData() | |||
{ | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; | |||
mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, showProgressInConsole: true).Result; | |||
(x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); | |||
(x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); | |||
(x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); | |||
@@ -52,7 +52,8 @@ namespace TensorFlowNET.Examples | |||
// The location where variable checkpoints will be stored. | |||
string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint"); | |||
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3"; | |||
string final_tensor_name = "final_result"; | |||
string input_tensor_name = "Placeholder"; | |||
string final_tensor_name = "Score"; | |||
float testing_percentage = 0.1f; | |||
float validation_percentage = 0.1f; | |||
float learning_rate = 0.01f; | |||
@@ -81,13 +82,13 @@ namespace TensorFlowNET.Examples | |||
PrepareData(); | |||
#region For debug purpose | |||
// predict images | |||
// Predict(null); | |||
// load saved pb and test new images. | |||
// Test(null); | |||
#endregion | |||
var graph = IsImportingGraph ? ImportGraph() : BuildGraph(); | |||
@@ -276,16 +277,13 @@ namespace TensorFlowNET.Examples | |||
private (Graph, Tensor, Tensor, bool) create_module_graph() | |||
{ | |||
var (height, width) = (299, 299); | |||
return tf_with(tf.Graph().as_default(), graph => | |||
{ | |||
tf.train.import_meta_graph("graph/InceptionV3.meta"); | |||
Tensor resized_input_tensor = graph.OperationByName("Placeholder"); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3)); | |||
// var m = hub.Module(module_spec); | |||
Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor); | |||
var wants_quantization = false; | |||
return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization); | |||
}); | |||
var graph = tf.Graph().as_default(); | |||
tf.train.import_meta_graph("graph/InceptionV3.meta"); | |||
Tensor resized_input_tensor = graph.OperationByName(input_tensor_name); //tf.placeholder(tf.float32, new TensorShape(-1, height, width, 3)); | |||
// var m = hub.Module(module_spec); | |||
Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");// m(resized_input_tensor); | |||
var wants_quantization = false; | |||
return (graph, bottleneck_tensor, resized_input_tensor, wants_quantization); | |||
} | |||
private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists, | |||
@@ -594,13 +592,10 @@ namespace TensorFlowNET.Examples | |||
create_module_graph(); | |||
// Add the new layer that we'll be training. | |||
tf_with(graph.as_default(), delegate | |||
{ | |||
(train_step, cross_entropy, bottleneck_input, | |||
ground_truth_input, final_tensor) = add_final_retrain_ops( | |||
class_count, final_tensor_name, bottleneck_tensor, | |||
wants_quantization, is_training: true); | |||
}); | |||
(train_step, cross_entropy, bottleneck_input, | |||
ground_truth_input, final_tensor) = add_final_retrain_ops( | |||
class_count, final_tensor_name, bottleneck_tensor, | |||
wants_quantization, is_training: true); | |||
return graph; | |||
} | |||
@@ -734,15 +729,15 @@ namespace TensorFlowNET.Examples | |||
var labels = File.ReadAllLines(output_labels); | |||
// predict image | |||
var img_path = Path.Join(image_dir, "roses", "12240303_80d87f77a3_n.jpg"); | |||
var img_path = Path.Join(image_dir, "daisy", "5547758_eea9edfd54_n.jpg"); | |||
var fileBytes = ReadTensorFromImageFile(img_path); | |||
// import graph and variables | |||
var graph = new Graph(); | |||
graph.Import(output_graph, ""); | |||
Tensor input = graph.OperationByName("Placeholder"); | |||
Tensor output = graph.OperationByName("final_result"); | |||
Tensor input = graph.OperationByName(input_tensor_name); | |||
Tensor output = graph.OperationByName(final_tensor_name); | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
@@ -7,12 +7,13 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class NameScopeTest | |||
{ | |||
Graph g = ops.get_default_graph(); | |||
string name = ""; | |||
[TestMethod] | |||
public void NestedNameScope() | |||
{ | |||
Graph g = tf.Graph().as_default(); | |||
tf_with(new ops.NameScope("scope1"), scope1 => | |||
{ | |||
name = scope1; | |||
@@ -37,6 +38,8 @@ namespace TensorFlowNET.UnitTest | |||
Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
}); | |||
g.Dispose(); | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
} | |||
@@ -131,7 +131,7 @@ namespace TensorFlowNET.UnitTest | |||
} | |||
[TestMethod] | |||
public void logicalAndTest() | |||
public void logicalOpsTest() | |||
{ | |||
var a = tf.constant(new[] {1f, 2f, 3f, 4f, -4f, -3f, -2f, -1f}); | |||
var b = tf.less(a, 0f); | |||
@@ -144,6 +144,24 @@ namespace TensorFlowNET.UnitTest | |||
var o = sess.run(d); | |||
Assert.IsTrue(o.array_equal(check)); | |||
} | |||
d = tf.cast(tf.logical_not(b), tf.int32); | |||
check = np.array(new[] { 1, 1, 1, 1, 0, 0, 0, 0 }); | |||
using (var sess = tf.Session()) | |||
{ | |||
var o = sess.run(d); | |||
Assert.IsTrue(o.array_equal(check)); | |||
} | |||
d = tf.cast(tf.logical_or(b, c), tf.int32); | |||
check = np.array(new[] { 1, 1, 1, 1, 1, 1, 1, 1 }); | |||
using (var sess = tf.Session()) | |||
{ | |||
var o = sess.run(d); | |||
Assert.IsTrue(o.array_equal(check)); | |||
} | |||
} | |||
[TestMethod] | |||