Browse Source

更新 'gpu/train_fail3.py'

test_v20221116
wjtest1215 1 year ago
parent
commit
3b00379b6c
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      gpu/train_fail3.py

+ 3
- 2
gpu/train_fail3.py View File

@@ -69,8 +69,9 @@ if __name__ == '__main__':
sgd.zero_grad()
predict_y = model(train_x.float())
loss = cost(predict_y, train_label.long())
if idx % 10 == 0:
print(gettime(), 'idx: {}, loss: {}'.format(idx, loss.sum().item()))
#if idx % 10 == 0:
#print(gettime(), 'idx: {}, loss: {}'.format(idx, loss.sum().item()))
print(gettime(), 'idx: {}, loss: {}'.format(idx, loss.sum().item()))
loss.backward()
sgd.step()



Loading…
Cancel
Save