|
|
@@ -70,7 +70,7 @@ def main(): |
|
|
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink") |
|
|
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink") |
|
|
parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10") |
|
|
parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10") |
|
|
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") |
|
|
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.") |
|
|
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path") |
|
|
|
|
|
|
|
|
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path") |
|
|
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") |
|
|
parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.") |
|
|
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") |
|
|
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.") |
|
|
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", |
|
|
parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train", |
|
|
@@ -138,8 +138,8 @@ def main(): |
|
|
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) |
|
|
opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale) |
|
|
net = TrainingWrapper(net, opt, loss_scale) |
|
|
net = TrainingWrapper(net, opt, loss_scale) |
|
|
|
|
|
|
|
|
if args_opt.checkpoint_path != "": |
|
|
|
|
|
param_dict = load_checkpoint(args_opt.checkpoint_path) |
|
|
|
|
|
|
|
|
if args_opt.pre_trained: |
|
|
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained) |
|
|
load_param_into_net(net, param_dict) |
|
|
load_param_into_net(net, param_dict) |
|
|
|
|
|
|
|
|
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] |
|
|
callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb] |
|
|
|