Trainer supports print_train and tqdm train (tags/v0.2.0^2)

# Conflicts:
#	fastNLP/core/dataset.py
#	fastNLP/core/trainer.py
#	test/core/test_trainer.py
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
@@ -1,7 +1,9 @@
+import _pickle as pickle
 import numpy as np

 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
+from fastNLP.core.utils import get_func_signature

 _READERS = {}
@@ -26,24 +28,6 @@ class DataSet(object):
     However, it stores data in a different way: Field-first, Instance-second.
     """

-    class DataSetIter(object):
-        def __init__(self, data_set, idx=-1, **fields):
-            self.data_set = data_set
-            self.idx = idx
-            self.fields = fields
-
-        def __next__(self):
-            self.idx += 1
-            if self.idx >= len(self.data_set):
-                raise StopIteration
-            # this returns a copy
-            return self.data_set[self.idx]
-
-        def __repr__(self):
-            return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
-                              in self.data_set.get_fields().keys()])
-
     def __init__(self, data=None):
         """
@@ -72,7 +56,27 @@ class DataSet(object):
         return item in self.field_arrays

     def __iter__(self):
-        return self.DataSetIter(self)
+        def iter_func():
+            for idx in range(len(self)):
+                yield self[idx]
+        return iter_func()
+
+    def _inner_iter(self):
+        class Iter_ptr:
+            def __init__(self, dataset, idx):
+                self.dataset = dataset
+                self.idx = idx
+
+            def __getitem__(self, item):
+                assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx)
+                assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+                return self.dataset.field_arrays[item][self.idx]
+
+            def __repr__(self):
+                return self.dataset[self.idx].__repr__()
+
+        def inner_iter_func():
+            for idx in range(len(self)):
+                yield Iter_ptr(self, idx)
+        return inner_iter_func()

     def __getitem__(self, idx):
         """Fetch Instance(s) at the `idx` position(s) in the dataset.
@@ -110,6 +114,15 @@ class DataSet(object):
         field = iter(self.field_arrays.values()).__next__()
         return len(field)

+    def __inner_repr__(self):
+        if len(self) < 20:
+            return ",\n".join([ins.__repr__() for ins in self])
+        else:
+            return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()
+
+    def __repr__(self):
+        return "DataSet(" + self.__inner_repr__() + ")"
+
     def append(self, ins):
         """Add an instance to the DataSet.
         If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet.
@@ -226,7 +239,10 @@ class DataSet(object):
             (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
-        results = [func(ins) for ins in self]
+        results = [func(ins) for ins in self._inner_iter()]
+        if len(list(filter(lambda x: x is not None, results))) == 0:  # all None
+            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
         extra_param = {}
         if 'is_input' in kwargs:
             extra_param['is_input'] = kwargs['is_input']
@@ -250,7 +266,7 @@ class DataSet(object):
         return results

     def drop(self, func):
-        results = [ins for ins in self if not func(ins)]
+        results = [ins for ins in self._inner_iter() if not func(ins)]
         for name, old_field in self.field_arrays.items():
             self.field_arrays[name].content = [ins[name] for ins in results]
@@ -317,3 +333,12 @@ class DataSet(object):
         for header, content in zip(headers, contents):
             _dict[header].append(content)
         return cls(_dict)
+    def save(self, path):
+        with open(path, 'wb') as f:
+            pickle.dump(self, f)
+
+    @staticmethod
+    def load(path):
+        with open(path, 'rb') as f:
+            return pickle.load(f)
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
@@ -1,5 +1,3 @@
-
-
 class Instance(object):
     """An Instance is an example of data. It is the collection of Fields.
@@ -33,4 +31,5 @@ class Instance(object):
         return self.add_field(name, field)

     def __repr__(self):
-        return self.fields.__repr__()
+        return "{" + ",\n".join(
+            "\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}"
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
@@ -70,13 +70,23 @@ class LossBase(object):
             raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
                             f"positional argument.).")

-    def __call__(self, output_dict, target_dict, force_check=False):
+    def _fast_param_map(self, pred_dict, target_dict):
+        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
+            return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
+        return None
+
+    def __call__(self, pred_dict, target_dict, check=False):
         """
-        :param output_dict: A dict from forward function of the network.
+        :param pred_dict: A dict from forward function of the network.
         :param target_dict: A dict from DataSet.batch_y.
-        :param force_check: Boolean. Force to check the mapping functions when it is running.
+        :param check: Boolean. Force to check the mapping functions when it is running.
         :return:
         """
+        fast_param = self._fast_param_map(pred_dict, target_dict)
+        if fast_param is not None:
+            loss = self.get_loss(*fast_param)
+            return loss
+
         args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
         if varargs is not None:
             raise RuntimeError(
@@ -88,7 +98,8 @@ class LossBase(object):
             raise RuntimeError(
                 f"There is not any param in function{get_func_signature(self.get_loss)}"
             )
-        self._checked = self._checked and not force_check
+
+        self._checked = self._checked and not check
         if not self._checked:
             for keys in args:
                 if keys not in param_map:
@@ -105,12 +116,12 @@ class LossBase(object):
         duplicated = []
         missing = []
         if not self._checked:
-            for keys, val in output_dict.items():
+            for keys, val in pred_dict.items():
                 if keys in target_dict.keys():
                     duplicated.append(keys)

         param_val_dict = {}
-        for keys, val in output_dict.items():
+        for keys, val in pred_dict.items():
             param_val_dict.update({keys: val})
         for keys, val in target_dict.items():
             param_val_dict.update({keys: val})
@@ -131,7 +142,6 @@ class LossBase(object):
         param_map_val = _map_args(reversed_param_map, **param_val_dict)
         param_value = _build_args(self.get_loss, **param_map_val)

-
         loss = self.get_loss(**param_value)

         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
@@ -158,29 +168,31 @@ class LossFunc(LossBase):

 class CrossEntropyLoss(LossBase):
-    def __init__(self, input=None, target=None):
+    def __init__(self, pred=None, target=None):
         super(CrossEntropyLoss, self).__init__()
         self.get_loss = F.cross_entropy
-        self._init_param_map(input=input, target=target)
+        self._init_param_map(input=pred, target=target)


 class L1Loss(LossBase):
-    def __init__(self):
+    def __init__(self, pred=None, target=None):
         super(L1Loss, self).__init__()
         self.get_loss = F.l1_loss
+        self._init_param_map(input=pred, target=target)


 class BCELoss(LossBase):
-    def __init__(self, input=None, target=None):
+    def __init__(self, pred=None, target=None):
         super(BCELoss, self).__init__()
         self.get_loss = F.binary_cross_entropy
-        self._init_param_map(input=input, target=target)
+        self._init_param_map(input=pred, target=target)


 class NLLLoss(LossBase):
-    def __init__(self):
+    def __init__(self, pred=None, target=None):
         super(NLLLoss, self).__init__()
         self.get_loss = F.nll_loss
+        self._init_param_map(input=pred, target=target)


 class LossInForward(LossBase):
@@ -199,10 +211,11 @@ class LossInForward(LossBase):
                                all_needed=[],
                                varargs=[])
             raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
+        return kwargs[self.loss_key]

-    def __call__(self, output_dict, predict_dict, force_check=False):
+    def __call__(self, pred_dict, target_dict, check=False):

-        loss = self.get_loss(**output_dict)
+        loss = self.get_loss(**pred_dict)

         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
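The new `_fast_param_map` short-circuits the signature checking: when the loss maps exactly two parameters and both the prediction and target dicts contain a single key, `get_loss` is called directly on the two values. A minimal sketch (the key names are arbitrary):

```python
import torch
import torch.nn.functional as F
from fastNLP.core.losses import CrossEntropyLoss

ce = CrossEntropyLoss(pred="my_predict", target="my_truth")
a = torch.randn(3, 5)
b = torch.empty(3, dtype=torch.long).random_(5)

# Single-key dicts take the fast path: F.cross_entropy(a, b) is called
# without building the reversed param map.
loss = ce({"my_predict": a}, {"my_truth": b})
assert torch.equal(loss, F.cross_entropy(a, b))
```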
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
@@ -1,4 +1,3 @@
 import inspect
 import warnings
 from collections import defaultdict
@@ -7,11 +6,12 @@ import numpy as np
 import torch

 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
-from fastNLP.core.utils import CheckRes


 class MetricBase(object):
     def __init__(self):
@@ -59,9 +59,10 @@ class MetricBase(object):
         func_args = [arg for arg in func_spect.args if arg != 'self']
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
-                                f"initialization parameters, or change the signature of"
-                                f" {get_func_signature(self.evaluate)}.")
+                raise NameError(
+                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                    f"initialization parameters, or change the signature of"
+                    f" {get_func_signature(self.evaluate)}.")

         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -71,7 +72,7 @@ class MetricBase(object):
     def get_metric(self, reset=True):
         raise NotImplemented

-    def _fast_call_evaluate(self, pred_dict, target_dict):
+    def _fast_param_map(self, pred_dict, target_dict):
""" | """ | ||||
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | ||||
@@ -80,7 +81,9 @@ class MetricBase(object): | |||||
:param target_dict: | :param target_dict: | ||||
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. | :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. | ||||
""" | """ | ||||
return False | |||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
return pred_dict.values[0] and target_dict.values[0] | |||||
return None | |||||
     def __call__(self, pred_dict, target_dict, check=False):
         """
@@ -103,13 +106,15 @@ class MetricBase(object):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

         if not check:
-            if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict):
+            fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
+            if fast_param is not None:
+                self.evaluate(*fast_param)
                 return

         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
-            func_args = set([arg for arg in func_spect.args if arg!='self'])
+            func_args = set([arg for arg in func_spect.args if arg != 'self'])
             for func_arg, input_arg in self.param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
@@ -117,7 +122,7 @@ class MetricBase(object):
             # 2. only part of the param_map are passed, left are not
             for arg in func_args:
                 if arg not in self.param_map:
-                    self.param_map[arg] = arg  #This param does not need mapping.
+                    self.param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
             self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
@@ -149,14 +154,14 @@ class MetricBase(object):
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
                 replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
-                    f"in `{self.__class__.__name__}`)"
+                                        f"in `{self.__class__.__name__}`)"

             check_res = CheckRes(missing=replaced_missing,
-                unused=check_res.unused,
-                duplicated=duplicated,
-                required=check_res.required,
-                all_needed=check_res.all_needed,
-                varargs=check_res.varargs)
+                                 unused=check_res.unused,
+                                 duplicated=duplicated,
+                                 required=check_res.required,
+                                 all_needed=check_res.all_needed,
+                                 varargs=check_res.varargs)

             if check_res.missing or check_res.duplicated or check_res.varargs:
                 raise CheckError(check_res=check_res,
@@ -168,6 +173,7 @@ class MetricBase(object):
         return


 class AccuracyMetric(MetricBase):
     def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
         super().__init__()
@@ -187,7 +193,7 @@ class AccuracyMetric(MetricBase):
         :param target_dict:
         :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
         """
-        if len(pred_dict)==1 and len(target_dict)==1:
+        if len(pred_dict) == 1 and len(target_dict) == 1:
             pred = list(pred_dict.values())[0]
             target = list(target_dict.values())[0]
             self.evaluate(pred=pred, target=target)
@@ -207,7 +213,7 @@ class AccuracyMetric(MetricBase):
                 None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
         """
-        #TODO: the error message here needs changing — the user has no idea what `pred` is; report the actual value.
+        # TODO: the error message here needs changing — the user has no idea what `pred` is; report the actual value.
         if not isinstance(pred, torch.Tensor):
             raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
@@ -220,14 +226,14 @@ class AccuracyMetric(MetricBase):
                             f"got {type(masks)}.")
         elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
-                f"got {type(seq_lens)}.")
+                            f"got {type(seq_lens)}.")

         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)

-        if pred.size()==target.size():
+        if pred.size() == target.size():
             pass
-        elif len(pred.size())==len(target.size())+1:
+        elif len(pred.size()) == len(target.size()) + 1:
             pred = pred.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
@@ -241,18 +247,17 @@ class AccuracyMetric(MetricBase):
             self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
             self.total += np.prod(list(pred.size()))

     def get_metric(self, reset=True):
-        evaluate_result = {'acc': round(self.acc_count/self.total, 6)}
+        evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
         if reset:
             self.acc_count = 0
             self.total = 0
         return evaluate_result


 def _prepare_metrics(metrics):
     """
@@ -274,7 +279,8 @@ def _prepare_metrics(metrics):
                 raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
             _metrics.append(metric)
         else:
-            raise TypeError(f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
+            raise TypeError(
+                f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
     elif isinstance(metrics, MetricBase):
         _metrics = [metrics]
     else:
@@ -296,6 +302,7 @@ class Evaluator(object):
         """
         raise NotImplementedError


 class ClassifyEvaluator(Evaluator):
     def __init__(self):
         super(ClassifyEvaluator, self).__init__()
@@ -331,6 +338,7 @@ class SeqLabelEvaluator(Evaluator):
         accuracy = total_correct / total_count
         return {"accuracy": float(accuracy)}


 class SeqLabelEvaluator2(Evaluator):
     # The evaluator above is probably wrong.
     def __init__(self, seq_lens_field_name='word_seq_origin_len'):
@@ -363,7 +371,7 @@ class SeqLabelEvaluator2(Evaluator):
                     if x_i in self.end_tagidx_set:
                         truth_count += 1
                         for j in range(start, idx_i + 1):
-                            if y_[j]!=x_[j]:
+                            if y_[j] != x_[j]:
                                 flag = False
                                 break
                         if flag:
@@ -376,8 +384,7 @@ class SeqLabelEvaluator2(Evaluator):
         R = corr_count / (float(truth_count) + 1e-6)
         F = 2 * P * R / (P + R + 1e-6)

-        return {"P": P, 'R':R, 'F': F}
+        return {"P": P, 'R': R, 'F': F}


 class SNLIEvaluator(Evaluator):
@@ -559,10 +566,6 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0


-def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
-    raise NotImplementedError
-
-
 def accuracy_topk(y_true, y_prob, k=1):
     """Compute accuracy of y_true matching top-k probable
     labels in y_prob.
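For context, a sketch of how the renamed metric entry point is used (tensor shapes and key names are illustrative):

```python
import torch
from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric(pred="output", target="label")
pred = torch.randn(4, 3)             # [B, n_classes]; argmax is taken when pred has one extra dim
target = torch.tensor([0, 1, 2, 1])  # [B]

metric({"output": pred}, {"label": target})  # accumulates acc_count and total
print(metric.get_metric())                   # {'acc': ...}; reset=True clears the counters
```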
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
@@ -4,40 +4,13 @@ import torch

 class Optimizer(object):
     def __init__(self, model_params, **kwargs):
         if model_params is not None and not hasattr(model_params, "__next__"):
-            raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
+            raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs


 class SGD(Optimizer):
-    def __init__(self, *args, **kwargs):
-        model_params, lr, momentum = None, 0.01, 0.9
-        if len(args) == 0 and len(kwargs) == 0:
-            # SGD()
-            pass
-        elif len(args) == 1 and len(kwargs) == 0:
-            if isinstance(args[0], float) or isinstance(args[0], int):
-                # SGD(0.001)
-                lr = args[0]
-            elif hasattr(args[0], "__next__"):
-                # SGD(model.parameters()) args[0] is a generator
-                model_params = args[0]
-            else:
-                raise RuntimeError("Not supported type {}.".format(type(args[0])))
-        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
-            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
-            if len(args) == 1:
-                if hasattr(args[0], "__next__"):
-                    model_params = args[0]
-                else:
-                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
-            if not all(key in ("lr", "momentum") for key in kwargs):
-                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
-            lr = kwargs.get("lr", 0.01)
-            momentum = kwargs.get("momentum", 0.9)
-        else:
-            raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
-
+    def __init__(self, model_params=None, lr=0.01, momentum=0):
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

     def construct_from_pytorch(self, model_params):
@@ -49,30 +22,7 @@ class SGD(Optimizer):

 class Adam(Optimizer):
-    def __init__(self, *args, **kwargs):
-        model_params, lr, weight_decay = None, 0.01, 0.9
-        if len(args) == 0 and len(kwargs) == 0:
-            pass
-        elif len(args) == 1 and len(kwargs) == 0:
-            if isinstance(args[0], float) or isinstance(args[0], int):
-                lr = args[0]
-            elif hasattr(args[0], "__next__"):
-                model_params = args[0]
-            else:
-                raise RuntimeError("Not supported type {}.".format(type(args[0])))
-        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
-            if len(args) == 1:
-                if hasattr(args[0], "__next__"):
-                    model_params = args[0]
-                else:
-                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
-            if not all(key in ("lr", "weight_decay") for key in kwargs):
-                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
-            lr = kwargs.get("lr", 0.01)
-            weight_decay = kwargs.get("weight_decay", 0.9)
-        else:
-            raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
-
+    def __init__(self, model_params=None, lr=0.01, weight_decay=0):
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

     def construct_from_pytorch(self, model_params):
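The effect of the simplified constructors: the old `*args` parsing accepted `SGD(0.001)` and similar positional forms; after this change hyper-parameters must be keyword arguments (note that the momentum/weight_decay defaults also drop from 0.9 to 0). A before/after sketch:

```python
import torch
from fastNLP.core.optimizer import SGD, Adam

model = torch.nn.Linear(10, 3)

optim = SGD(model.parameters(), lr=0.1)  # model_params may still be passed positionally
optim = Adam(lr=0.001)                   # params supplied later via construct_from_pytorch

# No longer supported (previously meant lr=0.001; now raises, since 0.001
# is not a parameter generator):
# optim = SGD(0.001)
```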
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
@@ -1,6 +1,7 @@
 import os
 import time
 from datetime import datetime
+from datetime import timedelta

 from tqdm import tqdm

 import torch
@@ -22,17 +23,16 @@ from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _relocate_pbar


 class Trainer(object):
     """Main Training Loop

     """

     def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
                  validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
-                 metric_key=None, sampler=RandomSampler()):
+                 metric_key=None, sampler=RandomSampler(), use_tqdm=True):
         """
         :param DataSet train_data: the training data
@@ -54,6 +54,7 @@ class Trainer(object):
             ::
                 metric_key="-PPL"   # language model gets better as perplexity gets smaller
         :param sampler: method used to generate batch data.
+        :param use_tqdm: boolean, use tqdm to show train progress.
         """
         super(Trainer, self).__init__()
@@ -117,19 +118,23 @@ class Trainer(object):
         else:
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

+        self.use_tqdm = use_tqdm
+        if self.use_tqdm:
+            tester_verbose = 0
+        else:
+            tester_verbose = 1
+
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
                                  metrics=self.metrics,
                                  batch_size=self.batch_size,
                                  use_cuda=self.use_cuda,
-                                 verbose=0)
+                                 verbose=tester_verbose)

         self.step = 0
         self.start_time = None  # start timestamp

-        # print(self.__dict__)
-
     def train(self):
         """Start Training.
@@ -155,8 +160,10 @@ class Trainer(object):
             else:
                 path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
             self._summary_writer = SummaryWriter(path)
-            self._tqdm_train()
+            if self.use_tqdm:
+                self._tqdm_train()
+            else:
+                self._print_train()

         finally:
             self._summary_writer.close()
@@ -196,31 +203,67 @@
                     eval_res = self._do_validation()
                     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                self.tester._format_eval_results(eval_res)
-                    pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
+                    time.sleep(0.1)
+                    pbar = _relocate_pbar(pbar, print_str=eval_str)
             if self.validate_every < 0 and self.dev_data:
                 eval_res = self._do_validation()
                 eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                            self.tester._format_eval_results(eval_res)
-                pbar = self._relocate_pbar(pbar, print_str=eval_str, total=total_steps, initial=self.step)
+                pbar = _relocate_pbar(pbar, print_str=eval_str)
             if epoch!=self.n_epochs:
                 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                       as_numpy=False)
         pbar.close()

-    def _relocate_pbar(self, pbar, total, initial, print_str=None):
-        postfix = pbar.postfix
-        desc = pbar.desc
-        pbar.close()
-        avg_time = pbar.avg_time
-        start_t = pbar.start_t
-        if print_str:
-            print(print_str)
-        pbar = tqdm(total=total, postfix=postfix, desc=desc, leave=False, initial=initial, dynamic_ncols=True)
-        pbar.start_t = start_t
-        pbar.avg_time = avg_time
-        pbar.sp(pbar.__repr__())
-        return pbar
-
+    def _print_train(self):
+        """Training procedure without a progress bar: print the train loss and
+        validation results to stdout at fixed step intervals.
+        """
+        epoch = 1
+        start = time.time()
+        while epoch <= self.n_epochs:
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
+                                  as_numpy=False)
+            for batch_x, batch_y in data_iterator:
+                # TODO: this may break if the user changes the device of `prediction` inside the model.
+                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+                prediction = self._data_forward(self.model, batch_x)
+                loss = self._compute_loss(prediction, batch_y)
+                self._grad_backward(loss)
+                self._update()
+                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
+                for name, param in self.model.named_parameters():
+                    if param.requires_grad:
+                        self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
+                        # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
+                        # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
+                if self.print_every > 0 and self.step % self.print_every == 0:
+                    end = time.time()
+                    diff = timedelta(seconds=round(end - start))
+                    print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
+                        epoch, self.step, loss.data, diff)
+                    print(print_output)
+
+                if (self.validate_every > 0 and self.step % self.validate_every == 0 and
+                        self.dev_data is not None):
+                    self._do_validation()
+
+                self.step += 1
+
+            # a positive validate_every overrides the end-of-epoch validation
+            if self.dev_data and self.validate_every <= 0:
+                self._do_validation()
+            epoch += 1

     def _do_validation(self):
         res = self.tester.test()
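A sketch of the new switch in use: `use_tqdm=True` keeps the tqdm bar (with eval results printed above it via `_relocate_pbar`), while `use_tqdm=False` routes training through the new `_print_train()` loop and raises the Tester's verbosity to 1. The data and model objects below are placeholders:

```python
from fastNLP.core.trainer import Trainer
from fastNLP.core.losses import BCELoss
from fastNLP.core.metrics import AccuracyMetric

trainer = Trainer(train_data=train_set, model=model,  # train_set/dev_set/model: placeholders
                  losser=BCELoss(pred="predict", target="y"),
                  metrics=AccuracyMetric(pred="predict", target="y"),
                  n_epochs=10, batch_size=32,
                  dev_data=dev_set,
                  use_tqdm=False)  # prints "[epoch: ... step: ...] train loss: ... time: ..." lines
trainer.train()
```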
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
@@ -7,9 +7,12 @@ from collections import namedtuple

 import numpy as np
 import torch
+from tqdm import tqdm

 CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                    'varargs'], verbose=False)


 def save_pickle(obj, pickle_path, file_name):
     """Save an object into a pickle file.
@@ -53,6 +56,7 @@ def pickle_exist(pickle_path, pickle_name):
     else:
         return False


 def _build_args(func, **kwargs):
     spect = inspect.getfullargspec(func)
     if spect.varkw is not None:
@@ -108,7 +112,7 @@ def _check_arg_dict_list(func, args):
     assert callable(func) and isinstance(arg_dict_list, (list, tuple))
     assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
     spect = inspect.getfullargspec(func)
-    all_args = set([arg for arg in spect.args if arg!='self'])
+    all_args = set([arg for arg in spect.args if arg != 'self'])
     defaults = []
     if spect.defaults is not None:
         defaults = [arg for arg in spect.defaults]
@@ -130,6 +134,7 @@ def _check_arg_dict_list(func, args):
                     all_needed=list(all_args),
                     varargs=varargs)


 def get_func_signature(func):
     """
@@ -153,7 +158,7 @@ def get_func_signature(func):
         class_name = func.__self__.__class__.__name__
         signature = inspect.signature(func)
         signature_str = str(signature)
-        if len(signature_str)>2:
+        if len(signature_str) > 2:
             _self = '(self, '
         else:
             _self = '(self'
@@ -176,12 +181,13 @@ def _is_function_or_method(func):
         return False
     return True


 def _check_function_or_method(func):
     if not _is_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")


-def _move_dict_value_to_device(*args, device:torch.device):
+def _move_dict_value_to_device(*args, device: torch.device):
     """

     move data to model's device, element in *args should be dict. This is a inplace change.
@@ -206,7 +212,8 @@ class CheckError(Exception):

     CheckError. Used in losses.LossBase, metrics.MetricBase.
     """

-    def __init__(self, check_res:CheckRes, func_signature:str):
+    def __init__(self, check_res: CheckRes, func_signature: str):
         errs = [f'The following problems occurred when calling `{func_signature}`']

         if check_res.varargs:
@@ -228,8 +235,9 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2

-def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:CheckRes,
-                         pred_dict:dict, target_dict:dict, dataset, check_level=0):
+
+def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
+                         pred_dict: dict, target_dict: dict, dataset, check_level=0):
     errs = []
     unuseds = []
     _unused_field = []
@@ -268,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res:
                     f"target is {list(target_dict.keys())}).")
         if _miss_out_dataset:
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                   f"target is {list(target_dict.keys())}) or output it "
-                   f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
+                    f"target is {list(target_dict.keys())}) or output it "
+                    f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
             if _unused_field:
                 _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
             suggestions.append(_tmp)
@@ -277,15 +285,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res:
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}.")
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
-                          f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
+                           f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)

-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
                 sugg_str += f'({idx+1}). {sugg}'
         else:
@@ -332,10 +340,10 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
     if check_level == STRICT_CHECK_LEVEL:
         errs.extend(_unused)

-    if len(errs)>0:
+    if len(errs) > 0:
         errs.insert(0, f'The following problems occurred when calling {func_signature}')
         sugg_str = ""
-        if len(suggestions)>1:
+        if len(suggestions) > 1:
            for idx, sugg in enumerate(suggestions):
                sugg_str += f'({idx+1}). {sugg}'
        else:
@@ -357,11 +365,11 @@ def seq_lens_to_masks(seq_lens, float=True):
     :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
     """
     if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens))==1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
+        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
         assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
         raise NotImplemented
     elif isinstance(seq_lens, torch.LongTensor):
-        assert len(seq_lens.size())==1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
+        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
         batch_size = seq_lens.size(0)
         max_len = seq_lens.max()
         indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
@@ -376,3 +384,54 @@ def seq_lens_to_masks(seq_lens, float=True):
     else:
         raise NotImplemented


+def seq_mask(seq_len, max_len):
+    """Create a sequence mask.
+
+    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
+    :param max_len: int, the maximum sequence length in a batch.
+    :return mask: torch.LongTensor, [batch_size, max_len]
+    """
+    if not isinstance(seq_len, torch.Tensor):
+        seq_len = torch.LongTensor(seq_len)
+    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
+    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]

+def _relocate_pbar(pbar: tqdm, print_str: str):
+    """Print a message while a tqdm bar is active.
+
+    Printing directly while tqdm is running duplicates the bar; this function
+    closes the bar, prints `print_str` above it, and recreates the bar with
+    its previous state.
+
+    :param pbar: tqdm
+    :param print_str: the message to show above the progress bar
+    :return: a new tqdm instance continuing from the old one
+    """
+    params = ['desc', 'total', 'leave', 'file', 'ncols', 'mininterval', 'maxinterval', 'miniters', 'ascii',
+              'disable', 'unit', 'unit_scale', 'dynamic_ncols', 'smoothing', 'bar_format', 'initial', 'position',
+              'postfix', 'unit_divisor', 'gui']
+    attr_map = {'file': 'fp', 'initial': 'n', 'position': 'pos'}
+    param_dict = {}
+    for param in params:
+        attr_name = param
+        if param in attr_map:
+            attr_name = attr_map[param]
+        value = getattr(pbar, attr_name)
+        if attr_name == 'pos':
+            value = abs(value)
+        param_dict[param] = value
+    avg_time = pbar.avg_time
+    start_t = pbar.start_t
+    pbar.close()
+    print(print_str)
+    pbar = tqdm(**param_dict)
+    pbar.start_t = start_t
+    pbar.avg_time = avg_time
+    pbar.sp(pbar.__repr__())
+    return pbar
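A quick check of what the added `seq_mask` helper produces (on older PyTorch the mask prints as 1/0 `uint8` rather than booleans):

```python
import torch
from fastNLP.core.utils import seq_mask

mask = seq_mask([3, 1, 2], max_len=4)  # positions < length are kept
print(mask)
# tensor([[1, 1, 1, 0],
#         [1, 0, 0, 0],
#         [1, 1, 0, 0]], dtype=torch.uint8)
```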
diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
@@ -105,9 +105,9 @@ class EmbedLoader(BaseLoader):
         if np.sum(hit_flags) < len(vocab):
             # some words from vocab are missing in pre-trained embedding
-            # we normally sample them
+            # we normally sample each dimension
             vocab_embed = embedding_matrix[np.where(hit_flags)]
-            mean, cov = vocab_embed.mean(axis=0), np.cov(vocab_embed.T)
-            sampled_vectors = np.random.multivariate_normal(mean, cov, size=(len(vocab) - np.sum(hit_flags),))
+            sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
+                                               size=(len(vocab) - np.sum(hit_flags), emb_dim))
             embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
         return embedding_matrix
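The replacement drops the multivariate-normal draw, whose `emb_dim x emb_dim` covariance from `np.cov` was segfaulting on Linux (see the re-enabled test further below), in favor of independent per-dimension normals. A standalone sketch of the new sampling with made-up sizes:

```python
import numpy as np

emb_dim = 4
hit_flags = np.array([1, 0, 1, 1, 0])  # which vocab words were found in the pre-trained file
embedding_matrix = np.zeros((len(hit_flags), emb_dim))
embedding_matrix[np.where(hit_flags)] = np.random.rand(int(np.sum(hit_flags)), emb_dim)

vocab_embed = embedding_matrix[np.where(hit_flags)]
# one mean and one std per dimension; each missing row is sampled independently
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
                                   size=(len(hit_flags) - np.sum(hit_flags), emb_dim))
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
```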
diff --git a/fastNLP/modules/encoder/char_embedding.py b/fastNLP/modules/encoder/char_embedding.py
@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module):
             # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
             y = torch.squeeze(y, 2)
             # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
-            y = F.tanh(y)
+            y = torch.tanh(y)
             y, __ = torch.max(y, 2)
             # [batch_size*sent_length, feature_maps[i]]
             feats.append(y)
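For reference, `torch.nn.functional.tanh` is deprecated in favor of `torch.tanh`; the two are numerically identical, so this change only silences a warning:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
assert torch.equal(torch.tanh(x), F.tanh(x))  # same values; F.tanh just emits a deprecation warning
```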
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
@@ -44,6 +44,9 @@ class TestDataSet(unittest.TestCase):
         self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
         self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)

+        with self.assertRaises(RuntimeError):
+            dd.add_field("??", [[1, 2]] * 40)
+
     def test_delete_field(self):
         dd = DataSet()
         dd.add_field("x", [[1, 2, 3]] * 10)
@@ -65,8 +68,66 @@ class TestDataSet(unittest.TestCase):
         self.assertTrue(isinstance(sub_ds, DataSet))
         self.assertEqual(len(sub_ds), 10)

+    def test_get_item_error(self):
+        with self.assertRaises(RuntimeError):
+            ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+            _ = ds[40:]
+
+        with self.assertRaises(KeyError):
+            ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+            _ = ds["kom"]
+
+    def test_len_(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        self.assertEqual(len(ds), 40)
+
+        ds = DataSet()
+        self.assertEqual(len(ds), 0)
+
     def test_apply(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
         self.assertTrue("rx" in ds.field_arrays)
         self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])

+    def test_contains(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        self.assertTrue("x" in ds)
+        self.assertTrue("y" in ds)
+        self.assertFalse("z" in ds)
+
+    def test_rename_field(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        ds.rename_field("x", "xx")
+        self.assertTrue("xx" in ds)
+        self.assertFalse("x" in ds)
+
+        with self.assertRaises(KeyError):
+            ds.rename_field("yyy", "oo")
+
+    def test_input_target(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        ds.set_input("x")
+        ds.set_target("y")
+        self.assertTrue(ds.field_arrays["x"].is_input)
+        self.assertTrue(ds.field_arrays["y"].is_target)
+
+        with self.assertRaises(KeyError):
+            ds.set_input("xxx")
+        with self.assertRaises(KeyError):
+            ds.set_input("yyy")
+
+    def test_get_input_name(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])
+
+    def test_get_target_name(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])
+
+
+class TestDataSetIter(unittest.TestCase):
+    def test__repr__(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        for iter in ds:
+            self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
diff --git a/test/core/test_instance.py b/test/core/test_instance.py
@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase):
         self.assertEqual(ins["x"], [1, 2, 3])
         self.assertEqual(ins["y"], [4, 5, 6])
         self.assertEqual(ins["z"], [1, 1, 1])
+
+    def test_repr(self):
+        fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
+        ins = Instance(**fields)
+        # simple print, that is enough.
+        print(ins)
diff --git a/test/core/test_loss.py b/test/core/test_loss.py
@@ -271,40 +271,32 @@ class TestLoss(unittest.TestCase):
         loss3 = get_loss_3({'predict': predict}, {'truth': truth})
         assert loss1 == loss2 and loss1 == loss3
+        """
         get_loss_4 = LossFunc(func4)
         loss4 = get_loss_4({'a': 1, 'b': 3}, {})
         print(loss4)
         assert loss4 == (1 + 3) * 2

         get_loss_5 = LossFunc(func4)
         loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
         print(loss5)
         assert loss5 == (1 + 3) * 4

         get_loss_6 = LossFunc(func6)
         loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
         print(loss6)
         assert loss6 == (1 + 3) * 4

         get_loss_7 = LossFunc(func6, c='cc')
         loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
         print(loss7)
         assert loss7 == (1 + 3) * 4
+        """


 class TestLoss_v2(unittest.TestCase):
     def test_CrossEntropyLoss(self):
-        ce = loss.CrossEntropyLoss(input="my_predict", target="my_truth")
+        ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
         a = torch.randn(3, 5, requires_grad=False)
         b = torch.empty(3, dtype=torch.long).random_(5)
         ans = ce({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))

     def test_BCELoss(self):
-        bce = loss.BCELoss(input="my_predict", target="my_truth")
+        bce = loss.BCELoss(pred="my_predict", target="my_truth")
         a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
         b = torch.randn((3, 5), requires_grad=False)
         ans = bce({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))

+    def test_L1Loss(self):
+        l1 = loss.L1Loss(pred="my_predict", target="my_truth")
+        a = torch.randn(3, 5, requires_grad=False)
+        b = torch.randn(3, 5)
+        ans = l1({"my_predict": a}, {"my_truth": b})
+        self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))
+
+    def test_NLLLoss(self):
+        l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
+        a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
+        b = torch.tensor([1, 0, 4])
+        ans = l1({"my_predict": a}, {"my_truth": b})
+        self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))
diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py
@@ -11,9 +11,6 @@ class TestOptim(unittest.TestCase):
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("momentum" in optim.__dict__["settings"])

-        optim = SGD(0.001)
-        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
-
         optim = SGD(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
@@ -25,17 +22,12 @@ class TestOptim(unittest.TestCase):
             _ = SGD("???")
         with self.assertRaises(RuntimeError):
             _ = SGD(0.001, lr=0.002)
-        with self.assertRaises(RuntimeError):
-            _ = SGD(lr=0.009, shit=9000)

     def test_Adam(self):
         optim = Adam(torch.nn.Linear(10, 3).parameters())
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("weight_decay" in optim.__dict__["settings"])

-        optim = Adam(0.001)
-        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
-
         optim = Adam(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
@@ -32,14 +32,14 @@ class TrainerTestGround(unittest.TestCase):
         model = NaiveClassifier(2, 1)

         trainer = Trainer(train_set, model,
-                          losser=BCELoss(input="predict", target="y"),
+                          losser=BCELoss(pred="predict", target="y"),
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           n_epochs=10,
                           batch_size=32,
                           update_every=1,
-                          validate_every=-1,
+                          validate_every=10,
                           dev_data=dev_set,
-                          optimizer=SGD(0.1),
-                          check_code_level=2
-                          )
-        trainer.train()
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=True)
+        trainer.train()
diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py
@@ -1,12 +1,12 @@
 import unittest

 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.embed_loader import EmbedLoader


 class TestEmbedLoader(unittest.TestCase):
     def test_case(self):
         vocab = Vocabulary()
         vocab.update(["the", "in", "I", "to", "of", "hahaha"])
-        # TODO: np.cov segfaults on Linux; cause unknown
-        # embedding = EmbedLoader().fast_load_embedding(50, "../data_for_tests/glove.6B.50d_test.txt", vocab)
-        # self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
+        embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
+        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))
diff --git a/test/test_tutorial.py b/test/test_tutorial.py
@@ -71,20 +71,16 @@ class TestTutorial(unittest.TestCase):
         # Instantiate a Trainer, pass in the model and data, and start training
         copy_model = deepcopy(model)
-        overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
-                                  losser=CrossEntropyLoss(input="output", target="label_seq"),
-                                  metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                                  save_path="./save",
-                                  batch_size=4,
-                                  n_epochs=10)
+        overfit_trainer = Trainer(train_data=test_data, model=copy_model,
+                                  losser=CrossEntropyLoss(pred="output", target="label_seq"),
+                                  metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                                  dev_data=test_data, save_path="./save")
         overfit_trainer.train()

-        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
-                          losser=CrossEntropyLoss(input="output", target="label_seq"),
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"),
-                          save_path="./save",
-                          batch_size=4,
-                          n_epochs=10)
+        trainer = Trainer(train_data=train_data, model=model,
+                          losser=CrossEntropyLoss(pred="output", target="label_seq"),
+                          metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
+                          dev_data=test_data, save_path="./save")
         trainer.train()

         print('Train finished!')