@@ -11,12 +11,20 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTest", "test\TensorFlow | |||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}" | ||||
EndProject | EndProject | ||||
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | |||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | |||||
EndProject | |||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorFlowNET.Hub\Tensorflow.Hub.csproj", "{95B077C1-E21B-486F-8BDD-1C902FE687AB}" | |||||
EndProject | |||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}" | |||||
EndProject | EndProject | ||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
Debug|Any CPU = Debug|Any CPU | Debug|Any CPU = Debug|Any CPU | ||||
Debug|x64 = Debug|x64 | Debug|x64 = Debug|x64 | ||||
Debug-Minimal|Any CPU = Debug-Minimal|Any CPU | |||||
Debug-Minimal|x64 = Debug-Minimal|x64 | |||||
Publish|Any CPU = Publish|Any CPU | |||||
Publish|x64 = Publish|x64 | |||||
Release|Any CPU = Release|Any CPU | Release|Any CPU = Release|Any CPU | ||||
Release|x64 = Release|x64 | Release|x64 = Release|x64 | ||||
EndGlobalSection | EndGlobalSection | ||||
@@ -25,6 +33,14 @@ Global | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU | |||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | ||||
@@ -33,6 +49,14 @@ Global | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU | |||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU | {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU | ||||
@@ -41,6 +65,14 @@ Global | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU | |||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU | {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU | ||||
@@ -49,6 +81,14 @@ Global | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU | |||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU | {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU | ||||
@@ -57,10 +97,50 @@ Global | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU | |||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | ||||
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | ||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.ActiveCfg = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.Build.0 = Debug|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.Build.0 = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.Build.0 = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.Build.0 = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.ActiveCfg = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.Build.0 = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.ActiveCfg = Release|Any CPU | |||||
{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -88,7 +88,7 @@ namespace Tensorflow | |||||
case ICollection arr: | case ICollection arr: | ||||
return arr.Count; | return arr.Count; | ||||
case NDArray ndArray: | case NDArray ndArray: | ||||
return ndArray.shape[0]; | |||||
return ndArray.ndim == 0 ? 1 : ndArray.shape[0]; | |||||
case IEnumerable enumerable: | case IEnumerable enumerable: | ||||
return enumerable.OfType<object>().Count(); | return enumerable.OfType<object>().Count(); | ||||
} | } | ||||
@@ -60,10 +60,13 @@ https://tensorflownet.readthedocs.io</Description> | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.11.2" /> | <PackageReference Include="Google.Protobuf" Version="3.11.2" /> | ||||
<PackageReference Include="NumSharp" Version="0.20.5" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<Folder Include="Keras\Initializers\" /> | <Folder Include="Keras\Initializers\" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
</ItemGroup> | |||||
</Project> | </Project> |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public abstract class DataSetBase : IDataSet | |||||
{ | |||||
public NDArray Data { get; protected set; } | |||||
public NDArray Labels { get; protected set; } | |||||
} | |||||
} |
@@ -0,0 +1,46 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class Datasets<TDataSet> where TDataSet : IDataSet | |||||
{ | |||||
public TDataSet Train { get; private set; } | |||||
public TDataSet Validation { get; private set; } | |||||
public TDataSet Test { get; private set; } | |||||
public Datasets(TDataSet train, TDataSet validation, TDataSet test) | |||||
{ | |||||
Train = train; | |||||
Validation = validation; | |||||
Test = test; | |||||
} | |||||
public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||||
{ | |||||
var perm = np.random.permutation(y.shape[0]); | |||||
np.random.shuffle(perm); | |||||
return (x[perm], y[perm]); | |||||
} | |||||
/// <summary> | |||||
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||||
/// </summary> | |||||
/// <param name="x"></param> | |||||
/// <param name="y"></param> | |||||
/// <param name="start"></param> | |||||
/// <param name="end"></param> | |||||
/// <returns></returns> | |||||
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||||
{ | |||||
var slice = new Slice(start, end); | |||||
var x_batch = x[slice]; | |||||
var y_batch = y[slice]; | |||||
return (x_batch, y_batch); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public interface IDataSet | |||||
{ | |||||
NDArray Data { get; } | |||||
NDArray Labels { get; } | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Threading.Tasks; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public interface IModelLoader<TDataSet> | |||||
where TDataSet : IDataSet | |||||
{ | |||||
Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting); | |||||
} | |||||
} |
@@ -0,0 +1,88 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Text; | |||||
using NumSharp; | |||||
using Tensorflow; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class MnistDataSet : DataSetBase | |||||
{ | |||||
public int NumOfExamples { get; private set; } | |||||
public int EpochsCompleted { get; private set; } | |||||
public int IndexInEpoch { get; private set; } | |||||
public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape) | |||||
{ | |||||
EpochsCompleted = 0; | |||||
IndexInEpoch = 0; | |||||
NumOfExamples = images.shape[0]; | |||||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||||
images = images.astype(dataType); | |||||
// for debug np.multiply performance | |||||
var sw = new Stopwatch(); | |||||
sw.Start(); | |||||
images = np.multiply(images, 1.0f / 255.0f); | |||||
sw.Stop(); | |||||
Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||||
Data = images; | |||||
labels = labels.astype(dataType); | |||||
Labels = labels; | |||||
} | |||||
public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true) | |||||
{ | |||||
if (IndexInEpoch >= NumOfExamples) | |||||
IndexInEpoch = 0; | |||||
var start = IndexInEpoch; | |||||
// Shuffle for the first epoch | |||||
if(EpochsCompleted == 0 && start == 0 && shuffle) | |||||
{ | |||||
var perm0 = np.arange(NumOfExamples); | |||||
np.random.shuffle(perm0); | |||||
Data = Data[perm0]; | |||||
Labels = Labels[perm0]; | |||||
} | |||||
// Go to the next epoch | |||||
if (start + batch_size > NumOfExamples) | |||||
{ | |||||
// Finished epoch | |||||
EpochsCompleted += 1; | |||||
// Get the rest examples in this epoch | |||||
var rest_num_examples = NumOfExamples - start; | |||||
var images_rest_part = Data[np.arange(start, NumOfExamples)]; | |||||
var labels_rest_part = Labels[np.arange(start, NumOfExamples)]; | |||||
// Shuffle the data | |||||
if (shuffle) | |||||
{ | |||||
var perm = np.arange(NumOfExamples); | |||||
np.random.shuffle(perm); | |||||
Data = Data[perm]; | |||||
Labels = Labels[perm]; | |||||
} | |||||
start = 0; | |||||
IndexInEpoch = batch_size - rest_num_examples; | |||||
var end = IndexInEpoch; | |||||
var images_new_part = Data[np.arange(start, end)]; | |||||
var labels_new_part = Labels[np.arange(start, end)]; | |||||
return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0), | |||||
np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0)); | |||||
} | |||||
else | |||||
{ | |||||
IndexInEpoch += batch_size; | |||||
var end = IndexInEpoch; | |||||
return (Data[np.arange(start, end)], Labels[np.arange(start, end)]); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,184 @@ | |||||
using System; | |||||
using System.Threading.Tasks; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using System.IO; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class MnistModelLoader : IModelLoader<MnistDataSet> | |||||
{ | |||||
private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; | |||||
private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; | |||||
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | |||||
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | |||||
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | |||||
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||||
{ | |||||
var loader = new MnistModelLoader(); | |||||
var setting = new ModelLoadSetting | |||||
{ | |||||
TrainDir = trainDir, | |||||
OneHot = oneHot, | |||||
ShowProgressInConsole = showProgressInConsole | |||||
}; | |||||
if (trainSize.HasValue) | |||||
setting.TrainSize = trainSize.Value; | |||||
if (validationSize.HasValue) | |||||
setting.ValidationSize = validationSize.Value; | |||||
if (testSize.HasValue) | |||||
setting.TestSize = testSize.Value; | |||||
return await loader.LoadAsync(setting); | |||||
} | |||||
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | |||||
{ | |||||
if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) | |||||
throw new ArgumentException("Validation set should be smaller than training set"); | |||||
var sourceUrl = setting.SourceUrl; | |||||
if (string.IsNullOrEmpty(sourceUrl)) | |||||
sourceUrl = DEFAULT_SOURCE_URL; | |||||
// load train images | |||||
await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | |||||
// load train labels | |||||
await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | |||||
// load test images | |||||
await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | |||||
// load test labels | |||||
await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||||
.ShowProgressInConsole(setting.ShowProgressInConsole); | |||||
var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | |||||
var end = trainImages.shape[0]; | |||||
var validationSize = setting.ValidationSize; | |||||
var validationImages = trainImages[np.arange(validationSize)]; | |||||
var validationLabels = trainLabels[np.arange(validationSize)]; | |||||
trainImages = trainImages[np.arange(validationSize, end)]; | |||||
trainLabels = trainLabels[np.arange(validationSize, end)]; | |||||
var dtype = setting.DataType; | |||||
var reshape = setting.ReShape; | |||||
var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); | |||||
var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape); | |||||
var test = new MnistDataSet(testImages, testLabels, dtype, reshape); | |||||
return new Datasets<MnistDataSet>(train, validation, test); | |||||
} | |||||
private NDArray ExtractImages(string file, int? limit = null) | |||||
{ | |||||
if (!Path.IsPathRooted(file)) | |||||
file = Path.Combine(AppContext.BaseDirectory, file); | |||||
using (var bytestream = new FileStream(file, FileMode.Open)) | |||||
{ | |||||
var magic = Read32(bytestream); | |||||
if (magic != 2051) | |||||
throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}"); | |||||
var num_images = Read32(bytestream); | |||||
num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); | |||||
var rows = Read32(bytestream); | |||||
var cols = Read32(bytestream); | |||||
var buf = new byte[rows * cols * num_images]; | |||||
bytestream.Read(buf, 0, buf.Length); | |||||
var data = np.frombuffer(buf, np.uint8); | |||||
data = data.reshape((int)num_images, (int)rows, (int)cols, 1); | |||||
return data; | |||||
} | |||||
} | |||||
private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) | |||||
{ | |||||
if (!Path.IsPathRooted(file)) | |||||
file = Path.Combine(AppContext.BaseDirectory, file); | |||||
using (var bytestream = new FileStream(file, FileMode.Open)) | |||||
{ | |||||
var magic = Read32(bytestream); | |||||
if (magic != 2049) | |||||
throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}"); | |||||
var num_items = Read32(bytestream); | |||||
num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit); | |||||
var buf = new byte[num_items]; | |||||
bytestream.Read(buf, 0, buf.Length); | |||||
var labels = np.frombuffer(buf, np.uint8); | |||||
if (one_hot) | |||||
return DenseToOneHot(labels, num_classes); | |||||
return labels; | |||||
} | |||||
} | |||||
private NDArray DenseToOneHot(NDArray labels_dense, int num_classes) | |||||
{ | |||||
var num_labels = labels_dense.shape[0]; | |||||
var index_offset = np.arange(num_labels) * num_classes; | |||||
var labels_one_hot = np.zeros(num_labels, num_classes); | |||||
var labels = labels_dense.Data<byte>(); | |||||
for (int row = 0; row < num_labels; row++) | |||||
{ | |||||
var col = labels[row]; | |||||
labels_one_hot.SetData(1.0, row, col); | |||||
} | |||||
return labels_one_hot; | |||||
} | |||||
private uint Read32(FileStream bytestream) | |||||
{ | |||||
var buffer = new byte[sizeof(uint)]; | |||||
var count = bytestream.Read(buffer, 0, 4); | |||||
return np.frombuffer(buffer, ">u4").Data<uint>()[0]; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,20 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class ModelLoadSetting | |||||
{ | |||||
public string TrainDir { get; set; } | |||||
public bool OneHot { get; set; } | |||||
public Type DataType { get; set; } = typeof(float); | |||||
public bool ReShape { get; set; } | |||||
public int ValidationSize { get; set; } = 5000; | |||||
public int? TrainSize { get; set; } | |||||
public int? TestSize { get; set; } | |||||
public string SourceUrl { get; set; } | |||||
public bool ShowProgressInConsole { get; set; } | |||||
} | |||||
} |
@@ -0,0 +1,5 @@ | |||||
## TensorFlow Hub | |||||
TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models. In particular, it provides **modules**, which are pre-trained pieces of TensorFlow models that can be reused on new tasks. | |||||
https://github.com/tensorflow/hub |
@@ -0,0 +1,26 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk"> | |||||
<PropertyGroup> | |||||
<RootNamespace>Tensorflow.Hub</RootNamespace> | |||||
<TargetFramework>netstandard2.0</TargetFramework> | |||||
<Version>0.0.5</Version> | |||||
<Authors>Kerry Jiang, Haiping Chen</Authors> | |||||
<Company>SciSharp STACK</Company> | |||||
<Copyright>Apache 2.0</Copyright> | |||||
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||||
<RepositoryType>git</RepositoryType> | |||||
<PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | |||||
<PackageTags>TensorFlow, SciSharp, MachineLearning</PackageTags> | |||||
<Description>TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models.</Description> | |||||
<PackageId>SciSharp.TensorFlowHub</PackageId> | |||||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||||
<PackageReleaseNotes>Fix GetNextBatch() bug.</PackageReleaseNotes> | |||||
<PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||||
<AssemblyName>TensorFlow.Hub</AssemblyName> | |||||
</PropertyGroup> | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
</PropertyGroup> | |||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
</ItemGroup> | |||||
</Project> |
@@ -0,0 +1,137 @@ | |||||
using System; | |||||
using System.IO; | |||||
using System.IO.Compression; | |||||
using System.Collections.Generic; | |||||
using System.Net; | |||||
using System.Text; | |||||
using System.Threading; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public static class Utils | |||||
{ | |||||
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string saveTo) | |||||
where TDataSet : IDataSet | |||||
{ | |||||
var dir = Path.GetDirectoryName(saveTo); | |||||
var fileName = Path.GetFileName(saveTo); | |||||
await modelLoader.DownloadAsync(url, dir, fileName); | |||||
} | |||||
public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) | |||||
where TDataSet : IDataSet | |||||
{ | |||||
if (!Path.IsPathRooted(dirSaveTo)) | |||||
dirSaveTo = Path.Combine(AppContext.BaseDirectory, dirSaveTo); | |||||
var fileSaveTo = Path.Combine(dirSaveTo, fileName); | |||||
if (showProgressInConsole) | |||||
{ | |||||
Console.WriteLine($"Downloading {fileName}"); | |||||
} | |||||
if (File.Exists(fileSaveTo)) | |||||
{ | |||||
if (showProgressInConsole) | |||||
{ | |||||
Console.WriteLine($"The file {fileName} already exists"); | |||||
} | |||||
return; | |||||
} | |||||
Directory.CreateDirectory(dirSaveTo); | |||||
using (var wc = new WebClient()) | |||||
{ | |||||
await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); | |||||
} | |||||
} | |||||
public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) | |||||
where TDataSet : IDataSet | |||||
{ | |||||
if (!Path.IsPathRooted(saveTo)) | |||||
saveTo = Path.Combine(AppContext.BaseDirectory, saveTo); | |||||
Directory.CreateDirectory(saveTo); | |||||
if (!Path.IsPathRooted(zipFile)) | |||||
zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | |||||
var destFileName = Path.GetFileNameWithoutExtension(zipFile); | |||||
var destFilePath = Path.Combine(saveTo, destFileName); | |||||
if (showProgressInConsole) | |||||
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||||
if (File.Exists(destFilePath)) | |||||
{ | |||||
if (showProgressInConsole) | |||||
Console.WriteLine($"The file {destFileName} already exists"); | |||||
} | |||||
using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | |||||
{ | |||||
using (var destStream = File.Create(destFilePath)) | |||||
{ | |||||
await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); | |||||
await destStream.FlushAsync().ConfigureAwait(false); | |||||
destStream.Close(); | |||||
} | |||||
unzipStream.Close(); | |||||
} | |||||
} | |||||
public static async Task ShowProgressInConsole(this Task task, bool enable) | |||||
{ | |||||
if (!enable) | |||||
{ | |||||
await task; | |||||
return; | |||||
} | |||||
var cts = new CancellationTokenSource(); | |||||
var showProgressTask = ShowProgressInConsole(cts); | |||||
try | |||||
{ | |||||
await task; | |||||
} | |||||
finally | |||||
{ | |||||
cts.Cancel(); | |||||
} | |||||
await showProgressTask; | |||||
Console.WriteLine("Done."); | |||||
} | |||||
private static async Task ShowProgressInConsole(CancellationTokenSource cts) | |||||
{ | |||||
var cols = 0; | |||||
await Task.Delay(100); | |||||
while (!cts.IsCancellationRequested) | |||||
{ | |||||
await Task.Delay(100); | |||||
Console.Write("."); | |||||
cols++; | |||||
if (cols % 50 == 0) | |||||
{ | |||||
Console.WriteLine(); | |||||
} | |||||
} | |||||
if (cols > 0) | |||||
Console.WriteLine(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,24 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System.Threading.Tasks; | |||||
using Tensorflow.Hub; | |||||
namespace UnitTest | |||||
{ | |||||
[TestClass] | |||||
public class MnistModelLoaderTest | |||||
{ | |||||
[TestMethod] | |||||
public async Task TestLoad() | |||||
{ | |||||
var loader = new MnistModelLoader(); | |||||
var result = await loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = true, | |||||
ValidationSize = 5000, | |||||
}); | |||||
Assert.IsNotNull(result); | |||||
} | |||||
} | |||||
} |
@@ -37,6 +37,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | ||||
<ProjectReference Include="..\..\src\TensorFlowNET.Hub\Tensorflow.Hub.csproj" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||