diff --git a/example/mnist_demo/lenet5_dp_pynative_model.py b/example/mnist_demo/lenet5_dp_pynative_model.py index e198d6b..13374c1 100644 --- a/example/mnist_demo/lenet5_dp_pynative_model.py +++ b/example/mnist_demo/lenet5_dp_pynative_model.py @@ -99,8 +99,7 @@ if __name__ == "__main__": # get training dataset ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"), - cfg.batch_size, - cfg.epoch_size) + cfg.batch_size) if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: raise ValueError("Number of micro_batches should divide evenly batch_size")