Browse Source

imporved MnistModelLoader's interface

tags/v0.12
Kerry Jiang 6 years ago
parent
commit
ef5ea664df
1 changed files with 17 additions and 4 deletions
  1. +17
    -4
      src/TensorFlowHub/MnistModelLoader.cs

+ 17
- 4
src/TensorFlowHub/MnistModelLoader.cs View File

@@ -15,14 +15,27 @@ namespace Tensorflow.Hub
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false)
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null)
{
var loader = new MnistModelLoader();
return await loader.LoadAsync(new ModelLoadSetting

var setting = new ModelLoadSetting
{
TrainDir = trainDir,
OneHot = oneHot
});
OneHot = oneHot,
TrainSize = trainSize
};

if (trainSize.HasValue)
setting.TrainSize = trainSize.Value;

if (validationSize.HasValue)
setting.ValidationSize = validationSize.Value;

if (testSize.HasValue)
setting.TestSize = testSize.Value;

return await loader.LoadAsync(setting);
}

public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)


Loading…
Cancel
Save