|
@@ -64,8 +64,9 @@ if __name__ == '__main__': |
|
|
sgd.zero_grad() |
|
|
sgd.zero_grad() |
|
|
predict_y = model(train_x.float()) |
|
|
predict_y = model(train_x.float()) |
|
|
loss = cost(predict_y, train_label.long()) |
|
|
loss = cost(predict_y, train_label.long()) |
|
|
if idx % 10 == 0: |
|
|
|
|
|
print('idx: {}, loss: {}'.format(idx, loss.sum().item())) |
|
|
|
|
|
|
|
|
#if idx % 10 == 0: |
|
|
|
|
|
#print('idx: {}, loss: {}'.format(idx, loss.sum().item())) |
|
|
|
|
|
print('idx: {}, loss: {}'.format(idx, loss.sum().item())) |
|
|
loss.backward() |
|
|
loss.backward() |
|
|
sgd.step() |
|
|
sgd.step() |
|
|
|
|
|
|
|
|