@@ -2,15 +2,19 @@ | |||||
batch 模块实现了 fastNLP 所需的 Batch 类。 | batch 模块实现了 fastNLP 所需的 Batch 类。 | ||||
""" | """ | ||||
__all__ = ["Batch"] | |||||
import atexit | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import atexit | |||||
from .sampler import RandomSampler, Sampler | |||||
import torch.multiprocessing as mp | import torch.multiprocessing as mp | ||||
from queue import Empty, Full | from queue import Empty, Full | ||||
from .sampler import RandomSampler | |||||
__all__ = [ | |||||
"Batch" | |||||
] | |||||
_python_is_exit = False | _python_is_exit = False | ||||
@@ -120,7 +124,7 @@ class Batch(object): | |||||
:return list(int) indexes: 下标序列 | :return list(int) indexes: 下标序列 | ||||
""" | """ | ||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
@staticmethod | @staticmethod | ||||
def _run_fetch(batch, q): | def _run_fetch(batch, q): | ||||
try: | try: | ||||
@@ -145,7 +149,7 @@ class Batch(object): | |||||
q.put(e) | q.put(e) | ||||
finally: | finally: | ||||
q.join() | q.join() | ||||
@staticmethod | @staticmethod | ||||
def _run_batch_iter(batch): | def _run_batch_iter(batch): | ||||
q = mp.JoinableQueue(maxsize=10) | q = mp.JoinableQueue(maxsize=10) | ||||
@@ -182,4 +186,3 @@ def _to_tensor(batch, dtype): | |||||
except: | except: | ||||
pass | pass | ||||
return batch | return batch | ||||
@@ -49,6 +49,18 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: | |||||
trainer.train() | trainer.train() | ||||
""" | """ | ||||
import os | |||||
import torch | |||||
try: | |||||
from tensorboardX import SummaryWriter | |||||
tensorboardX_flag = True | |||||
except: | |||||
tensorboardX_flag = False | |||||
from ..io.model_io import ModelSaver, ModelLoader | |||||
__all__ = [ | __all__ = [ | ||||
"Callback", | "Callback", | ||||
"GradientClipCallback", | "GradientClipCallback", | ||||
@@ -60,15 +72,6 @@ __all__ = [ | |||||
"CallbackException", | "CallbackException", | ||||
"EarlyStopError" | "EarlyStopError" | ||||
] | ] | ||||
import os | |||||
import torch | |||||
from ..io.model_io import ModelSaver, ModelLoader | |||||
try: | |||||
from tensorboardX import SummaryWriter | |||||
tensorboardX_flag = True | |||||
except: | |||||
tensorboardX_flag = False | |||||
class Callback(object): | class Callback(object): | ||||
@@ -587,7 +590,7 @@ class TensorboardCallback(Callback): | |||||
self._summary_writer = SummaryWriter(path) | self._summary_writer = SummaryWriter(path) | ||||
else: | else: | ||||
self._summary_writer = None | self._summary_writer = None | ||||
def on_batch_begin(self, batch_x, batch_y, indices): | def on_batch_begin(self, batch_x, batch_y, indices): | ||||
if "model" in self.options and self.graph_added is False: | if "model" in self.options and self.graph_added is False: | ||||
# tesorboardX 这里有大bug,暂时没法画模型图 | # tesorboardX 这里有大bug,暂时没法画模型图 | ||||
@@ -272,9 +272,7 @@ | |||||
""" | """ | ||||
__all__ = ["DataSet"] | |||||
import _pickle as pickle | import _pickle as pickle | ||||
import numpy as np | import numpy as np | ||||
import warnings | import warnings | ||||
@@ -283,6 +281,10 @@ from .field import FieldArray | |||||
from .instance import Instance | from .instance import Instance | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
__all__ = [ | |||||
"DataSet" | |||||
] | |||||
class DataSet(object): | class DataSet(object): | ||||
""" | """ | ||||
@@ -854,4 +856,4 @@ class DataSet(object): | |||||
with open(path, 'rb') as f: | with open(path, 'rb') as f: | ||||
d = pickle.load(f) | d = pickle.load(f) | ||||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | ||||
return d | |||||
return d |
@@ -3,11 +3,17 @@ field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fas | |||||
原理部分请参考 :doc:`fastNLP.core.dataset` | 原理部分请参考 :doc:`fastNLP.core.dataset` | ||||
""" | """ | ||||
import numpy as np | import numpy as np | ||||
from copy import deepcopy | from copy import deepcopy | ||||
__all__ = [ | |||||
"FieldArray", | |||||
"Padder", | |||||
"AutoPadder", | |||||
"EngChar2DPadder" | |||||
] | |||||
class FieldArray(object): | class FieldArray(object): | ||||
""" | """ | ||||
@@ -24,6 +30,7 @@ class FieldArray(object): | |||||
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, | :param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, | ||||
就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 | 就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 | ||||
""" | """ | ||||
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | ||||
self.name = name | self.name = name | ||||
if isinstance(content, list): | if isinstance(content, list): | ||||
@@ -41,7 +48,7 @@ class FieldArray(object): | |||||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | ||||
if len(content) == 0: | if len(content) == 0: | ||||
raise RuntimeError("Cannot initialize FieldArray with empty list.") | raise RuntimeError("Cannot initialize FieldArray with empty list.") | ||||
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | ||||
self.content_dim = None # 表示content是多少维的list | self.content_dim = None # 表示content是多少维的list | ||||
if padder is None: | if padder is None: | ||||
@@ -51,27 +58,27 @@ class FieldArray(object): | |||||
padder = deepcopy(padder) | padder = deepcopy(padder) | ||||
self.set_padder(padder) | self.set_padder(padder) | ||||
self.ignore_type = ignore_type | self.ignore_type = ignore_type | ||||
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | ||||
self.pytype = None | self.pytype = None | ||||
self.dtype = None | self.dtype = None | ||||
self._is_input = None | self._is_input = None | ||||
self._is_target = None | self._is_target = None | ||||
if is_input is not None or is_target is not None: | if is_input is not None or is_target is not None: | ||||
self.is_input = is_input | self.is_input = is_input | ||||
self.is_target = is_target | self.is_target = is_target | ||||
def _set_dtype(self): | def _set_dtype(self): | ||||
if self.ignore_type is False: | if self.ignore_type is False: | ||||
self.pytype = self._type_detection(self.content) | self.pytype = self._type_detection(self.content) | ||||
self.dtype = self._map_to_np_type(self.pytype) | self.dtype = self._map_to_np_type(self.pytype) | ||||
@property | @property | ||||
def is_input(self): | def is_input(self): | ||||
return self._is_input | return self._is_input | ||||
@is_input.setter | @is_input.setter | ||||
def is_input(self, value): | def is_input(self, value): | ||||
""" | """ | ||||
@@ -80,11 +87,11 @@ class FieldArray(object): | |||||
if value is True: | if value is True: | ||||
self._set_dtype() | self._set_dtype() | ||||
self._is_input = value | self._is_input = value | ||||
@property | @property | ||||
def is_target(self): | def is_target(self): | ||||
return self._is_target | return self._is_target | ||||
@is_target.setter | @is_target.setter | ||||
def is_target(self, value): | def is_target(self, value): | ||||
""" | """ | ||||
@@ -93,7 +100,7 @@ class FieldArray(object): | |||||
if value is True: | if value is True: | ||||
self._set_dtype() | self._set_dtype() | ||||
self._is_target = value | self._is_target = value | ||||
def _type_detection(self, content): | def _type_detection(self, content): | ||||
""" | """ | ||||
当该field被设置为is_input或者is_target时被调用 | 当该field被设置为is_input或者is_target时被调用 | ||||
@@ -101,9 +108,9 @@ class FieldArray(object): | |||||
""" | """ | ||||
if len(content) == 0: | if len(content) == 0: | ||||
raise RuntimeError("Empty list in Field {}.".format(self.name)) | raise RuntimeError("Empty list in Field {}.".format(self.name)) | ||||
type_set = set([type(item) for item in content]) | type_set = set([type(item) for item in content]) | ||||
if list in type_set: | if list in type_set: | ||||
if len(type_set) > 1: | if len(type_set) > 1: | ||||
# list 跟 非list 混在一起 | # list 跟 非list 混在一起 | ||||
@@ -139,7 +146,7 @@ class FieldArray(object): | |||||
self.name, self.BASIC_TYPES, content_type)) | self.name, self.BASIC_TYPES, content_type)) | ||||
self.content_dim = 1 | self.content_dim = 1 | ||||
return self._basic_type_detection(type_set) | return self._basic_type_detection(type_set) | ||||
def _basic_type_detection(self, type_set): | def _basic_type_detection(self, type_set): | ||||
""" | """ | ||||
:param type_set: a set of Python types | :param type_set: a set of Python types | ||||
@@ -158,7 +165,7 @@ class FieldArray(object): | |||||
else: | else: | ||||
# str, int, float混在一起 | # str, int, float混在一起 | ||||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | ||||
def _1d_list_check(self, val): | def _1d_list_check(self, val): | ||||
"""如果不是1D list就报错 | """如果不是1D list就报错 | ||||
""" | """ | ||||
@@ -168,7 +175,7 @@ class FieldArray(object): | |||||
self._basic_type_detection(type_set) | self._basic_type_detection(type_set) | ||||
# otherwise: _basic_type_detection will raise error | # otherwise: _basic_type_detection will raise error | ||||
return True | return True | ||||
def _2d_list_check(self, val): | def _2d_list_check(self, val): | ||||
"""如果不是2D list 就报错 | """如果不是2D list 就报错 | ||||
""" | """ | ||||
@@ -181,15 +188,15 @@ class FieldArray(object): | |||||
inner_type_set.add(type(obj)) | inner_type_set.add(type(obj)) | ||||
self._basic_type_detection(inner_type_set) | self._basic_type_detection(inner_type_set) | ||||
return True | return True | ||||
@staticmethod | @staticmethod | ||||
def _map_to_np_type(basic_type): | def _map_to_np_type(basic_type): | ||||
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | ||||
return type_mapping[basic_type] | return type_mapping[basic_type] | ||||
def __repr__(self): | def __repr__(self): | ||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | ||||
def append(self, val): | def append(self, val): | ||||
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 | """将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 | ||||
的内容是匹配的。 | 的内容是匹配的。 | ||||
@@ -208,7 +215,7 @@ class FieldArray(object): | |||||
else: | else: | ||||
raise RuntimeError( | raise RuntimeError( | ||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | ||||
if self.is_input is True or self.is_target is True: | if self.is_input is True or self.is_target is True: | ||||
if type(val) == list: | if type(val) == list: | ||||
if len(val) == 0: | if len(val) == 0: | ||||
@@ -231,14 +238,14 @@ class FieldArray(object): | |||||
raise RuntimeError( | raise RuntimeError( | ||||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | ||||
self.content.append(val) | self.content.append(val) | ||||
def __getitem__(self, indices): | def __getitem__(self, indices): | ||||
return self.get(indices, pad=False) | return self.get(indices, pad=False) | ||||
def __setitem__(self, idx, val): | def __setitem__(self, idx, val): | ||||
assert isinstance(idx, int) | assert isinstance(idx, int) | ||||
self.content[idx] = val | self.content[idx] = val | ||||
def get(self, indices, pad=True): | def get(self, indices, pad=True): | ||||
""" | """ | ||||
根据给定的indices返回内容 | 根据给定的indices返回内容 | ||||
@@ -251,13 +258,13 @@ class FieldArray(object): | |||||
return self.content[indices] | return self.content[indices] | ||||
if self.is_input is False and self.is_target is False: | if self.is_input is False and self.is_target is False: | ||||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | ||||
contents = [self.content[i] for i in indices] | contents = [self.content[i] for i in indices] | ||||
if self.padder is None or pad is False: | if self.padder is None or pad is False: | ||||
return np.array(contents) | return np.array(contents) | ||||
else: | else: | ||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) | return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) | ||||
def set_padder(self, padder): | def set_padder(self, padder): | ||||
""" | """ | ||||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | ||||
@@ -269,7 +276,7 @@ class FieldArray(object): | |||||
self.padder = deepcopy(padder) | self.padder = deepcopy(padder) | ||||
else: | else: | ||||
self.padder = None | self.padder = None | ||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
""" | """ | ||||
修改padder的pad_val. | 修改padder的pad_val. | ||||
@@ -279,8 +286,7 @@ class FieldArray(object): | |||||
if self.padder is not None: | if self.padder is not None: | ||||
self.padder.set_pad_val(pad_val) | self.padder.set_pad_val(pad_val) | ||||
return self | return self | ||||
def __len__(self): | def __len__(self): | ||||
""" | """ | ||||
Returns the size of FieldArray. | Returns the size of FieldArray. | ||||
@@ -288,7 +294,7 @@ class FieldArray(object): | |||||
:return int length: | :return int length: | ||||
""" | """ | ||||
return len(self.content) | return len(self.content) | ||||
def to(self, other): | def to(self, other): | ||||
""" | """ | ||||
将other的属性复制给本FieldArray(other必须为FieldArray类型). | 将other的属性复制给本FieldArray(other必须为FieldArray类型). | ||||
@@ -298,14 +304,15 @@ class FieldArray(object): | |||||
:return: :class:`~fastNLP.FieldArray` | :return: :class:`~fastNLP.FieldArray` | ||||
""" | """ | ||||
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) | assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) | ||||
self.is_input = other.is_input | self.is_input = other.is_input | ||||
self.is_target = other.is_target | self.is_target = other.is_target | ||||
self.padder = other.padder | self.padder = other.padder | ||||
self.ignore_type = other.ignore_type | self.ignore_type = other.ignore_type | ||||
return self | return self | ||||
def _is_iterable(content): | def _is_iterable(content): | ||||
try: | try: | ||||
_ = (e for e in content) | _ = (e for e in content) | ||||
@@ -331,13 +338,13 @@ class Padder: | |||||
:return: np.array([padded_element]) | :return: np.array([padded_element]) | ||||
""" | """ | ||||
def __init__(self, pad_val=0, **kwargs): | def __init__(self, pad_val=0, **kwargs): | ||||
self.pad_val = pad_val | self.pad_val = pad_val | ||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
self.pad_val = pad_val | self.pad_val = pad_val | ||||
def __call__(self, contents, field_name, field_ele_dtype): | def __call__(self, contents, field_name, field_ele_dtype): | ||||
""" | """ | ||||
传入的是List内容。假设有以下的DataSet。 | 传入的是List内容。假设有以下的DataSet。 | ||||
@@ -396,13 +403,13 @@ class AutoPadder(Padder): | |||||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | ||||
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | 即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | ||||
""" | """ | ||||
def __init__(self, pad_val=0): | def __init__(self, pad_val=0): | ||||
""" | """ | ||||
:param pad_val: int, padding的位置使用该index | :param pad_val: int, padding的位置使用该index | ||||
""" | """ | ||||
super().__init__(pad_val=pad_val) | super().__init__(pad_val=pad_val) | ||||
def _is_two_dimension(self, contents): | def _is_two_dimension(self, contents): | ||||
""" | """ | ||||
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | 判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | ||||
@@ -416,7 +423,7 @@ class AutoPadder(Padder): | |||||
return False | return False | ||||
return True | return True | ||||
return False | return False | ||||
def __call__(self, contents, field_name, field_ele_dtype): | def __call__(self, contents, field_name, field_ele_dtype): | ||||
if not _is_iterable(contents[0]): | if not _is_iterable(contents[0]): | ||||
@@ -458,6 +465,7 @@ class EngChar2DPadder(Padder): | |||||
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | ||||
""" | """ | ||||
def __init__(self, pad_val=0, pad_length=0): | def __init__(self, pad_val=0, pad_length=0): | ||||
""" | """ | ||||
:param pad_val: int, pad的位置使用该index | :param pad_val: int, pad的位置使用该index | ||||
@@ -465,9 +473,9 @@ class EngChar2DPadder(Padder): | |||||
都pad或截取到该长度. | 都pad或截取到该长度. | ||||
""" | """ | ||||
super().__init__(pad_val=pad_val) | super().__init__(pad_val=pad_val) | ||||
self.pad_length = pad_length | self.pad_length = pad_length | ||||
def _exactly_three_dims(self, contents, field_name): | def _exactly_three_dims(self, contents, field_name): | ||||
""" | """ | ||||
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character | 检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character | ||||
@@ -486,10 +494,10 @@ class EngChar2DPadder(Padder): | |||||
value = value[0] | value = value[0] | ||||
except: | except: | ||||
raise ValueError("Field:{} only has two dimensions.".format(field_name)) | raise ValueError("Field:{} only has two dimensions.".format(field_name)) | ||||
if _is_iterable(value): | if _is_iterable(value): | ||||
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) | raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) | ||||
def __call__(self, contents, field_name, field_ele_dtype): | def __call__(self, contents, field_name, field_ele_dtype): | ||||
""" | """ | ||||
期望输入类似于 | 期望输入类似于 | ||||
@@ -516,12 +524,12 @@ class EngChar2DPadder(Padder): | |||||
max_sent_length = max(len(word_lst) for word_lst in contents) | max_sent_length = max(len(word_lst) for word_lst in contents) | ||||
batch_size = len(contents) | batch_size = len(contents) | ||||
dtype = type(contents[0][0][0]) | dtype = type(contents[0][0][0]) | ||||
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | ||||
dtype=dtype) | |||||
dtype=dtype) | |||||
for b_idx, word_lst in enumerate(contents): | for b_idx, word_lst in enumerate(contents): | ||||
for c_idx, char_lst in enumerate(word_lst): | for c_idx, char_lst in enumerate(word_lst): | ||||
chars = char_lst[:max_char_length] | chars = char_lst[:max_char_length] | ||||
padded_array[b_idx, c_idx, :len(chars)] = chars | padded_array[b_idx, c_idx, :len(chars)] = chars | ||||
return padded_array | |||||
return padded_array |
@@ -3,7 +3,9 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 | |||||
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | ||||
""" | """ | ||||
__all__ = ["Instance"] | |||||
__all__ = [ | |||||
"Instance" | |||||
] | |||||
class Instance(object): | class Instance(object): | ||||
@@ -2,13 +2,12 @@ | |||||
losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 | ||||
""" | """ | ||||
__all__ = ["LossBase", "L1Loss", "LossFunc", "LossInForward", "BCELoss", "CrossEntropyLoss", "NLLLoss"] | |||||
import inspect | import inspect | ||||
from collections import defaultdict | |||||
import torch | import torch | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from collections import defaultdict | |||||
from .utils import _CheckError | from .utils import _CheckError | ||||
from .utils import _CheckRes | from .utils import _CheckRes | ||||
from .utils import _build_args | from .utils import _build_args | ||||
@@ -16,6 +15,18 @@ from .utils import _check_arg_dict_list | |||||
from .utils import _check_function_or_method | from .utils import _check_function_or_method | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
__all__ = [ | |||||
"LossBase", | |||||
"LossFunc", | |||||
"LossInForward", | |||||
"CrossEntropyLoss", | |||||
"BCELoss", | |||||
"L1Loss", | |||||
"NLLLoss" | |||||
] | |||||
class LossBase(object): | class LossBase(object): | ||||
""" | """ | ||||
@@ -3,11 +3,11 @@ metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 | |||||
""" | """ | ||||
import inspect | import inspect | ||||
from collections import defaultdict | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from collections import defaultdict | |||||
from .utils import _CheckError | from .utils import _CheckError | ||||
from .utils import _CheckRes | from .utils import _CheckRes | ||||
from .utils import _build_args | from .utils import _build_args | ||||
@@ -16,6 +16,13 @@ from .utils import _get_func_signature | |||||
from .utils import seq_len_to_mask | from .utils import seq_len_to_mask | ||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
__all__ = [ | |||||
"MetricBase", | |||||
"AccuracyMetric", | |||||
"SpanFPreRecMetric", | |||||
"SQuADMetric" | |||||
] | |||||
class MetricBase(object): | class MetricBase(object): | ||||
""" | """ | ||||
@@ -106,16 +113,17 @@ class MetricBase(object): | |||||
self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值 | self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值 | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
self.param_map = {} # key is param in function, value is input param. | self.param_map = {} # key is param in function, value is input param. | ||||
self._checked = False | self._checked = False | ||||
def evaluate(self, *args, **kwargs): | def evaluate(self, *args, **kwargs): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
raise NotImplemented | raise NotImplemented | ||||
def _init_param_map(self, key_map=None, **kwargs): | def _init_param_map(self, key_map=None, **kwargs): | ||||
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map | """检查key_map和其他参数map,并将这些映射关系添加到self.param_map | ||||
@@ -148,7 +156,7 @@ class MetricBase(object): | |||||
for value, key_set in value_counter.items(): | for value, key_set in value_counter.items(): | ||||
if len(key_set) > 1: | if len(key_set) > 1: | ||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | ||||
# check consistence between signature and param_map | # check consistence between signature and param_map | ||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | func_args = [arg for arg in func_spect.args if arg != 'self'] | ||||
@@ -157,7 +165,7 @@ class MetricBase(object): | |||||
raise NameError( | raise NameError( | ||||
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " | f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " | ||||
f"initialization parameters, or change its signature.") | f"initialization parameters, or change its signature.") | ||||
def _fast_param_map(self, pred_dict, target_dict): | def _fast_param_map(self, pred_dict, target_dict): | ||||
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | ||||
such as pred_dict has one element, target_dict has one element | such as pred_dict has one element, target_dict has one element | ||||
@@ -172,7 +180,7 @@ class MetricBase(object): | |||||
fast_param['target'] = list(target_dict.values())[0] | fast_param['target'] = list(target_dict.values())[0] | ||||
return fast_param | return fast_param | ||||
return fast_param | return fast_param | ||||
def __call__(self, pred_dict, target_dict): | def __call__(self, pred_dict, target_dict): | ||||
""" | """ | ||||
这个方法会调用self.evaluate 方法. | 这个方法会调用self.evaluate 方法. | ||||
@@ -187,12 +195,12 @@ class MetricBase(object): | |||||
:param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) | :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) | ||||
:return: | :return: | ||||
""" | """ | ||||
fast_param = self._fast_param_map(pred_dict, target_dict) | fast_param = self._fast_param_map(pred_dict, target_dict) | ||||
if fast_param: | if fast_param: | ||||
self.evaluate(**fast_param) | self.evaluate(**fast_param) | ||||
return | return | ||||
if not self._checked: | if not self._checked: | ||||
if not callable(self.evaluate): | if not callable(self.evaluate): | ||||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | ||||
@@ -202,14 +210,14 @@ class MetricBase(object): | |||||
for func_arg, input_arg in self.param_map.items(): | for func_arg, input_arg in self.param_map.items(): | ||||
if func_arg not in func_args: | if func_arg not in func_args: | ||||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") | raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") | ||||
# 2. only part of the param_map are passed, left are not | # 2. only part of the param_map are passed, left are not | ||||
for arg in func_args: | for arg in func_args: | ||||
if arg not in self.param_map: | if arg not in self.param_map: | ||||
self.param_map[arg] = arg # This param does not need mapping. | self.param_map[arg] = arg # This param does not need mapping. | ||||
self._evaluate_args = func_args | self._evaluate_args = func_args | ||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | ||||
# need to wrap inputs in dict. | # need to wrap inputs in dict. | ||||
mapped_pred_dict = {} | mapped_pred_dict = {} | ||||
mapped_target_dict = {} | mapped_target_dict = {} | ||||
@@ -229,7 +237,7 @@ class MetricBase(object): | |||||
not_duplicate_flag += 1 | not_duplicate_flag += 1 | ||||
if not_duplicate_flag == 3: | if not_duplicate_flag == 3: | ||||
duplicated.append(input_arg) | duplicated.append(input_arg) | ||||
# missing | # missing | ||||
if not self._checked: | if not self._checked: | ||||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | ||||
@@ -240,23 +248,23 @@ class MetricBase(object): | |||||
for idx, func_arg in enumerate(missing): | for idx, func_arg in enumerate(missing): | ||||
# Don't delete `` in this information, nor add `` | # Don't delete `` in this information, nor add `` | ||||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | ||||
f"in `{self.__class__.__name__}`)" | |||||
f"in `{self.__class__.__name__}`)" | |||||
check_res = _CheckRes(missing=replaced_missing, | check_res = _CheckRes(missing=replaced_missing, | ||||
unused=check_res.unused, | unused=check_res.unused, | ||||
duplicated=duplicated, | duplicated=duplicated, | ||||
required=check_res.required, | required=check_res.required, | ||||
all_needed=check_res.all_needed, | all_needed=check_res.all_needed, | ||||
varargs=check_res.varargs) | varargs=check_res.varargs) | ||||
if check_res.missing or check_res.duplicated: | if check_res.missing or check_res.duplicated: | ||||
raise _CheckError(check_res=check_res, | raise _CheckError(check_res=check_res, | ||||
func_signature=_get_func_signature(self.evaluate)) | func_signature=_get_func_signature(self.evaluate)) | ||||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | ||||
self.evaluate(**refined_args) | self.evaluate(**refined_args) | ||||
self._checked = True | self._checked = True | ||||
return | return | ||||
@@ -271,15 +279,16 @@ class AccuracyMetric(MetricBase): | |||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | ||||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | ||||
""" | """ | ||||
def __init__(self, pred=None, target=None, seq_len=None): | def __init__(self, pred=None, target=None, seq_len=None): | ||||
super().__init__() | super().__init__() | ||||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | self._init_param_map(pred=pred, target=target, seq_len=seq_len) | ||||
self.total = 0 | self.total = 0 | ||||
self.acc_count = 0 | self.acc_count = 0 | ||||
def evaluate(self, pred, target, seq_len=None): | def evaluate(self, pred, target, seq_len=None): | ||||
""" | """ | ||||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
@@ -299,16 +308,16 @@ class AccuracyMetric(MetricBase): | |||||
if not isinstance(target, torch.Tensor): | if not isinstance(target, torch.Tensor): | ||||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | ||||
f"got {type(target)}.") | f"got {type(target)}.") | ||||
if seq_len is not None and not isinstance(seq_len, torch.Tensor): | if seq_len is not None and not isinstance(seq_len, torch.Tensor): | ||||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | ||||
f"got {type(seq_len)}.") | f"got {type(seq_len)}.") | ||||
if seq_len is not None: | if seq_len is not None: | ||||
masks = seq_len_to_mask(seq_len=seq_len) | masks = seq_len_to_mask(seq_len=seq_len) | ||||
else: | else: | ||||
masks = None | masks = None | ||||
if pred.size() == target.size(): | if pred.size() == target.size(): | ||||
pass | pass | ||||
elif len(pred.size()) == len(target.size()) + 1: | elif len(pred.size()) == len(target.size()) + 1: | ||||
@@ -317,7 +326,7 @@ class AccuracyMetric(MetricBase): | |||||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | ||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | f"size:{pred.size()}, target should have size: {pred.size()} or " | ||||
f"{pred.size()[:-1]}, got {target.size()}.") | f"{pred.size()[:-1]}, got {target.size()}.") | ||||
target = target.to(pred) | target = target.to(pred) | ||||
if masks is not None: | if masks is not None: | ||||
self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item() | self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item() | ||||
@@ -325,7 +334,7 @@ class AccuracyMetric(MetricBase): | |||||
else: | else: | ||||
self.acc_count += torch.sum(torch.eq(pred, target)).item() | self.acc_count += torch.sum(torch.eq(pred, target)).item() | ||||
self.total += np.prod(list(pred.size())) | self.total += np.prod(list(pred.size())) | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
""" | """ | ||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | ||||
@@ -350,7 +359,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | ||||
""" | """ | ||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | ignore_labels = set(ignore_labels) if ignore_labels else set() | ||||
spans = [] | spans = [] | ||||
prev_bmes_tag = None | prev_bmes_tag = None | ||||
for idx, tag in enumerate(tags): | for idx, tag in enumerate(tags): | ||||
@@ -358,14 +367,14 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): | |||||
bmes_tag, label = tag[:1], tag[2:] | bmes_tag, label = tag[:1], tag[2:] | ||||
if bmes_tag in ('b', 's'): | if bmes_tag in ('b', 's'): | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: | |||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | spans[-1][1][1] = idx | ||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bmes_tag = bmes_tag | prev_bmes_tag = bmes_tag | ||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | ] | ||||
@@ -379,7 +388,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | ||||
""" | """ | ||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | ignore_labels = set(ignore_labels) if ignore_labels else set() | ||||
spans = [] | spans = [] | ||||
prev_bmes_tag = None | prev_bmes_tag = None | ||||
for idx, tag in enumerate(tags): | for idx, tag in enumerate(tags): | ||||
@@ -387,16 +396,16 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||||
bmes_tag, label = tag[:1], tag[2:] | bmes_tag, label = tag[:1], tag[2:] | ||||
if bmes_tag in ('b', 's'): | if bmes_tag in ('b', 's'): | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]: | |||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | spans[-1][1][1] = idx | ||||
elif bmes_tag == 'o': | elif bmes_tag == 'o': | ||||
pass | pass | ||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bmes_tag = bmes_tag | prev_bmes_tag = bmes_tag | ||||
return [(span[0], (span[1][0], span[1][1]+1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||||
for span in spans | |||||
if span[0] not in ignore_labels | |||||
] | ] | ||||
@@ -410,7 +419,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | ||||
""" | """ | ||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | ignore_labels = set(ignore_labels) if ignore_labels else set() | ||||
spans = [] | spans = [] | ||||
prev_bio_tag = None | prev_bio_tag = None | ||||
for idx, tag in enumerate(tags): | for idx, tag in enumerate(tags): | ||||
@@ -418,14 +427,14 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
bio_tag, label = tag[:1], tag[2:] | bio_tag, label = tag[:1], tag[2:] | ||||
if bio_tag == 'b': | if bio_tag == 'b': | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]: | |||||
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | spans[-1][1][1] = idx | ||||
elif bio_tag == 'o': # o tag does not count | |||||
elif bio_tag == 'o': # o tag does not count | |||||
pass | pass | ||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bio_tag = bio_tag | prev_bio_tag = bio_tag | ||||
return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels] | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||||
class SpanFPreRecMetric(MetricBase): | class SpanFPreRecMetric(MetricBase): | ||||
@@ -470,16 +479,17 @@ class SpanFPreRecMetric(MetricBase): | |||||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | ||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | ||||
""" | """ | ||||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | ||||
only_gross=True, f_type='micro', beta=1): | |||||
only_gross=True, f_type='micro', beta=1): | |||||
encoding_type = encoding_type.lower() | encoding_type = encoding_type.lower() | ||||
if not isinstance(tag_vocab, Vocabulary): | if not isinstance(tag_vocab, Vocabulary): | ||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | ||||
self.encoding_type = encoding_type | self.encoding_type = encoding_type | ||||
if self.encoding_type == 'bmes': | if self.encoding_type == 'bmes': | ||||
self.tag_to_span_func = _bmes_tag_to_spans | self.tag_to_span_func = _bmes_tag_to_spans | ||||
@@ -489,22 +499,22 @@ class SpanFPreRecMetric(MetricBase): | |||||
self.tag_to_span_func = _bmeso_tag_to_spans | self.tag_to_span_func = _bmeso_tag_to_spans | ||||
else: | else: | ||||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | ||||
self.ignore_labels = ignore_labels | self.ignore_labels = ignore_labels | ||||
self.f_type = f_type | self.f_type = f_type | ||||
self.beta = beta | self.beta = beta | ||||
self.beta_square = self.beta**2 | |||||
self.beta_square = self.beta ** 2 | |||||
self.only_gross = only_gross | self.only_gross = only_gross | ||||
super().__init__() | super().__init__() | ||||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | self._init_param_map(pred=pred, target=target, seq_len=seq_len) | ||||
self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
self._true_positives = defaultdict(int) | self._true_positives = defaultdict(int) | ||||
self._false_positives = defaultdict(int) | self._false_positives = defaultdict(int) | ||||
self._false_negatives = defaultdict(int) | self._false_negatives = defaultdict(int) | ||||
def evaluate(self, pred, target, seq_len): | def evaluate(self, pred, target, seq_len): | ||||
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | """evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
@@ -519,11 +529,11 @@ class SpanFPreRecMetric(MetricBase): | |||||
if not isinstance(target, torch.Tensor): | if not isinstance(target, torch.Tensor): | ||||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | ||||
f"got {type(target)}.") | f"got {type(target)}.") | ||||
if not isinstance(seq_len, torch.Tensor): | if not isinstance(seq_len, torch.Tensor): | ||||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | ||||
f"got {type(seq_len)}.") | f"got {type(seq_len)}.") | ||||
if pred.size() == target.size() and len(target.size()) == 2: | if pred.size() == target.size() and len(target.size()) == 2: | ||||
pass | pass | ||||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | ||||
@@ -536,20 +546,20 @@ class SpanFPreRecMetric(MetricBase): | |||||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | ||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | f"size:{pred.size()}, target should have size: {pred.size()} or " | ||||
f"{pred.size()[:-1]}, got {target.size()}.") | f"{pred.size()[:-1]}, got {target.size()}.") | ||||
batch_size = pred.size(0) | batch_size = pred.size(0) | ||||
pred = pred.tolist() | pred = pred.tolist() | ||||
target = target.tolist() | target = target.tolist() | ||||
for i in range(batch_size): | for i in range(batch_size): | ||||
pred_tags = pred[i][:int(seq_len[i])] | pred_tags = pred[i][:int(seq_len[i])] | ||||
gold_tags = target[i][:int(seq_len[i])] | gold_tags = target[i][:int(seq_len[i])] | ||||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | ||||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | ||||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | ||||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | ||||
for span in pred_spans: | for span in pred_spans: | ||||
if span in gold_spans: | if span in gold_spans: | ||||
self._true_positives[span[0]] += 1 | self._true_positives[span[0]] += 1 | ||||
@@ -558,7 +568,7 @@ class SpanFPreRecMetric(MetricBase): | |||||
self._false_positives[span[0]] += 1 | self._false_positives[span[0]] += 1 | ||||
for span in gold_spans: | for span in gold_spans: | ||||
self._false_negatives[span[0]] += 1 | self._false_negatives[span[0]] += 1 | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | ||||
evaluate_result = {} | evaluate_result = {} | ||||
@@ -577,19 +587,19 @@ class SpanFPreRecMetric(MetricBase): | |||||
f_sum += f | f_sum += f | ||||
pre_sum += pre | pre_sum += pre | ||||
rec_sum + rec | rec_sum + rec | ||||
if not self.only_gross and tag!='': # tag!=''防止无tag的情况 | |||||
if not self.only_gross and tag != '': # tag!=''防止无tag的情况 | |||||
f_key = 'f-{}'.format(tag) | f_key = 'f-{}'.format(tag) | ||||
pre_key = 'pre-{}'.format(tag) | pre_key = 'pre-{}'.format(tag) | ||||
rec_key = 'rec-{}'.format(tag) | rec_key = 'rec-{}'.format(tag) | ||||
evaluate_result[f_key] = f | evaluate_result[f_key] = f | ||||
evaluate_result[pre_key] = pre | evaluate_result[pre_key] = pre | ||||
evaluate_result[rec_key] = rec | evaluate_result[rec_key] = rec | ||||
if self.f_type == 'macro': | if self.f_type == 'macro': | ||||
evaluate_result['f'] = f_sum/len(tags) | |||||
evaluate_result['pre'] = pre_sum/len(tags) | |||||
evaluate_result['rec'] = rec_sum/len(tags) | |||||
evaluate_result['f'] = f_sum / len(tags) | |||||
evaluate_result['pre'] = pre_sum / len(tags) | |||||
evaluate_result['rec'] = rec_sum / len(tags) | |||||
if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | ||||
sum(self._false_negatives.values()), | sum(self._false_negatives.values()), | ||||
@@ -597,17 +607,17 @@ class SpanFPreRecMetric(MetricBase): | |||||
evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
if reset: | if reset: | ||||
self._true_positives = defaultdict(int) | self._true_positives = defaultdict(int) | ||||
self._false_positives = defaultdict(int) | self._false_positives = defaultdict(int) | ||||
self._false_negatives = defaultdict(int) | self._false_negatives = defaultdict(int) | ||||
for key, value in evaluate_result.items(): | for key, value in evaluate_result.items(): | ||||
evaluate_result[key] = round(value, 6) | evaluate_result[key] = round(value, 6) | ||||
return evaluate_result | return evaluate_result | ||||
def _compute_f_pre_rec(self, tp, fn, fp): | def _compute_f_pre_rec(self, tp, fn, fp): | ||||
""" | """ | ||||
@@ -619,11 +629,10 @@ class SpanFPreRecMetric(MetricBase): | |||||
pre = tp / (fp + tp + 1e-13) | pre = tp / (fp + tp + 1e-13) | ||||
rec = tp / (fn + tp + 1e-13) | rec = tp / (fn + tp + 1e-13) | ||||
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | ||||
return f, pre, rec | return f, pre, rec | ||||
def _prepare_metrics(metrics): | def _prepare_metrics(metrics): | ||||
""" | """ | ||||
@@ -705,33 +714,33 @@ class SQuADMetric(MetricBase): | |||||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | ||||
""" | """ | ||||
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | ||||
beta=1, right_open=True, print_predict_stat=False): | beta=1, right_open=True, print_predict_stat=False): | ||||
super(SQuADMetric, self).__init__() | super(SQuADMetric, self).__init__() | ||||
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | ||||
self.print_predict_stat = print_predict_stat | self.print_predict_stat = print_predict_stat | ||||
self.no_ans_correct = 0 | self.no_ans_correct = 0 | ||||
self.no_ans_wrong = 0 | self.no_ans_wrong = 0 | ||||
self.has_ans_correct = 0 | self.has_ans_correct = 0 | ||||
self.has_ans_wrong = 0 | self.has_ans_wrong = 0 | ||||
self.has_ans_f = 0. | self.has_ans_f = 0. | ||||
self.no2no = 0 | self.no2no = 0 | ||||
self.no2yes = 0 | self.no2yes = 0 | ||||
self.yes2no = 0 | self.yes2no = 0 | ||||
self.yes2yes = 0 | self.yes2yes = 0 | ||||
self.f_beta = beta | self.f_beta = beta | ||||
self.right_open = right_open | self.right_open = right_open | ||||
def evaluate(self, pred1, pred2, target1, target2): | def evaluate(self, pred1, pred2, target1, target2): | ||||
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | """evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
@@ -745,7 +754,7 @@ class SQuADMetric(MetricBase): | |||||
pred_end = pred2 | pred_end = pred2 | ||||
target_start = target1 | target_start = target1 | ||||
target_end = target2 | target_end = target2 | ||||
if len(pred_start.size()) == 2: | if len(pred_start.size()) == 2: | ||||
start_inference = pred_start.max(dim=-1)[1].cpu().tolist() | start_inference = pred_start.max(dim=-1)[1].cpu().tolist() | ||||
else: | else: | ||||
@@ -754,12 +763,12 @@ class SQuADMetric(MetricBase): | |||||
end_inference = pred_end.max(dim=-1)[1].cpu().tolist() | end_inference = pred_end.max(dim=-1)[1].cpu().tolist() | ||||
else: | else: | ||||
end_inference = pred_end.cpu().tolist() | end_inference = pred_end.cpu().tolist() | ||||
start, end = [], [] | start, end = [], [] | ||||
max_len = pred_start.size(1) | max_len = pred_start.size(1) | ||||
t_start = target_start.cpu().tolist() | t_start = target_start.cpu().tolist() | ||||
t_end = target_end.cpu().tolist() | t_end = target_end.cpu().tolist() | ||||
for s, e in zip(start_inference, end_inference): | for s, e in zip(start_inference, end_inference): | ||||
start.append(min(s, e)) | start.append(min(s, e)) | ||||
end.append(max(s, e)) | end.append(max(s, e)) | ||||
@@ -779,7 +788,7 @@ class SQuADMetric(MetricBase): | |||||
self.yes2no += 1 | self.yes2no += 1 | ||||
else: | else: | ||||
self.yes2yes += 1 | self.yes2yes += 1 | ||||
if s == ts and e == te: | if s == ts and e == te: | ||||
self.has_ans_correct += 1 | self.has_ans_correct += 1 | ||||
else: | else: | ||||
@@ -787,29 +796,29 @@ class SQuADMetric(MetricBase): | |||||
a = [0] * s + [1] * (e - s) + [0] * (max_len - e) | a = [0] * s + [1] * (e - s) + [0] * (max_len - e) | ||||
b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te) | b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te) | ||||
a, b = torch.tensor(a), torch.tensor(b) | a, b = torch.tensor(a), torch.tensor(b) | ||||
TP = int(torch.sum(a * b)) | TP = int(torch.sum(a * b)) | ||||
pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0 | pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0 | ||||
rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0 | rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0 | ||||
if pre + rec > 0: | if pre + rec > 0: | ||||
f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec) | |||||
f = (1 + (self.f_beta ** 2)) * pre * rec / ((self.f_beta ** 2) * pre + rec) | |||||
else: | else: | ||||
f = 0 | f = 0 | ||||
self.has_ans_f += f | self.has_ans_f += f | ||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | ||||
evaluate_result = {} | evaluate_result = {} | ||||
if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: | if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: | ||||
return evaluate_result | return evaluate_result | ||||
evaluate_result['EM'] = 0 | evaluate_result['EM'] = 0 | ||||
evaluate_result[f'f_{self.f_beta}'] = 0 | evaluate_result[f'f_{self.f_beta}'] = 0 | ||||
flag = 0 | flag = 0 | ||||
if self.no_ans_correct + self.no_ans_wrong > 0: | if self.no_ans_correct + self.no_ans_wrong > 0: | ||||
evaluate_result[f'noAns-f_{self.f_beta}'] = \ | evaluate_result[f'noAns-f_{self.f_beta}'] = \ | ||||
round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) | round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) | ||||
@@ -818,7 +827,7 @@ class SQuADMetric(MetricBase): | |||||
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}'] | evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}'] | ||||
evaluate_result['EM'] += evaluate_result['noAns-EM'] | evaluate_result['EM'] += evaluate_result['noAns-EM'] | ||||
flag += 1 | flag += 1 | ||||
if self.has_ans_correct + self.has_ans_wrong > 0: | if self.has_ans_correct + self.has_ans_wrong > 0: | ||||
evaluate_result[f'hasAns-f_{self.f_beta}'] = \ | evaluate_result[f'hasAns-f_{self.f_beta}'] = \ | ||||
round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3) | round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3) | ||||
@@ -827,32 +836,31 @@ class SQuADMetric(MetricBase): | |||||
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}'] | evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}'] | ||||
evaluate_result['EM'] += evaluate_result['hasAns-EM'] | evaluate_result['EM'] += evaluate_result['hasAns-EM'] | ||||
flag += 1 | flag += 1 | ||||
if self.print_predict_stat: | if self.print_predict_stat: | ||||
evaluate_result['no2no'] = self.no2no | evaluate_result['no2no'] = self.no2no | ||||
evaluate_result['no2yes'] = self.no2yes | evaluate_result['no2yes'] = self.no2yes | ||||
evaluate_result['yes2no'] = self.yes2no | evaluate_result['yes2no'] = self.yes2no | ||||
evaluate_result['yes2yes'] = self.yes2yes | evaluate_result['yes2yes'] = self.yes2yes | ||||
if flag <= 0: | if flag <= 0: | ||||
return evaluate_result | return evaluate_result | ||||
evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3) | evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3) | ||||
evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3) | evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3) | ||||
if reset: | if reset: | ||||
self.no_ans_correct = 0 | self.no_ans_correct = 0 | ||||
self.no_ans_wrong = 0 | self.no_ans_wrong = 0 | ||||
self.has_ans_correct = 0 | self.has_ans_correct = 0 | ||||
self.has_ans_wrong = 0 | self.has_ans_wrong = 0 | ||||
self.has_ans_f = 0. | self.has_ans_f = 0. | ||||
self.no2no = 0 | self.no2no = 0 | ||||
self.no2yes = 0 | self.no2yes = 0 | ||||
self.yes2no = 0 | self.yes2no = 0 | ||||
self.yes2yes = 0 | self.yes2yes = 0 | ||||
return evaluate_result | return evaluate_result | ||||
@@ -4,6 +4,12 @@ optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :cl | |||||
""" | """ | ||||
import torch | import torch | ||||
__all__ = [ | |||||
"Optimizer", | |||||
"SGD", | |||||
"Adam" | |||||
] | |||||
class Optimizer(object): | class Optimizer(object): | ||||
""" | """ | ||||
@@ -12,15 +18,16 @@ class Optimizer(object): | |||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
:param kwargs: additional parameters. | :param kwargs: additional parameters. | ||||
""" | """ | ||||
def __init__(self, model_params, **kwargs): | def __init__(self, model_params, **kwargs): | ||||
if model_params is not None and not hasattr(model_params, "__next__"): | if model_params is not None and not hasattr(model_params, "__next__"): | ||||
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) | raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) | ||||
self.model_params = model_params | self.model_params = model_params | ||||
self.settings = kwargs | self.settings = kwargs | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def _get_require_grads_param(self, params): | def _get_require_grads_param(self, params): | ||||
""" | """ | ||||
将params中不需要gradient的删除 | 将params中不需要gradient的删除 | ||||
@@ -29,6 +36,7 @@ class Optimizer(object): | |||||
""" | """ | ||||
return [param for param in params if param.requires_grad] | return [param for param in params if param.requires_grad] | ||||
class SGD(Optimizer): | class SGD(Optimizer): | ||||
""" | """ | ||||
别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` | 别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` | ||||
@@ -37,12 +45,12 @@ class SGD(Optimizer): | |||||
:param float momentum: momentum. Default: 0 | :param float momentum: momentum. Default: 0 | ||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
""" | """ | ||||
def __init__(self, lr=0.001, momentum=0, model_params=None): | def __init__(self, lr=0.001, momentum=0, model_params=None): | ||||
if not isinstance(lr, float): | if not isinstance(lr, float): | ||||
raise TypeError("learning rate has to be float.") | raise TypeError("learning rate has to be float.") | ||||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
if self.model_params is None: | if self.model_params is None: | ||||
# careful! generator cannot be assigned. | # careful! generator cannot be assigned. | ||||
@@ -59,13 +67,13 @@ class Adam(Optimizer): | |||||
:param float weight_decay: | :param float weight_decay: | ||||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | ||||
""" | """ | ||||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | ||||
if not isinstance(lr, float): | if not isinstance(lr, float): | ||||
raise TypeError("learning rate has to be float.") | raise TypeError("learning rate has to be float.") | ||||
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, | super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, | ||||
weight_decay=weight_decay) | weight_decay=weight_decay) | ||||
def construct_from_pytorch(self, model_params): | def construct_from_pytorch(self, model_params): | ||||
if self.model_params is None: | if self.model_params is None: | ||||
# careful! generator cannot be assigned. | # careful! generator cannot be assigned. | ||||
@@ -1,7 +1,11 @@ | |||||
from collections import defaultdict | |||||
""" | |||||
..todo:: | |||||
检查这个类是否需要 | |||||
""" | |||||
import torch | import torch | ||||
from collections import defaultdict | |||||
from . import Batch | from . import Batch | ||||
from . import DataSet | from . import DataSet | ||||
from . import SequentialSampler | from . import SequentialSampler | ||||
@@ -9,7 +13,8 @@ from .utils import _build_args | |||||
class Predictor(object): | class Predictor(object): | ||||
"""An interface for predicting outputs based on trained models. | |||||
""" | |||||
An interface for predicting outputs based on trained models. | |||||
It does not care about evaluations of the model, which is different from Tester. | It does not care about evaluations of the model, which is different from Tester. | ||||
This is a high-level model wrapper to be called by FastNLP. | This is a high-level model wrapper to be called by FastNLP. | ||||
@@ -1,12 +1,16 @@ | |||||
""" | """ | ||||
sampler 子类实现了 fastNLP 所需的各种采样器。 | sampler 子类实现了 fastNLP 所需的各种采样器。 | ||||
""" | """ | ||||
__all__ = ["Sampler", "BucketSampler", "SequentialSampler", "RandomSampler"] | |||||
import numpy as np | |||||
from itertools import chain | from itertools import chain | ||||
import numpy as np | |||||
__all__ = [ | |||||
"Sampler", | |||||
"BucketSampler", | |||||
"SequentialSampler", | |||||
"RandomSampler" | |||||
] | |||||
class Sampler(object): | class Sampler(object): | ||||
@@ -33,9 +33,8 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation | |||||
""" | """ | ||||
import warnings | import warnings | ||||
import torch | import torch | ||||
from torch import nn | |||||
import torch.nn as nn | |||||
from .batch import Batch | from .batch import Batch | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
@@ -49,6 +48,10 @@ from .utils import _get_func_signature | |||||
from .utils import _get_model_device | from .utils import _get_model_device | ||||
from .utils import _move_model_to_device | from .utils import _move_model_to_device | ||||
__all__ = [ | |||||
"Tester" | |||||
] | |||||
class Tester(object): | class Tester(object): | ||||
""" | """ | ||||
@@ -77,29 +80,29 @@ class Tester(object): | |||||
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | ||||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | ||||
""" | """ | ||||
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | ||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
if not isinstance(data, DataSet): | if not isinstance(data, DataSet): | ||||
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") | raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") | ||||
if not isinstance(model, nn.Module): | if not isinstance(model, nn.Module): | ||||
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") | raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") | ||||
self.metrics = _prepare_metrics(metrics) | self.metrics = _prepare_metrics(metrics) | ||||
self.data = data | self.data = data | ||||
self._model = _move_model_to_device(model, device=device) | self._model = _move_model_to_device(model, device=device) | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.verbose = verbose | self.verbose = verbose | ||||
# 如果是DataParallel将没有办法使用predict方法 | # 如果是DataParallel将没有办法使用predict方法 | ||||
if isinstance(self._model, nn.DataParallel): | if isinstance(self._model, nn.DataParallel): | ||||
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): | if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): | ||||
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," | warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," | ||||
" while DataParallel has no predict() function.") | " while DataParallel has no predict() function.") | ||||
self._model = self._model.module | self._model = self._model.module | ||||
# check predict | # check predict | ||||
if hasattr(self._model, 'predict'): | if hasattr(self._model, 'predict'): | ||||
self._predict_func = self._model.predict | self._predict_func = self._model.predict | ||||
@@ -109,7 +112,7 @@ class Tester(object): | |||||
f"for evaluation, not `{type(self._predict_func)}`.") | f"for evaluation, not `{type(self._predict_func)}`.") | ||||
else: | else: | ||||
self._predict_func = self._model.forward | self._predict_func = self._model.forward | ||||
def test(self): | def test(self): | ||||
"""开始进行验证,并返回验证结果。 | """开始进行验证,并返回验证结果。 | ||||
@@ -144,12 +147,12 @@ class Tester(object): | |||||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | ||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | ||||
dataset=self.data, check_level=0) | dataset=self.data, check_level=0) | ||||
if self.verbose >= 1: | if self.verbose >= 1: | ||||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | print("[tester] \n{}".format(self._format_eval_results(eval_results))) | ||||
self._mode(network, is_test=False) | self._mode(network, is_test=False) | ||||
return eval_results | return eval_results | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
"""Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
@@ -161,13 +164,13 @@ class Tester(object): | |||||
model.eval() | model.eval() | ||||
else: | else: | ||||
model.train() | model.train() | ||||
def _data_forward(self, func, x): | def _data_forward(self, func, x): | ||||
"""A forward pass of the model. """ | """A forward pass of the model. """ | ||||
x = _build_args(func, **x) | x = _build_args(func, **x) | ||||
y = func(**x) | y = func(**x) | ||||
return y | return y | ||||
def _format_eval_results(self, results): | def _format_eval_results(self, results): | ||||
"""Override this method to support more print formats. | """Override this method to support more print formats. | ||||
@@ -297,13 +297,12 @@ Example2.3 | |||||
""" | """ | ||||
import os | import os | ||||
import time | |||||
from datetime import datetime | |||||
from datetime import timedelta | |||||
import numpy as np | import numpy as np | ||||
import time | |||||
import torch | import torch | ||||
from torch import nn | |||||
import torch.nn as nn | |||||
from datetime import datetime, timedelta | |||||
try: | try: | ||||
from tqdm.auto import tqdm | from tqdm.auto import tqdm | ||||
@@ -315,6 +314,7 @@ from .callback import CallbackManager, CallbackException | |||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .losses import _prepare_losser | from .losses import _prepare_losser | ||||
from .metrics import _prepare_metrics | from .metrics import _prepare_metrics | ||||
from .optimizer import Optimizer | |||||
from .sampler import Sampler | from .sampler import Sampler | ||||
from .sampler import RandomSampler | from .sampler import RandomSampler | ||||
from .sampler import SequentialSampler | from .sampler import SequentialSampler | ||||
@@ -326,7 +326,6 @@ from .utils import _check_loss_evaluate | |||||
from .utils import _move_dict_value_to_device | from .utils import _move_dict_value_to_device | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import _get_model_device | from .utils import _get_model_device | ||||
from .optimizer import Optimizer | |||||
from .utils import _move_model_to_device | from .utils import _move_model_to_device | ||||
@@ -464,7 +463,7 @@ class Trainer(object): | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | len(self.train_data) % self.batch_size != 0)) * self.n_epochs | ||||
self.model = _move_model_to_device(self.model, device=device) | self.model = _move_model_to_device(self.model, device=device) | ||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
elif isinstance(optimizer, Optimizer): | elif isinstance(optimizer, Optimizer): | ||||
@@ -1,20 +1,25 @@ | |||||
""" | """ | ||||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | ||||
""" | """ | ||||
__all__ = ["cache_results", "seq_len_to_mask"] | |||||
import _pickle | import _pickle | ||||
import inspect | import inspect | ||||
import numpy as np | |||||
import os | import os | ||||
import torch | |||||
import torch.nn as nn | |||||
import warnings | import warnings | ||||
from collections import Counter | from collections import Counter | ||||
from collections import namedtuple | from collections import namedtuple | ||||
import numpy as np | |||||
import torch | |||||
from torch import nn | |||||
__all__ = [ | |||||
"cache_results", | |||||
"seq_len_to_mask" | |||||
] | |||||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs']) | |||||
'varargs']) | |||||
def _prepare_cache_filepath(filepath): | def _prepare_cache_filepath(filepath): | ||||
""" | """ | ||||
@@ -83,11 +88,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
:param int _verbose: 是否打印cache的信息。 | :param int _verbose: 是否打印cache的信息。 | ||||
:return: | :return: | ||||
""" | """ | ||||
def wrapper_(func): | def wrapper_(func): | ||||
signature = inspect.signature(func) | signature = inspect.signature(func) | ||||
for key, _ in signature.parameters.items(): | for key, _ in signature.parameters.items(): | ||||
if key in ('_cache_fp', '_refresh', '_verbose'): | if key in ('_cache_fp', '_refresh', '_verbose'): | ||||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | ||||
def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
if '_cache_fp' in kwargs: | if '_cache_fp' in kwargs: | ||||
cache_filepath = kwargs.pop('_cache_fp') | cache_filepath = kwargs.pop('_cache_fp') | ||||
@@ -95,7 +102,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
else: | else: | ||||
cache_filepath = _cache_fp | cache_filepath = _cache_fp | ||||
if '_refresh' in kwargs: | if '_refresh' in kwargs: | ||||
refresh = kwargs.pop('_refresh') | |||||
refresh = kwargs.pop('_refresh') | |||||
assert isinstance(refresh, bool), "_refresh can only be bool." | assert isinstance(refresh, bool), "_refresh can only be bool." | ||||
else: | else: | ||||
refresh = _refresh | refresh = _refresh | ||||
@@ -105,16 +112,16 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
else: | else: | ||||
verbose = _verbose | verbose = _verbose | ||||
refresh_flag = True | refresh_flag = True | ||||
if cache_filepath is not None and refresh is False: | if cache_filepath is not None and refresh is False: | ||||
# load data | # load data | ||||
if os.path.exists(cache_filepath): | if os.path.exists(cache_filepath): | ||||
with open(cache_filepath, 'rb') as f: | with open(cache_filepath, 'rb') as f: | ||||
results = _pickle.load(f) | results = _pickle.load(f) | ||||
if verbose==1: | |||||
if verbose == 1: | |||||
print("Read cache from {}.".format(cache_filepath)) | print("Read cache from {}.".format(cache_filepath)) | ||||
refresh_flag = False | refresh_flag = False | ||||
if refresh_flag: | if refresh_flag: | ||||
results = func(*args, **kwargs) | results = func(*args, **kwargs) | ||||
if cache_filepath is not None: | if cache_filepath is not None: | ||||
@@ -124,11 +131,14 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
with open(cache_filepath, 'wb') as f: | with open(cache_filepath, 'wb') as f: | ||||
_pickle.dump(results, f) | _pickle.dump(results, f) | ||||
print("Save cache to {}.".format(cache_filepath)) | print("Save cache to {}.".format(cache_filepath)) | ||||
return results | return results | ||||
return wrapper | return wrapper | ||||
return wrapper_ | return wrapper_ | ||||
# def save_pickle(obj, pickle_path, file_name): | # def save_pickle(obj, pickle_path, file_name): | ||||
# """Save an object into a pickle file. | # """Save an object into a pickle file. | ||||
# | # | ||||
@@ -196,7 +206,7 @@ def _move_model_to_device(model, device): | |||||
""" | """ | ||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel): | if isinstance(model, torch.nn.parallel.DistributedDataParallel): | ||||
raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | ||||
if device is None: | if device is None: | ||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
model.cuda() | model.cuda() | ||||
@@ -205,34 +215,35 @@ def _move_model_to_device(model, device): | |||||
if not torch.cuda.is_available() and ( | if not torch.cuda.is_available() and ( | ||||
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | ||||
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | ||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | ||||
if isinstance(device, int): | if isinstance(device, int): | ||||
assert device>-1, "device can only be non-negative integer" | |||||
assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(), | |||||
device) | |||||
assert device > -1, "device can only be non-negative integer" | |||||
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( | |||||
torch.cuda.device_count(), | |||||
device) | |||||
device = torch.device('cuda:{}'.format(device)) | device = torch.device('cuda:{}'.format(device)) | ||||
elif isinstance(device, str): | elif isinstance(device, str): | ||||
device = torch.device(device) | device = torch.device(device) | ||||
if device.type == 'cuda' and device.index is not None: | if device.type == 'cuda' and device.index is not None: | ||||
assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format( | |||||
torch.cuda.device_count(), | |||||
device) | |||||
assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format( | |||||
torch.cuda.device_count(), | |||||
device) | |||||
elif isinstance(device, torch.device): | elif isinstance(device, torch.device): | ||||
if device.type == 'cuda' and device.index is not None: | if device.type == 'cuda' and device.index is not None: | ||||
assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format( | |||||
torch.cuda.device_count(), | |||||
device) | |||||
assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format( | |||||
torch.cuda.device_count(), | |||||
device) | |||||
elif isinstance(device, list): | elif isinstance(device, list): | ||||
types = set([type(d) for d in device]) | types = set([type(d) for d in device]) | ||||
assert len(types)==1, "Mixed type in device, only `int` allowed." | |||||
assert len(types) == 1, "Mixed type in device, only `int` allowed." | |||||
assert list(types)[0] == int, "Only int supported for multiple devices." | assert list(types)[0] == int, "Only int supported for multiple devices." | ||||
assert len(set(device))==len(device), "Duplicated device id found in device." | |||||
assert len(set(device)) == len(device), "Duplicated device id found in device." | |||||
for d in device: | for d in device: | ||||
assert d>-1, "Only non-negative device id allowed." | |||||
if len(device)>1: | |||||
assert d > -1, "Only non-negative device id allowed." | |||||
if len(device) > 1: | |||||
output_device = device[0] | output_device = device[0] | ||||
model = nn.DataParallel(model, device_ids=device, output_device=output_device) | model = nn.DataParallel(model, device_ids=device, output_device=output_device) | ||||
device = torch.device(device[0]) | device = torch.device(device[0]) | ||||
@@ -250,9 +261,9 @@ def _get_model_device(model): | |||||
:return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 | :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 | ||||
""" | """ | ||||
assert isinstance(model, nn.Module) | assert isinstance(model, nn.Module) | ||||
parameters = list(model.parameters()) | parameters = list(model.parameters()) | ||||
if len(parameters)==0: | |||||
if len(parameters) == 0: | |||||
return None | return None | ||||
else: | else: | ||||
return parameters[0].device | return parameters[0].device | ||||
@@ -407,7 +418,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||||
if not isinstance(device, torch.device): | if not isinstance(device, torch.device): | ||||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | ||||
for arg in args: | for arg in args: | ||||
if isinstance(arg, dict): | if isinstance(arg, dict): | ||||
for key, value in arg.items(): | for key, value in arg.items(): | ||||
@@ -422,10 +433,10 @@ class _CheckError(Exception): | |||||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | _CheckError. Used in losses.LossBase, metrics.MetricBase. | ||||
""" | """ | ||||
def __init__(self, check_res: _CheckRes, func_signature: str): | def __init__(self, check_res: _CheckRes, func_signature: str): | ||||
errs = [f'Problems occurred when calling `{func_signature}`'] | errs = [f'Problems occurred when calling `{func_signature}`'] | ||||
if check_res.varargs: | if check_res.varargs: | ||||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | ||||
if check_res.missing: | if check_res.missing: | ||||
@@ -434,9 +445,9 @@ class _CheckError(Exception): | |||||
errs.append(f"\tduplicated param: {check_res.duplicated}") | errs.append(f"\tduplicated param: {check_res.duplicated}") | ||||
if check_res.unused: | if check_res.unused: | ||||
errs.append(f"\tunused param: {check_res.unused}") | errs.append(f"\tunused param: {check_res.unused}") | ||||
Exception.__init__(self, '\n'.join(errs)) | Exception.__init__(self, '\n'.join(errs)) | ||||
self.check_res = check_res | self.check_res = check_res | ||||
self.func_signature = func_signature | self.func_signature = func_signature | ||||
@@ -456,7 +467,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
# if check_res.varargs: | # if check_res.varargs: | ||||
# errs.append(f"\tvarargs: *{check_res.varargs}") | # errs.append(f"\tvarargs: *{check_res.varargs}") | ||||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | ||||
if check_res.unused: | if check_res.unused: | ||||
for _unused in check_res.unused: | for _unused in check_res.unused: | ||||
if _unused in target_dict: | if _unused in target_dict: | ||||
@@ -466,8 +477,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
if _unused_field: | if _unused_field: | ||||
unuseds.append(f"\tunused field: {_unused_field}") | unuseds.append(f"\tunused field: {_unused_field}") | ||||
if _unused_param: | if _unused_param: | ||||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||||
module_name = func_signature.split('.')[0] | module_name = func_signature.split('.')[0] | ||||
if check_res.missing: | if check_res.missing: | ||||
errs.append(f"\tmissing param: {check_res.missing}") | errs.append(f"\tmissing param: {check_res.missing}") | ||||
@@ -488,14 +499,14 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
mapped_missing.append(_miss) | mapped_missing.append(_miss) | ||||
else: | else: | ||||
unmapped_missing.append(_miss) | unmapped_missing.append(_miss) | ||||
for _miss in mapped_missing + unmapped_missing: | for _miss in mapped_missing + unmapped_missing: | ||||
if _miss in dataset: | if _miss in dataset: | ||||
suggestions.append(f"Set `{_miss}` as target.") | suggestions.append(f"Set `{_miss}` as target.") | ||||
else: | else: | ||||
_tmp = '' | _tmp = '' | ||||
if check_res.unused: | if check_res.unused: | ||||
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||||
_tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}." | |||||
if _tmp: | if _tmp: | ||||
_tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' | _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' | ||||
else: | else: | ||||
@@ -513,25 +524,25 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
# else: | # else: | ||||
# _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' | # _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' | ||||
# suggestions.append(_tmp) | # suggestions.append(_tmp) | ||||
if check_res.duplicated: | if check_res.duplicated: | ||||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | errs.append(f"\tduplicated param: {check_res.duplicated}.") | ||||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | suggestions.append(f"Delete {check_res.duplicated} in the output of " | ||||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | ||||
if len(errs)>0: | |||||
if len(errs) > 0: | |||||
errs.extend(unuseds) | errs.extend(unuseds) | ||||
elif check_level == STRICT_CHECK_LEVEL: | elif check_level == STRICT_CHECK_LEVEL: | ||||
errs.extend(unuseds) | errs.extend(unuseds) | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | errs.insert(0, f'Problems occurred when calling {func_signature}') | ||||
sugg_str = "" | sugg_str = "" | ||||
if len(suggestions) > 1: | if len(suggestions) > 1: | ||||
for idx, sugg in enumerate(suggestions): | for idx, sugg in enumerate(suggestions): | ||||
if idx>0: | |||||
if idx > 0: | |||||
sugg_str += '\t\t\t' | sugg_str += '\t\t\t' | ||||
sugg_str += f'({idx+1}). {sugg}\n' | |||||
sugg_str += f'({idx + 1}). {sugg}\n' | |||||
sugg_str = sugg_str[:-1] | sugg_str = sugg_str[:-1] | ||||
else: | else: | ||||
sugg_str += suggestions[0] | sugg_str += suggestions[0] | ||||
@@ -546,14 +557,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
_unused_warn = f'{check_res.unused} is not used by {module_name}.' | _unused_warn = f'{check_res.unused} is not used by {module_name}.' | ||||
warnings.warn(message=_unused_warn) | warnings.warn(message=_unused_warn) | ||||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | def _check_forward_error(forward_func, batch_x, dataset, check_level): | ||||
check_res = _check_arg_dict_list(forward_func, batch_x) | check_res = _check_arg_dict_list(forward_func, batch_x) | ||||
func_signature = _get_func_signature(forward_func) | func_signature = _get_func_signature(forward_func) | ||||
errs = [] | errs = [] | ||||
suggestions = [] | suggestions = [] | ||||
_unused = [] | _unused = [] | ||||
# if check_res.varargs: | # if check_res.varargs: | ||||
# errs.append(f"\tvarargs: {check_res.varargs}") | # errs.append(f"\tvarargs: {check_res.varargs}") | ||||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | ||||
@@ -574,20 +586,20 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | ||||
# f"rename the field in `unused field:`." | # f"rename the field in `unused field:`." | ||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.unused: | if check_res.unused: | ||||
_unused = [f"\tunused field: {check_res.unused}"] | _unused = [f"\tunused field: {check_res.unused}"] | ||||
if len(errs)>0: | |||||
if len(errs) > 0: | |||||
errs.extend(_unused) | errs.extend(_unused) | ||||
elif check_level == STRICT_CHECK_LEVEL: | elif check_level == STRICT_CHECK_LEVEL: | ||||
errs.extend(_unused) | errs.extend(_unused) | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | errs.insert(0, f'Problems occurred when calling {func_signature}') | ||||
sugg_str = "" | sugg_str = "" | ||||
if len(suggestions) > 1: | if len(suggestions) > 1: | ||||
for idx, sugg in enumerate(suggestions): | for idx, sugg in enumerate(suggestions): | ||||
sugg_str += f'({idx+1}). {sugg}' | |||||
sugg_str += f'({idx + 1}). {sugg}' | |||||
else: | else: | ||||
sugg_str += suggestions[0] | sugg_str += suggestions[0] | ||||
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str | err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str | ||||
@@ -622,8 +634,8 @@ def seq_len_to_mask(seq_len): | |||||
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | ||||
max_len = int(seq_len.max()) | max_len = int(seq_len.max()) | ||||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | ||||
mask = broad_cast_seq_len<seq_len.reshape(-1, 1) | |||||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||||
elif isinstance(seq_len, torch.Tensor): | elif isinstance(seq_len, torch.Tensor): | ||||
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | ||||
batch_size = seq_len.size(0) | batch_size = seq_len.size(0) | ||||
@@ -632,7 +644,7 @@ def seq_len_to_mask(seq_len): | |||||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | ||||
else: | else: | ||||
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | ||||
return mask | return mask | ||||
@@ -640,24 +652,24 @@ class _pseudo_tqdm: | |||||
""" | """ | ||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | ||||
""" | """ | ||||
def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
pass | pass | ||||
def write(self, info): | def write(self, info): | ||||
print(info) | print(info) | ||||
def set_postfix_str(self, info): | def set_postfix_str(self, info): | ||||
print(info) | print(info) | ||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
def pass_func(*args, **kwargs): | def pass_func(*args, **kwargs): | ||||
pass | pass | ||||
return pass_func | return pass_func | ||||
def __enter__(self): | def __enter__(self): | ||||
return self | return self | ||||
def __exit__(self, exc_type, exc_val, exc_tb): | def __exit__(self, exc_type, exc_val, exc_tb): | ||||
del self | del self |
@@ -1,7 +1,12 @@ | |||||
from functools import wraps | from functools import wraps | ||||
from collections import Counter | from collections import Counter | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
__all__ = [ | |||||
"Vocabulary" | |||||
] | |||||
def _check_build_vocab(func): | def _check_build_vocab(func): | ||||
"""A decorator to make sure the indexing is built before used. | """A decorator to make sure the indexing is built before used. | ||||