diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs index 7c4ff109..b22d359e 100644 --- a/src/TensorFlowHub/MnistModelLoader.cs +++ b/src/TensorFlowHub/MnistModelLoader.cs @@ -15,14 +15,27 @@ namespace Tensorflow.Hub private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static async Task> LoadAsync(string trainDir, bool oneHot = false) + public static async Task> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) { var loader = new MnistModelLoader(); - return await loader.LoadAsync(new ModelLoadSetting + + var setting = new ModelLoadSetting { TrainDir = trainDir, - OneHot = oneHot - }); + OneHot = oneHot, + TrainSize = trainSize + }; + + if (trainSize.HasValue) + setting.TrainSize = trainSize.Value; + + if (validationSize.HasValue) + setting.ValidationSize = validationSize.Value; + + if (testSize.HasValue) + setting.TestSize = testSize.Value; + + return await loader.LoadAsync(setting); } public async Task> LoadAsync(ModelLoadSetting setting)