|
|
@@ -15,14 +15,27 @@ namespace Tensorflow.Hub |
|
|
|
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; |
|
|
|
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; |
|
|
|
|
|
|
|
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false) |
|
|
|
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) |
|
|
|
{ |
|
|
|
var loader = new MnistModelLoader(); |
|
|
|
return await loader.LoadAsync(new ModelLoadSetting |
|
|
|
|
|
|
|
var setting = new ModelLoadSetting |
|
|
|
{ |
|
|
|
TrainDir = trainDir, |
|
|
|
OneHot = oneHot |
|
|
|
}); |
|
|
|
OneHot = oneHot, |
|
|
|
TrainSize = trainSize |
|
|
|
}; |
|
|
|
|
|
|
|
if (trainSize.HasValue) |
|
|
|
setting.TrainSize = trainSize.Value; |
|
|
|
|
|
|
|
if (validationSize.HasValue) |
|
|
|
setting.ValidationSize = validationSize.Value; |
|
|
|
|
|
|
|
if (testSize.HasValue) |
|
|
|
setting.TestSize = testSize.Value; |
|
|
|
|
|
|
|
return await loader.LoadAsync(setting); |
|
|
|
} |
|
|
|
|
|
|
|
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) |
|
|
|