|
@@ -452,17 +452,15 @@ class Trainer(object): |
|
|
else: |
|
|
else: |
|
|
raise TypeError("train_data type {} not support".format(type(train_data))) |
|
|
raise TypeError("train_data type {} not support".format(type(train_data))) |
|
|
|
|
|
|
|
|
self.model = _move_model_to_device(model, device=device) |
|
|
|
|
|
|
|
|
|
|
|
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): |
|
|
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): |
|
|
_check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, |
|
|
|
|
|
|
|
|
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, |
|
|
metric_key=metric_key, check_level=check_code_level, |
|
|
metric_key=metric_key, check_level=check_code_level, |
|
|
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) |
|
|
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) |
|
|
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 |
|
|
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 |
|
|
|
|
|
|
|
|
|
|
|
self.model = _move_model_to_device(model, device=device) |
|
|
|
|
|
|
|
|
self.train_data = train_data |
|
|
self.train_data = train_data |
|
|
self.dev_data = dev_data # If None, No validation. |
|
|
self.dev_data = dev_data # If None, No validation. |
|
|
self.model = model |
|
|
|
|
|
self.losser = losser |
|
|
self.losser = losser |
|
|
self.metrics = metrics |
|
|
self.metrics = metrics |
|
|
self.n_epochs = int(n_epochs) |
|
|
self.n_epochs = int(n_epochs) |
|
@@ -480,16 +478,16 @@ class Trainer(object): |
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
self.optimizer = optimizer |
|
|
self.optimizer = optimizer |
|
|
elif isinstance(optimizer, Optimizer): |
|
|
elif isinstance(optimizer, Optimizer): |
|
|
self.optimizer = optimizer.construct_from_pytorch(model.parameters()) |
|
|
|
|
|
|
|
|
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) |
|
|
elif optimizer is None: |
|
|
elif optimizer is None: |
|
|
self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3) |
|
|
|
|
|
|
|
|
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) |
|
|
else: |
|
|
else: |
|
|
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) |
|
|
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) |
|
|
|
|
|
|
|
|
self.use_tqdm = use_tqdm |
|
|
self.use_tqdm = use_tqdm |
|
|
self.pbar = None |
|
|
self.pbar = None |
|
|
self.print_every = abs(self.print_every) |
|
|
self.print_every = abs(self.print_every) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.dev_data is not None: |
|
|
if self.dev_data is not None: |
|
|
self.tester = Tester(model=self.model, |
|
|
self.tester = Tester(model=self.model, |
|
|
data=self.dev_data, |
|
|
data=self.dev_data, |
|
|