@@ -151,16 +151,19 @@ class DataSet(object): | |||||
assert name in self.field_arrays | assert name in self.field_arrays | ||||
self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): | |||||
def add_field(self, name, fields, padder=None, is_input=False, is_target=False, ignore_type=False): | |||||
"""Add a new field to the DataSet. | """Add a new field to the DataSet. | ||||
:param str name: the name of the field. | :param str name: the name of the field. | ||||
:param fields: a list of int, float, or other objects. | :param fields: a list of int, float, or other objects. | ||||
:param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | |||||
:param padder: PadBase对象,如何对该Field进行padding。如果为None则使用 | |||||
:param bool is_input: whether this field is model input. | :param bool is_input: whether this field is model input. | ||||
:param bool is_target: whether this field is label or target. | :param bool is_target: whether this field is label or target. | ||||
:param bool ignore_type: If True, do not perform type check. (Default: False) | :param bool ignore_type: If True, do not perform type check. (Default: False) | ||||
""" | """ | ||||
if padder is None: | |||||
padder = AutoPadder(pad_val=0) | |||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
if len(self) != len(fields): | if len(self) != len(fields): | ||||
raise RuntimeError(f"The field to append must have the same size as dataset. " | raise RuntimeError(f"The field to append must have the same size as dataset. " | ||||
@@ -231,8 +234,8 @@ class DataSet(object): | |||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
def set_padder(self, field_name, padder): | def set_padder(self, field_name, padder): | ||||
""" | |||||
为field_name设置padder | |||||
"""为field_name设置padder | |||||
:param field_name: str, 设置field的padding方式为padder | :param field_name: str, 设置field的padding方式为padder | ||||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | ||||
:return: | :return: | ||||
@@ -242,8 +245,7 @@ class DataSet(object): | |||||
self.field_arrays[field_name].set_padder(padder) | self.field_arrays[field_name].set_padder(padder) | ||||
def set_pad_val(self, field_name, pad_val): | def set_pad_val(self, field_name, pad_val): | ||||
""" | |||||
为某个 | |||||
"""为某个field设置对应的pad_val. | |||||
:param field_name: str,修改该field的pad_val | :param field_name: str,修改该field的pad_val | ||||
:param pad_val: int,该field的padder会以pad_val作为padding index | :param pad_val: int,该field的padder会以pad_val作为padding index | ||||
@@ -254,43 +256,60 @@ class DataSet(object): | |||||
self.field_arrays[field_name].set_pad_val(pad_val) | self.field_arrays[field_name].set_pad_val(pad_val) | ||||
def get_input_name(self): | def get_input_name(self): | ||||
"""Get all field names with `is_input` as True. | |||||
"""返回所有is_input被设置为True的field名称 | |||||
:return field_names: a list of str | |||||
:return list, 里面的元素为被设置为input的field名称 | |||||
""" | """ | ||||
return [name for name, field in self.field_arrays.items() if field.is_input] | return [name for name, field in self.field_arrays.items() if field.is_input] | ||||
def get_target_name(self): | def get_target_name(self): | ||||
"""Get all field names with `is_target` as True. | |||||
"""返回所有is_target被设置为True的field名称 | |||||
:return field_names: a list of str | |||||
:return list, 里面的元素为被设置为target的field名称 | |||||
""" | """ | ||||
return [name for name, field in self.field_arrays.items() if field.is_target] | return [name for name, field in self.field_arrays.items() if field.is_target] | ||||
def apply(self, func, new_field_name=None, **kwargs): | |||||
"""Apply a function to every instance of the DataSet. | |||||
:param func: a function that takes an instance as input. | |||||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||||
:param **kwargs: Accept parameters will be | |||||
(1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. | |||||
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. | |||||
:return results: if new_field_name is not passed, returned values of the function over all instances. | |||||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||||
"""将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值. | |||||
:param func: Callable, input是instance的`field_name`这个field. | |||||
:param field_name: str, 传入func的是哪个field. | |||||
:param new_field_name: (str, None). 如果不是None,将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有 | |||||
的field相同,则覆盖之前的field. | |||||
:param **kwargs: 合法的参数有以下三个 | |||||
(1) is_input: bool, 如果为True则将`new_field_name`这个field设置为input | |||||
(2) is_target: bool, 如果为True则将`new_field_name`这个field设置为target | |||||
(3) ignore_type: bool, 如果为True则将`new_field_name`这个field的ignore_type设置为true, 忽略其类型 | |||||
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||||
""" | """ | ||||
assert len(self)!=0, "Null dataset cannot use .apply()." | |||||
assert len(self)!=0, "Null DataSet cannot use apply()." | |||||
if field_name not in self: | |||||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||||
results = [] | results = [] | ||||
idx = -1 | idx = -1 | ||||
try: | try: | ||||
for idx, ins in enumerate(self._inner_iter()): | for idx, ins in enumerate(self._inner_iter()): | ||||
results.append(func(ins)) | |||||
results.append(func(ins[field_name])) | |||||
except Exception as e: | except Exception as e: | ||||
if idx!=-1: | if idx!=-1: | ||||
print("Exception happens at the `{}`th instance.".format(idx)) | print("Exception happens at the `{}`th instance.".format(idx)) | ||||
raise e | raise e | ||||
# results = [func(ins) for ins in self._inner_iter()] | |||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | ||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | raise ValueError("{} always return None.".format(get_func_signature(func=func))) | ||||
if new_field_name is not None: | |||||
self._add_apply_field(results, new_field_name, kwargs) | |||||
return results | |||||
def _add_apply_field(self, results, new_field_name, kwargs): | |||||
"""将results作为加入到新的field中,field名称为new_field_name | |||||
:param results: List[], 一般是apply*()之后的结果 | |||||
:param new_field_name: str, 新加入的field的名称 | |||||
:param kwargs: dict, 用户apply*()时传入的自定义参数 | |||||
:return: | |||||
""" | |||||
extra_param = {} | extra_param = {} | ||||
if 'is_input' in kwargs: | if 'is_input' in kwargs: | ||||
extra_param['is_input'] = kwargs['is_input'] | extra_param['is_input'] = kwargs['is_input'] | ||||
@@ -298,56 +317,84 @@ class DataSet(object): | |||||
extra_param['is_target'] = kwargs['is_target'] | extra_param['is_target'] = kwargs['is_target'] | ||||
if 'ignore_type' in kwargs: | if 'ignore_type' in kwargs: | ||||
extra_param['ignore_type'] = kwargs['ignore_type'] | extra_param['ignore_type'] = kwargs['ignore_type'] | ||||
if new_field_name is not None: | |||||
if new_field_name in self.field_arrays: | |||||
# overwrite the field, keep same attributes | |||||
old_field = self.field_arrays[new_field_name] | |||||
if 'is_input' not in extra_param: | |||||
extra_param['is_input'] = old_field.is_input | |||||
if 'is_target' not in extra_param: | |||||
extra_param['is_target'] = old_field.is_target | |||||
if 'ignore_type' not in extra_param: | |||||
extra_param['ignore_type'] = old_field.ignore_type | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||||
else: | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||||
is_target=extra_param.get("is_target", None), | |||||
ignore_type=extra_param.get("ignore_type", False)) | |||||
if new_field_name in self.field_arrays: | |||||
# overwrite the field, keep same attributes | |||||
old_field = self.field_arrays[new_field_name] | |||||
if 'is_input' not in extra_param: | |||||
extra_param['is_input'] = old_field.is_input | |||||
if 'is_target' not in extra_param: | |||||
extra_param['is_target'] = old_field.is_target | |||||
if 'ignore_type' not in extra_param: | |||||
extra_param['ignore_type'] = old_field.ignore_type | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||||
else: | else: | ||||
return results | |||||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||||
is_target=extra_param.get("is_target", None), | |||||
ignore_type=extra_param.get("ignore_type", False)) | |||||
def apply(self, func, new_field_name=None, **kwargs): | |||||
"""将DataSet中每个instance传入到func中,并获取它的返回值. | |||||
:param func: Callable, 参数是DataSet中的instance | |||||
:param new_field_name: (None, str). (1) None, 不创建新的field; (2) str,将func的返回值放入这个名为 | |||||
`new_field_name`的新field中,如果名称与已有的field相同,则覆盖之前的field; | |||||
:param kwargs: 合法的参数有以下三个 | |||||
(1) is_input: bool, 如果为True则将`new_field_name`的field设置为input | |||||
(2) is_target: bool, 如果为True则将`new_field_name`的field设置为target | |||||
(3) ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 | |||||
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||||
""" | |||||
assert len(self)!=0, "Null DataSet cannot use apply()." | |||||
idx = -1 | |||||
try: | |||||
results = [] | |||||
for idx, ins in enumerate(self._inner_iter()): | |||||
results.append(func(ins)) | |||||
except Exception as e: | |||||
if idx!=-1: | |||||
print("Exception happens at the `{}`th instance.".format(idx)) | |||||
raise e | |||||
# results = [func(ins) for ins in self._inner_iter()] | |||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||||
if new_field_name is not None: | |||||
self._add_apply_field(results, new_field_name, kwargs) | |||||
return results | |||||
def drop(self, func, inplace=True): | def drop(self, func, inplace=True): | ||||
"""Drop instances if a condition holds. | |||||
"""func接受一个instance,返回bool值,返回值为True时,该instance会被删除。 | |||||
:param func: a function that takes an Instance object as input, and returns bool. | |||||
The instance will be dropped if the function returns True. | |||||
:param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned. | |||||
:param func: Callable, 接受一个instance作为参数,返回bool值。为True时删除该instance | |||||
:param inplace: bool, 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet | |||||
:return: DataSet. | |||||
""" | """ | ||||
if inplace: | if inplace: | ||||
results = [ins for ins in self._inner_iter() if not func(ins)] | results = [ins for ins in self._inner_iter() if not func(ins)] | ||||
for name, old_field in self.field_arrays.items(): | for name, old_field in self.field_arrays.items(): | ||||
self.field_arrays[name].content = [ins[name] for ins in results] | self.field_arrays[name].content = [ins[name] for ins in results] | ||||
return self | |||||
else: | else: | ||||
results = [ins for ins in self if not func(ins)] | results = [ins for ins in self if not func(ins)] | ||||
data = DataSet(results) | data = DataSet(results) | ||||
for field_name, field in self.field_arrays.items(): | for field_name, field in self.field_arrays.items(): | ||||
data.field_arrays[field_name].to(field) | data.field_arrays[field_name].to(field) | ||||
return data | |||||
def split(self, dev_ratio): | |||||
"""Split the dataset into training and development(validation) set. | |||||
def split(self, ratio): | |||||
"""将DataSet按照ratio的比例拆分,返回两个DataSet | |||||
:param float dev_ratio: the ratio of test set in all data. | |||||
:return (train_set, dev_set): | |||||
train_set: the training set | |||||
dev_set: the development set | |||||
:param ratio: float, 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据 | |||||
:return (DataSet, DataSet) | |||||
""" | """ | ||||
assert isinstance(dev_ratio, float) | |||||
assert 0 < dev_ratio < 1 | |||||
assert isinstance(ratio, float) | |||||
assert 0 < ratio < 1 | |||||
all_indices = [_ for _ in range(len(self))] | all_indices = [_ for _ in range(len(self))] | ||||
np.random.shuffle(all_indices) | np.random.shuffle(all_indices) | ||||
split = int(dev_ratio * len(self)) | |||||
split = int(ratio * len(self)) | |||||
dev_indices = all_indices[:split] | dev_indices = all_indices[:split] | ||||
train_indices = all_indices[split:] | train_indices = all_indices[split:] | ||||
dev_set = DataSet() | dev_set = DataSet() | ||||
@@ -398,26 +445,25 @@ class DataSet(object): | |||||
_dict[header].append(content) | _dict[header].append(content) | ||||
return cls(_dict) | return cls(_dict) | ||||
# def read_pos(self): | |||||
# return DataLoaderRegister.get_reader('read_pos') | |||||
def save(self, path): | def save(self, path): | ||||
"""Save the DataSet object as pickle. | |||||
"""保存DataSet. | |||||
:param str path: the path to the pickle | |||||
:param path: str, 将DataSet存在哪个路径 | |||||
""" | """ | ||||
with open(path, 'wb') as f: | with open(path, 'wb') as f: | ||||
pickle.dump(self, f) | pickle.dump(self, f) | ||||
@staticmethod | @staticmethod | ||||
def load(path): | def load(path): | ||||
"""Load a DataSet object from pickle. | |||||
"""从保存的DataSet pickle路径中读取DataSet | |||||
:param str path: the path to the pickle | |||||
:return data_set: | |||||
:param path: str, 读取路径 | |||||
:return DataSet: | |||||
""" | """ | ||||
with open(path, 'rb') as f: | with open(path, 'rb') as f: | ||||
return pickle.load(f) | |||||
d = pickle.load(f) | |||||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | |||||
return d | |||||
def construct_dataset(sentences): | def construct_dataset(sentences): | ||||
@@ -84,7 +84,7 @@ class AutoPadder(PadderBase): | |||||
for i, content in enumerate(contents): | for i, content in enumerate(contents): | ||||
array[i][:len(content)] = content | array[i][:len(content)] = content | ||||
elif field_ele_dtype is None: | elif field_ele_dtype is None: | ||||
array = contents # 当ignore_type=True时,直接返回contents | |||||
array = np.array(contents) # 当ignore_type=True时,直接返回contents | |||||
else: # should only be str | else: # should only be str | ||||
array = np.array([content for content in contents]) | array = np.array([content for content in contents]) | ||||
return array | return array | ||||
@@ -290,9 +290,10 @@ class FieldArray(object): | |||||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | ||||
def append(self, val): | def append(self, val): | ||||
"""Add a new item to the tail of FieldArray. | |||||
"""将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为 | |||||
input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。 | |||||
:param val: int, float, str, or a list of one. | |||||
:param val: Any. | |||||
""" | """ | ||||
if self.ignore_type is False: | if self.ignore_type is False: | ||||
if isinstance(val, list): | if isinstance(val, list): | ||||
@@ -367,13 +368,14 @@ class FieldArray(object): | |||||
self.padder = deepcopy(padder) | self.padder = deepcopy(padder) | ||||
def set_pad_val(self, pad_val): | def set_pad_val(self, pad_val): | ||||
""" | |||||
修改padder的pad_val. | |||||
:param pad_val: int。 | |||||
"""修改padder的pad_val. | |||||
:param pad_val: int。将该field的pad值设置为该值 | |||||
:return: | :return: | ||||
""" | """ | ||||
if self.padder is not None: | if self.padder is not None: | ||||
self.padder.set_pad_val(pad_val) | self.padder.set_pad_val(pad_val) | ||||
return self | |||||
def __len__(self): | def __len__(self): | ||||
@@ -385,8 +387,7 @@ class FieldArray(object): | |||||
def to(self, other): | def to(self, other): | ||||
""" | """ | ||||
将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim | |||||
ignore_type | |||||
将other的属性复制给本FieldArray(other必须为FieldArray类型). 包含 is_input, is_target, padder, ignore_type | |||||
:param other: FieldArray | :param other: FieldArray | ||||
:return: | :return: | ||||
@@ -396,11 +397,10 @@ class FieldArray(object): | |||||
self.is_input = other.is_input | self.is_input = other.is_input | ||||
self.is_target = other.is_target | self.is_target = other.is_target | ||||
self.padder = other.padder | self.padder = other.padder | ||||
self.dtype = other.dtype | |||||
self.pytype = other.pytype | |||||
self.content_dim = other.content_dim | |||||
self.ignore_type = other.ignore_type | self.ignore_type = other.ignore_type | ||||
return self | |||||
def is_iterable(content): | def is_iterable(content): | ||||
try: | try: | ||||
_ = (e for e in content) | _ = (e for e in content) | ||||
@@ -24,7 +24,7 @@ def _prepare_cache_filepath(filepath): | |||||
if not os.path.exists(cache_dir): | if not os.path.exists(cache_dir): | ||||
os.makedirs(cache_dir) | os.makedirs(cache_dir) | ||||
# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||||
def cache_results(cache_filepath, refresh=False, verbose=1): | def cache_results(cache_filepath, refresh=False, verbose=1): | ||||
def wrapper_(func): | def wrapper_(func): | ||||
signature = inspect.signature(func) | signature = inspect.signature(func) | ||||
@@ -79,7 +79,7 @@ class SeqLabeling(BaseModel): | |||||
:return prediction: list of [decode path(list)] | :return prediction: list of [decode path(list)] | ||||
""" | """ | ||||
max_len = x.shape[1] | max_len = x.shape[1] | ||||
tag_seq = self.Crf.viterbi_decode(x, self.mask) | |||||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) | |||||
# pad prediction to equal length | # pad prediction to equal length | ||||
if pad is True: | if pad is True: | ||||
for pred in tag_seq: | for pred in tag_seq: | ||||
@@ -2,12 +2,7 @@ import torch | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
def log_sum_exp(x, dim=-1): | |||||
max_value, _ = x.max(dim=dim, keepdim=True) | |||||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||||
return res.squeeze(dim) | |||||
from fastNLP.modules.decoder.utils import log_sum_exp | |||||
def seq_len_to_byte_mask(seq_lens): | def seq_len_to_byte_mask(seq_lens): | ||||
@@ -20,22 +15,27 @@ def seq_len_to_byte_mask(seq_lens): | |||||
return mask | return mask | ||||
def allowed_transitions(id2label, encoding_type='bio'): | |||||
def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||||
""" | """ | ||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||||
:param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | |||||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | :param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | ||||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||||
:param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||||
为False, 返回的结果中不含与开始结尾相关的内容 | |||||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | |||||
""" | """ | ||||
num_tags = len(id2label) | num_tags = len(id2label) | ||||
start_idx = num_tags | start_idx = num_tags | ||||
end_idx = num_tags + 1 | end_idx = num_tags + 1 | ||||
encoding_type = encoding_type.lower() | encoding_type = encoding_type.lower() | ||||
allowed_trans = [] | allowed_trans = [] | ||||
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||||
id_label_lst = list(id2label.items()) | |||||
if include_start_end: | |||||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||||
def split_tag_label(from_label): | def split_tag_label(from_label): | ||||
from_label = from_label.lower() | from_label = from_label.lower() | ||||
if from_label in ['start', 'end']: | if from_label in ['start', 'end']: | ||||
@@ -54,12 +54,12 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||||
if to_label in ['<pad>', '<unk>']: | if to_label in ['<pad>', '<unk>']: | ||||
continue | continue | ||||
to_tag, to_label = split_tag_label(to_label) | to_tag, to_label = split_tag_label(to_label) | ||||
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
allowed_trans.append((from_id, to_id)) | allowed_trans.append((from_id, to_id)) | ||||
return allowed_trans | return allowed_trans | ||||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
""" | """ | ||||
:param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | :param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | ||||
@@ -140,20 +140,22 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | ||||
else: | else: | ||||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||||
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) | |||||
class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
""" | |||||
:param int num_tags: 标签的数量。 | |||||
:param bool include_start_end_trans: 是否包含起始tag | |||||
:param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。 | |||||
如果为None,则所有跃迁均为合法 | |||||
:param str initial_method: | |||||
""" | |||||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | |||||
initial_method=None): | |||||
"""条件随机场。 | |||||
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | |||||
:param num_tags: int, 标签的数量 | |||||
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 | |||||
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), | |||||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||||
:param initial_method: str, 初始化方法。见initial_parameter | |||||
""" | |||||
super(ConditionalRandomField, self).__init__() | super(ConditionalRandomField, self).__init__() | ||||
self.include_start_end_trans = include_start_end_trans | self.include_start_end_trans = include_start_end_trans | ||||
@@ -168,18 +170,12 @@ class ConditionalRandomField(nn.Module): | |||||
if allowed_transitions is None: | if allowed_transitions is None: | ||||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | constrain = torch.zeros(num_tags + 2, num_tags + 2) | ||||
else: | else: | ||||
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||||
constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) | |||||
for from_tag_id, to_tag_id in allowed_transitions: | for from_tag_id, to_tag_id in allowed_transitions: | ||||
constrain[from_tag_id, to_tag_id] = 0 | constrain[from_tag_id, to_tag_id] = 0 | ||||
self._constrain = nn.Parameter(constrain, requires_grad=False) | self._constrain = nn.Parameter(constrain, requires_grad=False) | ||||
# self.reset_parameter() | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def reset_parameter(self): | |||||
nn.init.xavier_normal_(self.trans_m) | |||||
if self.include_start_end_trans: | |||||
nn.init.normal_(self.start_scores) | |||||
nn.init.normal_(self.end_scores) | |||||
def _normalizer_likelihood(self, logits, mask): | def _normalizer_likelihood(self, logits, mask): | ||||
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | """Computes the (batch_size,) denominator term for the log-likelihood, which is the | ||||
@@ -239,10 +235,11 @@ class ConditionalRandomField(nn.Module): | |||||
def forward(self, feats, tags, mask): | def forward(self, feats, tags, mask): | ||||
""" | """ | ||||
Calculate the neg log likelihood | |||||
:param feats:FloatTensor, batch_size x max_len x num_tags | |||||
:param tags:LongTensor, batch_size x max_len | |||||
:param mask:ByteTensor batch_size x max_len | |||||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||||
:param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||||
:param tags:LongTensor, batch_size x max_len,标签矩阵。 | |||||
:param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 | |||||
:return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
""" | """ | ||||
feats = feats.transpose(0, 1) | feats = feats.transpose(0, 1) | ||||
@@ -253,28 +250,27 @@ class ConditionalRandomField(nn.Module): | |||||
return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||||
"""Given a feats matrix, return best decode path and best score. | |||||
def viterbi_decode(self, feats, mask, unpad=False): | |||||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||||
:param data:FloatTensor, batch_size x max_len x num_tags | |||||
:param mask:ByteTensor batch_size x max_len | |||||
:param get_score: bool, whether to output the decode score. | |||||
:param unpad: bool, 是否将结果unpad, | |||||
如果False, 返回的是batch_size x max_len的tensor, | |||||
如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 | |||||
List[int]的长度是这个sample的有效长度 | |||||
:return: 如果get_score为False,返回结果根据unpadding变动 | |||||
如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] | |||||
为每个seqence的解码分数。 | |||||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||||
:param unpad: bool, 是否将结果删去padding, | |||||
False, 返回的是batch_size x max_len的tensor, | |||||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] | |||||
的长度是这个sample的有效长度。 | |||||
:return: 返回 (paths, scores)。 | |||||
paths: 是解码后的路径, 其值参照unpad参数. | |||||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||||
""" | """ | ||||
batch_size, seq_len, n_tags = data.size() | |||||
data = data.transpose(0, 1).data # L, B, H | |||||
batch_size, seq_len, n_tags = feats.size() | |||||
feats = feats.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.byte() # L, B | mask = mask.transpose(0, 1).data.byte() # L, B | ||||
# dp | # dp | ||||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = data[0] | |||||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = feats[0] | |||||
transitions = self._constrain.data.clone() | transitions = self._constrain.data.clone() | ||||
transitions[:n_tags, :n_tags] += self.trans_m.data | transitions[:n_tags, :n_tags] += self.trans_m.data | ||||
if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
@@ -285,23 +281,24 @@ class ConditionalRandomField(nn.Module): | |||||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | ||||
for i in range(1, seq_len): | for i in range(1, seq_len): | ||||
prev_score = vscore.view(batch_size, n_tags, 1) | prev_score = vscore.view(batch_size, n_tags, 1) | ||||
cur_score = data[i].view(batch_size, 1, n_tags) | |||||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||||
score = prev_score + trans_score + cur_score | score = prev_score + trans_score + cur_score | ||||
best_score, best_dst = score.max(1) | best_score, best_dst = score.max(1) | ||||
vpath[i] = best_dst | vpath[i] = best_dst | ||||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | ||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | vscore.masked_fill(mask[i].view(batch_size, 1), 0) | ||||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||||
if self.include_start_end_trans: | |||||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||||
# backtrace | # backtrace | ||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||||
lens = (mask.long().sum(0) - 1) | lens = (mask.long().sum(0) - 1) | ||||
# idxes [L, B], batched idx from seq_len-1 to 0 | # idxes [L, B], batched idx from seq_len-1 to 0 | ||||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | ||||
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans_score, last_tags = vscore.max(1) | ans_score, last_tags = vscore.max(1) | ||||
ans[idxes[0], batch_idx] = last_tags | ans[idxes[0], batch_idx] = last_tags | ||||
for i in range(seq_len - 1): | for i in range(seq_len - 1): | ||||
@@ -0,0 +1,70 @@ | |||||
import torch | |||||
def log_sum_exp(x, dim=-1): | |||||
max_value, _ = x.max(dim=dim, keepdim=True) | |||||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||||
return res.squeeze(dim) | |||||
def viterbi_decode(feats, transitions, mask=None, unpad=False): | |||||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||||
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||||
:param unpad: bool, 是否将结果删去padding, | |||||
False, 返回的是batch_size x max_len的tensor, | |||||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 | |||||
这个sample的有效长度。 | |||||
:return: 返回 (paths, scores)。 | |||||
paths: 是解码后的路径, 其值参照unpad参数. | |||||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||||
""" | |||||
batch_size, seq_len, n_tags = feats.size() | |||||
assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \ | |||||
"compatible." | |||||
feats = feats.transpose(0, 1).data # L, B, H | |||||
if mask is not None: | |||||
mask = mask.transpose(0, 1).data.byte() # L, B | |||||
else: | |||||
mask = feats.new_ones((seq_len, batch_size), dtype=torch.uint8) | |||||
# dp | |||||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = feats[0] | |||||
vscore += transitions[n_tags, :n_tags] | |||||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||||
for i in range(1, seq_len): | |||||
prev_score = vscore.view(batch_size, n_tags, 1) | |||||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||||
score = prev_score + trans_score + cur_score | |||||
best_score, best_dst = score.max(1) | |||||
vpath[i] = best_dst | |||||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||||
# backtrace | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||||
lens = (mask.long().sum(0) - 1) | |||||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | |||||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||||
ans_score, last_tags = vscore.max(1) | |||||
ans[idxes[0], batch_idx] = last_tags | |||||
for i in range(seq_len - 1): | |||||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||||
ans[idxes[i + 1], batch_idx] = last_tags | |||||
ans = ans.transpose(0, 1) | |||||
if unpad: | |||||
paths = [] | |||||
for idx, seq_len in enumerate(lens): | |||||
paths.append(ans[idx, :seq_len + 1].tolist()) | |||||
else: | |||||
paths = ans | |||||
return paths, ans_score |
@@ -183,7 +183,7 @@ class CWSBiLSTMCRF(BaseModel): | |||||
masks = seq_lens_to_mask(seq_lens) | masks = seq_lens_to_mask(seq_lens) | ||||
feats = self.encoder_model(chars, bigrams, seq_lens) | feats = self.encoder_model(chars, bigrams, seq_lens) | ||||
feats = self.decoder_model(feats) | feats = self.decoder_model(feats) | ||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||||
paths, _ = self.crf.viterbi_decode(feats, masks) | |||||
return {'pred': probs, 'seq_lens':seq_lens} | |||||
return {'pred': paths, 'seq_lens':seq_lens} | |||||
@@ -72,9 +72,9 @@ class TransformerCWS(nn.Module): | |||||
feats = self.transformer(x, masks) | feats = self.transformer(x, masks) | ||||
feats = self.fc2(feats) | feats = self.fc2(feats) | ||||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||||
paths, _ = self.crf.viterbi_decode(feats, masks) | |||||
return {'pred': probs, 'seq_lens':seq_lens} | |||||
return {'pred': paths, 'seq_lens':seq_lens} | |||||
class NoamOpt(torch.optim.Optimizer): | class NoamOpt(torch.optim.Optimizer): | ||||