@@ -151,16 +151,19 @@ class DataSet(object): | |||
assert name in self.field_arrays | |||
self.field_arrays[name].append(field) | |||
def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): | |||
def add_field(self, name, fields, padder=None, is_input=False, is_target=False, ignore_type=False): | |||
"""Add a new field to the DataSet. | |||
:param str name: the name of the field. | |||
:param fields: a list of int, float, or other objects. | |||
:param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 | |||
:param padder: PadBase对象,如何对该Field进行padding。如果为None则使用 | |||
:param bool is_input: whether this field is model input. | |||
:param bool is_target: whether this field is label or target. | |||
:param bool ignore_type: If True, do not perform type check. (Default: False) | |||
""" | |||
if padder is None: | |||
padder = AutoPadder(pad_val=0) | |||
if len(self.field_arrays) != 0: | |||
if len(self) != len(fields): | |||
raise RuntimeError(f"The field to append must have the same size as dataset. " | |||
@@ -231,8 +234,8 @@ class DataSet(object): | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
def set_padder(self, field_name, padder): | |||
""" | |||
为field_name设置padder | |||
"""为field_name设置padder | |||
:param field_name: str, 设置field的padding方式为padder | |||
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. | |||
:return: | |||
@@ -242,8 +245,7 @@ class DataSet(object): | |||
self.field_arrays[field_name].set_padder(padder) | |||
def set_pad_val(self, field_name, pad_val): | |||
""" | |||
为某个 | |||
"""为某个field设置对应的pad_val. | |||
:param field_name: str,修改该field的pad_val | |||
:param pad_val: int,该field的padder会以pad_val作为padding index | |||
@@ -254,43 +256,60 @@ class DataSet(object): | |||
self.field_arrays[field_name].set_pad_val(pad_val) | |||
def get_input_name(self): | |||
"""Get all field names with `is_input` as True. | |||
"""返回所有is_input被设置为True的field名称 | |||
:return field_names: a list of str | |||
:return list, 里面的元素为被设置为input的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||
def get_target_name(self): | |||
"""Get all field names with `is_target` as True. | |||
"""返回所有is_target被设置为True的field名称 | |||
:return field_names: a list of str | |||
:return list, 里面的元素为被设置为target的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
"""Apply a function to every instance of the DataSet. | |||
:param func: a function that takes an instance as input. | |||
:param str new_field_name: If not None, results of the function will be stored as a new field. | |||
:param **kwargs: Accept parameters will be | |||
(1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. | |||
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. | |||
:return results: if new_field_name is not passed, returned values of the function over all instances. | |||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
"""将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值. | |||
:param func: Callable, input是instance的`field_name`这个field. | |||
:param field_name: str, 传入func的是哪个field. | |||
:param new_field_name: (str, None). 如果不是None,将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有 | |||
的field相同,则覆盖之前的field. | |||
:param **kwargs: 合法的参数有以下三个 | |||
(1) is_input: bool, 如果为True则将`new_field_name`这个field设置为input | |||
(2) is_target: bool, 如果为True则将`new_field_name`这个field设置为target | |||
(3) ignore_type: bool, 如果为True则将`new_field_name`这个field的ignore_type设置为true, 忽略其类型 | |||
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||
""" | |||
assert len(self)!=0, "Null dataset cannot use .apply()." | |||
assert len(self)!=0, "Null DataSet cannot use apply()." | |||
if field_name not in self: | |||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||
results = [] | |||
idx = -1 | |||
try: | |||
for idx, ins in enumerate(self._inner_iter()): | |||
results.append(func(ins)) | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx!=-1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def _add_apply_field(self, results, new_field_name, kwargs): | |||
"""将results作为加入到新的field中,field名称为new_field_name | |||
:param results: List[], 一般是apply*()之后的结果 | |||
:param new_field_name: str, 新加入的field的名称 | |||
:param kwargs: dict, 用户apply*()时传入的自定义参数 | |||
:return: | |||
""" | |||
extra_param = {} | |||
if 'is_input' in kwargs: | |||
extra_param['is_input'] = kwargs['is_input'] | |||
@@ -298,56 +317,84 @@ class DataSet(object): | |||
extra_param['is_target'] = kwargs['is_target'] | |||
if 'ignore_type' in kwargs: | |||
extra_param['ignore_type'] = kwargs['ignore_type'] | |||
if new_field_name is not None: | |||
if new_field_name in self.field_arrays: | |||
# overwrite the field, keep same attributes | |||
old_field = self.field_arrays[new_field_name] | |||
if 'is_input' not in extra_param: | |||
extra_param['is_input'] = old_field.is_input | |||
if 'is_target' not in extra_param: | |||
extra_param['is_target'] = old_field.is_target | |||
if 'ignore_type' not in extra_param: | |||
extra_param['ignore_type'] = old_field.ignore_type | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||
else: | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
if new_field_name in self.field_arrays: | |||
# overwrite the field, keep same attributes | |||
old_field = self.field_arrays[new_field_name] | |||
if 'is_input' not in extra_param: | |||
extra_param['is_input'] = old_field.is_input | |||
if 'is_target' not in extra_param: | |||
extra_param['is_target'] = old_field.is_target | |||
if 'ignore_type' not in extra_param: | |||
extra_param['ignore_type'] = old_field.ignore_type | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) | |||
else: | |||
return results | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
"""将DataSet中每个instance传入到func中,并获取它的返回值. | |||
:param func: Callable, 参数是DataSet中的instance | |||
:param new_field_name: (None, str). (1) None, 不创建新的field; (2) str,将func的返回值放入这个名为 | |||
`new_field_name`的新field中,如果名称与已有的field相同,则覆盖之前的field; | |||
:param kwargs: 合法的参数有以下三个 | |||
(1) is_input: bool, 如果为True则将`new_field_name`的field设置为input | |||
(2) is_target: bool, 如果为True则将`new_field_name`的field设置为target | |||
(3) ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 | |||
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 | |||
""" | |||
assert len(self)!=0, "Null DataSet cannot use apply()." | |||
idx = -1 | |||
try: | |||
results = [] | |||
for idx, ins in enumerate(self._inner_iter()): | |||
results.append(func(ins)) | |||
except Exception as e: | |||
if idx!=-1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def drop(self, func, inplace=True): | |||
"""Drop instances if a condition holds. | |||
"""func接受一个instance,返回bool值,返回值为True时,该instance会被删除。 | |||
:param func: a function that takes an Instance object as input, and returns bool. | |||
The instance will be dropped if the function returns True. | |||
:param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned. | |||
:param func: Callable, 接受一个instance作为参数,返回bool值。为True时删除该instance | |||
:param inplace: bool, 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet | |||
:return: DataSet. | |||
""" | |||
if inplace: | |||
results = [ins for ins in self._inner_iter() if not func(ins)] | |||
for name, old_field in self.field_arrays.items(): | |||
self.field_arrays[name].content = [ins[name] for ins in results] | |||
return self | |||
else: | |||
results = [ins for ins in self if not func(ins)] | |||
data = DataSet(results) | |||
for field_name, field in self.field_arrays.items(): | |||
data.field_arrays[field_name].to(field) | |||
return data | |||
def split(self, dev_ratio): | |||
"""Split the dataset into training and development(validation) set. | |||
def split(self, ratio): | |||
"""将DataSet按照ratio的比例拆分,返回两个DataSet | |||
:param float dev_ratio: the ratio of test set in all data. | |||
:return (train_set, dev_set): | |||
train_set: the training set | |||
dev_set: the development set | |||
:param ratio: float, 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据 | |||
:return (DataSet, DataSet) | |||
""" | |||
assert isinstance(dev_ratio, float) | |||
assert 0 < dev_ratio < 1 | |||
assert isinstance(ratio, float) | |||
assert 0 < ratio < 1 | |||
all_indices = [_ for _ in range(len(self))] | |||
np.random.shuffle(all_indices) | |||
split = int(dev_ratio * len(self)) | |||
split = int(ratio * len(self)) | |||
dev_indices = all_indices[:split] | |||
train_indices = all_indices[split:] | |||
dev_set = DataSet() | |||
@@ -398,26 +445,25 @@ class DataSet(object): | |||
_dict[header].append(content) | |||
return cls(_dict) | |||
# def read_pos(self): | |||
# return DataLoaderRegister.get_reader('read_pos') | |||
def save(self, path): | |||
"""Save the DataSet object as pickle. | |||
"""保存DataSet. | |||
:param str path: the path to the pickle | |||
:param path: str, 将DataSet存在哪个路径 | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@staticmethod | |||
def load(path): | |||
"""Load a DataSet object from pickle. | |||
"""从保存的DataSet pickle路径中读取DataSet | |||
:param str path: the path to the pickle | |||
:return data_set: | |||
:param path: str, 读取路径 | |||
:return DataSet: | |||
""" | |||
with open(path, 'rb') as f: | |||
return pickle.load(f) | |||
d = pickle.load(f) | |||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | |||
return d | |||
def construct_dataset(sentences): | |||
@@ -84,7 +84,7 @@ class AutoPadder(PadderBase): | |||
for i, content in enumerate(contents): | |||
array[i][:len(content)] = content | |||
elif field_ele_dtype is None: | |||
array = contents # 当ignore_type=True时,直接返回contents | |||
array = np.array(contents) # 当ignore_type=True时,直接返回contents | |||
else: # should only be str | |||
array = np.array([content for content in contents]) | |||
return array | |||
@@ -290,9 +290,10 @@ class FieldArray(object): | |||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||
def append(self, val): | |||
"""Add a new item to the tail of FieldArray. | |||
"""将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为 | |||
input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。 | |||
:param val: int, float, str, or a list of one. | |||
:param val: Any. | |||
""" | |||
if self.ignore_type is False: | |||
if isinstance(val, list): | |||
@@ -367,13 +368,14 @@ class FieldArray(object): | |||
self.padder = deepcopy(padder) | |||
def set_pad_val(self, pad_val): | |||
""" | |||
修改padder的pad_val. | |||
:param pad_val: int。 | |||
"""修改padder的pad_val. | |||
:param pad_val: int。将该field的pad值设置为该值 | |||
:return: | |||
""" | |||
if self.padder is not None: | |||
self.padder.set_pad_val(pad_val) | |||
return self | |||
def __len__(self): | |||
@@ -385,8 +387,7 @@ class FieldArray(object): | |||
def to(self, other): | |||
""" | |||
将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim | |||
ignore_type | |||
将other的属性复制给本FieldArray(other必须为FieldArray类型). 包含 is_input, is_target, padder, ignore_type | |||
:param other: FieldArray | |||
:return: | |||
@@ -396,11 +397,10 @@ class FieldArray(object): | |||
self.is_input = other.is_input | |||
self.is_target = other.is_target | |||
self.padder = other.padder | |||
self.dtype = other.dtype | |||
self.pytype = other.pytype | |||
self.content_dim = other.content_dim | |||
self.ignore_type = other.ignore_type | |||
return self | |||
def is_iterable(content): | |||
try: | |||
_ = (e for e in content) | |||
@@ -24,7 +24,7 @@ def _prepare_cache_filepath(filepath): | |||
if not os.path.exists(cache_dir): | |||
os.makedirs(cache_dir) | |||
# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||
def cache_results(cache_filepath, refresh=False, verbose=1): | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
@@ -79,7 +79,7 @@ class SeqLabeling(BaseModel): | |||
:return prediction: list of [decode path(list)] | |||
""" | |||
max_len = x.shape[1] | |||
tag_seq = self.Crf.viterbi_decode(x, self.mask) | |||
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) | |||
# pad prediction to equal length | |||
if pad is True: | |||
for pred in tag_seq: | |||
@@ -2,12 +2,7 @@ import torch | |||
from torch import nn | |||
from fastNLP.modules.utils import initial_parameter | |||
def log_sum_exp(x, dim=-1): | |||
max_value, _ = x.max(dim=dim, keepdim=True) | |||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
return res.squeeze(dim) | |||
from fastNLP.modules.decoder.utils import log_sum_exp | |||
def seq_len_to_byte_mask(seq_lens): | |||
@@ -20,22 +15,27 @@ def seq_len_to_byte_mask(seq_lens): | |||
return mask | |||
def allowed_transitions(id2label, encoding_type='bio'): | |||
def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): | |||
""" | |||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||
:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 | |||
:param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 | |||
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 | |||
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
:param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||
start_idx=len(id2label), end_idx=len(id2label)+1。 | |||
为False, 返回的结果中不含与开始结尾相关的内容 | |||
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 | |||
""" | |||
num_tags = len(id2label) | |||
start_idx = num_tags | |||
end_idx = num_tags + 1 | |||
encoding_type = encoding_type.lower() | |||
allowed_trans = [] | |||
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] | |||
id_label_lst = list(id2label.items()) | |||
if include_start_end: | |||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||
def split_tag_label(from_label): | |||
from_label = from_label.lower() | |||
if from_label in ['start', 'end']: | |||
@@ -54,12 +54,12 @@ def allowed_transitions(id2label, encoding_type='bio'): | |||
if to_label in ['<pad>', '<unk>']: | |||
continue | |||
to_tag, to_label = split_tag_label(to_label) | |||
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
allowed_trans.append((from_id, to_id)) | |||
return allowed_trans | |||
def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 | |||
@@ -140,20 +140,22 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) | |||
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) | |||
class ConditionalRandomField(nn.Module): | |||
""" | |||
:param int num_tags: 标签的数量。 | |||
:param bool include_start_end_trans: 是否包含起始tag | |||
:param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。 | |||
如果为None,则所有跃迁均为合法 | |||
:param str initial_method: | |||
""" | |||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): | |||
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, | |||
initial_method=None): | |||
"""条件随机场。 | |||
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 | |||
:param num_tags: int, 标签的数量 | |||
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 | |||
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), | |||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||
:param initial_method: str, 初始化方法。见initial_parameter | |||
""" | |||
super(ConditionalRandomField, self).__init__() | |||
self.include_start_end_trans = include_start_end_trans | |||
@@ -168,18 +170,12 @@ class ConditionalRandomField(nn.Module): | |||
if allowed_transitions is None: | |||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
else: | |||
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 | |||
constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) | |||
for from_tag_id, to_tag_id in allowed_transitions: | |||
constrain[from_tag_id, to_tag_id] = 0 | |||
self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
# self.reset_parameter() | |||
initial_parameter(self, initial_method) | |||
def reset_parameter(self): | |||
nn.init.xavier_normal_(self.trans_m) | |||
if self.include_start_end_trans: | |||
nn.init.normal_(self.start_scores) | |||
nn.init.normal_(self.end_scores) | |||
def _normalizer_likelihood(self, logits, mask): | |||
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
@@ -239,10 +235,11 @@ class ConditionalRandomField(nn.Module): | |||
def forward(self, feats, tags, mask): | |||
""" | |||
Calculate the neg log likelihood | |||
:param feats:FloatTensor, batch_size x max_len x num_tags | |||
:param tags:LongTensor, batch_size x max_len | |||
:param mask:ByteTensor batch_size x max_len | |||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||
:param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param tags:LongTensor, batch_size x max_len,标签矩阵。 | |||
:param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 | |||
:return:FloatTensor, batch_size | |||
""" | |||
feats = feats.transpose(0, 1) | |||
@@ -253,28 +250,27 @@ class ConditionalRandomField(nn.Module): | |||
return all_path_score - gold_path_score | |||
def viterbi_decode(self, data, mask, get_score=False, unpad=False): | |||
"""Given a feats matrix, return best decode path and best score. | |||
def viterbi_decode(self, feats, mask, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param data:FloatTensor, batch_size x max_len x num_tags | |||
:param mask:ByteTensor batch_size x max_len | |||
:param get_score: bool, whether to output the decode score. | |||
:param unpad: bool, 是否将结果unpad, | |||
如果False, 返回的是batch_size x max_len的tensor, | |||
如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 | |||
List[int]的长度是这个sample的有效长度 | |||
:return: 如果get_score为False,返回结果根据unpadding变动 | |||
如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] | |||
为每个seqence的解码分数。 | |||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param unpad: bool, 是否将结果删去padding, | |||
False, 返回的是batch_size x max_len的tensor, | |||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] | |||
的长度是这个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
""" | |||
batch_size, seq_len, n_tags = data.size() | |||
data = data.transpose(0, 1).data # L, B, H | |||
batch_size, seq_len, n_tags = feats.size() | |||
feats = feats.transpose(0, 1).data # L, B, H | |||
mask = mask.transpose(0, 1).data.byte() # L, B | |||
# dp | |||
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = data[0] | |||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = feats[0] | |||
transitions = self._constrain.data.clone() | |||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||
if self.include_start_end_trans: | |||
@@ -285,23 +281,24 @@ class ConditionalRandomField(nn.Module): | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = data[i].view(batch_size, 1, n_tags) | |||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||
score = prev_score + trans_score + cur_score | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
if self.include_start_end_trans: | |||
vscore += transitions[:n_tags, n_tags+1].view(1, -1) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
lens = (mask.long().sum(0) - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len | |||
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(seq_len - 1): | |||
@@ -0,0 +1,70 @@ | |||
import torch | |||
def log_sum_exp(x, dim=-1): | |||
max_value, _ = x.max(dim=dim, keepdim=True) | |||
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value | |||
return res.squeeze(dim) | |||
def viterbi_decode(feats, transitions, mask=None, unpad=False): | |||
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 | |||
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 | |||
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
:param unpad: bool, 是否将结果删去padding, | |||
False, 返回的是batch_size x max_len的tensor, | |||
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 | |||
这个sample的有效长度。 | |||
:return: 返回 (paths, scores)。 | |||
paths: 是解码后的路径, 其值参照unpad参数. | |||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
""" | |||
batch_size, seq_len, n_tags = feats.size() | |||
assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \ | |||
"compatible." | |||
feats = feats.transpose(0, 1).data # L, B, H | |||
if mask is not None: | |||
mask = mask.transpose(0, 1).data.byte() # L, B | |||
else: | |||
mask = feats.new_ones((seq_len, batch_size), dtype=torch.uint8) | |||
# dp | |||
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) | |||
vscore = feats[0] | |||
vscore += transitions[n_tags, :n_tags] | |||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
for i in range(1, seq_len): | |||
prev_score = vscore.view(batch_size, n_tags, 1) | |||
cur_score = feats[i].view(batch_size, 1, n_tags) | |||
score = prev_score + trans_score + cur_score | |||
best_score, best_dst = score.max(1) | |||
vpath[i] = best_dst | |||
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ | |||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
# backtrace | |||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) | |||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) | |||
lens = (mask.long().sum(0) - 1) | |||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len | |||
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) | |||
ans_score, last_tags = vscore.max(1) | |||
ans[idxes[0], batch_idx] = last_tags | |||
for i in range(seq_len - 1): | |||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
ans[idxes[i + 1], batch_idx] = last_tags | |||
ans = ans.transpose(0, 1) | |||
if unpad: | |||
paths = [] | |||
for idx, seq_len in enumerate(lens): | |||
paths.append(ans[idx, :seq_len + 1].tolist()) | |||
else: | |||
paths = ans | |||
return paths, ans_score |
@@ -183,7 +183,7 @@ class CWSBiLSTMCRF(BaseModel): | |||
masks = seq_lens_to_mask(seq_lens) | |||
feats = self.encoder_model(chars, bigrams, seq_lens) | |||
feats = self.decoder_model(feats) | |||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
paths, _ = self.crf.viterbi_decode(feats, masks) | |||
return {'pred': probs, 'seq_lens':seq_lens} | |||
return {'pred': paths, 'seq_lens':seq_lens} | |||
@@ -72,9 +72,9 @@ class TransformerCWS(nn.Module): | |||
feats = self.transformer(x, masks) | |||
feats = self.fc2(feats) | |||
probs = self.crf.viterbi_decode(feats, masks, get_score=False) | |||
paths, _ = self.crf.viterbi_decode(feats, masks) | |||
return {'pred': probs, 'seq_lens':seq_lens} | |||
return {'pred': paths, 'seq_lens':seq_lens} | |||
class NoamOpt(torch.optim.Optimizer): | |||