Browse Source

base classes for Tensorflow.Hub

tags/v0.10
Kerry Jiang 6 years ago
parent
commit
ce94a5c56c
9 changed files with 160 additions and 10 deletions
  1. +0
    -8
      src/TensorFlowHub/Class1.cs
  2. +13
    -0
      src/TensorFlowHub/DataSetBase.cs
  3. +46
    -0
      src/TensorFlowHub/Datasets.cs
  4. +13
    -0
      src/TensorFlowHub/IDataSet.cs
  5. +14
    -0
      src/TensorFlowHub/IModelLoader.cs
  6. +31
    -0
      src/TensorFlowHub/MnistDataSet.cs
  7. +16
    -0
      src/TensorFlowHub/MnistModelLoader.cs
  8. +19
    -0
      src/TensorFlowHub/ModelLoadSetting.cs
  9. +8
    -2
      src/TensorFlowHub/TensorFlowHub.csproj

+ 0
- 8
src/TensorFlowHub/Class1.cs View File

@@ -1,8 +0,0 @@
using System;

namespace TensorFlowHub
{
public class Class1
{
}
}

+ 13
- 0
src/TensorFlowHub/DataSetBase.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public abstract class DataSetBase : IDataSet
{
public NDArray Data { get; protected set; }
public NDArray Labels { get; protected set; }
}
}

+ 46
- 0
src/TensorFlowHub/Datasets.cs View File

@@ -0,0 +1,46 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public class Datasets<TDataSet> where TDataSet : IDataSet
{
public TDataSet Train { get; private set; }

public TDataSet Validation { get; private set; }

public TDataSet Test { get; private set; }

public Datasets(TDataSet train, TDataSet validation, TDataSet test)
{
Train = train;
Validation = validation;
Test = test;
}

public (NDArray, NDArray) Randomize(NDArray x, NDArray y)
{
var perm = np.random.permutation(y.shape[0]);
np.random.shuffle(perm);
return (x[perm], y[perm]);
}

/// <summary>
/// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="start"></param>
/// <param name="end"></param>
/// <returns></returns>
public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
{
var slice = new Slice(start, end);
var x_batch = x[slice];
var y_batch = y[slice];
return (x_batch, y_batch);
}
}
}

+ 13
- 0
src/TensorFlowHub/IDataSet.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public interface IDataSet
{
NDArray Data { get; }
NDArray Labels { get; }
}
}

+ 14
- 0
src/TensorFlowHub/IModelLoader.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public interface IModelLoader<TDataSet>
where TDataSet : IDataSet
{
Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting);
}
}

+ 31
- 0
src/TensorFlowHub/MnistDataSet.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;
using Tensorflow;

namespace Tensorflow.Hub
{
public class MnistDataSet : DataSetBase
{
public int NumOfExamples { get; private set; }
public int EpochsCompleted { get; private set; }
public int IndexInEpoch { get; private set; }

public MnistDataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
{
EpochsCompleted = 0;
IndexInEpoch = 0;

NumOfExamples = images.shape[0];

images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
images.astype(dtype.as_numpy_datatype());
images = np.multiply(images, 1.0f / 255.0f);
Data = images;

labels.astype(dtype.as_numpy_datatype());
Labels = labels;
}
}
}

+ 16
- 0
src/TensorFlowHub/MnistModelLoader.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public class MnistModelLoader : IModelLoader<MnistDataSet>
{
public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
{
throw new NotImplementedException();
}
}
}

+ 19
- 0
src/TensorFlowHub/ModelLoadSetting.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;

namespace Tensorflow.Hub
{
public class ModelLoadSetting
{
public string TrainDir { get; set; }
public bool OneHot { get; set; }
public TF_DataType DtType { get; set; } = TF_DataType.TF_FLOAT;
public bool ReShape { get; set; }
public int ValidationSize { get; set; } = 5000;
public int? TrainSize { get; set; }
public int? TestSize { get; set; }
public string SourceUrl { get; set; }
}
}

+ 8
- 2
src/TensorFlowHub/TensorFlowHub.csproj View File

@@ -1,7 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<AssemblyName>TensorFlow.Net.Hub</AssemblyName>
<RootNamespace>Tensorflow.Hub</RootNamespace>
<TargetFramework>netstandard2.0</TargetFramework>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="NumSharp" Version="0.10.4" />
</ItemGroup>
</Project>

Loading…
Cancel
Save