|
|
@@ -60,7 +60,7 @@ if __name__ == '__main__': |
|
|
|
epoch = args.epoch_size |
|
|
|
print('epoch_size is:{}'.format(epoch)) |
|
|
|
for _epoch in range(epoch): |
|
|
|
print('the {} epoch_size begin'.format(_epoch + 1)) |
|
|
|
#print('the {} epoch_size begin'.format(_epoch + 1)) |
|
|
|
model.train() |
|
|
|
for idx, (train_x, train_label) in enumerate(train_loader): |
|
|
|
train_x = train_x.to(device) |
|
|
|