@@ -14,15 +14,17 @@ class Batch(object): | |||||
:param DataSet dataset: a DataSet object | :param DataSet dataset: a DataSet object | ||||
:param int batch_size: the size of the batch | :param int batch_size: the size of the batch | ||||
:param Sampler sampler: a Sampler object | |||||
:param Sampler sampler: a Sampler object. If None, use fastNLP.sampler.RandomSampler | |||||
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. | ||||
:param bool prefetch: If True, use multiprocessing to fetch next batch when training. | :param bool prefetch: If True, use multiprocessing to fetch next batch when training. | ||||
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. | :param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. | ||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False): | |||||
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
if sampler is None: | |||||
sampler = RandomSampler() | |||||
self.sampler = sampler | self.sampler = sampler | ||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
self.idx_list = None | self.idx_list = None | ||||
@@ -17,37 +17,37 @@ class Callback(object): | |||||
super(Callback, self).__init__() | super(Callback, self).__init__() | ||||
self.trainer = None # 在Trainer内部被重新赋值 | self.trainer = None # 在Trainer内部被重新赋值 | ||||
# callback只读属性 | |||||
self._n_epochs = None | |||||
self._n_steps = None | |||||
self._batch_size = None | |||||
self._model = None | |||||
self._pbar = None | |||||
self._optimizer = None | |||||
# Read-only views onto the attached trainer. ``self.trainer`` starts out as
# None and is assigned by the Trainer, so reading any of these before a
# trainer is attached raises AttributeError.

@property
def n_epochs(self):
    """Return ``n_epochs`` of the attached trainer."""
    return self.trainer.n_epochs

@property
def epoch(self):
    """Return ``epoch`` of the attached trainer."""
    return self.trainer.epoch

@property
def n_steps(self):
    """Return ``n_steps`` of the attached trainer."""
    return self.trainer.n_steps

@property
def step(self):
    """Return ``step`` of the attached trainer."""
    return self.trainer.step

@property
def batch_size(self):
    """Return ``batch_size`` of the attached trainer."""
    return self.trainer.batch_size

@property
def model(self):
    """Return ``model`` of the attached trainer."""
    return self.trainer.model

@property
def pbar(self):
    """Return ``pbar`` of the attached trainer."""
    return self.trainer.pbar

@property
def optimizer(self):
    """Return ``optimizer`` of the attached trainer."""
    return self.trainer.optimizer
def on_train_begin(self): | def on_train_begin(self): | ||||
# before the main training loop | # before the main training loop | ||||
@@ -82,13 +82,14 @@ class Callback(object): | |||||
def on_valid_begin(self):
    """Hook invoked right before an evaluation run; does nothing by default."""
    return None
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
    """Hook invoked after every evaluation on the validation set.

    The default implementation does nothing; subclasses override it.

    :param eval_result: Dict[str: Dict[str: float]], the evaluation results
    :param metric_key: str
    :param optimizer: optimizer passed to trainer
    :param is_better_eval: bool, whether the current dev result is better than
        the previous ones
    :return:
    """
    return None
@@ -145,11 +146,10 @@ class CallbackManager(Callback): | |||||
""" | """ | ||||
def __init__(self, env, attr, callbacks=None): | |||||
def __init__(self, env, callbacks=None): | |||||
""" | """ | ||||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | ||||
:param dict attr: read-only attributes for all callbacks | |||||
:param Callback callbacks: | :param Callback callbacks: | ||||
""" | """ | ||||
super(CallbackManager, self).__init__() | super(CallbackManager, self).__init__() | ||||
@@ -170,19 +170,6 @@ class CallbackManager(Callback): | |||||
for callback in self.callbacks: | for callback in self.callbacks: | ||||
setattr(callback, env_name, env_val) # Callback.trainer | setattr(callback, env_name, env_val) # Callback.trainer | ||||
self.set_property(**attr) | |||||
def set_property(self, **kwargs): | |||||
"""设置所有callback的只读属性 | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
for callback in self.callbacks: | |||||
for k, v in kwargs.items(): | |||||
setattr(callback, "_" + k, v) | |||||
@transfer
def on_train_begin(self):
    """Manager-level hook for the point before the main training loop.

    The body is intentionally empty: all behavior comes from the ``transfer``
    decorator (defined elsewhere in this module).
    NOTE(review): presumably ``transfer`` forwards the call to every
    registered callback -- confirm against its definition.
    """
    pass
@@ -220,7 +207,7 @@ class CallbackManager(Callback): | |||||
pass | pass | ||||
@transfer
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
    """Manager-level hook for the point after an evaluation run.

    Takes the same arguments as ``Callback.on_valid_end``. The body is
    intentionally empty: all behavior comes from the ``transfer`` decorator
    (defined elsewhere in this module).
    NOTE(review): presumably ``transfer`` forwards the call to every
    registered callback -- confirm against its definition.
    """
    pass
@transfer | @transfer | ||||
@@ -90,7 +90,7 @@ class DataSet(object): | |||||
data_set = DataSet() | data_set = DataSet() | ||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, | ||||
is_input=field.is_input, is_target=field.is_target) | |||||
is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) | |||||
return data_set | return data_set | ||||
elif isinstance(idx, str): | elif isinstance(idx, str): | ||||
if idx not in self: | if idx not in self: | ||||
@@ -313,16 +313,23 @@ class DataSet(object): | |||||
else: | else: | ||||
return results | return results | ||||
def drop(self, func, inplace=True):
    """Drop instances if a condition holds.

    :param func: a function that takes an Instance object as input, and returns bool.
        The instance will be dropped if the function returns True.
    :param inplace: bool, whether to drop in place. If False, this DataSet is left
        untouched and a new DataSet holding the remaining instances is returned.
    :return: None when ``inplace`` is True, otherwise the new DataSet.
    """
    if inplace:
        kept = [ins for ins in self._inner_iter() if not func(ins)]
        # Rebuild every field's content from the surviving instances.
        for name, field in self.field_arrays.items():
            field.content = [ins[name] for ins in kept]
        return None

    kept = [ins for ins in self if not func(ins)]
    data = DataSet(kept)
    # Carry the per-field settings (is_input, padder, ...) over to the copy.
    for field_name, field in self.field_arrays.items():
        # NOTE(review): presumably a DataSet built from an empty list has no
        # fields, so skip missing ones instead of raising KeyError -- confirm.
        if field_name in data.field_arrays:
            data.field_arrays[field_name].to(field)
    # The new DataSet was previously built and then discarded.
    return data
def split(self, dev_ratio): | def split(self, dev_ratio): | ||||
"""Split the dataset into training and development(validation) set. | """Split the dataset into training and development(validation) set. | ||||
@@ -346,19 +353,8 @@ class DataSet(object): | |||||
for idx in train_indices: | for idx in train_indices: | ||||
train_set.append(self[idx]) | train_set.append(self[idx]) | ||||
for field_name in self.field_arrays: | for field_name in self.field_arrays: | ||||
train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||||
train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||||
train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||||
train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||||
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||||
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||||
dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||||
dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||||
dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||||
dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||||
return train_set, dev_set | return train_set, dev_set | ||||
@@ -383,6 +383,23 @@ class FieldArray(object): | |||||
""" | """ | ||||
return len(self.content) | return len(self.content) | ||||
def to(self, other):
    """Copy the attributes of ``other`` onto this FieldArray.

    The copied attributes are is_input, is_target, padder, dtype, pytype,
    content_dim and ignore_type. ``other`` must be a FieldArray.

    :param other: FieldArray to copy the attributes from
    :return:
    """
    assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))
    # Plain attribute assignment, one attribute at a time, in this fixed order.
    copied_attrs = ("is_input", "is_target", "padder", "dtype", "pytype",
                    "content_dim", "ignore_type")
    for attr_name in copied_attrs:
        setattr(self, attr_name, getattr(other, attr_name))
def is_iterable(content): | def is_iterable(content): | ||||
try: | try: | ||||
@@ -91,7 +91,6 @@ class MetricBase(object): | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | ||||
will be conducted.) | will be conducted.) | ||||
However, in some cases where type check is not necessary, ``_fast_param_map`` will be used. | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
@@ -146,21 +145,6 @@ class MetricBase(object): | |||||
def get_metric(self, reset=True):
    """Return the metric result; must be overridden by subclasses.

    :param bool reset: flag handed to the concrete implementation
        (AccuracyMetric uses it to decide whether to recount next time).
    :raises NotImplementedError: always, in this base implementation.
    """
    # ``NotImplemented`` is a comparison sentinel, not an exception; raising
    # it fails with a confusing TypeError instead of the intended error.
    raise NotImplementedError
def _fast_param_map(self, pred_dict, target_dict): | |||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||||
such as pred_dict has one element, target_dict has one element | |||||
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||||
""" | |||||
fast_param = {} | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | |||||
fast_param['target'] = list(target_dict.values())[0] | |||||
return fast_param | |||||
return fast_param | |||||
def __call__(self, pred_dict, target_dict): | def __call__(self, pred_dict, target_dict): | ||||
""" | """ | ||||
@@ -172,7 +156,6 @@ class MetricBase(object): | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | ||||
will be conducted.) | will be conducted.) | ||||
This function also support _fast_param_map. | |||||
:param pred_dict: usually the output of forward or prediction function | :param pred_dict: usually the output of forward or prediction function | ||||
:param target_dict: usually features set as target.. | :param target_dict: usually features set as target.. | ||||
:return: | :return: | ||||
@@ -180,11 +163,6 @@ class MetricBase(object): | |||||
if not callable(self.evaluate): | if not callable(self.evaluate): | ||||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | ||||
fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict) | |||||
if fast_param: | |||||
self.evaluate(**fast_param) | |||||
return | |||||
if not self._checked: | if not self._checked: | ||||
# 1. check consistence between signature and param_map | # 1. check consistence between signature and param_map | ||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
@@ -262,41 +240,6 @@ class AccuracyMetric(MetricBase): | |||||
self.total = 0 | self.total = 0 | ||||
self.acc_count = 0 | self.acc_count = 0 | ||||
def _fast_param_map(self, pred_dict, target_dict): | |||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||||
such as pred_dict has one element, target_dict has one element | |||||
:param pred_dict: | |||||
:param target_dict: | |||||
:return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping. | |||||
""" | |||||
fast_param = {} | |||||
targets = list(target_dict.values()) | |||||
if len(targets) == 1 and isinstance(targets[0], torch.Tensor): | |||||
if len(pred_dict) == 1: | |||||
pred = list(pred_dict.values())[0] | |||||
fast_param['pred'] = pred | |||||
elif len(pred_dict) == 2: | |||||
pred1 = list(pred_dict.values())[0] | |||||
pred2 = list(pred_dict.values())[1] | |||||
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | |||||
return fast_param | |||||
if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1: | |||||
seq_lens = pred1 | |||||
pred = pred2 | |||||
elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1: | |||||
seq_lens = pred2 | |||||
pred = pred1 | |||||
else: | |||||
return fast_param | |||||
fast_param['pred'] = pred | |||||
fast_param['seq_lens'] = seq_lens | |||||
else: | |||||
return fast_param | |||||
fast_param['target'] = targets[0] | |||||
# TODO need to make sure they all have same batch_size | |||||
return fast_param | |||||
def evaluate(self, pred, target, seq_lens=None): | def evaluate(self, pred, target, seq_lens=None): | ||||
""" | """ | ||||
@@ -321,7 +264,7 @@ class AccuracyMetric(MetricBase): | |||||
f"got {type(seq_lens)}.") | f"got {type(seq_lens)}.") | ||||
if seq_lens is not None: | if seq_lens is not None: | ||||
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) | |||||
masks = seq_lens_to_masks(seq_lens=seq_lens).long() | |||||
else: | else: | ||||
masks = None | masks = None | ||||
@@ -334,14 +277,12 @@ class AccuracyMetric(MetricBase): | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | f"size:{pred.size()}, target should have size: {pred.size()} or " | ||||
f"{pred.size()[:-1]}, got {target.size()}.") | f"{pred.size()[:-1]}, got {target.size()}.") | ||||
pred = pred.float() | |||||
target = target.float() | |||||
if masks is not None: | if masks is not None: | ||||
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() | |||||
self.total += torch.sum(masks.float()).item() | |||||
self.acc_count += torch.sum(torch.eq(pred, target) * masks).item() | |||||
self.total += torch.sum(masks).item() | |||||
else: | else: | ||||
self.acc_count += torch.sum(torch.eq(pred, target).float()).item() | |||||
self.acc_count += torch.sum(torch.eq(pred, target)).item() | |||||
self.total += np.prod(list(pred.size())) | self.total += np.prod(list(pred.size())) | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
@@ -350,7 +291,7 @@ class AccuracyMetric(MetricBase): | |||||
:param bool reset: whether to recount next time. | :param bool reset: whether to recount next time. | ||||
:return evaluate_result: {"acc": float} | :return evaluate_result: {"acc": float} | ||||
""" | """ | ||||
evaluate_result = {'acc': round(self.acc_count / self.total, 6)} | |||||
evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)} | |||||
if reset: | if reset: | ||||
self.acc_count = 0 | self.acc_count = 0 | ||||
self.total = 0 | self.total = 0 | ||||
@@ -441,8 +382,7 @@ def bio_tag_to_spans(tags, ignore_labels=None): | |||||
prev_bio_tag = bio_tag | prev_bio_tag = bio_tag | ||||
return [(span[0], (span[1][0], span[1][1]+1)) | return [(span[0], (span[1][0], span[1][1]+1)) | ||||
for span in spans | for span in spans | ||||
if span[0] not in ignore_labels | |||||
] | |||||
if span[0] not in ignore_labels] | |||||
class SpanFPreRecMetric(MetricBase): | class SpanFPreRecMetric(MetricBase): | ||||
@@ -34,7 +34,7 @@ class Trainer(object): | |||||
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50, | ||||
validate_every=-1, dev_data=None, save_path=None, optimizer=None, | validate_every=-1, dev_data=None, save_path=None, optimizer=None, | ||||
check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, | check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True, | ||||
use_cuda=False, callbacks=None): | |||||
use_cuda=False, callbacks=None, update_every=1): | |||||
""" | """ | ||||
:param DataSet train_data: the training data | :param DataSet train_data: the training data | ||||
:param torch.nn.modules.module model: a PyTorch model | :param torch.nn.modules.module model: a PyTorch model | ||||
@@ -62,6 +62,8 @@ class Trainer(object): | |||||
:param bool use_tqdm: whether to use tqdm to show train progress. | :param bool use_tqdm: whether to use tqdm to show train progress. | ||||
:param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | :param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | ||||
通过callback机制实现。 | 通过callback机制实现。 | ||||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128会导致内存 | |||||
不足,通过设置batch_size=32, update_every=4达到目的 | |||||
""" | """ | ||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
@@ -76,6 +78,10 @@ class Trainer(object): | |||||
if metrics and (dev_data is None): | if metrics and (dev_data is None): | ||||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | ||||
# check update every | |||||
assert update_every>=1, "update_every must be no less than 1." | |||||
self.update_every = int(update_every) | |||||
# check save_path | # check save_path | ||||
if not (save_path is None or isinstance(save_path, str)): | if not (save_path is None or isinstance(save_path, str)): | ||||
raise ValueError("save_path can only be None or `str`.") | raise ValueError("save_path can only be None or `str`.") | ||||
@@ -144,11 +150,9 @@ class Trainer(object): | |||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
self.callback_manager = CallbackManager(env={"trainer": self}, | self.callback_manager = CallbackManager(env={"trainer": self}, | ||||
attr={"n_epochs": self.n_epochs, "n_steps": self.step, | |||||
"batch_size": self.batch_size, "model": self.model, | |||||
"optimizer": self.optimizer}, | |||||
callbacks=callbacks) | callbacks=callbacks) | ||||
def train(self, load_best_model=True): | def train(self, load_best_model=True): | ||||
""" | """ | ||||
@@ -241,7 +245,6 @@ class Trainer(object): | |||||
avg_loss = 0 | avg_loss = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | ||||
prefetch=self.prefetch) | prefetch=self.prefetch) | ||||
self.callback_manager.set_property(pbar=pbar) | |||||
for epoch in range(1, self.n_epochs+1): | for epoch in range(1, self.n_epochs+1): | ||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | ||||
# early stopping | # early stopping | ||||
@@ -257,6 +260,7 @@ class Trainer(object): | |||||
self.callback_manager.on_loss_begin(batch_y, prediction) | self.callback_manager.on_loss_begin(batch_y, prediction) | ||||
loss = self._compute_loss(prediction, batch_y) | loss = self._compute_loss(prediction, batch_y) | ||||
avg_loss += loss.item() | avg_loss += loss.item() | ||||
loss = loss/self.update_every | |||||
# Is loss NaN or inf? requires_grad = False | # Is loss NaN or inf? requires_grad = False | ||||
self.callback_manager.on_backward_begin(loss, self.model) | self.callback_manager.on_backward_begin(loss, self.model) | ||||
@@ -267,8 +271,9 @@ class Trainer(object): | |||||
self.callback_manager.on_step_end(self.optimizer) | self.callback_manager.on_step_end(self.optimizer) | ||||
if (self.step+1) % self.print_every == 0: | if (self.step+1) % self.print_every == 0: | ||||
avg_loss = avg_loss / self.print_every | |||||
if self.use_tqdm: | if self.use_tqdm: | ||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||||
print_output = "loss:{0:<6.5f}".format(avg_loss) | |||||
pbar.update(self.print_every) | pbar.update(self.print_every) | ||||
else: | else: | ||||
end = time.time() | end = time.time() | ||||
@@ -286,8 +291,8 @@ class Trainer(object): | |||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | eval_res = self._do_validation(epoch=epoch, step=self.step) | ||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | ||||
total_steps) + \ | total_steps) + \ | ||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str + '\n') | |||||
# ================= mini-batch end ==================== # | # ================= mini-batch end ==================== # | ||||
@@ -301,6 +306,7 @@ class Trainer(object): | |||||
self.callback_manager.on_valid_begin() | self.callback_manager.on_valid_begin() | ||||
res = self.tester.test() | res = self.tester.test() | ||||
is_better_eval = False | |||||
if self._better_eval_result(res): | if self._better_eval_result(res): | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
self._save_model(self.model, | self._save_model(self.model, | ||||
@@ -310,8 +316,9 @@ class Trainer(object): | |||||
self.best_dev_perf = res | self.best_dev_perf = res | ||||
self.best_dev_epoch = epoch | self.best_dev_epoch = epoch | ||||
self.best_dev_step = step | self.best_dev_step = step | ||||
is_better_eval = True | |||||
# get validation results; adjust optimizer | # get validation results; adjust optimizer | ||||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer) | |||||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | |||||
return res | return res | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
@@ -330,7 +337,8 @@ class Trainer(object): | |||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
""" | """ | ||||
self.optimizer.step() | |||||
if (self.step+1)%self.update_every==0: | |||||
self.optimizer.step() | |||||
def _data_forward(self, network, x): | def _data_forward(self, network, x): | ||||
x = _build_args(network.forward, **x) | x = _build_args(network.forward, **x) | ||||
@@ -346,7 +354,8 @@ class Trainer(object): | |||||
For PyTorch, just do "loss.backward()" | For PyTorch, just do "loss.backward()" | ||||
""" | """ | ||||
self.model.zero_grad() | |||||
if self.step%self.update_every==0: | |||||
self.model.zero_grad() | |||||
loss.backward() | loss.backward() | ||||
def _compute_loss(self, predict, truth): | def _compute_loss(self, predict, truth): | ||||
@@ -13,12 +13,12 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
# Package metadata handed to setuptools. ``readme``, ``license`` and ``reqs``
# are defined earlier in this file.
_package_metadata = dict(
    name='FastNLP',
    version='0.4.0',
    description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
    long_description=readme,
    license=license,
    author='FudanNLP',
    python_requires='>=3.6',
    packages=find_packages(),
    # one requirement per line of the requirements text
    install_requires=reqs.strip().split('\n'),
)
setup(**_package_metadata)
@@ -35,7 +35,7 @@ class TestENAS(unittest.TestCase): | |||||
print(dataset[0]) | print(dataset[0]) | ||||
# DataSet.drop(func)筛除数据 | # DataSet.drop(func)筛除数据 | ||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True) | |||||
print(len(dataset)) | print(len(dataset)) | ||||
# 设置DataSet中,哪些field要转为tensor | # 设置DataSet中,哪些field要转为tensor | ||||
@@ -125,7 +125,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
def test_drop(self):
    """In-place drop removes every instance whose ``y`` has fewer than 3 items."""
    fields = {"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}
    ds = DataSet(fields)
    # Half of the 40 ``y`` entries are [5, 6], so 20 instances should remain.
    ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
    self.assertEqual(len(ds), 20)
def test_contains(self): | def test_contains(self): | ||||
@@ -169,7 +169,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | ||||
sep='\t') | sep='\t') | ||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0) | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | |||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | dataset.apply(split_sent, new_field_name='words', is_input=True) | ||||
# print(dataset) | # print(dataset) | ||||
@@ -35,7 +35,7 @@ class TestTutorial(unittest.TestCase): | |||||
print(dataset[0]) | print(dataset[0]) | ||||
# DataSet.drop(func)筛除数据 | # DataSet.drop(func)筛除数据 | ||||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True) | |||||
print(len(dataset)) | print(len(dataset)) | ||||
# 设置DataSet中,哪些field要转为tensor | # 设置DataSet中,哪些field要转为tensor | ||||
@@ -296,7 +296,7 @@ class TestTutorial(unittest.TestCase): | |||||
# 筛选数据 | # 筛选数据 | ||||
origin_data_set_len = len(data_set) | origin_data_set_len = len(data_set) | ||||
data_set.drop(lambda x: len(x['premise']) <= 6) | |||||
data_set.drop(lambda x: len(x['premise']) <= 6, inplace=True) | |||||
origin_data_set_len, len(data_set) | origin_data_set_len, len(data_set) | ||||
# In[17]: | # In[17]: | ||||