diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 235a9a3a..90f0fc8c 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -2,15 +2,19 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 """ -__all__ = ["Batch"] +import atexit import numpy as np import torch -import atexit - -from .sampler import RandomSampler, Sampler import torch.multiprocessing as mp + from queue import Empty, Full +from .sampler import RandomSampler + +__all__ = [ + "Batch" +] + _python_is_exit = False @@ -120,7 +124,7 @@ class Batch(object): :return list(int) indexes: 下标序列 """ return self.cur_batch_indices - + @staticmethod def _run_fetch(batch, q): try: @@ -145,7 +149,7 @@ class Batch(object): q.put(e) finally: q.join() - + @staticmethod def _run_batch_iter(batch): q = mp.JoinableQueue(maxsize=10) @@ -182,4 +186,3 @@ def _to_tensor(batch, dtype): except: pass return batch - diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 9dce426b..0a5ddc52 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -49,6 +49,18 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: trainer.train() """ +import os +import torch + +try: + from tensorboardX import SummaryWriter + + tensorboardX_flag = True +except: + tensorboardX_flag = False + +from ..io.model_io import ModelSaver, ModelLoader + __all__ = [ "Callback", "GradientClipCallback", @@ -60,15 +72,6 @@ __all__ = [ "CallbackException", "EarlyStopError" ] -import os -import torch -from ..io.model_io import ModelSaver, ModelLoader - -try: - from tensorboardX import SummaryWriter - tensorboardX_flag = True -except: - tensorboardX_flag = False class Callback(object): @@ -587,7 +590,7 @@ class TensorboardCallback(Callback): self._summary_writer = SummaryWriter(path) else: self._summary_writer = None - + def on_batch_begin(self, batch_x, batch_y, indices): if "model" in self.options and self.graph_added is False: # tesorboardX 这里有大bug,暂时没法画模型图 diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index b506dfae..63f66019 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -272,9 +272,7 @@ """ -__all__ = ["DataSet"] import _pickle as pickle - import numpy as np import warnings @@ -283,6 +281,10 @@ from .field import FieldArray from .instance import Instance from .utils import _get_func_signature +__all__ = [ + "DataSet" +] + class DataSet(object): """ @@ -854,4 +856,4 @@ class DataSet(object): with open(path, 'rb') as f: d = pickle.load(f) assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) - return d \ No newline at end of file + return d diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index a355c4d2..4029a4ca 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -3,11 +3,17 @@ field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fas 原理部分请参考 :doc:`fastNLP.core.dataset` """ - - import numpy as np + from copy import deepcopy +__all__ = [ + "FieldArray", + "Padder", + "AutoPadder", + "EngChar2DPadder" +] + class FieldArray(object): """ @@ -24,6 +30,7 @@ class FieldArray(object): :param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, 就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 """ + def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): self.name = name if isinstance(content, list): @@ -41,7 +48,7 @@ class FieldArray(object): raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) if len(content) == 0: raise RuntimeError("Cannot initialize FieldArray with empty list.") - + self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 self.content_dim = None # 表示content是多少维的list if padder is None: @@ -51,27 +58,27 @@ class FieldArray(object): padder = deepcopy(padder) self.set_padder(padder) self.ignore_type = ignore_type - + self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array - + self.pytype = None self.dtype = None self._is_input = None self._is_target = None - + if is_input is not None or is_target is not None: self.is_input = is_input self.is_target = is_target - + def _set_dtype(self): if self.ignore_type is False: self.pytype = self._type_detection(self.content) self.dtype = self._map_to_np_type(self.pytype) - + @property def is_input(self): return self._is_input - + @is_input.setter def is_input(self, value): """ @@ -80,11 +87,11 @@ class FieldArray(object): if value is True: self._set_dtype() self._is_input = value - + @property def is_target(self): return self._is_target - + @is_target.setter def is_target(self, value): """ @@ -93,7 +100,7 @@ class FieldArray(object): if value is True: self._set_dtype() self._is_target = value - + def _type_detection(self, content): """ 当该field被设置为is_input或者is_target时被调用 @@ -101,9 +108,9 @@ class FieldArray(object): """ if len(content) == 0: raise RuntimeError("Empty list in Field {}.".format(self.name)) - + type_set = set([type(item) for item in content]) - + if list in type_set: if len(type_set) > 1: # list 跟 非list 混在一起 @@ -139,7 +146,7 @@ class FieldArray(object): self.name, self.BASIC_TYPES, content_type)) self.content_dim = 1 return self._basic_type_detection(type_set) - + def _basic_type_detection(self, type_set): """ :param type_set: a set of Python types @@ -158,7 +165,7 @@ class FieldArray(object): else: # str, int, float混在一起 raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) - + def _1d_list_check(self, val): """如果不是1D list就报错 """ @@ -168,7 +175,7 @@ class FieldArray(object): self._basic_type_detection(type_set) # otherwise: _basic_type_detection will raise error return True - + def _2d_list_check(self, val): """如果不是2D list 就报错 """ @@ -181,15 +188,15 @@ class FieldArray(object): inner_type_set.add(type(obj)) self._basic_type_detection(inner_type_set) return True - + @staticmethod def _map_to_np_type(basic_type): type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} return type_mapping[basic_type] - + def __repr__(self): return "FieldArray {}: {}".format(self.name, self.content.__repr__()) - + def append(self, val): """将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 的内容是匹配的。 @@ -208,7 +215,7 @@ class FieldArray(object): else: raise RuntimeError( "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - + if self.is_input is True or self.is_target is True: if type(val) == list: if len(val) == 0: @@ -231,14 +238,14 @@ class FieldArray(object): raise RuntimeError( "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) self.content.append(val) - + def __getitem__(self, indices): return self.get(indices, pad=False) - + def __setitem__(self, idx, val): assert isinstance(idx, int) self.content[idx] = val - + def get(self, indices, pad=True): """ 根据给定的indices返回内容 @@ -251,13 +258,13 @@ class FieldArray(object): return self.content[indices] if self.is_input is False and self.is_target is False: raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) - + contents = [self.content[i] for i in indices] if self.padder is None or pad is False: return np.array(contents) else: return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) - + def set_padder(self, padder): """ 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 @@ -269,7 +276,7 @@ class FieldArray(object): self.padder = deepcopy(padder) else: self.padder = None - + def set_pad_val(self, pad_val): """ 修改padder的pad_val. @@ -279,8 +286,7 @@ class FieldArray(object): if self.padder is not None: self.padder.set_pad_val(pad_val) return self - - + def __len__(self): """ Returns the size of FieldArray. @@ -288,7 +294,7 @@ class FieldArray(object): :return int length: """ return len(self.content) - + def to(self, other): """ 将other的属性复制给本FieldArray(other必须为FieldArray类型). @@ -298,14 +304,15 @@ class FieldArray(object): :return: :class:`~fastNLP.FieldArray` """ assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) - + self.is_input = other.is_input self.is_target = other.is_target self.padder = other.padder self.ignore_type = other.ignore_type - + return self + def _is_iterable(content): try: _ = (e for e in content) @@ -331,13 +338,13 @@ class Padder: :return: np.array([padded_element]) """ - + def __init__(self, pad_val=0, **kwargs): self.pad_val = pad_val - + def set_pad_val(self, pad_val): self.pad_val = pad_val - + def __call__(self, contents, field_name, field_ele_dtype): """ 传入的是List内容。假设有以下的DataSet。 @@ -396,13 +403,13 @@ class AutoPadder(Padder): 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad """ - + def __init__(self, pad_val=0): """ :param pad_val: int, padding的位置使用该index """ super().__init__(pad_val=pad_val) - + def _is_two_dimension(self, contents): """ 判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 @@ -416,7 +423,7 @@ class AutoPadder(Padder): return False return True return False - + def __call__(self, contents, field_name, field_ele_dtype): if not _is_iterable(contents[0]): @@ -458,6 +465,7 @@ class EngChar2DPadder(Padder): dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder """ + def __init__(self, pad_val=0, pad_length=0): """ :param pad_val: int, pad的位置使用该index @@ -465,9 +473,9 @@ class EngChar2DPadder(Padder): 都pad或截取到该长度. """ super().__init__(pad_val=pad_val) - + self.pad_length = pad_length - + def _exactly_three_dims(self, contents, field_name): """ 检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character @@ -486,10 +494,10 @@ class EngChar2DPadder(Padder): value = value[0] except: raise ValueError("Field:{} only has two dimensions.".format(field_name)) - + if _is_iterable(value): raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) - + def __call__(self, contents, field_name, field_ele_dtype): """ 期望输入类似于 @@ -516,12 +524,12 @@ class EngChar2DPadder(Padder): max_sent_length = max(len(word_lst) for word_lst in contents) batch_size = len(contents) dtype = type(contents[0][0][0]) - + padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, - dtype=dtype) + dtype=dtype) for b_idx, word_lst in enumerate(contents): for c_idx, char_lst in enumerate(word_lst): chars = char_lst[:max_char_length] padded_array[b_idx, c_idx, :len(chars)] = chars - - return padded_array \ No newline at end of file + + return padded_array diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 2303c510..07ae6495 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -3,7 +3,9 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 """ -__all__ = ["Instance"] +__all__ = [ + "Instance" +] class Instance(object): diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 7a5fdf9d..b98c5ac7 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -2,13 +2,12 @@ losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 """ -__all__ = ["LossBase", "L1Loss", "LossFunc", "LossInForward", "BCELoss", "CrossEntropyLoss", "NLLLoss"] import inspect -from collections import defaultdict - import torch import torch.nn.functional as F +from collections import defaultdict + from .utils import _CheckError from .utils import _CheckRes from .utils import _build_args @@ -16,6 +15,18 @@ from .utils import _check_arg_dict_list from .utils import _check_function_or_method from .utils import _get_func_signature +__all__ = [ + "LossBase", + + "LossFunc", + "LossInForward", + + "CrossEntropyLoss", + "BCELoss", + "L1Loss", + "NLLLoss" +] + class LossBase(object): """ diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 7a96020b..df85a318 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -3,11 +3,11 @@ metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 """ import inspect -from collections import defaultdict - import numpy as np import torch +from collections import defaultdict + from .utils import _CheckError from .utils import _CheckRes from .utils import _build_args @@ -16,6 +16,13 @@ from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary +__all__ = [ + "MetricBase", + "AccuracyMetric", + "SpanFPreRecMetric", + "SQuADMetric" +] + class MetricBase(object): """ @@ -106,16 +113,17 @@ class MetricBase(object): self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值 """ + def __init__(self): self.param_map = {} # key is param in function, value is input param. self._checked = False - + def evaluate(self, *args, **kwargs): raise NotImplementedError - + def get_metric(self, reset=True): raise NotImplemented - + def _init_param_map(self, key_map=None, **kwargs): """检查key_map和其他参数map,并将这些映射关系添加到self.param_map @@ -148,7 +156,7 @@ class MetricBase(object): for value, key_set in value_counter.items(): if len(key_set) > 1: raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") - + # check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = [arg for arg in func_spect.args if arg != 'self'] @@ -157,7 +165,7 @@ class MetricBase(object): raise NameError( f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " f"initialization parameters, or change its signature.") - + def _fast_param_map(self, pred_dict, target_dict): """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. such as pred_dict has one element, target_dict has one element @@ -172,7 +180,7 @@ class MetricBase(object): fast_param['target'] = list(target_dict.values())[0] return fast_param return fast_param - + def __call__(self, pred_dict, target_dict): """ 这个方法会调用self.evaluate 方法. @@ -187,12 +195,12 @@ class MetricBase(object): :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) :return: """ - + fast_param = self._fast_param_map(pred_dict, target_dict) if fast_param: self.evaluate(**fast_param) return - + if not self._checked: if not callable(self.evaluate): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") @@ -202,14 +210,14 @@ class MetricBase(object): for func_arg, input_arg in self.param_map.items(): if func_arg not in func_args: raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") - + # 2. only part of the param_map are passed, left are not for arg in func_args: if arg not in self.param_map: self.param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} - + # need to wrap inputs in dict. mapped_pred_dict = {} mapped_target_dict = {} @@ -229,7 +237,7 @@ class MetricBase(object): not_duplicate_flag += 1 if not_duplicate_flag == 3: duplicated.append(input_arg) - + # missing if not self._checked: check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) @@ -240,23 +248,23 @@ class MetricBase(object): for idx, func_arg in enumerate(missing): # Don't delete `` in this information, nor add `` replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ - f"in `{self.__class__.__name__}`)" - + f"in `{self.__class__.__name__}`)" + check_res = _CheckRes(missing=replaced_missing, unused=check_res.unused, duplicated=duplicated, required=check_res.required, all_needed=check_res.all_needed, varargs=check_res.varargs) - + if check_res.missing or check_res.duplicated: raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.evaluate)) refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) - + self.evaluate(**refined_args) self._checked = True - + return @@ -271,15 +279,16 @@ class AccuracyMetric(MetricBase): :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` """ + def __init__(self, pred=None, target=None, seq_len=None): super().__init__() - + self._init_param_map(pred=pred, target=target, seq_len=seq_len) - + self.total = 0 self.acc_count = 0 - + def evaluate(self, pred, target, seq_len=None): """ evaluate函数将针对一个批次的预测结果做评价指标的累计 @@ -299,16 +308,16 @@ class AccuracyMetric(MetricBase): if not isinstance(target, torch.Tensor): raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") - + if seq_len is not None and not isinstance(seq_len, torch.Tensor): raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_len)}.") - + if seq_len is not None: masks = seq_len_to_mask(seq_len=seq_len) else: masks = None - + if pred.size() == target.size(): pass elif len(pred.size()) == len(target.size()) + 1: @@ -317,7 +326,7 @@ class AccuracyMetric(MetricBase): raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " f"size:{pred.size()}, target should have size: {pred.size()} or " f"{pred.size()[:-1]}, got {target.size()}.") - + target = target.to(pred) if masks is not None: self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item() @@ -325,7 +334,7 @@ class AccuracyMetric(MetricBase): else: self.acc_count += torch.sum(torch.eq(pred, target)).item() self.total += np.prod(list(pred.size())) - + def get_metric(self, reset=True): """ get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. @@ -350,7 +359,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] """ ignore_labels = set(ignore_labels) if ignore_labels else set() - + spans = [] prev_bmes_tag = None for idx, tag in enumerate(tags): @@ -358,14 +367,14 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): bmes_tag, label = tag[:1], tag[2:] if bmes_tag in ('b', 's'): spans.append((label, [idx, idx])) - elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: spans[-1][1][1] = idx else: spans.append((label, [idx, idx])) prev_bmes_tag = bmes_tag - return [(span[0], (span[1][0], span[1][1]+1)) - for span in spans - if span[0] not in ignore_labels + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels ] @@ -379,7 +388,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] """ ignore_labels = set(ignore_labels) if ignore_labels else set() - + spans = [] prev_bmes_tag = None for idx, tag in enumerate(tags): @@ -387,16 +396,16 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): bmes_tag, label = tag[:1], tag[2:] if bmes_tag in ('b', 's'): spans.append((label, [idx, idx])) - elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: + elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: spans[-1][1][1] = idx elif bmes_tag == 'o': pass else: spans.append((label, [idx, idx])) prev_bmes_tag = bmes_tag - return [(span[0], (span[1][0], span[1][1]+1)) - for span in spans - if span[0] not in ignore_labels + return [(span[0], (span[1][0], span[1][1] + 1)) + for span in spans + if span[0] not in ignore_labels ] @@ -410,7 +419,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] """ ignore_labels = set(ignore_labels) if ignore_labels else set() - + spans = [] prev_bio_tag = None for idx, tag in enumerate(tags): @@ -418,14 +427,14 @@ def _bio_tag_to_spans(tags, ignore_labels=None): bio_tag, label = tag[:1], tag[2:] if bio_tag == 'b': spans.append((label, [idx, idx])) - elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]: + elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]: spans[-1][1][1] = idx - elif bio_tag == 'o': # o tag does not count + elif bio_tag == 'o': # o tag does not count pass else: spans.append((label, [idx, idx])) prev_bio_tag = bio_tag - return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels] + return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] class SpanFPreRecMetric(MetricBase): @@ -470,16 +479,17 @@ class SpanFPreRecMetric(MetricBase): :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 """ + def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, - only_gross=True, f_type='micro', beta=1): + only_gross=True, f_type='micro', beta=1): encoding_type = encoding_type.lower() - + if not isinstance(tag_vocab, Vocabulary): raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) if f_type not in ('micro', 'macro'): raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) - + self.encoding_type = encoding_type if self.encoding_type == 'bmes': self.tag_to_span_func = _bmes_tag_to_spans @@ -489,22 +499,22 @@ class SpanFPreRecMetric(MetricBase): self.tag_to_span_func = _bmeso_tag_to_spans else: raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") - + self.ignore_labels = ignore_labels self.f_type = f_type self.beta = beta - self.beta_square = self.beta**2 + self.beta_square = self.beta ** 2 self.only_gross = only_gross - + super().__init__() self._init_param_map(pred=pred, target=target, seq_len=seq_len) - + self.tag_vocab = tag_vocab - + self._true_positives = defaultdict(int) self._false_positives = defaultdict(int) self._false_negatives = defaultdict(int) - + def evaluate(self, pred, target, seq_len): """evaluate函数将针对一个批次的预测结果做评价指标的累计 @@ -519,11 +529,11 @@ class SpanFPreRecMetric(MetricBase): if not isinstance(target, torch.Tensor): raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") - + if not isinstance(seq_len, torch.Tensor): raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_len)}.") - + if pred.size() == target.size() and len(target.size()) == 2: pass elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: @@ -536,20 +546,20 @@ class SpanFPreRecMetric(MetricBase): raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " f"size:{pred.size()}, target should have size: {pred.size()} or " f"{pred.size()[:-1]}, got {target.size()}.") - + batch_size = pred.size(0) pred = pred.tolist() target = target.tolist() for i in range(batch_size): pred_tags = pred[i][:int(seq_len[i])] gold_tags = target[i][:int(seq_len[i])] - + pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] - + pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) - + for span in pred_spans: if span in gold_spans: self._true_positives[span[0]] += 1 @@ -558,7 +568,7 @@ class SpanFPreRecMetric(MetricBase): self._false_positives[span[0]] += 1 for span in gold_spans: self._false_negatives[span[0]] += 1 - + def get_metric(self, reset=True): """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" evaluate_result = {} @@ -577,19 +587,19 @@ class SpanFPreRecMetric(MetricBase): f_sum += f pre_sum += pre rec_sum + rec - if not self.only_gross and tag!='': # tag!=''防止无tag的情况 + if not self.only_gross and tag != '': # tag!=''防止无tag的情况 f_key = 'f-{}'.format(tag) pre_key = 'pre-{}'.format(tag) rec_key = 'rec-{}'.format(tag) evaluate_result[f_key] = f evaluate_result[pre_key] = pre evaluate_result[rec_key] = rec - + if self.f_type == 'macro': - evaluate_result['f'] = f_sum/len(tags) - evaluate_result['pre'] = pre_sum/len(tags) - evaluate_result['rec'] = rec_sum/len(tags) - + evaluate_result['f'] = f_sum / len(tags) + evaluate_result['pre'] = pre_sum / len(tags) + evaluate_result['rec'] = rec_sum / len(tags) + if self.f_type == 'micro': f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), sum(self._false_negatives.values()), @@ -597,17 +607,17 @@ class SpanFPreRecMetric(MetricBase): evaluate_result['f'] = f evaluate_result['pre'] = pre evaluate_result['rec'] = rec - + if reset: self._true_positives = defaultdict(int) self._false_positives = defaultdict(int) self._false_negatives = defaultdict(int) - + for key, value in evaluate_result.items(): evaluate_result[key] = round(value, 6) - + return evaluate_result - + def _compute_f_pre_rec(self, tp, fn, fp): """ @@ -619,11 +629,10 @@ class SpanFPreRecMetric(MetricBase): pre = tp / (fp + tp + 1e-13) rec = tp / (fn + tp + 1e-13) f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) - + return f, pre, rec - def _prepare_metrics(metrics): """ @@ -705,33 +714,33 @@ class SQuADMetric(MetricBase): :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 """ - + def __init__(self, pred1=None, pred2=None, target1=None, target2=None, beta=1, right_open=True, print_predict_stat=False): super(SQuADMetric, self).__init__() - + self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) - + self.print_predict_stat = print_predict_stat - + self.no_ans_correct = 0 self.no_ans_wrong = 0 - + self.has_ans_correct = 0 self.has_ans_wrong = 0 - + self.has_ans_f = 0. - + self.no2no = 0 self.no2yes = 0 self.yes2no = 0 self.yes2yes = 0 - + self.f_beta = beta - + self.right_open = right_open - + def evaluate(self, pred1, pred2, target1, target2): """evaluate函数将针对一个批次的预测结果做评价指标的累计 @@ -745,7 +754,7 @@ class SQuADMetric(MetricBase): pred_end = pred2 target_start = target1 target_end = target2 - + if len(pred_start.size()) == 2: start_inference = pred_start.max(dim=-1)[1].cpu().tolist() else: @@ -754,12 +763,12 @@ class SQuADMetric(MetricBase): end_inference = pred_end.max(dim=-1)[1].cpu().tolist() else: end_inference = pred_end.cpu().tolist() - + start, end = [], [] max_len = pred_start.size(1) t_start = target_start.cpu().tolist() t_end = target_end.cpu().tolist() - + for s, e in zip(start_inference, end_inference): start.append(min(s, e)) end.append(max(s, e)) @@ -779,7 +788,7 @@ class SQuADMetric(MetricBase): self.yes2no += 1 else: self.yes2yes += 1 - + if s == ts and e == te: self.has_ans_correct += 1 else: @@ -787,29 +796,29 @@ class SQuADMetric(MetricBase): a = [0] * s + [1] * (e - s) + [0] * (max_len - e) b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te) a, b = torch.tensor(a), torch.tensor(b) - + TP = int(torch.sum(a * b)) pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0 rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0 - + if pre + rec > 0: - f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec) + f = (1 + (self.f_beta ** 2)) * pre * rec / ((self.f_beta ** 2) * pre + rec) else: f = 0 self.has_ans_f += f - + def get_metric(self, reset=True): """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" evaluate_result = {} - + if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: return evaluate_result - + evaluate_result['EM'] = 0 evaluate_result[f'f_{self.f_beta}'] = 0 - + flag = 0 - + if self.no_ans_correct + self.no_ans_wrong > 0: evaluate_result[f'noAns-f_{self.f_beta}'] = \ round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) @@ -818,7 +827,7 @@ class SQuADMetric(MetricBase): evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}'] evaluate_result['EM'] += evaluate_result['noAns-EM'] flag += 1 - + if self.has_ans_correct + self.has_ans_wrong > 0: evaluate_result[f'hasAns-f_{self.f_beta}'] = \ round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3) @@ -827,32 +836,31 @@ class SQuADMetric(MetricBase): evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}'] evaluate_result['EM'] += evaluate_result['hasAns-EM'] flag += 1 - + if self.print_predict_stat: evaluate_result['no2no'] = self.no2no evaluate_result['no2yes'] = self.no2yes evaluate_result['yes2no'] = self.yes2no evaluate_result['yes2yes'] = self.yes2yes - + if flag <= 0: return evaluate_result - + evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3) evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3) - + if reset: self.no_ans_correct = 0 self.no_ans_wrong = 0 - + self.has_ans_correct = 0 self.has_ans_wrong = 0 - + self.has_ans_f = 0. - + self.no2no = 0 self.no2yes = 0 self.yes2no = 0 self.yes2yes = 0 - + return evaluate_result - diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index ea4905eb..28f618f9 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -4,6 +4,12 @@ optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :cl """ import torch +__all__ = [ + "Optimizer", + "SGD", + "Adam" +] + class Optimizer(object): """ @@ -12,15 +18,16 @@ class Optimizer(object): :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. :param kwargs: additional parameters. """ + def __init__(self, model_params, **kwargs): if model_params is not None and not hasattr(model_params, "__next__"): raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) self.model_params = model_params self.settings = kwargs - + def construct_from_pytorch(self, model_params): raise NotImplementedError - + def _get_require_grads_param(self, params): """ 将params中不需要gradient的删除 @@ -29,6 +36,7 @@ class Optimizer(object): """ return [param for param in params if param.requires_grad] + class SGD(Optimizer): """ 别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` @@ -37,12 +45,12 @@ class SGD(Optimizer): :param float momentum: momentum. Default: 0 :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. """ - + def __init__(self, lr=0.001, momentum=0, model_params=None): if not isinstance(lr, float): raise TypeError("learning rate has to be float.") super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) - + def construct_from_pytorch(self, model_params): if self.model_params is None: # careful! generator cannot be assigned. @@ -59,13 +67,13 @@ class Adam(Optimizer): :param float weight_decay: :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. """ - + def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): if not isinstance(lr, float): raise TypeError("learning rate has to be float.") super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, weight_decay=weight_decay) - + def construct_from_pytorch(self, model_params): if self.model_params is None: # careful! generator cannot be assigned. diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 34784b7c..a9ef7924 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -1,7 +1,11 @@ -from collections import defaultdict - +""" + ..todo:: + 检查这个类是否需要 +""" import torch +from collections import defaultdict + from . import Batch from . import DataSet from . import SequentialSampler @@ -9,7 +13,8 @@ from .utils import _build_args class Predictor(object): - """An interface for predicting outputs based on trained models. + """ + An interface for predicting outputs based on trained models. It does not care about evaluations of the model, which is different from Tester. This is a high-level model wrapper to be called by FastNLP. diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index e270dac1..0900e733 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -1,12 +1,16 @@ """ sampler 子类实现了 fastNLP 所需的各种采样器。 - - """ -__all__ = ["Sampler", "BucketSampler", "SequentialSampler", "RandomSampler"] +import numpy as np + from itertools import chain -import numpy as np +__all__ = [ + "Sampler", + "BucketSampler", + "SequentialSampler", + "RandomSampler" +] class Sampler(object): diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 7b6fdda5..47aef46e 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -33,9 +33,8 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation """ import warnings - import torch -from torch import nn +import torch.nn as nn from .batch import Batch from .dataset import DataSet @@ -49,6 +48,10 @@ from .utils import _get_func_signature from .utils import _get_model_device from .utils import _move_model_to_device +__all__ = [ + "Tester" +] + class Tester(object): """ @@ -77,29 +80,29 @@ class Tester(object): 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 """ - + def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): super(Tester, self).__init__() - + if not isinstance(data, DataSet): raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") - + self.metrics = _prepare_metrics(metrics) - + self.data = data self._model = _move_model_to_device(model, device=device) self.batch_size = batch_size self.verbose = verbose - + # 如果是DataParallel将没有办法使用predict方法 if isinstance(self._model, nn.DataParallel): if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," " while DataParallel has no predict() function.") self._model = self._model.module - + # check predict if hasattr(self._model, 'predict'): self._predict_func = self._model.predict @@ -109,7 +112,7 @@ class Tester(object): f"for evaluation, not `{type(self._predict_func)}`.") else: self._predict_func = self._model.forward - + def test(self): """开始进行验证,并返回验证结果。 @@ -144,12 +147,12 @@ class Tester(object): _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, dataset=self.data, check_level=0) - + if self.verbose >= 1: print("[tester] \n{}".format(self._format_eval_results(eval_results))) self._mode(network, is_test=False) return eval_results - + def _mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. @@ -161,13 +164,13 @@ class Tester(object): model.eval() else: model.train() - + def _data_forward(self, func, x): """A forward pass of the model. """ x = _build_args(func, **x) y = func(**x) return y - + def _format_eval_results(self, results): """Override this method to support more print formats. diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 9b56d834..87d57f12 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -297,13 +297,12 @@ Example2.3 """ import os -import time -from datetime import datetime -from datetime import timedelta - import numpy as np +import time import torch -from torch import nn +import torch.nn as nn + +from datetime import datetime, timedelta try: from tqdm.auto import tqdm @@ -315,6 +314,7 @@ from .callback import CallbackManager, CallbackException from .dataset import DataSet from .losses import _prepare_losser from .metrics import _prepare_metrics +from .optimizer import Optimizer from .sampler import Sampler from .sampler import RandomSampler from .sampler import SequentialSampler @@ -326,7 +326,6 @@ from .utils import _check_loss_evaluate from .utils import _move_dict_value_to_device from .utils import _get_func_signature from .utils import _get_model_device -from .optimizer import Optimizer from .utils import _move_model_to_device @@ -464,7 +463,7 @@ class Trainer(object): len(self.train_data) % self.batch_size != 0)) * self.n_epochs self.model = _move_model_to_device(self.model, device=device) - + if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer elif isinstance(optimizer, Optimizer): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 2c386bbe..a7ad3326 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -1,20 +1,25 @@ """ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 """ -__all__ = ["cache_results", "seq_len_to_mask"] import _pickle import inspect +import numpy as np import os +import torch +import torch.nn as nn import warnings + from collections import Counter from collections import namedtuple -import numpy as np -import torch -from torch import nn +__all__ = [ + "cache_results", + "seq_len_to_mask" +] _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', - 'varargs']) + 'varargs']) + def _prepare_cache_filepath(filepath): """ @@ -83,11 +88,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): :param int _verbose: 是否打印cache的信息。 :return: """ + def wrapper_(func): signature = inspect.signature(func) for key, _ in signature.parameters.items(): if key in ('_cache_fp', '_refresh', '_verbose'): raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) + def wrapper(*args, **kwargs): if '_cache_fp' in kwargs: cache_filepath = kwargs.pop('_cache_fp') @@ -95,7 +102,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): else: cache_filepath = _cache_fp if '_refresh' in kwargs: - refresh = kwargs.pop('_refresh') + refresh = kwargs.pop('_refresh') assert isinstance(refresh, bool), "_refresh can only be bool." else: refresh = _refresh @@ -105,16 +112,16 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): else: verbose = _verbose refresh_flag = True - + if cache_filepath is not None and refresh is False: # load data if os.path.exists(cache_filepath): with open(cache_filepath, 'rb') as f: results = _pickle.load(f) - if verbose==1: + if verbose == 1: print("Read cache from {}.".format(cache_filepath)) refresh_flag = False - + if refresh_flag: results = func(*args, **kwargs) if cache_filepath is not None: @@ -124,11 +131,14 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): with open(cache_filepath, 'wb') as f: _pickle.dump(results, f) print("Save cache to {}.".format(cache_filepath)) - + return results + return wrapper + return wrapper_ + # def save_pickle(obj, pickle_path, file_name): # """Save an object into a pickle file. # @@ -196,7 +206,7 @@ def _move_model_to_device(model, device): """ if isinstance(model, torch.nn.parallel.DistributedDataParallel): raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") - + if device is None: if isinstance(model, torch.nn.DataParallel): model.cuda() @@ -205,34 +215,35 @@ def _move_model_to_device(model, device): if not torch.cuda.is_available() and ( device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") - + if isinstance(model, torch.nn.DataParallel): raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") - + if isinstance(device, int): - assert device>-1, "device can only be non-negative integer" - assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(), - device) + assert device > -1, "device can only be non-negative integer" + assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( + torch.cuda.device_count(), + device) device = torch.device('cuda:{}'.format(device)) elif isinstance(device, str): device = torch.device(device) if device.type == 'cuda' and device.index is not None: - assert device.index-1, "Only non-negative device id allowed." - if len(device)>1: + assert d > -1, "Only non-negative device id allowed." + if len(device) > 1: output_device = device[0] model = nn.DataParallel(model, device_ids=device, output_device=output_device) device = torch.device(device[0]) @@ -250,9 +261,9 @@ def _get_model_device(model): :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 """ assert isinstance(model, nn.Module) - + parameters = list(model.parameters()) - if len(parameters)==0: + if len(parameters) == 0: return None else: return parameters[0].device @@ -407,7 +418,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): if not isinstance(device, torch.device): raise TypeError(f"device must be `torch.device`, got `{type(device)}`") - + for arg in args: if isinstance(arg, dict): for key, value in arg.items(): @@ -422,10 +433,10 @@ class _CheckError(Exception): _CheckError. Used in losses.LossBase, metrics.MetricBase. """ - + def __init__(self, check_res: _CheckRes, func_signature: str): errs = [f'Problems occurred when calling `{func_signature}`'] - + if check_res.varargs: errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") if check_res.missing: @@ -434,9 +445,9 @@ class _CheckError(Exception): errs.append(f"\tduplicated param: {check_res.duplicated}") if check_res.unused: errs.append(f"\tunused param: {check_res.unused}") - + Exception.__init__(self, '\n'.join(errs)) - + self.check_res = check_res self.func_signature = func_signature @@ -456,7 +467,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re # if check_res.varargs: # errs.append(f"\tvarargs: *{check_res.varargs}") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") - + if check_res.unused: for _unused in check_res.unused: if _unused in target_dict: @@ -466,8 +477,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re if _unused_field: unuseds.append(f"\tunused field: {_unused_field}") if _unused_param: - unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward - + unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward + module_name = func_signature.split('.')[0] if check_res.missing: errs.append(f"\tmissing param: {check_res.missing}") @@ -488,14 +499,14 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re mapped_missing.append(_miss) else: unmapped_missing.append(_miss) - + for _miss in mapped_missing + unmapped_missing: if _miss in dataset: suggestions.append(f"Set `{_miss}` as target.") else: _tmp = '' if check_res.unused: - _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." + _tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}." if _tmp: _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' else: @@ -513,25 +524,25 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re # else: # _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' # suggestions.append(_tmp) - + if check_res.duplicated: errs.append(f"\tduplicated param: {check_res.duplicated}.") suggestions.append(f"Delete {check_res.duplicated} in the output of " f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") - - if len(errs)>0: + + if len(errs) > 0: errs.extend(unuseds) elif check_level == STRICT_CHECK_LEVEL: errs.extend(unuseds) - + if len(errs) > 0: errs.insert(0, f'Problems occurred when calling {func_signature}') sugg_str = "" if len(suggestions) > 1: for idx, sugg in enumerate(suggestions): - if idx>0: + if idx > 0: sugg_str += '\t\t\t' - sugg_str += f'({idx+1}). {sugg}\n' + sugg_str += f'({idx + 1}). {sugg}\n' sugg_str = sugg_str[:-1] else: sugg_str += suggestions[0] @@ -546,14 +557,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re _unused_warn = f'{check_res.unused} is not used by {module_name}.' warnings.warn(message=_unused_warn) + def _check_forward_error(forward_func, batch_x, dataset, check_level): check_res = _check_arg_dict_list(forward_func, batch_x) func_signature = _get_func_signature(forward_func) - + errs = [] suggestions = [] _unused = [] - + # if check_res.varargs: # errs.append(f"\tvarargs: {check_res.varargs}") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") @@ -574,20 +586,20 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ # f"rename the field in `unused field:`." suggestions.append(_tmp) - + if check_res.unused: _unused = [f"\tunused field: {check_res.unused}"] - if len(errs)>0: + if len(errs) > 0: errs.extend(_unused) elif check_level == STRICT_CHECK_LEVEL: errs.extend(_unused) - + if len(errs) > 0: errs.insert(0, f'Problems occurred when calling {func_signature}') sugg_str = "" if len(suggestions) > 1: for idx, sugg in enumerate(suggestions): - sugg_str += f'({idx+1}). {sugg}' + sugg_str += f'({idx + 1}). {sugg}' else: sugg_str += suggestions[0] err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str @@ -622,8 +634,8 @@ def seq_len_to_mask(seq_len): assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." max_len = int(seq_len.max()) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) - mask = broad_cast_seq_len