|
@@ -99,8 +99,7 @@ if __name__ == "__main__": |
|
|
|
|
|
|
|
|
# get training dataset |
|
|
# get training dataset |
|
|
ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"), |
|
|
ds_train = generate_mnist_dataset(os.path.join(cfg.data_path, "train"), |
|
|
cfg.batch_size, |
|
|
|
|
|
cfg.epoch_size) |
|
|
|
|
|
|
|
|
cfg.batch_size) |
|
|
|
|
|
|
|
|
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: |
|
|
if cfg.micro_batches and cfg.batch_size % cfg.micro_batches != 0: |
|
|
raise ValueError("Number of micro_batches should divide evenly batch_size") |
|
|
raise ValueError("Number of micro_batches should divide evenly batch_size") |
|
|