@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
 from .trainer import Trainer
-from .utils import cache_results
+from .utils import cache_results, seq_len_to_mask
 from .vocabulary import Vocabulary
@@ -13,7 +13,7 @@ from .utils import _CheckRes
 from .utils import _build_args
 from .utils import _check_arg_dict_list
 from .utils import _get_func_signature
-from .utils import seq_lens_to_masks
+from .utils import seq_len_to_mask
 from .vocabulary import Vocabulary
@@ -305,7 +305,7 @@ class AccuracyMetric(MetricBase):
                              f"got {type(seq_len)}.")
         if seq_len is not None:
-            masks = seq_lens_to_masks(seq_lens=seq_len)
+            masks = seq_len_to_mask(seq_len=seq_len)
         else:
             masks = None
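
With the rename in place, AccuracyMetric builds its mask straight from seq_len. A minimal sketch of the masked-accuracy idea the metric presumably implements (the tensors and variable names below are illustrative, not fastNLP internals):

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    pred = torch.tensor([[1, 2], [3, 0]])
    target = torch.tensor([[1, 2], [3, 9]])        # the 9 sits in padding
    mask = seq_len_to_mask(torch.tensor([2, 1]))   # (B, max_len) byte/bool mask
    correct = pred.eq(target).masked_select(mask).sum().item()
    acc = correct / mask.sum().item()              # 1.0 -- padding is ignored
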
@@ -1,7 +1,7 @@
 """
 The utils module implements many of the utilities fastNLP needs internally and externally. The piece meant for users is the :func:`cache_results` decorator.
 """
-__all__ = ["cache_results"]
+__all__ = ["cache_results", "seq_len_to_mask"]
 import _pickle
 import inspect
 import os
@@ -600,48 +600,41 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
     warnings.warn(message=_unused_warn)

-def seq_lens_to_masks(seq_lens, float=False):
+def seq_len_to_mask(seq_len):
     """
-    Convert seq_lens to masks.
-    :param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,)
-    :param float: if True, the return masks is in float type, otherwise it is byte.
-    :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
+    Convert a 1-d seq_len array of sequence lengths into a 2-d mask; positions past each sequence's length are 0.
+
+    Example::
+
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.size())
+        torch.Size([14, 15])
+        >>> seq_len = np.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.shape)
+        (14, 15)
+
+    :param np.ndarray,torch.LongTensor seq_len: shape is (B,)
+    :return: np.ndarray or torch.Tensor of shape (B, max_length); element type is bool or torch.uint8
     """
-    if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
-        assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
-        raise NotImplemented
-    elif isinstance(seq_lens, torch.Tensor):
-        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
-        batch_size = seq_lens.size(0)
-        max_len = seq_lens.max()
-        indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long()
-        masks = indexes.lt(seq_lens.unsqueeze(1))
-        if float:
-            masks = masks.float()
-        return masks
-    elif isinstance(seq_lens, list):
-        raise NotImplemented
+    if isinstance(seq_len, np.ndarray):
+        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
+        max_len = int(seq_len.max())
+        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+    elif isinstance(seq_len, torch.Tensor):
+        assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim()}."
+        batch_size = seq_len.size(0)
+        max_len = seq_len.max().long()
+        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
+        mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
     else:
-        raise NotImplemented
+        raise TypeError("Only 1-d numpy.ndarray and 1-d torch.Tensor are supported.")
+    return mask

-def seq_mask(seq_len, max_len):
-    """Create sequence mask.
-
-    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
-    :param max_len: int, the maximum sequence length in a batch.
-    :return mask: torch.LongTensor, [batch_size, max_len]
-    """
-    if not isinstance(seq_len, torch.Tensor):
-        seq_len = torch.LongTensor(seq_len)
-    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
-    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]

 class _pseudo_tqdm:
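
The rewritten body builds the mask with a single broadcast comparison between a row of positions and a column of lengths; the same trick in isolation:

    import torch

    seq_len = torch.tensor([1, 3, 2])
    positions = torch.arange(int(seq_len.max())).unsqueeze(0)  # (1, max_len)
    mask = positions < seq_len.unsqueeze(1)                    # (B, max_len) by broadcasting
    # tensor([[ True, False, False],
    #         [ True,  True,  True],
    #         [ True,  True, False]])
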
@@ -10,14 +10,13 @@ from torch.nn import functional as F
 from ..core.const import Const as C
 from ..core.losses import LossFunc
 from ..core.metrics import MetricBase
-from ..core.utils import seq_lens_to_masks
 from ..modules.dropout import TimestepDropout
 from ..modules.encoder.transformer import TransformerEncoder
 from ..modules.encoder.variational_rnn import VarLSTM
 from ..modules.utils import initial_parameter
-from ..modules.utils import seq_mask
 from ..modules.utils import get_embeddings
 from .base_model import BaseModel
+from ..core.utils import seq_len_to_mask

 def _mst(scores):
     """
@@ -346,7 +345,7 @@ class BiaffineParser(GraphParser):
         # print('forward {} {}'.format(batch_size, seq_len))
         # get sequence mask
-        mask = seq_mask(seq_len, length).long()
+        mask = seq_len_to_mask(seq_len).long()

         word = self.word_embedding(words1) # [N,L] -> [N,L,C_0]
         pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1]
@@ -418,7 +417,7 @@ class BiaffineParser(GraphParser):
         """
         batch_size, length, _ = pred1.shape
-        mask = seq_mask(seq_len, length)
+        mask = seq_len_to_mask(seq_len)
         flip_mask = (mask == 0)
         _arc_pred = pred1.clone()
         _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
@@ -514,7 +513,7 @@ class ParserMetric(MetricBase):
         if seq_len is None:
             seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
         else:
-            seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long()
+            seq_mask = seq_len_to_mask(seq_len.long()).long()
         # mask out <root> tag
         seq_mask[:,0] = 0
         head_pred_correct = (pred1 == target1).long() * seq_mask
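
Zeroing column 0 drops the artificial <root> position from the correctness counts; on a toy mask:

    import torch

    seq_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 0, 0]])
    seq_mask[:, 0] = 0   # <root> never counts toward parser accuracy
    # tensor([[0, 1, 1, 0],
    #         [0, 1, 0, 0]])
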
@@ -3,7 +3,7 @@ import torch
 from .base_model import BaseModel
 from ..modules import decoder, encoder
 from ..modules.decoder.CRF import allowed_transitions
-from ..modules.utils import seq_mask
+from ..core.utils import seq_len_to_mask
 from ..core.const import Const as C

 from torch import nn
@@ -84,7 +84,7 @@ class SeqLabeling(BaseModel):
     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_mask(seq_len, max_len)
+        mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask
@@ -160,7 +160,7 @@ class AdvSeqLabel(nn.Module):
     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_mask(seq_len, max_len)
+        mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask
@@ -6,7 +6,7 @@ from ..core.const import Const
 from ..modules import decoder as Decoder
 from ..modules import encoder as Encoder
 from ..modules import aggregator as Aggregator
-from ..modules.utils import seq_mask
+from ..core.utils import seq_len_to_mask

 my_inf = 10e12
@@ -75,12 +75,12 @@ class ESIM(BaseModel):
         hypothesis0 = self.embedding_layer(self.embedding(words2))

         if seq_len1 is not None:
-            seq_len1 = seq_mask(seq_len1, premise0.size(1))
+            seq_len1 = seq_len_to_mask(seq_len1)
         else:
             seq_len1 = torch.ones(premise0.size(0), premise0.size(1))
         seq_len1 = (seq_len1.long()).to(device=premise0.device)
         if seq_len2 is not None:
-            seq_len2 = seq_mask(seq_len2, hypothesis0.size(1))
+            seq_len2 = seq_len_to_mask(seq_len2)
         else:
             seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
         seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)
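
One behavioral note: the old seq_mask took the padded length explicitly, while seq_len_to_mask infers it from seq_len.max(). The two coincide when a batch is padded exactly to its longest sequence; if a batch were padded longer, the mask would come out narrower than the embeddings. A hedged sketch of a guard for that case (the shapes are illustrative, and the guard itself is not part of this change):

    import torch
    import torch.nn.functional as F
    from fastNLP.core.utils import seq_len_to_mask

    premise0 = torch.randn(2, 5, 8)      # padded out to length 5
    seq_len1 = torch.tensor([3, 4])      # longest real sequence is only 4
    mask = seq_len_to_mask(seq_len1)     # shape (2, 4), one column short
    if mask.size(1) < premise0.size(1):
        # right-pad the mask with zeros up to the padded length
        mask = F.pad(mask.long(), (0, premise0.size(1) - mask.size(1)))
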
@@ -1,7 +1,7 @@
 """A PyTorch implementation of the Star-Transformer.
 """
 from ..modules.encoder.star_transformer import StarTransformer
-from ..core.utils import seq_lens_to_masks
+from ..core.utils import seq_len_to_mask
 from ..modules.utils import get_embeddings
 from ..core.const import Const
@@ -134,7 +134,7 @@ class STSeqLabel(nn.Module):
         :param seq_len: [batch,] lengths of the input sequences
         :return output: [batch, num_cls, seq_len] class probabilities for each element of the output sequence
         """
-        mask = seq_lens_to_masks(seq_len)
+        mask = seq_len_to_mask(seq_len)
         nodes, _ = self.enc(words, mask)
         output = self.cls(nodes)
         output = output.transpose(1,2) # make hidden to be dim 1
@@ -195,7 +195,7 @@ class STSeqCls(nn.Module):
         :param seq_len: [batch,] lengths of the input sequences
         :return output: [batch, num_cls] class probabilities for the input sequence
         """
-        mask = seq_lens_to_masks(seq_len)
+        mask = seq_len_to_mask(seq_len)
         nodes, relay = self.enc(words, mask)
         y = 0.5 * (relay + nodes.max(1)[0])
         output = self.cls(y) # [bsz, n_cls]
@@ -258,8 +258,8 @@ class STNLICls(nn.Module):
         :param seq_len2: [batch,] lengths of input sequence 2
         :return output: [batch, num_cls] class probabilities
         """
-        mask1 = seq_lens_to_masks(seq_len1)
-        mask2 = seq_lens_to_masks(seq_len2)
+        mask1 = seq_len_to_mask(seq_len1)
+        mask2 = seq_len_to_mask(seq_len2)

         def enc(seq, mask):
             nodes, relay = self.enc(seq, mask)
             return 0.5 * (relay + nodes.max(1)[0])
@@ -2,17 +2,6 @@ import torch
 from torch import nn

 from ..utils import initial_parameter
-from ..decoder.utils import log_sum_exp

-def seq_len_to_byte_mask(seq_lens):
-    # usually seq_lens: LongTensor, batch_size
-    # return value: ByteTensor, batch_size x max_len
-    batch_size = seq_lens.size(0)
-    max_len = seq_lens.max()
-    broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
-    mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1))
-    return mask

 def allowed_transitions(id2label, encoding_type='bio', include_start_end=True):
@@ -197,13 +186,13 @@ class ConditionalRandomField(nn.Module):
             emit_score = logits[i].view(batch_size, 1, n_tags)
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
-            alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
+            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                     alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)

-        return log_sum_exp(alpha, 1)
+        return torch.logsumexp(alpha, 1)

     def _gold_score(self, logits, tags, mask):
         """
@@ -1,4 +1,5 @@
-__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode"]
+__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode", "allowed_transitions"]

 from .CRF import ConditionalRandomField
 from .MLP import MLP
 from .utils import viterbi_decode
+from .CRF import allowed_transitions
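
allowed_transitions is now part of the decoder's public surface. Going by its signature at the definition site, it maps an id-to-label dict to the legal tag transitions for a given encoding scheme; a small usage sketch (the id2label dict is made up, and what the call returns is inferred from the parameter names rather than stated in this diff):

    from fastNLP.modules.decoder.CRF import allowed_transitions

    # a toy BIO tag vocabulary; real code would take this from a Vocabulary
    id2label = {0: 'B-PER', 1: 'I-PER', 2: 'O'}
    trans = allowed_transitions(id2label, encoding_type='bio', include_start_end=True)
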
@@ -2,12 +2,6 @@ __all__ = ["viterbi_decode"]
 import torch

-def log_sum_exp(x, dim=-1):
-    max_value, _ = x.max(dim=dim, keepdim=True)
-    res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
-    return res.squeeze(dim)

 def viterbi_decode(logits, transitions, mask=None, unpad=False):
     """Given a feature matrix and a transition score matrix, compute the best path and its score
@@ -68,22 +68,6 @@ def initial_parameter(net, initial_method=None):
     net.apply(weights_init)

-def seq_mask(seq_len, max_len):
-    """
-    Create sequence mask.
-
-    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
-    :param max_len: int, the maximum sequence length in a batch.
-    :return: mask, torch.LongTensor, [batch_size, max_len]
-
-    """
-    if not isinstance(seq_len, torch.Tensor):
-        seq_len = torch.LongTensor(seq_len)
-    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
-    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]

 def get_embeddings(init_embed):
     """
     Get word embeddings. TODO
@@ -9,7 +9,8 @@ import os
 import torch
 from torch import nn
 from fastNLP.core.utils import _move_model_to_device, _get_model_device
+import numpy as np
+from fastNLP.core.utils import seq_len_to_mask

 class Model(nn.Module):
     def __init__(self):
@@ -210,3 +211,42 @@ class TestCache(unittest.TestCase):
         finally:
             os.remove('test/demo1/demo.pkl')
             os.rmdir('test/demo1')

+class TestSeqLenToMask(unittest.TestCase):
+
+    def evaluate_mask_seq_len(self, seq_len, mask):
+        max_len = int(max(seq_len))
+        for i in range(len(seq_len)):
+            length = seq_len[i]
+            mask_i = mask[i]
+            for j in range(max_len):
+                self.assertEqual(mask_i[j], j < length)
+
+    def test_numpy_seq_len(self):
+        # test that a numpy seq_len converts correctly
+        # 1. randomized check
+        seq_len = np.random.randint(1, 10, size=(10, ))
+        mask = seq_len_to_mask(seq_len)
+        max_len = seq_len.max()
+        self.assertEqual(max_len, mask.shape[1])
+        self.evaluate_mask_seq_len(seq_len, mask)
+
+        # 2. error detection
+        seq_len = np.random.randint(10, size=(10, 1))
+        with self.assertRaises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+    def test_pytorch_seq_len(self):
+        # 1. randomized check
+        seq_len = torch.randint(1, 10, size=(10, ))
+        max_len = seq_len.max()
+        mask = seq_len_to_mask(seq_len)
+        self.assertEqual(max_len, mask.shape[1])
+        self.evaluate_mask_seq_len(seq_len.tolist(), mask)
+
+        # 2. error detection
+        seq_len = torch.randn(3, 4)
+        with self.assertRaises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
@@ -105,7 +105,7 @@ class TestCRF(unittest.TestCase):
         # test that the CRF loss never goes negative
         import torch
         from fastNLP.modules.decoder.CRF import ConditionalRandomField
-        from fastNLP.core.utils import seq_lens_to_masks
+        from fastNLP.core.utils import seq_len_to_mask
         from torch import optim
         from torch import nn
@@ -114,7 +114,7 @@ class TestCRF(unittest.TestCase):
         lengths = torch.randint(3, 50, size=(num_samples, )).long()
         max_len = lengths.max()
         tags = torch.randint(num_tags, size=(num_samples, max_len))
-        masks = seq_lens_to_masks(lengths)
+        masks = seq_len_to_mask(lengths)
         feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
         crf = ConditionalRandomField(num_tags, include_start_end_trans)
         optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
@@ -125,4 +125,4 @@ class TestCRF(unittest.TestCase):
             optimizer.step()
             if _%1000==0:
                 print(loss)
-        assert loss.item()>0, "CRF loss cannot be less than 0."
+        self.assertGreater(loss.item(), 0, "CRF loss must be positive.")