Browse Source

更新 'gpu/train.py'

test_v20221116
wjtest1201 2 years ago
parent
commit
d430e901ca
1 changed files with 9 additions and 4 deletions
  1. +9
    -4
      gpu/train.py

+ 9
- 4
gpu/train.py View File

@@ -39,10 +39,14 @@ parser.add_argument('--testdata', default="/dataset/test" ,help='path to test da
parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')

def gettime():
timestr = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
return timestr

if __name__ == '__main__':
args, unknown = parser.parse_known_args()
#log output
print('cuda is available:{}'.format(torch.cuda.is_available()))
print(gettime(), 'cuda is available:{}'.format(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(),download=False)
@@ -53,9 +57,9 @@ if __name__ == '__main__':
sgd = SGD(model.parameters(), lr=1e-1)
cost = CrossEntropyLoss()
epoch = args.epoch_size
print('epoch_size is:{}'.format(epoch))
print(gettime(), 'epoch_size is:{}'.format(epoch))
for _epoch in range(epoch):
print('the {} epoch_size begin'.format(_epoch + 1))
print(gettime(), 'the {} epoch_size begin'.format(_epoch + 1))
model.train()
for idx, (train_x, train_label) in enumerate(train_loader):
train_x = train_x.to(device)
@@ -66,6 +70,7 @@ if __name__ == '__main__':
loss = cost(predict_y, train_label.long())
#if idx % 10 == 0:
#print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
print(gettime())
print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
loss.backward()
sgd.step()
@@ -82,6 +87,6 @@ if __name__ == '__main__':
_ = predict_ys == test_label
correct += np.sum(_.numpy(), axis=-1)
_sum += _.shape[0]
print('accuracy: {:.2f}'.format(correct / _sum))
print(gettime(), 'accuracy: {:.2f}'.format(correct / _sum))
#The model output location is placed under /model
torch.save(model, '/model/mnist_epoch{}_{:.2f}.pkl'.format(_epoch+1, correct / _sum))

Loading…
Cancel
Save