Browse Source

更新 'gpu/train.py'

test_v20221116
wjtest1201 2 years ago
parent
commit
b26241c965
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      gpu/train.py

+ 3
- 2
gpu/train.py View File

@@ -64,8 +64,9 @@ if __name__ == '__main__':
sgd.zero_grad() sgd.zero_grad()
predict_y = model(train_x.float()) predict_y = model(train_x.float())
loss = cost(predict_y, train_label.long()) loss = cost(predict_y, train_label.long())
if idx % 10 == 0:
print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
#if idx % 10 == 0:
#print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
loss.backward() loss.backward()
sgd.step() sgd.step()




Loading…
Cancel
Save