@@ -1,8 +0,0 @@ | |||||
using System; | |||||
namespace TensorFlowHub | |||||
{ | |||||
public class Class1 | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public abstract class DataSetBase : IDataSet | |||||
{ | |||||
public NDArray Data { get; protected set; } | |||||
public NDArray Labels { get; protected set; } | |||||
} | |||||
} |
@@ -0,0 +1,46 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class Datasets<TDataSet> where TDataSet : IDataSet | |||||
{ | |||||
public TDataSet Train { get; private set; } | |||||
public TDataSet Validation { get; private set; } | |||||
public TDataSet Test { get; private set; } | |||||
public Datasets(TDataSet train, TDataSet validation, TDataSet test) | |||||
{ | |||||
Train = train; | |||||
Validation = validation; | |||||
Test = test; | |||||
} | |||||
public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||||
{ | |||||
var perm = np.random.permutation(y.shape[0]); | |||||
np.random.shuffle(perm); | |||||
return (x[perm], y[perm]); | |||||
} | |||||
/// <summary> | |||||
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||||
/// </summary> | |||||
/// <param name="x"></param> | |||||
/// <param name="y"></param> | |||||
/// <param name="start"></param> | |||||
/// <param name="end"></param> | |||||
/// <returns></returns> | |||||
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||||
{ | |||||
var slice = new Slice(start, end); | |||||
var x_batch = x[slice]; | |||||
var y_batch = y[slice]; | |||||
return (x_batch, y_batch); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,13 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public interface IDataSet | |||||
{ | |||||
NDArray Data { get; } | |||||
NDArray Labels { get; } | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Threading.Tasks; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public interface IModelLoader<TDataSet> | |||||
where TDataSet : IDataSet | |||||
{ | |||||
Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting); | |||||
} | |||||
} |
@@ -0,0 +1,31 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
using Tensorflow; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class MnistDataSet : DataSetBase | |||||
{ | |||||
public int NumOfExamples { get; private set; } | |||||
public int EpochsCompleted { get; private set; } | |||||
public int IndexInEpoch { get; private set; } | |||||
public MnistDataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | |||||
{ | |||||
EpochsCompleted = 0; | |||||
IndexInEpoch = 0; | |||||
NumOfExamples = images.shape[0]; | |||||
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||||
images.astype(dtype.as_numpy_datatype()); | |||||
images = np.multiply(images, 1.0f / 255.0f); | |||||
Data = images; | |||||
labels.astype(dtype.as_numpy_datatype()); | |||||
Labels = labels; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Threading.Tasks; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class MnistModelLoader : IModelLoader<MnistDataSet> | |||||
{ | |||||
public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,19 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Hub | |||||
{ | |||||
public class ModelLoadSetting | |||||
{ | |||||
public string TrainDir { get; set; } | |||||
public bool OneHot { get; set; } | |||||
public TF_DataType DtType { get; set; } = TF_DataType.TF_FLOAT; | |||||
public bool ReShape { get; set; } | |||||
public int ValidationSize { get; set; } = 5000; | |||||
public int? TrainSize { get; set; } | |||||
public int? TestSize { get; set; } | |||||
public string SourceUrl { get; set; } | |||||
} | |||||
} |
@@ -1,7 +1,13 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk"> | <Project Sdk="Microsoft.NET.Sdk"> | ||||
<PropertyGroup> | <PropertyGroup> | ||||
<AssemblyName>TensorFlow.Net.Hub</AssemblyName> | |||||
<RootNamespace>Tensorflow.Hub</RootNamespace> | |||||
<TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||||
</ItemGroup> | |||||
<ItemGroup> | |||||
<PackageReference Include="NumSharp" Version="0.10.4" /> | |||||
</ItemGroup> | |||||
</Project> | </Project> |