@@ -11,12 +11,20 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTest", "test\TensorFlow | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}" | |||
EndProject | |||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorFlowNET.Hub\Tensorflow.Hub.csproj", "{95B077C1-E21B-486F-8BDD-1C902FE687AB}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}" | |||
EndProject | |||
Global | |||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
Debug|Any CPU = Debug|Any CPU | |||
Debug|x64 = Debug|x64 | |||
Debug-Minimal|Any CPU = Debug-Minimal|Any CPU | |||
Debug-Minimal|x64 = Debug-Minimal|x64 | |||
Publish|Any CPU = Publish|Any CPU | |||
Publish|x64 = Publish|x64 | |||
Release|Any CPU = Release|Any CPU | |||
Release|x64 = Release|x64 | |||
EndGlobalSection | |||
@@ -25,6 +33,14 @@ Global | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | |||
@@ -33,6 +49,14 @@ Global | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU | |||
@@ -41,6 +65,14 @@ Global | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU | |||
@@ -49,6 +81,14 @@ Global | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU | |||
@@ -57,10 +97,50 @@ Global | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | |||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.Build.0 = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.ActiveCfg = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.Build.0 = Debug|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.ActiveCfg = Release|Any CPU | |||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.Build.0 = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.Build.0 = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.ActiveCfg = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.Build.0 = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.ActiveCfg = Release|Any CPU | |||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.Build.0 = Release|Any CPU | |||
EndGlobalSection | |||
GlobalSection(SolutionProperties) = preSolution | |||
HideSolutionNode = FALSE | |||
@@ -88,7 +88,7 @@ namespace Tensorflow | |||
case ICollection arr: | |||
return arr.Count; | |||
case NDArray ndArray: | |||
return ndArray.shape[0]; | |||
return ndArray.ndim == 0 ? 1 : ndArray.shape[0]; | |||
case IEnumerable enumerable: | |||
return enumerable.OfType<object>().Count(); | |||
} | |||
@@ -60,10 +60,13 @@ https://tensorflownet.readthedocs.io</Description> | |||
<ItemGroup> | |||
<PackageReference Include="Google.Protobuf" Version="3.11.2" /> | |||
<PackageReference Include="NumSharp" Version="0.20.5" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<Folder Include="Keras\Initializers\" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using NumSharp; | |||
namespace Tensorflow.Hub | |||
{ | |||
public abstract class DataSetBase : IDataSet | |||
{ | |||
public NDArray Data { get; protected set; } | |||
public NDArray Labels { get; protected set; } | |||
} | |||
} |
@@ -0,0 +1,46 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using NumSharp; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class Datasets<TDataSet> where TDataSet : IDataSet | |||
{ | |||
public TDataSet Train { get; private set; } | |||
public TDataSet Validation { get; private set; } | |||
public TDataSet Test { get; private set; } | |||
public Datasets(TDataSet train, TDataSet validation, TDataSet test) | |||
{ | |||
Train = train; | |||
Validation = validation; | |||
Test = test; | |||
} | |||
public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||
{ | |||
var perm = np.random.permutation(y.shape[0]); | |||
np.random.shuffle(perm); | |||
return (x[perm], y[perm]); | |||
} | |||
/// <summary> | |||
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||
/// </summary> | |||
/// <param name="x"></param> | |||
/// <param name="y"></param> | |||
/// <param name="start"></param> | |||
/// <param name="end"></param> | |||
/// <returns></returns> | |||
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||
{ | |||
var slice = new Slice(start, end); | |||
var x_batch = x[slice]; | |||
var y_batch = y[slice]; | |||
return (x_batch, y_batch); | |||
} | |||
} | |||
} |
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using NumSharp; | |||
namespace Tensorflow.Hub | |||
{ | |||
public interface IDataSet | |||
{ | |||
NDArray Data { get; } | |||
NDArray Labels { get; } | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Threading.Tasks; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using NumSharp; | |||
namespace Tensorflow.Hub | |||
{ | |||
public interface IModelLoader<TDataSet> | |||
where TDataSet : IDataSet | |||
{ | |||
Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting); | |||
} | |||
} |
@@ -0,0 +1,88 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Text; | |||
using NumSharp; | |||
using Tensorflow; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class MnistDataSet : DataSetBase | |||
{ | |||
public int NumOfExamples { get; private set; } | |||
public int EpochsCompleted { get; private set; } | |||
public int IndexInEpoch { get; private set; } | |||
public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape) | |||
{ | |||
EpochsCompleted = 0; | |||
IndexInEpoch = 0; | |||
NumOfExamples = images.shape[0]; | |||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||
images = images.astype(dataType); | |||
// for debug np.multiply performance | |||
var sw = new Stopwatch(); | |||
sw.Start(); | |||
images = np.multiply(images, 1.0f / 255.0f); | |||
sw.Stop(); | |||
Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||
Data = images; | |||
labels = labels.astype(dataType); | |||
Labels = labels; | |||
} | |||
public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true) | |||
{ | |||
if (IndexInEpoch >= NumOfExamples) | |||
IndexInEpoch = 0; | |||
var start = IndexInEpoch; | |||
// Shuffle for the first epoch | |||
if(EpochsCompleted == 0 && start == 0 && shuffle) | |||
{ | |||
var perm0 = np.arange(NumOfExamples); | |||
np.random.shuffle(perm0); | |||
Data = Data[perm0]; | |||
Labels = Labels[perm0]; | |||
} | |||
// Go to the next epoch | |||
if (start + batch_size > NumOfExamples) | |||
{ | |||
// Finished epoch | |||
EpochsCompleted += 1; | |||
// Get the rest examples in this epoch | |||
var rest_num_examples = NumOfExamples - start; | |||
var images_rest_part = Data[np.arange(start, NumOfExamples)]; | |||
var labels_rest_part = Labels[np.arange(start, NumOfExamples)]; | |||
// Shuffle the data | |||
if (shuffle) | |||
{ | |||
var perm = np.arange(NumOfExamples); | |||
np.random.shuffle(perm); | |||
Data = Data[perm]; | |||
Labels = Labels[perm]; | |||
} | |||
start = 0; | |||
IndexInEpoch = batch_size - rest_num_examples; | |||
var end = IndexInEpoch; | |||
var images_new_part = Data[np.arange(start, end)]; | |||
var labels_new_part = Labels[np.arange(start, end)]; | |||
return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0), | |||
np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0)); | |||
} | |||
else | |||
{ | |||
IndexInEpoch += batch_size; | |||
var end = IndexInEpoch; | |||
return (Data[np.arange(start, end)], Labels[np.arange(start, end)]); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,184 @@ | |||
using System; | |||
using System.Threading.Tasks; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using System.IO; | |||
using NumSharp; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class MnistModelLoader : IModelLoader<MnistDataSet> | |||
{ | |||
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; | |||
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; | |||
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | |||
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | |||
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | |||
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||
{ | |||
var loader = new MnistModelLoader(); | |||
var setting = new ModelLoadSetting | |||
{ | |||
TrainDir = trainDir, | |||
OneHot = oneHot, | |||
ShowProgressInConsole = showProgressInConsole | |||
}; | |||
if (trainSize.HasValue) | |||
setting.TrainSize = trainSize.Value; | |||
if (validationSize.HasValue) | |||
setting.ValidationSize = validationSize.Value; | |||
if (testSize.HasValue) | |||
setting.TestSize = testSize.Value; | |||
return await loader.LoadAsync(setting); | |||
} | |||
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | |||
{ | |||
if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) | |||
throw new ArgumentException("Validation set should be smaller than training set"); | |||
var sourceUrl = setting.SourceUrl; | |||
if (string.IsNullOrEmpty(sourceUrl)) | |||
sourceUrl = DEFAULT_SOURCE_URL; | |||
// load train images | |||
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | |||
// load train labels | |||
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | |||
// load test images | |||
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | |||
// load test labels | |||
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||
var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | |||
var end = trainImages.shape[0]; | |||
var validationSize = setting.ValidationSize; | |||
var validationImages = trainImages[np.arange(validationSize)]; | |||
var validationLabels = trainLabels[np.arange(validationSize)]; | |||
trainImages = trainImages[np.arange(validationSize, end)]; | |||
trainLabels = trainLabels[np.arange(validationSize, end)]; | |||
var dtype = setting.DataType; | |||
var reshape = setting.ReShape; | |||
var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); | |||
var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape); | |||
var test = new MnistDataSet(testImages, testLabels, dtype, reshape); | |||
return new Datasets<MnistDataSet>(train, validation, test); | |||
} | |||
private NDArray ExtractImages(string file, int? limit = null) | |||
{ | |||
if (!Path.IsPathRooted(file)) | |||
file = Path.Combine(AppContext.BaseDirectory, file); | |||
using (var bytestream = new FileStream(file, FileMode.Open)) | |||
{ | |||
var magic = Read32(bytestream); | |||
if (magic != 2051) | |||
throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}"); | |||
var num_images = Read32(bytestream); | |||
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); | |||
var rows = Read32(bytestream); | |||
var cols = Read32(bytestream); | |||
var buf = new byte[rows * cols * num_images]; | |||
bytestream.Read(buf, 0, buf.Length); | |||
var data = np.frombuffer(buf, np.uint8); | |||
data = data.reshape((int)num_images, (int)rows, (int)cols, 1); | |||
return data; | |||
} | |||
} | |||
private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) | |||
{ | |||
if (!Path.IsPathRooted(file)) | |||
file = Path.Combine(AppContext.BaseDirectory, file); | |||
using (var bytestream = new FileStream(file, FileMode.Open)) | |||
{ | |||
var magic = Read32(bytestream); | |||
if (magic != 2049) | |||
throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}"); | |||
var num_items = Read32(bytestream); | |||
num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit); | |||
var buf = new byte[num_items]; | |||
bytestream.Read(buf, 0, buf.Length); | |||
var labels = np.frombuffer(buf, np.uint8); | |||
if (one_hot) | |||
return DenseToOneHot(labels, num_classes); | |||
return labels; | |||
} | |||
} | |||
private NDArray DenseToOneHot(NDArray labels_dense, int num_classes) | |||
{ | |||
var num_labels = labels_dense.shape[0]; | |||
var index_offset = np.arange(num_labels) * num_classes; | |||
var labels_one_hot = np.zeros(num_labels, num_classes); | |||
var labels = labels_dense.Data<byte>(); | |||
for (int row = 0; row < num_labels; row++) | |||
{ | |||
var col = labels[row]; | |||
labels_one_hot.SetData(1.0, row, col); | |||
} | |||
return labels_one_hot; | |||
} | |||
private uint Read32(FileStream bytestream) | |||
{ | |||
var buffer = new byte[sizeof(uint)]; | |||
var count = bytestream.Read(buffer, 0, 4); | |||
return np.frombuffer(buffer, ">u4").Data<uint>()[0]; | |||
} | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using NumSharp; | |||
namespace Tensorflow.Hub | |||
{ | |||
public class ModelLoadSetting | |||
{ | |||
public string TrainDir { get; set; } | |||
public bool OneHot { get; set; } | |||
public Type DataType { get; set; } = typeof(float); | |||
public bool ReShape { get; set; } | |||
public int ValidationSize { get; set; } = 5000; | |||
public int? TrainSize { get; set; } | |||
public int? TestSize { get; set; } | |||
public string SourceUrl { get; set; } | |||
public bool ShowProgressInConsole { get; set; } | |||
} | |||
} |
@@ -0,0 +1,5 @@ | |||
## TensorFlow Hub | |||
TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models. In particular, it provides **modules**, which are pre-trained pieces of TensorFlow models that can be reused on new tasks. | |||
https://github.com/tensorflow/hub |
@@ -0,0 +1,26 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<RootNamespace>Tensorflow.Hub</RootNamespace> | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
<Version>0.0.5</Version> | |||
<Authors>Kerry Jiang, Haiping Chen</Authors> | |||
<Company>SciSharp STACK</Company> | |||
<Copyright>Apache 2.0</Copyright> | |||
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||
<RepositoryType>git</RepositoryType> | |||
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | |||
<PackageTags>TensorFlow, SciSharp, MachineLearning</PackageTags> | |||
<Description>TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models.</Description> | |||
<PackageId>SciSharp.TensorFlowHub</PackageId> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
<PackageReleaseNotes>Fix GetNextBatch() bug.</PackageReleaseNotes> | |||
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||
<AssemblyName>TensorFlow.Hub</AssemblyName> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -0,0 +1,137 @@ | |||
using System; | |||
using System.IO; | |||
using System.IO.Compression; | |||
using System.Collections.Generic; | |||
using System.Net; | |||
using System.Text; | |||
using System.Threading; | |||
using System.Threading.Tasks; | |||
namespace Tensorflow.Hub | |||
{ | |||
public static class Utils | |||
{ | |||
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string saveTo) | |||
where TDataSet : IDataSet | |||
{ | |||
var dir = Path.GetDirectoryName(saveTo); | |||
var fileName = Path.GetFileName(saveTo); | |||
await modelLoader.DownloadAsync(url, dir, fileName); | |||
} | |||
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) | |||
where TDataSet : IDataSet | |||
{ | |||
if (!Path.IsPathRooted(dirSaveTo)) | |||
dirSaveTo = Path.Combine(AppContext.BaseDirectory, dirSaveTo); | |||
var fileSaveTo = Path.Combine(dirSaveTo, fileName); | |||
if (showProgressInConsole) | |||
{ | |||
Console.WriteLine($"Downloading {fileName}"); | |||
} | |||
if (File.Exists(fileSaveTo)) | |||
{ | |||
if (showProgressInConsole) | |||
{ | |||
Console.WriteLine($"The file {fileName} already exists"); | |||
} | |||
return; | |||
} | |||
Directory.CreateDirectory(dirSaveTo); | |||
using (var wc = new WebClient()) | |||
{ | |||
await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); | |||
} | |||
} | |||
public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) | |||
where TDataSet : IDataSet | |||
{ | |||
if (!Path.IsPathRooted(saveTo)) | |||
saveTo = Path.Combine(AppContext.BaseDirectory, saveTo); | |||
Directory.CreateDirectory(saveTo); | |||
if (!Path.IsPathRooted(zipFile)) | |||
zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | |||
var destFileName = Path.GetFileNameWithoutExtension(zipFile); | |||
var destFilePath = Path.Combine(saveTo, destFileName); | |||
if (showProgressInConsole) | |||
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||
if (File.Exists(destFilePath)) | |||
{ | |||
if (showProgressInConsole) | |||
Console.WriteLine($"The file {destFileName} already exists"); | |||
} | |||
using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | |||
{ | |||
using (var destStream = File.Create(destFilePath)) | |||
{ | |||
await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); | |||
await destStream.FlushAsync().ConfigureAwait(false); | |||
destStream.Close(); | |||
} | |||
unzipStream.Close(); | |||
} | |||
} | |||
public static async Task ShowProgressInConsole(this Task task, bool enable) | |||
{ | |||
if (!enable) | |||
{ | |||
await task; | |||
return; | |||
} | |||
var cts = new CancellationTokenSource(); | |||
var showProgressTask = ShowProgressInConsole(cts); | |||
try | |||
{ | |||
await task; | |||
} | |||
finally | |||
{ | |||
cts.Cancel(); | |||
} | |||
await showProgressTask; | |||
Console.WriteLine("Done."); | |||
} | |||
private static async Task ShowProgressInConsole(CancellationTokenSource cts) | |||
{ | |||
var cols = 0; | |||
await Task.Delay(100); | |||
while (!cts.IsCancellationRequested) | |||
{ | |||
await Task.Delay(100); | |||
Console.Write("."); | |||
cols++; | |||
if (cols % 50 == 0) | |||
{ | |||
Console.WriteLine(); | |||
} | |||
} | |||
if (cols > 0) | |||
Console.WriteLine(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,24 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System.Threading.Tasks; | |||
using Tensorflow.Hub; | |||
namespace UnitTest | |||
{ | |||
[TestClass] | |||
public class MnistModelLoaderTest | |||
{ | |||
[TestMethod] | |||
public async Task TestLoad() | |||
{ | |||
var loader = new MnistModelLoader(); | |||
var result = await loader.LoadAsync(new ModelLoadSetting | |||
{ | |||
TrainDir = "mnist", | |||
OneHot = true, | |||
ValidationSize = 5000, | |||
}); | |||
Assert.IsNotNull(result); | |||
} | |||
} | |||
} |
@@ -37,6 +37,7 @@ | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Hub\Tensorflow.Hub.csproj" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||