@@ -31,7 +31,7 @@ class Saver: | |||||
folder = Path.cwd() | folder = Path.cwd() | ||||
folder = Path(folder) | folder = Path(folder) | ||||
if not folder.exists(): | if not folder.exists(): | ||||
raise NotADirectoryError(f"Path '{folder.absolute()}' is not existed!") | |||||
folder.mkdir(parents=True, exist_ok=True) | |||||
elif folder.is_file(): | elif folder.is_file(): | ||||
raise ValueError("Parameter `folder` should be a directory instead of a file.") | raise ValueError("Parameter `folder` should be a directory instead of a file.") | ||||
@@ -36,7 +36,8 @@ class TrainBatchLoop(Loop): | |||||
raise e | raise e | ||||
trainer.on_train_batch_begin(batch, indices) | trainer.on_train_batch_begin(batch, indices) | ||||
self.batch_step_fn(trainer, batch) | |||||
with trainer.get_no_sync_context(): # 在多卡的时候可能需要关闭 sync | |||||
self.batch_step_fn(trainer, batch) | |||||
trainer.global_forward_batches += 1 | trainer.global_forward_batches += 1 | ||||
trainer.batch_idx_in_epoch += 1 | trainer.batch_idx_in_epoch += 1 | ||||
@@ -696,8 +696,9 @@ class Trainer(TrainerEventTrigger): | |||||
self.on_before_backward(outputs) | self.on_before_backward(outputs) | ||||
loss = self.extract_loss_from_outputs(outputs) | loss = self.extract_loss_from_outputs(outputs) | ||||
loss = loss / self.accumulation_steps | loss = loss / self.accumulation_steps | ||||
with self.get_no_sync_context(): | |||||
self.driver.backward(loss) | |||||
# with self.get_no_sync_context(): | |||||
# self.driver.backward(loss) | |||||
self.driver.backward(loss) | |||||
self.on_after_backward() | self.on_after_backward() | ||||
def zero_grad(self): | def zero_grad(self): | ||||