|
|
@@ -696,8 +696,9 @@ class Trainer(TrainerEventTrigger): |
|
|
|
self.on_before_backward(outputs) |
|
|
|
loss = self.extract_loss_from_outputs(outputs) |
|
|
|
loss = loss / self.accumulation_steps |
|
|
|
with self.get_no_sync_context(): |
|
|
|
self.driver.backward(loss) |
|
|
|
# with self.get_no_sync_context(): |
|
|
|
# self.driver.backward(loss) |
|
|
|
self.driver.backward(loss) |
|
|
|
self.on_after_backward() |
|
|
|
|
|
|
|
def zero_grad(self): |
|
|
|