Browse Source

修改为中文注释,增加viterbi解码方法

tags/v0.4.10
yh 5 years ago
parent
commit
c520d35082
8 changed files with 249 additions and 136 deletions
  1. +107
    -61
      fastNLP/core/dataset.py
  2. +11
    -11
      fastNLP/core/fieldarray.py
  3. +1
    -1
      fastNLP/core/utils.py
  4. +1
    -1
      fastNLP/models/sequence_modeling.py
  5. +55
    -58
      fastNLP/modules/decoder/CRF.py
  6. +70
    -0
      fastNLP/modules/decoder/utils.py
  7. +2
    -2
      reproduction/Chinese_word_segmentation/models/cws_model.py
  8. +2
    -2
      reproduction/Chinese_word_segmentation/models/cws_transformer.py

+ 107
- 61
fastNLP/core/dataset.py View File

@@ -151,16 +151,19 @@ class DataSet(object):
assert name in self.field_arrays
self.field_arrays[name].append(field)

def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False):
def add_field(self, name, fields, padder=None, is_input=False, is_target=False, ignore_type=False):
"""Add a new field to the DataSet.
:param str name: the name of the field.
:param fields: a list of int, float, or other objects.
:param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可
:param padder: PadBase对象,如何对该Field进行padding。如果为None则使用
:param bool is_input: whether this field is model input.
:param bool is_target: whether this field is label or target.
:param bool ignore_type: If True, do not perform type check. (Default: False)
"""
if padder is None:
padder = AutoPadder(pad_val=0)

if len(self.field_arrays) != 0:
if len(self) != len(fields):
raise RuntimeError(f"The field to append must have the same size as dataset. "
@@ -231,8 +234,8 @@ class DataSet(object):
raise KeyError("{} is not a valid field name.".format(name))

def set_padder(self, field_name, padder):
"""
为field_name设置padder
"""为field_name设置padder
:param field_name: str, 设置field的padding方式为padder
:param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作.
:return:
@@ -242,8 +245,7 @@ class DataSet(object):
self.field_arrays[field_name].set_padder(padder)

def set_pad_val(self, field_name, pad_val):
"""
为某个
"""为某个field设置对应的pad_val.

:param field_name: str,修改该field的pad_val
:param pad_val: int,该field的padder会以pad_val作为padding index
@@ -254,43 +256,60 @@ class DataSet(object):
self.field_arrays[field_name].set_pad_val(pad_val)

def get_input_name(self):
"""Get all field names with `is_input` as True.
"""返回所有is_input被设置为True的field名称

:return field_names: a list of str
:return list, 里面的元素为被设置为input的field名称
"""
return [name for name, field in self.field_arrays.items() if field.is_input]

def get_target_name(self):
"""Get all field names with `is_target` as True.
"""返回所有is_target被设置为True的field名称

:return field_names: a list of str
:return list, 里面的元素为被设置为target的field名称
"""
return [name for name, field in self.field_arrays.items() if field.is_target]

def apply(self, func, new_field_name=None, **kwargs):
"""Apply a function to every instance of the DataSet.

:param func: a function that takes an instance as input.
:param str new_field_name: If not None, results of the function will be stored as a new field.
:param **kwargs: Accept parameters will be
(1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input.
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
:return results: if new_field_name is not passed, returned values of the function over all instances.
def apply_field(self, func, field_name, new_field_name=None, **kwargs):
"""将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值.

:param func: Callable, input是instance的`field_name`这个field.
:param field_name: str, 传入func的是哪个field.
:param new_field_name: (str, None). 如果不是None,将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有
的field相同,则覆盖之前的field.
:param **kwargs: 合法的参数有以下三个
(1) is_input: bool, 如果为True则将`new_field_name`这个field设置为input
(2) is_target: bool, 如果为True则将`new_field_name`这个field设置为target
(3) ignore_type: bool, 如果为True则将`new_field_name`这个field的ignore_type设置为true, 忽略其类型
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度
"""
assert len(self)!=0, "Null dataset cannot use .apply()."
assert len(self)!=0, "Null DataSet cannot use apply()."
if field_name not in self:
raise KeyError("DataSet has no field named `{}`.".format(field_name))
results = []
idx = -1
try:
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
results.append(func(ins[field_name]))
except Exception as e:
if idx!=-1:
print("Exception happens at the `{}`th instance.".format(idx))
raise e
# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))

if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)

return results

def _add_apply_field(self, results, new_field_name, kwargs):
"""将results作为加入到新的field中,field名称为new_field_name

:param results: List[], 一般是apply*()之后的结果
:param new_field_name: str, 新加入的field的名称
:param kwargs: dict, 用户apply*()时传入的自定义参数
:return:
"""
extra_param = {}
if 'is_input' in kwargs:
extra_param['is_input'] = kwargs['is_input']
@@ -298,56 +317,84 @@ class DataSet(object):
extra_param['is_target'] = kwargs['is_target']
if 'ignore_type' in kwargs:
extra_param['ignore_type'] = kwargs['ignore_type']
if new_field_name is not None:
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name]
if 'is_input' not in extra_param:
extra_param['is_input'] = old_field.is_input
if 'is_target' not in extra_param:
extra_param['is_target'] = old_field.is_target
if 'ignore_type' not in extra_param:
extra_param['ignore_type'] = old_field.ignore_type
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"],
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type'])
else:
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False))
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name]
if 'is_input' not in extra_param:
extra_param['is_input'] = old_field.is_input
if 'is_target' not in extra_param:
extra_param['is_target'] = old_field.is_target
if 'ignore_type' not in extra_param:
extra_param['ignore_type'] = old_field.ignore_type
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"],
is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type'])
else:
return results
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False))

def apply(self, func, new_field_name=None, **kwargs):
"""将DataSet中每个instance传入到func中,并获取它的返回值.

:param func: Callable, 参数是DataSet中的instance
:param new_field_name: (None, str). (1) None, 不创建新的field; (2) str,将func的返回值放入这个名为
`new_field_name`的新field中,如果名称与已有的field相同,则覆盖之前的field;
:param kwargs: 合法的参数有以下三个
(1) is_input: bool, 如果为True则将`new_field_name`的field设置为input
(2) is_target: bool, 如果为True则将`new_field_name`的field设置为target
(3) ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型
:return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度
"""
assert len(self)!=0, "Null DataSet cannot use apply()."
idx = -1
try:
results = []
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
except Exception as e:
if idx!=-1:
print("Exception happens at the `{}`th instance.".format(idx))
raise e
# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))

if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)

return results

def drop(self, func, inplace=True):
"""Drop instances if a condition holds.
"""func接受一个instance,返回bool值,返回值为True时,该instance会被删除。

:param func: a function that takes an Instance object as input, and returns bool.
The instance will be dropped if the function returns True.
:param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned.
:param func: Callable, 接受一个instance作为参数,返回bool值。为True时删除该instance
:param inplace: bool, 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet

:return: DataSet.
"""
if inplace:
results = [ins for ins in self._inner_iter() if not func(ins)]
for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results]
return self
else:
results = [ins for ins in self if not func(ins)]
data = DataSet(results)
for field_name, field in self.field_arrays.items():
data.field_arrays[field_name].to(field)
return data

def split(self, dev_ratio):
"""Split the dataset into training and development(validation) set.
def split(self, ratio):
"""将DataSet按照ratio的比例拆分,返回两个DataSet

:param float dev_ratio: the ratio of test set in all data.
:return (train_set, dev_set):
train_set: the training set
dev_set: the development set
:param ratio: float, 0<ratio<1, 返回的第一个DataSet拥有ratio这么多数据,第二个DataSet拥有(1-ratio)这么多数据
:return (DataSet, DataSet)
"""
assert isinstance(dev_ratio, float)
assert 0 < dev_ratio < 1
assert isinstance(ratio, float)
assert 0 < ratio < 1
all_indices = [_ for _ in range(len(self))]
np.random.shuffle(all_indices)
split = int(dev_ratio * len(self))
split = int(ratio * len(self))
dev_indices = all_indices[:split]
train_indices = all_indices[split:]
dev_set = DataSet()
@@ -398,26 +445,25 @@ class DataSet(object):
_dict[header].append(content)
return cls(_dict)

# def read_pos(self):
# return DataLoaderRegister.get_reader('read_pos')

def save(self, path):
"""Save the DataSet object as pickle.
"""保存DataSet.

:param str path: the path to the pickle
:param path: str, 将DataSet存在哪个路径
"""
with open(path, 'wb') as f:
pickle.dump(self, f)

@staticmethod
def load(path):
"""Load a DataSet object from pickle.
"""从保存的DataSet pickle路径中读取DataSet

:param str path: the path to the pickle
:return data_set:
:param path: str, 读取路径
:return DataSet:
"""
with open(path, 'rb') as f:
return pickle.load(f)
d = pickle.load(f)
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
return d


def construct_dataset(sentences):


+ 11
- 11
fastNLP/core/fieldarray.py View File

@@ -84,7 +84,7 @@ class AutoPadder(PadderBase):
for i, content in enumerate(contents):
array[i][:len(content)] = content
elif field_ele_dtype is None:
array = contents # 当ignore_type=True时,直接返回contents
array = np.array(contents) # 当ignore_type=True时,直接返回contents
else: # should only be str
array = np.array([content for content in contents])
return array
@@ -290,9 +290,10 @@ class FieldArray(object):
return "FieldArray {}: {}".format(self.name, self.content.__repr__())

def append(self, val):
"""Add a new item to the tail of FieldArray.
"""将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为
input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。

:param val: int, float, str, or a list of one.
:param val: Any.
"""
if self.ignore_type is False:
if isinstance(val, list):
@@ -367,13 +368,14 @@ class FieldArray(object):
self.padder = deepcopy(padder)

def set_pad_val(self, pad_val):
"""
修改padder的pad_val.
:param pad_val: int。
"""修改padder的pad_val.
:param pad_val: int。将该field的pad值设置为该值
:return:
"""
if self.padder is not None:
self.padder.set_pad_val(pad_val)
return self


def __len__(self):
@@ -385,8 +387,7 @@ class FieldArray(object):

def to(self, other):
"""
将other的属性复制给本fieldarray(必须通过fieldarray类型). 包含 is_input, is_target, padder, dtype, pytype, content_dim
ignore_type
将other的属性复制给本FieldArray(other必须为FieldArray类型). 包含 is_input, is_target, padder, ignore_type

:param other: FieldArray
:return:
@@ -396,11 +397,10 @@ class FieldArray(object):
self.is_input = other.is_input
self.is_target = other.is_target
self.padder = other.padder
self.dtype = other.dtype
self.pytype = other.pytype
self.content_dim = other.content_dim
self.ignore_type = other.ignore_type

return self

def is_iterable(content):
try:
_ = (e for e in content)


+ 1
- 1
fastNLP/core/utils.py View File

@@ -24,7 +24,7 @@ def _prepare_cache_filepath(filepath):
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。
def cache_results(cache_filepath, refresh=False, verbose=1):
def wrapper_(func):
signature = inspect.signature(func)


+ 1
- 1
fastNLP/models/sequence_modeling.py View File

@@ -79,7 +79,7 @@ class SeqLabeling(BaseModel):
:return prediction: list of [decode path(list)]
"""
max_len = x.shape[1]
tag_seq = self.Crf.viterbi_decode(x, self.mask)
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask)
# pad prediction to equal length
if pad is True:
for pred in tag_seq:


+ 55
- 58
fastNLP/modules/decoder/CRF.py View File

@@ -2,12 +2,7 @@ import torch
from torch import nn

from fastNLP.modules.utils import initial_parameter


def log_sum_exp(x, dim=-1):
max_value, _ = x.max(dim=dim, keepdim=True)
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
return res.squeeze(dim)
from fastNLP.modules.decoder.utils import log_sum_exp


def seq_len_to_byte_mask(seq_lens):
@@ -20,22 +15,27 @@ def seq_len_to_byte_mask(seq_lens):
return mask


def allowed_transitions(id2label, encoding_type='bio'):
def allowed_transitions(id2label, encoding_type='bio', include_start_end=True):
"""
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。

:param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。
:param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。
:param encoding_type: str, 支持"bio", "bmes", "bmeso"。
:return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以
位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx).
start_idx=len(id2label), end_idx=len(id2label)+1。
:param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头;
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx);
start_idx=len(id2label), end_idx=len(id2label)+1。
为False, 返回的结果中不含与开始结尾相关的内容
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。
"""
num_tags = len(id2label)
start_idx = num_tags
end_idx = num_tags + 1
encoding_type = encoding_type.lower()
allowed_trans = []
id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')]
id_label_lst = list(id2label.items())
if include_start_end:
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
def split_tag_label(from_label):
from_label = from_label.lower()
if from_label in ['start', 'end']:
@@ -54,12 +54,12 @@ def allowed_transitions(id2label, encoding_type='bio'):
if to_label in ['<pad>', '<unk>']:
continue
to_tag, to_label = split_tag_label(to_label)
if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
allowed_trans.append((from_id, to_id))
return allowed_trans


def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
"""

:param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。
@@ -140,20 +140,22 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label)
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))

else:
raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type))
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))


class ConditionalRandomField(nn.Module):
"""

:param int num_tags: 标签的数量。
:param bool include_start_end_trans: 是否包含起始tag
:param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 允许的跃迁,可以通过allowed_transitions()得到。
如果为None,则所有跃迁均为合法
:param str initial_method:
"""

def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None):
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None,
initial_method=None):
"""条件随机场。
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。

:param num_tags: int, 标签的数量
:param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。
:param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int),
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法
:param initial_method: str, 初始化方法。见initial_parameter
"""
super(ConditionalRandomField, self).__init__()

self.include_start_end_trans = include_start_end_trans
@@ -168,18 +170,12 @@ class ConditionalRandomField(nn.Module):
if allowed_transitions is None:
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000
constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float)
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
self._constrain = nn.Parameter(constrain, requires_grad=False)

# self.reset_parameter()
initial_parameter(self, initial_method)
def reset_parameter(self):
nn.init.xavier_normal_(self.trans_m)
if self.include_start_end_trans:
nn.init.normal_(self.start_scores)
nn.init.normal_(self.end_scores)

def _normalizer_likelihood(self, logits, mask):
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
@@ -239,10 +235,11 @@ class ConditionalRandomField(nn.Module):

def forward(self, feats, tags, mask):
"""
Calculate the neg log likelihood
:param feats:FloatTensor, batch_size x max_len x num_tags
:param tags:LongTensor, batch_size x max_len
:param mask:ByteTensor batch_size x max_len
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。

:param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。
:param tags:LongTensor, batch_size x max_len,标签矩阵。
:param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。
:return:FloatTensor, batch_size
"""
feats = feats.transpose(0, 1)
@@ -253,28 +250,27 @@ class ConditionalRandomField(nn.Module):

return all_path_score - gold_path_score

def viterbi_decode(self, data, mask, get_score=False, unpad=False):
"""Given a feats matrix, return best decode path and best score.
def viterbi_decode(self, feats, mask, unpad=False):
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数

:param data:FloatTensor, batch_size x max_len x num_tags
:param mask:ByteTensor batch_size x max_len
:param get_score: bool, whether to output the decode score.
:param unpad: bool, 是否将结果unpad,
如果False, 返回的是batch_size x max_len的tensor,
如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个
List[int]的长度是这个sample的有效长度
:return: 如果get_score为False,返回结果根据unpadding变动
如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float]
为每个seqence的解码分数。
:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param unpad: bool, 是否将结果删去padding,
False, 返回的是batch_size x max_len的tensor,
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]
的长度是这个sample的有效长度。
:return: 返回 (paths, scores)。
paths: 是解码后的路径, 其值参照unpad参数.
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。

"""
batch_size, seq_len, n_tags = data.size()
data = data.transpose(0, 1).data # L, B, H
batch_size, seq_len, n_tags = feats.size()
feats = feats.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.byte() # L, B

# dp
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = data[0]
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = feats[0]
transitions = self._constrain.data.clone()
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
@@ -285,23 +281,24 @@ class ConditionalRandomField(nn.Module):
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = data[i].view(batch_size, 1, n_tags)
cur_score = feats[i].view(batch_size, 1, n_tags)
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

vscore += transitions[:n_tags, n_tags+1].view(1, -1)
if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags+1].view(1, -1)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len

ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
ans = feats.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):


+ 70
- 0
fastNLP/modules/decoder/utils.py View File

@@ -0,0 +1,70 @@

import torch


def log_sum_exp(x, dim=-1):
max_value, _ = x.max(dim=dim, keepdim=True)
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
return res.squeeze(dim)


def viterbi_decode(feats, transitions, mask=None, unpad=False):
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数

:param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。
:param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。
:param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。
:param unpad: bool, 是否将结果删去padding,
False, 返回的是batch_size x max_len的tensor,
True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是
这个sample的有效长度。
:return: 返回 (paths, scores)。
paths: 是解码后的路径, 其值参照unpad参数.
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。

"""
batch_size, seq_len, n_tags = feats.size()
assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \
"compatible."
feats = feats.transpose(0, 1).data # L, B, H
if mask is not None:
mask = mask.transpose(0, 1).data.byte() # L, B
else:
mask = feats.new_ones((seq_len, batch_size), dtype=torch.uint8)

# dp
vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = feats[0]

vscore += transitions[n_tags, :n_tags]
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = feats[i].view(batch_size, 1, n_tags)
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len

ans = feats.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i + 1], batch_idx] = last_tags
ans = ans.transpose(0, 1)
if unpad:
paths = []
for idx, seq_len in enumerate(lens):
paths.append(ans[idx, :seq_len + 1].tolist())
else:
paths = ans
return paths, ans_score

+ 2
- 2
reproduction/Chinese_word_segmentation/models/cws_model.py View File

@@ -183,7 +183,7 @@ class CWSBiLSTMCRF(BaseModel):
masks = seq_lens_to_mask(seq_lens)
feats = self.encoder_model(chars, bigrams, seq_lens)
feats = self.decoder_model(feats)
probs = self.crf.viterbi_decode(feats, masks, get_score=False)
paths, _ = self.crf.viterbi_decode(feats, masks)

return {'pred': probs, 'seq_lens':seq_lens}
return {'pred': paths, 'seq_lens':seq_lens}


+ 2
- 2
reproduction/Chinese_word_segmentation/models/cws_transformer.py View File

@@ -72,9 +72,9 @@ class TransformerCWS(nn.Module):
feats = self.transformer(x, masks)
feats = self.fc2(feats)

probs = self.crf.viterbi_decode(feats, masks, get_score=False)
paths, _ = self.crf.viterbi_decode(feats, masks)

return {'pred': probs, 'seq_lens':seq_lens}
return {'pred': paths, 'seq_lens':seq_lens}


class NoamOpt(torch.optim.Optimizer):


Loading…
Cancel
Save