Browse Source

Tensorflow Hub

tags/v0.20
Oceania2018 5 years ago
parent
commit
3516079dbb
15 changed files with 657 additions and 3 deletions
  1. +81
    -1
      TensorFlow.NET.sln
  2. +1
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  3. +4
    -1
      src/TensorFlowNET.Core/TensorFlow.Binding.csproj
  4. +13
    -0
      src/TensorFlowNET.Hub/DataSetBase.cs
  5. +46
    -0
      src/TensorFlowNET.Hub/Datasets.cs
  6. +13
    -0
      src/TensorFlowNET.Hub/IDataSet.cs
  7. +14
    -0
      src/TensorFlowNET.Hub/IModelLoader.cs
  8. +88
    -0
      src/TensorFlowNET.Hub/MnistDataSet.cs
  9. +184
    -0
      src/TensorFlowNET.Hub/MnistModelLoader.cs
  10. +20
    -0
      src/TensorFlowNET.Hub/ModelLoadSetting.cs
  11. +5
    -0
      src/TensorFlowNET.Hub/README.md
  12. +26
    -0
      src/TensorFlowNET.Hub/Tensorflow.Hub.csproj
  13. +137
    -0
      src/TensorFlowNET.Hub/Utils.cs
  14. +24
    -0
      test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs
  15. +1
    -0
      test/TensorFlowNET.UnitTest/UnitTest.csproj

+ 81
- 1
TensorFlow.NET.sln View File

@@ -11,12 +11,20 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTest", "test\TensorFlow
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorFlowNET.Hub\Tensorflow.Hub.csproj", "{95B077C1-E21B-486F-8BDD-1C902FE687AB}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Debug|x64 = Debug|x64
Debug-Minimal|Any CPU = Debug-Minimal|Any CPU
Debug-Minimal|x64 = Debug-Minimal|x64
Publish|Any CPU = Publish|Any CPU
Publish|x64 = Publish|x64
Release|Any CPU = Release|Any CPU
Release|x64 = Release|x64
EndGlobalSection
@@ -25,6 +33,14 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU
@@ -33,6 +49,14 @@ Global
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU
@@ -41,6 +65,14 @@ Global
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU
@@ -49,6 +81,14 @@ Global
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU
@@ -57,10 +97,50 @@ Global
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.ActiveCfg = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.Build.0 = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.Build.0 = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.ActiveCfg = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.Build.0 = Debug|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.Build.0 = Release|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.ActiveCfg = Release|Any CPU
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.Build.0 = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.ActiveCfg = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.Build.0 = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.Build.0 = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.ActiveCfg = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.Build.0 = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.Build.0 = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.ActiveCfg = Release|Any CPU
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 1
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -88,7 +88,7 @@ namespace Tensorflow
case ICollection arr:
return arr.Count;
case NDArray ndArray:
return ndArray.shape[0];
return ndArray.ndim == 0 ? 1 : ndArray.shape[0];
case IEnumerable enumerable:
return enumerable.OfType<object>().Count();
}


+ 4
- 1
src/TensorFlowNET.Core/TensorFlow.Binding.csproj View File

@@ -60,10 +60,13 @@ https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.11.2" />
<PackageReference Include="NumSharp" Version="0.20.5" />
</ItemGroup>

<ItemGroup>
<Folder Include="Keras\Initializers\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>
</Project>

+ 13
- 0
src/TensorFlowNET.Hub/DataSetBase.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public abstract class DataSetBase : IDataSet
{
public NDArray Data { get; protected set; }
public NDArray Labels { get; protected set; }
}
}

+ 46
- 0
src/TensorFlowNET.Hub/Datasets.cs View File

@@ -0,0 +1,46 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public class Datasets<TDataSet> where TDataSet : IDataSet
{
public TDataSet Train { get; private set; }

public TDataSet Validation { get; private set; }

public TDataSet Test { get; private set; }

public Datasets(TDataSet train, TDataSet validation, TDataSet test)
{
Train = train;
Validation = validation;
Test = test;
}

public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
var perm = np.random.permutation(y.shape[0]);
np.random.shuffle(perm);
return (x[perm], y[perm]);
}

/// <summary>
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="start"></param>
/// <param name="end"></param>
/// <returns></returns>
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
var slice = new Slice(start, end);
var x_batch = x[slice];
var y_batch = y[slice];
return (x_batch, y_batch);
}
}
}

+ 13
- 0
src/TensorFlowNET.Hub/IDataSet.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public interface IDataSet
{
NDArray Data { get; }
NDArray Labels { get; }
}
}

+ 14
- 0
src/TensorFlowNET.Hub/IModelLoader.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public interface IModelLoader<TDataSet>
where TDataSet : IDataSet
{
Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting);
}
}

+ 88
- 0
src/TensorFlowNET.Hub/MnistDataSet.cs View File

@@ -0,0 +1,88 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using NumSharp;
using Tensorflow;

namespace Tensorflow.Hub
{
public class MnistDataSet : DataSetBase
{
public int NumOfExamples { get; private set; }
public int EpochsCompleted { get; private set; }
public int IndexInEpoch { get; private set; }

public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
{
EpochsCompleted = 0;
IndexInEpoch = 0;

NumOfExamples = images.shape[0];

images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
images = images.astype(dataType);
// for debug np.multiply performance
var sw = new Stopwatch();
sw.Start();
images = np.multiply(images, 1.0f / 255.0f);
sw.Stop();
Console.WriteLine($"{sw.ElapsedMilliseconds}ms");
Data = images;

labels = labels.astype(dataType);
Labels = labels;
}

public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
{
if (IndexInEpoch >= NumOfExamples)
IndexInEpoch = 0;

var start = IndexInEpoch;
// Shuffle for the first epoch
if(EpochsCompleted == 0 && start == 0 && shuffle)
{
var perm0 = np.arange(NumOfExamples);
np.random.shuffle(perm0);
Data = Data[perm0];
Labels = Labels[perm0];
}

// Go to the next epoch
if (start + batch_size > NumOfExamples)
{
// Finished epoch
EpochsCompleted += 1;

// Get the rest examples in this epoch
var rest_num_examples = NumOfExamples - start;
var images_rest_part = Data[np.arange(start, NumOfExamples)];
var labels_rest_part = Labels[np.arange(start, NumOfExamples)];
// Shuffle the data
if (shuffle)
{
var perm = np.arange(NumOfExamples);
np.random.shuffle(perm);
Data = Data[perm];
Labels = Labels[perm];
}

start = 0;
IndexInEpoch = batch_size - rest_num_examples;
var end = IndexInEpoch;
var images_new_part = Data[np.arange(start, end)];
var labels_new_part = Labels[np.arange(start, end)];

return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0),
np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0));
}
else
{
IndexInEpoch += batch_size;
var end = IndexInEpoch;
return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
}
}
}
}

+ 184
- 0
src/TensorFlowNET.Hub/MnistModelLoader.cs View File

@@ -0,0 +1,184 @@
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Text;
using System.IO;
using NumSharp;

namespace Tensorflow.Hub
{
public class MnistModelLoader : IModelLoader<MnistDataSet>
{
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
{
var loader = new MnistModelLoader();

var setting = new ModelLoadSetting
{
TrainDir = trainDir,
OneHot = oneHot,
ShowProgressInConsole = showProgressInConsole
};

if (trainSize.HasValue)
setting.TrainSize = trainSize.Value;

if (validationSize.HasValue)
setting.ValidationSize = validationSize.Value;

if (testSize.HasValue)
setting.TestSize = testSize.Value;

return await loader.LoadAsync(setting);
}

public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
{
if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
throw new ArgumentException("Validation set should be smaller than training set");

var sourceUrl = setting.SourceUrl;

if (string.IsNullOrEmpty(sourceUrl))
sourceUrl = DEFAULT_SOURCE_URL;

// load train images
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize);

// load train labels
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize);

// load test images
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize);

// load test labels
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
.ShowProgressInConsole(setting.ShowProgressInConsole);

var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize);

var end = trainImages.shape[0];

var validationSize = setting.ValidationSize;

var validationImages = trainImages[np.arange(validationSize)];
var validationLabels = trainLabels[np.arange(validationSize)];
trainImages = trainImages[np.arange(validationSize, end)];
trainLabels = trainLabels[np.arange(validationSize, end)];

var dtype = setting.DataType;
var reshape = setting.ReShape;

var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
var test = new MnistDataSet(testImages, testLabels, dtype, reshape);

return new Datasets<MnistDataSet>(train, validation, test);
}

private NDArray ExtractImages(string file, int? limit = null)
{
if (!Path.IsPathRooted(file))
file = Path.Combine(AppContext.BaseDirectory, file);

using (var bytestream = new FileStream(file, FileMode.Open))
{
var magic = Read32(bytestream);
if (magic != 2051)
throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}");
var num_images = Read32(bytestream);
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);

var rows = Read32(bytestream);
var cols = Read32(bytestream);

var buf = new byte[rows * cols * num_images];

bytestream.Read(buf, 0, buf.Length);

var data = np.frombuffer(buf, np.uint8);
data = data.reshape((int)num_images, (int)rows, (int)cols, 1);

return data;
}
}

private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
{
if (!Path.IsPathRooted(file))
file = Path.Combine(AppContext.BaseDirectory, file);
using (var bytestream = new FileStream(file, FileMode.Open))
{
var magic = Read32(bytestream);
if (magic != 2049)
throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}");
var num_items = Read32(bytestream);
num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
var buf = new byte[num_items];

bytestream.Read(buf, 0, buf.Length);
var labels = np.frombuffer(buf, np.uint8);

if (one_hot)
return DenseToOneHot(labels, num_classes);
return labels;
}
}

private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
{
var num_labels = labels_dense.shape[0];
var index_offset = np.arange(num_labels) * num_classes;
var labels_one_hot = np.zeros(num_labels, num_classes);
var labels = labels_dense.Data<byte>();
for (int row = 0; row < num_labels; row++)
{
var col = labels[row];
labels_one_hot.SetData(1.0, row, col);
}

return labels_one_hot;
}

private uint Read32(FileStream bytestream)
{
var buffer = new byte[sizeof(uint)];
var count = bytestream.Read(buffer, 0, 4);
return np.frombuffer(buffer, ">u4").Data<uint>()[0];
}
}
}

+ 20
- 0
src/TensorFlowNET.Hub/ModelLoadSetting.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public class ModelLoadSetting
{
public string TrainDir { get; set; }
public bool OneHot { get; set; }
public Type DataType { get; set; } = typeof(float);
public bool ReShape { get; set; }
public int ValidationSize { get; set; } = 5000;
public int? TrainSize { get; set; }
public int? TestSize { get; set; }
public string SourceUrl { get; set; }
public bool ShowProgressInConsole { get; set; }
}
}

+ 5
- 0
src/TensorFlowNET.Hub/README.md View File

@@ -0,0 +1,5 @@
## TensorFlow Hub

TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models. In particular, it provides **modules**, which are pre-trained pieces of TensorFlow models that can be reused on new tasks.

https://github.com/tensorflow/hub

+ 26
- 0
src/TensorFlowNET.Hub/Tensorflow.Hub.csproj View File

@@ -0,0 +1,26 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RootNamespace>Tensorflow.Hub</RootNamespace>
<TargetFramework>netstandard2.0</TargetFramework>
<Version>0.0.5</Version>
<Authors>Kerry Jiang, Haiping Chen</Authors>
<Company>SciSharp STACK</Company>
<Copyright>Apache 2.0</Copyright>
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
<RepositoryType>git</RepositoryType>
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl>
<PackageTags>TensorFlow, SciSharp, MachineLearning</PackageTags>
<Description>TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models.</Description>
<PackageId>SciSharp.TensorFlowHub</PackageId>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<PackageReleaseNotes>Fix GetNextBatch() bug.</PackageReleaseNotes>
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&amp;v=4</PackageIconUrl>
<AssemblyName>TensorFlow.Hub</AssemblyName>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
</ItemGroup>
</Project>

+ 137
- 0
src/TensorFlowNET.Hub/Utils.cs View File

@@ -0,0 +1,137 @@
using System;
using System.IO;
using System.IO.Compression;
using System.Collections.Generic;
using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Tensorflow.Hub
{
public static class Utils
{
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string saveTo)
where TDataSet : IDataSet
{
var dir = Path.GetDirectoryName(saveTo);
var fileName = Path.GetFileName(saveTo);
await modelLoader.DownloadAsync(url, dir, fileName);
}

public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false)
where TDataSet : IDataSet
{
if (!Path.IsPathRooted(dirSaveTo))
dirSaveTo = Path.Combine(AppContext.BaseDirectory, dirSaveTo);

var fileSaveTo = Path.Combine(dirSaveTo, fileName);

if (showProgressInConsole)
{
Console.WriteLine($"Downloading {fileName}");
}

if (File.Exists(fileSaveTo))
{
if (showProgressInConsole)
{
Console.WriteLine($"The file {fileName} already exists");
}

return;
}
Directory.CreateDirectory(dirSaveTo);
using (var wc = new WebClient())
{
await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false);
}

}

public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false)
where TDataSet : IDataSet
{
if (!Path.IsPathRooted(saveTo))
saveTo = Path.Combine(AppContext.BaseDirectory, saveTo);

Directory.CreateDirectory(saveTo);

if (!Path.IsPathRooted(zipFile))
zipFile = Path.Combine(AppContext.BaseDirectory, zipFile);

var destFileName = Path.GetFileNameWithoutExtension(zipFile);
var destFilePath = Path.Combine(saveTo, destFileName);

if (showProgressInConsole)
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");

if (File.Exists(destFilePath))
{
if (showProgressInConsole)
Console.WriteLine($"The file {destFileName} already exists");
}

using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
{
using (var destStream = File.Create(destFilePath))
{
await unzipStream.CopyToAsync(destStream).ConfigureAwait(false);
await destStream.FlushAsync().ConfigureAwait(false);
destStream.Close();
}

unzipStream.Close();
}
}

public static async Task ShowProgressInConsole(this Task task, bool enable)
{
if (!enable)
{
await task;
return;
}

var cts = new CancellationTokenSource();

var showProgressTask = ShowProgressInConsole(cts);

try
{
await task;
}
finally
{
cts.Cancel();
}

await showProgressTask;
Console.WriteLine("Done.");
}

private static async Task ShowProgressInConsole(CancellationTokenSource cts)
{
var cols = 0;

await Task.Delay(100);

while (!cts.IsCancellationRequested)
{
await Task.Delay(100);
Console.Write(".");
cols++;

if (cols % 50 == 0)
{
Console.WriteLine();
}
}

if (cols > 0)
Console.WriteLine();
}
}
}

+ 24
- 0
test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs View File

@@ -0,0 +1,24 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Threading.Tasks;
using Tensorflow.Hub;

namespace UnitTest
{
[TestClass]
public class MnistModelLoaderTest
{
[TestMethod]
public async Task TestLoad()
{
var loader = new MnistModelLoader();
var result = await loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = true,
ValidationSize = 5000,
});

Assert.IsNotNull(result);
}
}
}

+ 1
- 0
test/TensorFlowNET.UnitTest/UnitTest.csproj View File

@@ -37,6 +37,7 @@

<ItemGroup>
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Hub\Tensorflow.Hub.csproj" />
</ItemGroup>

<ItemGroup>


Loading…
Cancel
Save