|
@@ -34,8 +34,8 @@ from fastNLP.core.utils import get_func_signature
 class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=Adam(lr=0.01, weight_decay=0),
-                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0,
-                 use_tqdm=True, use_cuda=False, callbacks=None):
+                 check_code_level=0, metric_key=None, sampler=RandomSampler(), num_workers=0, pin_memory=False,
+                 timeout=0, use_tqdm=True, use_cuda=False, callbacks=None):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model
@@ -127,6 +127,8 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler
         self.num_workers = num_workers
+        self.pin_memory = pin_memory
+        self.timeout = timeout
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
 
         if isinstance(optimizer, torch.optim.Optimizer):
@@ -247,7 +249,9 @@ class Trainer(object):
                                len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+                                  num_workers=self.num_workers, pin_memory=self.pin_memory, timeout=self.timeout,
+                                  keep_process=True)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -257,7 +261,7 @@ class Trainer(object):
                     # negative sampling; replace unknown; re-weight batch_y
                     self.callback_manager.before_batch(batch_x, batch_y, indices)
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device,
-                                               non_blocking=self.use_cuda)  # pin_memory, use non_blockling.
+                                               non_blocking=self.pin_memory)  # with pin_memory, use non_blocking.
                     prediction = self._data_forward(self.model, batch_x)
 
                     # edit prediction
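
Usage sketch (not part of the patch): a minimal example of driving the three loader
options this diff introduces. The dataset and model construction is assumed setup and
not defined here; only num_workers, pin_memory and timeout come from the diff above.

    from fastNLP import Trainer
    from fastNLP.core.optimizer import Adam

    # Hypothetical placeholders: `train_data` (a fastNLP DataSet) and `model`
    # (a torch.nn.Module) are assumed to be built elsewhere.
    trainer = Trainer(train_data=train_data, model=model,
                      optimizer=Adam(lr=0.01, weight_decay=0),
                      num_workers=4,     # worker subprocesses that prepare batches in parallel
                      pin_memory=True,   # stage batches in page-locked memory for faster host-to-device copies
                      timeout=60,        # seconds to wait for a batch from a worker
                      use_cuda=True)
    trainer.train()

Design note: the last hunk switches non_blocking from use_cuda to pin_memory. Asynchronous
host-to-device copies only take effect when the source tensors sit in pinned memory, so
requesting non_blocking=True is only meaningful when pin_memory is enabled.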
|
|