|
|
@@ -15,6 +15,16 @@ namespace Tensorflow.Hub |
|
|
|
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; |
|
|
|
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; |
|
|
|
|
|
|
|
public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false) |
|
|
|
{ |
|
|
|
var loader = new MnistModelLoader(); |
|
|
|
return await loader.LoadAsync(new ModelLoadSetting |
|
|
|
{ |
|
|
|
TrainDir = trainDir, |
|
|
|
OneHot = oneHot |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) |
|
|
|
{ |
|
|
|
if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) |
|
|
|