@@ -138,6 +138,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
         msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
               "information please set logger's level to DEBUG."
         if must_pad:
+            logger.error(msg)
             raise type(e)(msg=msg)
         logger.debug(msg)
         return NullPadder()
@@ -16,6 +16,7 @@ if _NEED_IMPORT_TORCH:
     import torch
     import torch.nn as nn
     import torch.nn.functional as F
+    from torch.nn import LSTM
 from .embedding import TokenEmbedding
 from .static_embedding import StaticEmbedding
@@ -23,7 +24,6 @@ from .utils import _construct_char_vocab_from_vocab
 from .utils import get_embeddings
 from ...core import logger
 from ...core.vocabulary import Vocabulary
-from ...modules.torch.encoder.lstm import LSTM
 class CNNCharEmbedding(TokenEmbedding):
@@ -0,0 +1,21 @@ | |||||
__all__ = [ | |||||
'BiaffineParser', | |||||
"CNNText", | |||||
"SequenceGeneratorModel", | |||||
"Seq2SeqModel", | |||||
'TransformerSeq2SeqModel', | |||||
'LSTMSeq2SeqModel', | |||||
"SeqLabeling", | |||||
"AdvSeqLabel", | |||||
"BiLSTMCRF", | |||||
] | |||||
from .biaffine_parser import BiaffineParser | |||||
from .cnn_text_classification import CNNText | |||||
from .seq2seq_generator import SequenceGeneratorModel | |||||
from .seq2seq_model import * | |||||
from .sequence_labeling import * |
@@ -0,0 +1,475 @@ | |||||
r""" | |||||
PyTorch implementation of the Biaffine Dependency Parser.
""" | |||||
__all__ = [ | |||||
"BiaffineParser", | |||||
"GraphParser" | |||||
] | |||||
from collections import defaultdict | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from ...core.utils import seq_len_to_mask | |||||
from ...embeddings.torch.utils import get_embeddings | |||||
from ...modules.torch.dropout import TimestepDropout | |||||
from ...modules.torch.encoder.transformer import TransformerEncoder | |||||
from ...modules.torch.encoder.variational_rnn import VarLSTM | |||||
def _mst(scores): | |||||
r""" | |||||
with some modification to support parser output for MST decoding | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 | |||||
""" | |||||
length = scores.shape[0] | |||||
min_score = scores.min() - 1 | |||||
eye = np.eye(length) | |||||
scores = scores * (1 - eye) + min_score * eye | |||||
heads = np.argmax(scores, axis=1) | |||||
heads[0] = 0 | |||||
tokens = np.arange(1, length) | |||||
roots = np.where(heads[tokens] == 0)[0] + 1 | |||||
if len(roots) < 1: | |||||
root_scores = scores[tokens, 0] | |||||
head_scores = scores[tokens, heads[tokens]] | |||||
new_root = tokens[np.argmax(root_scores / head_scores)] | |||||
heads[new_root] = 0 | |||||
elif len(roots) > 1: | |||||
root_scores = scores[roots, 0] | |||||
scores[roots, 0] = 0 | |||||
new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1 | |||||
new_root = roots[np.argmin( | |||||
scores[roots, new_heads] / root_scores)] | |||||
heads[roots] = new_heads | |||||
heads[new_root] = 0 | |||||
edges = defaultdict(set) | |||||
vertices = set((0,)) | |||||
for dep, head in enumerate(heads[tokens]): | |||||
vertices.add(dep + 1) | |||||
edges[head].add(dep + 1) | |||||
for cycle in _find_cycle(vertices, edges): | |||||
dependents = set() | |||||
to_visit = set(cycle) | |||||
while len(to_visit) > 0: | |||||
node = to_visit.pop() | |||||
if node not in dependents: | |||||
dependents.add(node) | |||||
to_visit.update(edges[node]) | |||||
cycle = np.array(list(cycle)) | |||||
old_heads = heads[cycle] | |||||
old_scores = scores[cycle, old_heads] | |||||
non_heads = np.array(list(dependents)) | |||||
scores[np.repeat(cycle, len(non_heads)), | |||||
np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score | |||||
new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1 | |||||
new_scores = scores[cycle, new_heads] / old_scores | |||||
change = np.argmax(new_scores) | |||||
changed_cycle = cycle[change] | |||||
old_head = old_heads[change] | |||||
new_head = new_heads[change] | |||||
heads[changed_cycle] = new_head | |||||
edges[new_head].add(changed_cycle) | |||||
edges[old_head].remove(changed_cycle) | |||||
return heads | |||||
def _find_cycle(vertices, edges): | |||||
r""" | |||||
https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm | |||||
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py | |||||
""" | |||||
_index = 0 | |||||
_stack = [] | |||||
_indices = {} | |||||
_lowlinks = {} | |||||
_onstack = defaultdict(lambda: False) | |||||
_SCCs = [] | |||||
def _strongconnect(v): | |||||
nonlocal _index | |||||
_indices[v] = _index | |||||
_lowlinks[v] = _index | |||||
_index += 1 | |||||
_stack.append(v) | |||||
_onstack[v] = True | |||||
for w in edges[v]: | |||||
if w not in _indices: | |||||
_strongconnect(w) | |||||
_lowlinks[v] = min(_lowlinks[v], _lowlinks[w]) | |||||
elif _onstack[w]: | |||||
_lowlinks[v] = min(_lowlinks[v], _indices[w]) | |||||
if _lowlinks[v] == _indices[v]: | |||||
SCC = set() | |||||
while True: | |||||
w = _stack.pop() | |||||
_onstack[w] = False | |||||
SCC.add(w) | |||||
if not (w != v): | |||||
break | |||||
_SCCs.append(SCC) | |||||
for v in vertices: | |||||
if v not in _indices: | |||||
_strongconnect(v) | |||||
return [SCC for SCC in _SCCs if len(SCC) > 1] | |||||
class GraphParser(nn.Module): | |||||
r""" | |||||
基于图的parser base class, 支持贪婪解码和最大生成树解码 | |||||
""" | |||||
def __init__(self): | |||||
super(GraphParser, self).__init__() | |||||
@staticmethod | |||||
def greedy_decoder(arc_matrix, mask=None): | |||||
r""" | |||||
贪心解码方式, 输入图, 输出贪心解码的parsing结果, 不保证合法的构成树 | |||||
:param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | |||||
:param mask: [batch, seq_len] 输入图的padding mask, 有内容的部分为 1, 否则为 0. | |||||
若为 ``None`` 时, 默认为全1向量. Default: ``None`` | |||||
:return heads: [batch, seq_len] 每个元素在树中对应的head(parent)预测结果 | |||||
""" | |||||
        _, seq_len, _ = arc_matrix.shape
        matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
        if mask is not None:
            flip_mask = mask.eq(False)
            matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
        _, heads = torch.max(matrix, dim=2)
        if mask is not None:
            heads *= mask.long()
return heads | |||||
@staticmethod | |||||
def mst_decoder(arc_matrix, mask=None): | |||||
r""" | |||||
用最大生成树算法, 计算parsing结果, 保证输出合法的树结构 | |||||
:param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 | |||||
:param mask: [batch, seq_len] 输入图的padding mask, 有内容的部分为 1, 否则为 0. | |||||
若为 ``None`` 时, 默认为全1向量. Default: ``None`` | |||||
:return heads: [batch, seq_len] 每个元素在树中对应的head(parent)预测结果 | |||||
""" | |||||
batch_size, seq_len, _ = arc_matrix.shape | |||||
matrix = arc_matrix.clone() | |||||
ans = matrix.new_zeros(batch_size, seq_len).long() | |||||
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len | |||||
for i, graph in enumerate(matrix): | |||||
len_i = lens[i] | |||||
ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) | |||||
if mask is not None: | |||||
ans *= mask.long() | |||||
return ans | |||||
class ArcBiaffine(nn.Module): | |||||
r""" | |||||
Biaffine Dependency Parser 的子模块, 用于构建预测边的图 | |||||
""" | |||||
def __init__(self, hidden_size, bias=True): | |||||
r""" | |||||
:param hidden_size: 输入的特征维度 | |||||
:param bias: 是否使用bias. Default: ``True`` | |||||
""" | |||||
super(ArcBiaffine, self).__init__() | |||||
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True) | |||||
self.has_bias = bias | |||||
if self.has_bias: | |||||
self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True) | |||||
else: | |||||
self.register_parameter("bias", None) | |||||
def forward(self, head, dep): | |||||
r""" | |||||
:param head: arc-head tensor [batch, length, hidden] | |||||
:param dep: arc-dependent tensor [batch, length, hidden] | |||||
:return output: tensor [bacth, length, length] | |||||
""" | |||||
output = dep.matmul(self.U) | |||||
output = output.bmm(head.transpose(-1, -2)) | |||||
if self.has_bias: | |||||
output = output + head.matmul(self.bias).unsqueeze(1) | |||||
return output | |||||
class LabelBilinear(nn.Module): | |||||
r""" | |||||
Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图 | |||||
""" | |||||
def __init__(self, in1_features, in2_features, num_label, bias=True): | |||||
r""" | |||||
:param in1_features: 输入的特征1维度 | |||||
:param in2_features: 输入的特征2维度 | |||||
:param num_label: 边类别的个数 | |||||
:param bias: 是否使用bias. Default: ``True`` | |||||
""" | |||||
super(LabelBilinear, self).__init__() | |||||
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias) | |||||
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) | |||||
def forward(self, x1, x2): | |||||
r""" | |||||
:param x1: [batch, seq_len, hidden] 输入特征1, 即label-head | |||||
:param x2: [batch, seq_len, hidden] 输入特征2, 即label-dep | |||||
:return output: [batch, seq_len, num_cls] 每个元素对应类别的概率图 | |||||
""" | |||||
output = self.bilinear(x1, x2) | |||||
output = output + self.lin(torch.cat([x1, x2], dim=2)) | |||||
return output | |||||
class BiaffineParser(GraphParser): | |||||
r""" | |||||
Biaffine Dependency Parser 实现. | |||||
论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ . | |||||
""" | |||||
def __init__(self, | |||||
embed, | |||||
pos_vocab_size, | |||||
pos_emb_dim, | |||||
num_label, | |||||
rnn_layers=1, | |||||
rnn_hidden_size=200, | |||||
arc_mlp_size=100, | |||||
label_mlp_size=100, | |||||
dropout=0.3, | |||||
encoder='lstm', | |||||
use_greedy_infer=False): | |||||
r""" | |||||
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即 | |||||
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, | |||||
此时就以传入的对象作为embedding | |||||
:param pos_vocab_size: part-of-speech 词典大小 | |||||
:param pos_emb_dim: part-of-speech 向量维度 | |||||
:param num_label: 边的类别个数 | |||||
:param rnn_layers: rnn encoder的层数 | |||||
:param rnn_hidden_size: rnn encoder 的隐状态维度 | |||||
:param arc_mlp_size: 边预测的MLP维度 | |||||
:param label_mlp_size: 类别预测的MLP维度 | |||||
:param dropout: dropout概率. | |||||
:param encoder: encoder类别, 可选 ('lstm', 'var-lstm', 'transformer'). Default: lstm | |||||
:param use_greedy_infer: 是否在inference时使用贪心算法. | |||||
若 ``False`` , 使用更加精确但相对缓慢的MST算法. Default: ``False`` | |||||
""" | |||||
super(BiaffineParser, self).__init__() | |||||
rnn_out_size = 2 * rnn_hidden_size | |||||
word_hid_dim = pos_hid_dim = rnn_hidden_size | |||||
self.word_embedding = get_embeddings(embed) | |||||
word_emb_dim = self.word_embedding.embedding_dim | |||||
self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim) | |||||
self.word_fc = nn.Linear(word_emb_dim, word_hid_dim) | |||||
self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim) | |||||
self.word_norm = nn.LayerNorm(word_hid_dim) | |||||
self.pos_norm = nn.LayerNorm(pos_hid_dim) | |||||
self.encoder_name = encoder | |||||
self.max_len = 512 | |||||
if encoder == 'var-lstm': | |||||
self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
input_dropout=dropout, | |||||
hidden_dropout=dropout, | |||||
bidirectional=True) | |||||
elif encoder == 'lstm': | |||||
self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim, | |||||
hidden_size=rnn_hidden_size, | |||||
num_layers=rnn_layers, | |||||
bias=True, | |||||
batch_first=True, | |||||
dropout=dropout, | |||||
bidirectional=True) | |||||
elif encoder == 'transformer': | |||||
n_head = 16 | |||||
d_k = d_v = int(rnn_out_size / n_head) | |||||
if (d_k * n_head) != rnn_out_size: | |||||
raise ValueError('Unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) | |||||
self.position_emb = nn.Embedding(num_embeddings=self.max_len, | |||||
embedding_dim=rnn_out_size, ) | |||||
self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size, | |||||
n_head=n_head, dim_ff=1024, dropout=dropout) | |||||
else: | |||||
raise ValueError('Unsupported encoder type: {}'.format(encoder)) | |||||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||||
nn.ELU(), | |||||
TimestepDropout(p=dropout), ) | |||||
self.arc_mlp_size = arc_mlp_size | |||||
self.label_mlp_size = label_mlp_size | |||||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||||
self.use_greedy_infer = use_greedy_infer | |||||
self.reset_parameters() | |||||
self.dropout = dropout | |||||
def reset_parameters(self): | |||||
for m in self.modules(): | |||||
if isinstance(m, nn.Embedding): | |||||
continue | |||||
elif isinstance(m, nn.LayerNorm): | |||||
nn.init.constant_(m.weight, 0.1) | |||||
nn.init.constant_(m.bias, 0) | |||||
else: | |||||
for p in m.parameters(): | |||||
nn.init.normal_(p, 0, 0.1) | |||||
def forward(self, words1, words2, seq_len, target1=None): | |||||
r"""模型forward阶段 | |||||
:param words1: [batch_size, seq_len] 输入word序列 | |||||
:param words2: [batch_size, seq_len] 输入pos序列 | |||||
:param seq_len: [batch_size, seq_len] 输入序列长度 | |||||
:param target1: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, | |||||
用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器 | |||||
Default: ``None`` | |||||
:return dict: parsing | |||||
结果:: | |||||
pred1: [batch_size, seq_len, seq_len] 边预测logits | |||||
pred2: [batch_size, seq_len, num_label] label预测logits | |||||
pred3: [batch_size, seq_len] heads的预测结果, 在 ``target1=None`` 时预测 | |||||
""" | |||||
# prepare embeddings | |||||
batch_size, length = words1.shape | |||||
# print('forward {} {}'.format(batch_size, seq_len)) | |||||
# get sequence mask | |||||
mask = seq_len_to_mask(seq_len, max_len=length).long() | |||||
word = self.word_embedding(words1) # [N,L] -> [N,L,C_0] | |||||
pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] | |||||
word, pos = self.word_fc(word), self.pos_fc(pos) | |||||
word, pos = self.word_norm(word), self.pos_norm(pos) | |||||
x = torch.cat([word, pos], dim=2) # -> [N,L,C] | |||||
# encoder, extract features | |||||
if self.encoder_name.endswith('lstm'): | |||||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||||
x = x[sort_idx] | |||||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True) | |||||
feat, _ = self.encoder(x) # -> [N,L,C] | |||||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
feat = feat[unsort_idx] | |||||
else: | |||||
seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :] | |||||
x = x + self.position_emb(seq_range) | |||||
feat = self.encoder(x, mask.float()) | |||||
# for arc biaffine | |||||
# mlp, reduce dim | |||||
feat = self.mlp(feat) | |||||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||||
arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz] | |||||
label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:] | |||||
# biaffine arc classifier | |||||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||||
# use gold or predicted arc to predict label | |||||
if target1 is None or not self.training: | |||||
# use greedy decoding in training | |||||
if self.training or self.use_greedy_infer: | |||||
heads = self.greedy_decoder(arc_pred, mask) | |||||
else: | |||||
heads = self.mst_decoder(arc_pred, mask) | |||||
head_pred = heads | |||||
else: | |||||
assert self.training # must be training mode | |||||
if target1 is None: | |||||
heads = self.greedy_decoder(arc_pred, mask) | |||||
head_pred = heads | |||||
else: | |||||
head_pred = None | |||||
heads = target1 | |||||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1) | |||||
label_head = label_head[batch_range, heads].contiguous() | |||||
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] | |||||
res_dict = {'pred1': arc_pred, 'pred2': label_pred} | |||||
if head_pred is not None: | |||||
res_dict['pred3'] = head_pred | |||||
return res_dict | |||||
def train_step(self, words1, words2, seq_len, target1, target2): | |||||
res = self(words1, words2, seq_len, target1) | |||||
arc_pred = res['pred1'] | |||||
label_pred = res['pred2'] | |||||
loss = self.loss(pred1=arc_pred, pred2=label_pred, target1=target1, target2=target2, seq_len=seq_len) | |||||
return {'loss': loss} | |||||
@staticmethod | |||||
def loss(pred1, pred2, target1, target2, seq_len): | |||||
r""" | |||||
计算parser的loss | |||||
:param pred1: [batch_size, seq_len, seq_len] 边预测logits | |||||
:param pred2: [batch_size, seq_len, num_label] label预测logits | |||||
:param target1: [batch_size, seq_len] 真实边的标注 | |||||
:param target2: [batch_size, seq_len] 真实类别的标注 | |||||
:param seq_len: [batch_size, seq_len] 真实目标的长度 | |||||
:return loss: scalar | |||||
""" | |||||
batch_size, length, _ = pred1.shape | |||||
mask = seq_len_to_mask(seq_len, max_len=length) | |||||
flip_mask = (mask.eq(False)) | |||||
_arc_pred = pred1.clone() | |||||
_arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf')) | |||||
arc_logits = F.log_softmax(_arc_pred, dim=2) | |||||
label_logits = F.log_softmax(pred2, dim=2) | |||||
batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1) | |||||
child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0) | |||||
arc_loss = arc_logits[batch_index, child_index, target1] | |||||
label_loss = label_logits[batch_index, child_index, target2] | |||||
arc_loss = arc_loss.masked_fill(flip_mask, 0) | |||||
label_loss = label_loss.masked_fill(flip_mask, 0) | |||||
arc_nll = -arc_loss.mean() | |||||
label_nll = -label_loss.mean() | |||||
return arc_nll + label_nll | |||||
def evaluate_step(self, words1, words2, seq_len): | |||||
r"""模型预测API | |||||
:param words1: [batch_size, seq_len] 输入word序列 | |||||
:param words2: [batch_size, seq_len] 输入pos序列 | |||||
:param seq_len: [batch_size, seq_len] 输入序列长度 | |||||
:return dict: parsing | |||||
结果:: | |||||
pred1: [batch_size, seq_len] heads的预测结果 | |||||
pred2: [batch_size, seq_len, num_label] label预测logits | |||||
""" | |||||
res = self(words1, words2, seq_len) | |||||
output = {} | |||||
output['pred1'] = res.pop('pred3') | |||||
_, label_pred = res.pop('pred2').max(2) | |||||
output['pred2'] = label_pred | |||||
return output | |||||
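A minimal usage sketch for the parser defined above (the embedding tuple, vocabulary sizes and label count are illustrative assumptions, not values from the source):

import torch

# Toy batch: 2 sentences, 5 token positions each (position 0 acts as the root).
words1 = torch.randint(1, 1000, (2, 5))      # word indices
words2 = torch.randint(1, 30, (2, 5))        # POS-tag indices
seq_len = torch.tensor([5, 3])
heads = torch.randint(0, 5, (2, 5))          # gold head positions
labels = torch.randint(0, 10, (2, 5))        # gold arc labels

parser = BiaffineParser(embed=(1000, 100), pos_vocab_size=30, pos_emb_dim=50,
                        num_label=10, rnn_hidden_size=128)
loss = parser.train_step(words1, words2, seq_len, target1=heads, target2=labels)['loss']
loss.backward()

parser.eval()
with torch.no_grad():
    out = parser.evaluate_step(words1, words2, seq_len)
# out['pred1']: predicted heads [2, 5]; out['pred2']: predicted labels [2, 5]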
@@ -0,0 +1,92 @@ | |||||
r""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
__all__ = [ | |||||
"CNNText" | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from ...core.utils import seq_len_to_mask | |||||
from ...embeddings.torch import embedding | |||||
from ...modules.torch import encoder | |||||
class CNNText(torch.nn.Module): | |||||
r""" | |||||
使用CNN进行文本分类的模型 | |||||
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' | |||||
""" | |||||
def __init__(self, embed, | |||||
num_classes, | |||||
kernel_nums=(30, 40, 50), | |||||
kernel_sizes=(1, 3, 5), | |||||
dropout=0.5): | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | |||||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | |||||
:param int num_classes: 一共有多少类 | |||||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||||
:param float dropout: Dropout的大小 | |||||
""" | |||||
super(CNNText, self).__init__() | |||||
# no support for pre-trained embedding currently | |||||
self.embed = embedding.Embedding(embed) | |||||
self.conv_pool = encoder.ConvMaxpool( | |||||
in_channels=self.embed.embedding_dim, | |||||
out_channels=kernel_nums, | |||||
kernel_sizes=kernel_sizes) | |||||
self.dropout = nn.Dropout(dropout) | |||||
self.fc = nn.Linear(sum(kernel_nums), num_classes) | |||||
def forward(self, words, seq_len=None): | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | |||||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | |||||
:param target: 每个 sample 的目标值。 | |||||
:return output: | |||||
""" | |||||
x = self.embed(words) # [N,L] -> [N,L,C] | |||||
if seq_len is not None: | |||||
mask = seq_len_to_mask(seq_len) | |||||
x = self.conv_pool(x, mask) | |||||
else: | |||||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | |||||
x = self.dropout(x) | |||||
x = self.fc(x) # [N,C] -> [N, N_class] | |||||
res = {'pred': x} | |||||
return res | |||||
def train_step(self, words, target, seq_len=None): | |||||
""" | |||||
:param words: | |||||
:param target: | |||||
:param seq_len: | |||||
:return: | |||||
""" | |||||
res = self(words, seq_len) | |||||
x = res['pred'] | |||||
loss = F.cross_entropy(x, target) | |||||
return {'loss': loss} | |||||
def evaluate_step(self, words, seq_len=None): | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index | |||||
:param torch.LongTensor seq_len: [batch,] 每个句子的长度 | |||||
:return predict: dict of torch.LongTensor, [batch_size, ] | |||||
""" | |||||
output = self(words, seq_len) | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
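A minimal usage sketch for CNNText (vocabulary size, class count and batch shapes are assumptions for illustration):

import torch

model = CNNText(embed=(100, 50), num_classes=3)
words = torch.randint(1, 100, (4, 20))       # [batch_size, seq_len] word indices
seq_len = torch.tensor([20, 15, 12, 20])
target = torch.tensor([0, 2, 1, 0])

loss = model.train_step(words, target, seq_len)['loss']
loss.backward()
pred = model.evaluate_step(words, seq_len)['pred']   # [4] predicted class ids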
@@ -0,0 +1,81 @@ | |||||
r"""undocumented""" | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from fastNLP import seq_len_to_mask | |||||
from .seq2seq_model import Seq2SeqModel | |||||
from ...modules.torch.generator.seq2seq_generator import SequenceGenerator | |||||
__all__ = ['SequenceGeneratorModel'] | |||||
class SequenceGeneratorModel(nn.Module): | |||||
""" | |||||
通过使用本模型封装seq2seq_model使得其既可以用于训练也可以用于生成。训练的时候,本模型的forward函数会被调用,生成的时候本模型的predict | |||||
函数会被调用。 | |||||
""" | |||||
def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, | |||||
num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||||
""" | |||||
:param Seq2SeqModel seq2seq_model: 序列到序列模型 | |||||
:param int,None bos_token_id: 句子开头的token id | |||||
:param int,None eos_token_id: 句子结束的token id | |||||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||||
:param int num_beams: beam search的大小 | |||||
:param bool do_sample: 是否通过采样的方式生成 | |||||
:param float temperature: 只有在do_sample为True才有意义 | |||||
:param int top_k: 只从top_k中采样 | |||||
:param float top_p: 只从top_p的token中采样,nucles sample | |||||
:param float repetition_penalty: 多大程度上惩罚重复的token | |||||
:param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 | |||||
:param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 | |||||
""" | |||||
super().__init__() | |||||
self.seq2seq_model = seq2seq_model | |||||
self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a, | |||||
num_beams=num_beams, | |||||
do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, | |||||
bos_token_id=bos_token_id, | |||||
eos_token_id=eos_token_id, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||||
""" | |||||
透传调用seq2seq_model的forward。 | |||||
:param torch.LongTensor src_tokens: bsz x max_len | |||||
:param torch.LongTensor tgt_tokens: bsz x max_len' | |||||
:param torch.LongTensor src_seq_len: bsz | |||||
:param torch.LongTensor tgt_seq_len: bsz | |||||
:return: | |||||
""" | |||||
return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||||
def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||||
res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||||
pred = res['pred'] | |||||
if tgt_seq_len is not None: | |||||
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) | |||||
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) | |||||
loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) | |||||
return {'loss': loss} | |||||
def evaluate_step(self, src_tokens, src_seq_len=None): | |||||
""" | |||||
给定source的内容,输出generate的内容。 | |||||
:param torch.LongTensor src_tokens: bsz x max_len | |||||
:param torch.LongTensor src_seq_len: bsz | |||||
:return: | |||||
""" | |||||
state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) | |||||
result = self.generator.generate(state) | |||||
return {'pred': result} |
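A minimal generation sketch wrapping a small Transformer seq2seq model (vocabulary sizes, model dimensions and the special token ids are assumptions; TransformerSeq2SeqModel is defined in seq2seq_model.py below):

import torch

seq2seq = TransformerSeq2SeqModel.build_model(src_embed=(100, 64), tgt_embed=(90, 64),
                                               num_layers=2, d_model=64, n_head=4, dim_ff=128)
model = SequenceGeneratorModel(seq2seq, bos_token_id=1, eos_token_id=2,
                               max_length=20, num_beams=1, do_sample=False)

src_tokens = torch.randint(3, 100, (2, 7))
src_seq_len = torch.tensor([7, 5])
generated = model.evaluate_step(src_tokens, src_seq_len)['pred']  # bsz x generated_len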
@@ -0,0 +1,196 @@ | |||||
r""" | |||||
Models that make up a sequence-to-sequence system.
""" | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from fastNLP import seq_len_to_mask | |||||
from ...embeddings.torch.utils import get_embeddings | |||||
from ...embeddings.torch.utils import get_sinusoid_encoding_table | |||||
from ...modules.torch.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder | |||||
from ...modules.torch.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||||
__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel'] | |||||
class Seq2SeqModel(nn.Module): | |||||
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): | |||||
""" | |||||
可以用于在Trainer中训练的Seq2Seq模型。正常情况下,继承了该函数之后,只需要实现classmethod build_model即可。如果需要使用该模型 | |||||
进行生成,需要把该模型输入到 :class:`~fastNLP.models.SequenceGeneratorModel` 中。在本模型中,forward()会把encoder后的 | |||||
结果传入到decoder中,并将decoder的输出output出来。 | |||||
:param encoder: Seq2SeqEncoder 对象,需要实现对应的forward()函数,接受两个参数,第一个为bsz x max_len的source tokens, 第二个为 | |||||
bsz的source的长度;需要返回两个tensor: encoder_outputs: bsz x max_len x hidden_size, encoder_mask: bsz x max_len | |||||
为1的地方需要被attend。如果encoder的输出或者输入有变化,可以重载本模型的prepare_state()函数或者forward()函数 | |||||
:param decoder: Seq2SeqDecoder 对象,需要实现init_state()函数,输出为两个参数,第一个为bsz x max_len x hidden_size是 | |||||
encoder的输出; 第二个为bsz x max_len,为encoder输出的mask,为0的地方为pad。若decoder需要更多输入,请重载当前模型的 | |||||
prepare_state()或forward()函数 | |||||
""" | |||||
super().__init__() | |||||
self.encoder = encoder | |||||
self.decoder = decoder | |||||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||||
""" | |||||
:param torch.LongTensor src_tokens: source的token | |||||
:param torch.LongTensor tgt_tokens: target的token | |||||
:param torch.LongTensor src_seq_len: src的长度 | |||||
:param torch.LongTensor tgt_seq_len: target的长度,默认用不上 | |||||
:return: {'pred': torch.Tensor}, 其中pred的shape为bsz x max_len x vocab_size | |||||
""" | |||||
state = self.prepare_state(src_tokens, src_seq_len) | |||||
decoder_output = self.decoder(tgt_tokens, state) | |||||
if isinstance(decoder_output, torch.Tensor): | |||||
return {'pred': decoder_output} | |||||
elif isinstance(decoder_output, (tuple, list)): | |||||
return {'pred': decoder_output[0]} | |||||
else: | |||||
raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}") | |||||
def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||||
res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||||
pred = res['pred'] | |||||
if tgt_seq_len is not None: | |||||
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) | |||||
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) | |||||
loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) | |||||
return {'loss': loss} | |||||
def prepare_state(self, src_tokens, src_seq_len=None): | |||||
""" | |||||
调用encoder获取state,会把encoder的encoder_output, encoder_mask直接传入到decoder.init_state中初始化一个state | |||||
:param src_tokens: | |||||
:param src_seq_len: | |||||
:return: | |||||
""" | |||||
encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len) | |||||
state = self.decoder.init_state(encoder_output, encoder_mask) | |||||
return state | |||||
@classmethod | |||||
def build_model(cls, *args, **kwargs): | |||||
""" | |||||
需要实现本方法来进行Seq2SeqModel的初始化 | |||||
:return: | |||||
""" | |||||
raise NotImplemented | |||||
class TransformerSeq2SeqModel(Seq2SeqModel): | |||||
""" | |||||
Encoder为TransformerSeq2SeqEncoder, decoder为TransformerSeq2SeqDecoder,通过build_model方法初始化 | |||||
""" | |||||
def __init__(self, encoder, decoder): | |||||
super().__init__(encoder, decoder) | |||||
@classmethod | |||||
def build_model(cls, src_embed, tgt_embed=None, | |||||
pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, | |||||
bind_encoder_decoder_embed=False, | |||||
bind_decoder_input_output_embed=True): | |||||
""" | |||||
初始化一个TransformerSeq2SeqModel | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 | |||||
True,则不要输入该值 | |||||
:param str pos_embed: 支持sin, learned两种 | |||||
:param int max_position: 最大支持长度 | |||||
:param int num_layers: encoder和decoder的层数 | |||||
:param int d_model: encoder和decoder输入输出的大小 | |||||
:param int n_head: encoder和decoder的head的数量 | |||||
:param int dim_ff: encoder和decoder中FFN中间映射的维度 | |||||
:param float dropout: Attention和FFN dropout的大小 | |||||
:param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding | |||||
:param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 | |||||
:return: TransformerSeq2SeqModel | |||||
""" | |||||
if bind_encoder_decoder_embed and tgt_embed is not None: | |||||
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||||
src_embed = get_embeddings(src_embed) | |||||
if bind_encoder_decoder_embed: | |||||
tgt_embed = src_embed | |||||
else: | |||||
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||||
tgt_embed = get_embeddings(tgt_embed) | |||||
        if pos_embed == 'sin':
            encoder_pos_embed = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0),
                freeze=True)  # index 0 is reserved for padding
            decoder_pos_embed = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0),
                freeze=True)  # index 0 is reserved for padding
        elif pos_embed == 'learned':
            encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0)
            decoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1)
else: | |||||
raise ValueError("pos_embed only supports sin or learned.") | |||||
encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed, | |||||
num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, | |||||
dropout=dropout) | |||||
        decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed,
d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, | |||||
dropout=dropout, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||||
return cls(encoder, decoder) | |||||
class LSTMSeq2SeqModel(Seq2SeqModel): | |||||
""" | |||||
使用LSTMSeq2SeqEncoder和LSTMSeq2SeqDecoder的model | |||||
""" | |||||
def __init__(self, encoder, decoder): | |||||
super().__init__(encoder, decoder) | |||||
@classmethod | |||||
def build_model(cls, src_embed, tgt_embed=None, | |||||
num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True, | |||||
attention=True, bind_encoder_decoder_embed=False, | |||||
bind_decoder_input_output_embed=True): | |||||
""" | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 | |||||
True,则不要输入该值 | |||||
:param int num_layers: Encoder和Decoder的层数 | |||||
:param int hidden_size: encoder和decoder的隐藏层大小 | |||||
:param float dropout: 每层之间的Dropout的大小 | |||||
:param bool bidirectional: encoder是否使用双向LSTM | |||||
:param bool attention: decoder是否使用attention attend encoder在所有时刻的状态 | |||||
:param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding | |||||
:param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 | |||||
:return: LSTMSeq2SeqModel | |||||
""" | |||||
if bind_encoder_decoder_embed and tgt_embed is not None: | |||||
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||||
src_embed = get_embeddings(src_embed) | |||||
if bind_encoder_decoder_embed: | |||||
tgt_embed = src_embed | |||||
else: | |||||
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||||
tgt_embed = get_embeddings(tgt_embed) | |||||
encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers = num_layers, | |||||
hidden_size = hidden_size, dropout = dropout, bidirectional=bidirectional) | |||||
decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers = num_layers, hidden_size = hidden_size, | |||||
dropout = dropout, bind_decoder_input_output_embed = bind_decoder_input_output_embed, | |||||
attention=attention) | |||||
return cls(encoder, decoder) |
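A minimal training sketch for the LSTM variant, assuming the sizes below are accepted by the encoder and decoder (vocabulary sizes, dimensions and shapes are illustrative):

import torch

model = LSTMSeq2SeqModel.build_model(src_embed=(100, 32), tgt_embed=(90, 32),
                                     num_layers=2, hidden_size=64, dropout=0.1,
                                     bidirectional=True, attention=True)
src_tokens = torch.randint(3, 100, (2, 7))
tgt_tokens = torch.randint(3, 90, (2, 6))
src_seq_len = torch.tensor([7, 5])
tgt_seq_len = torch.tensor([6, 4])
loss = model.train_step(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)['loss']
loss.backward()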
@@ -0,0 +1,271 @@ | |||||
r""" | |||||
This module implements several sequence labelling models.
""" | |||||
__all__ = [ | |||||
"SeqLabeling", | |||||
"AdvSeqLabel", | |||||
"BiLSTMCRF" | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from ...core.utils import seq_len_to_mask | |||||
from ...embeddings.torch.utils import get_embeddings | |||||
from ...modules.torch.decoder import ConditionalRandomField | |||||
from ...modules.torch.encoder import LSTM | |||||
from ...modules.torch import decoder, encoder | |||||
from ...modules.torch.decoder.crf import allowed_transitions | |||||
class BiLSTMCRF(nn.Module): | |||||
r""" | |||||
Architecture: embedding + BiLSTM + FC + Dropout + CRF.
""" | |||||
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5, | |||||
target_vocab=None): | |||||
r""" | |||||
:param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100) | |||||
:param num_classes: 一共多少个类 | |||||
:param num_layers: BiLSTM的层数 | |||||
:param hidden_size: BiLSTM的hidden_size,实际hidden size为该值的两倍(前向、后向) | |||||
:param dropout: dropout的概率,0为不dropout | |||||
:param target_vocab: Vocabulary对象,target与index的对应关系。如果传入该值,将自动避免非法的解码序列。 | |||||
""" | |||||
super().__init__() | |||||
self.embed = get_embeddings(embed) | |||||
if num_layers>1: | |||||
self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, | |||||
batch_first=True, dropout=dropout) | |||||
else: | |||||
self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True, | |||||
batch_first=True) | |||||
self.dropout = nn.Dropout(dropout) | |||||
self.fc = nn.Linear(hidden_size*2, num_classes) | |||||
trans = None | |||||
if target_vocab is not None: | |||||
assert len(target_vocab)==num_classes, "The number of classes should be same with the length of target vocabulary." | |||||
trans = allowed_transitions(target_vocab.idx2word, include_start_end=True) | |||||
self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans) | |||||
def forward(self, words, seq_len=None, target=None): | |||||
words = self.embed(words) | |||||
feats, _ = self.lstm(words, seq_len=seq_len) | |||||
feats = self.fc(feats) | |||||
feats = self.dropout(feats) | |||||
logits = F.log_softmax(feats, dim=-1) | |||||
mask = seq_len_to_mask(seq_len) | |||||
if target is None: | |||||
pred, _ = self.crf.viterbi_decode(logits, mask) | |||||
return {'pred':pred} | |||||
else: | |||||
loss = self.crf(logits, target, mask).mean() | |||||
return {'loss':loss} | |||||
def train_step(self, words, seq_len, target): | |||||
return self(words, seq_len, target) | |||||
def evaluate_step(self, words, seq_len): | |||||
return self(words, seq_len) | |||||
class SeqLabeling(nn.Module): | |||||
r""" | |||||
一个基础的Sequence labeling的模型。 | |||||
用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。 | |||||
""" | |||||
def __init__(self, embed, hidden_size, num_classes): | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | |||||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding | |||||
:param int hidden_size: LSTM隐藏层的大小 | |||||
:param int num_classes: 一共有多少类 | |||||
""" | |||||
super(SeqLabeling, self).__init__() | |||||
self.embedding = get_embeddings(embed) | |||||
self.rnn = encoder.LSTM(self.embedding.embedding_dim, hidden_size) | |||||
self.fc = nn.Linear(hidden_size, num_classes) | |||||
self.crf = decoder.ConditionalRandomField(num_classes) | |||||
def forward(self, words, seq_len): | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, max_len],序列的index | |||||
:param torch.LongTensor seq_len: [batch_size,], 这个序列的长度 | |||||
:return | |||||
""" | |||||
x = self.embedding(words) | |||||
# [batch_size, max_len, word_emb_dim] | |||||
x, _ = self.rnn(x, seq_len) | |||||
# [batch_size, max_len, hidden_size * direction] | |||||
x = self.fc(x) | |||||
return {'pred': x} | |||||
# [batch_size, max_len, num_classes] | |||||
def train_step(self, words, seq_len, target): | |||||
res = self(words, seq_len) | |||||
pred = res['pred'] | |||||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)) | |||||
return {'loss': self._internal_loss(pred, target, mask)} | |||||
def evaluate_step(self, words, seq_len): | |||||
r""" | |||||
用于在预测时使用 | |||||
:param torch.LongTensor words: [batch_size, max_len] | |||||
:param torch.LongTensor seq_len: [batch_size,] | |||||
:return: {'pred': xx}, [batch_size, max_len] | |||||
""" | |||||
mask = seq_len_to_mask(seq_len, max_len=words.size(1)) | |||||
res = self(words, seq_len) | |||||
pred = res['pred'] | |||||
# [batch_size, max_len, num_classes] | |||||
pred = self._decode(pred, mask) | |||||
return {'pred': pred} | |||||
def _internal_loss(self, x, y, mask): | |||||
r""" | |||||
Negative log likelihood loss. | |||||
:param x: Tensor, [batch_size, max_len, tag_size] | |||||
:param y: Tensor, [batch_size, max_len] | |||||
:return loss: a scalar Tensor | |||||
""" | |||||
x = x.float() | |||||
y = y.long() | |||||
total_loss = self.crf(x, y, mask) | |||||
return torch.mean(total_loss) | |||||
def _decode(self, x, mask): | |||||
r""" | |||||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||||
:return prediction: [batch_size, max_len] | |||||
""" | |||||
tag_seq, _ = self.crf.viterbi_decode(x, mask) | |||||
return tag_seq | |||||
class AdvSeqLabel(nn.Module): | |||||
r""" | |||||
更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。 | |||||
""" | |||||
def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'): | |||||
r""" | |||||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), | |||||
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding | |||||
:param int hidden_size: LSTM的隐层大小 | |||||
:param int num_classes: 有多少个类 | |||||
:param float dropout: LSTM中以及DropOut层的drop概率 | |||||
:param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S' | |||||
不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证 | |||||
'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。) | |||||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 只有在id2words不为None的情况有用。 | |||||
""" | |||||
super().__init__() | |||||
self.Embedding = get_embeddings(embed) | |||||
self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim) | |||||
self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2, | |||||
dropout=dropout, | |||||
bidirectional=True, batch_first=True) | |||||
self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3) | |||||
self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3) | |||||
self.relu = torch.nn.LeakyReLU() | |||||
self.drop = torch.nn.Dropout(dropout) | |||||
self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes) | |||||
if id2words is None: | |||||
self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
else: | |||||
self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False, | |||||
allowed_transitions=allowed_transitions(id2words, | |||||
encoding_type=encoding_type)) | |||||
def _decode(self, x, mask): | |||||
r""" | |||||
:param torch.FloatTensor x: [batch_size, max_len, tag_size] | |||||
:param torch.ByteTensor mask: [batch_size, max_len] | |||||
:return torch.LongTensor, [batch_size, max_len] | |||||
""" | |||||
tag_seq, _ = self.Crf.viterbi_decode(x, mask) | |||||
return tag_seq | |||||
def _internal_loss(self, x, y, mask): | |||||
r""" | |||||
Negative log likelihood loss. | |||||
:param x: Tensor, [batch_size, max_len, tag_size] | |||||
:param y: Tensor, [batch_size, max_len] | |||||
:param mask: Tensor, [batch_size, max_len] | |||||
:return loss: a scalar Tensor | |||||
""" | |||||
x = x.float() | |||||
y = y.long() | |||||
total_loss = self.Crf(x, y, mask) | |||||
return torch.mean(total_loss) | |||||
def forward(self, words, seq_len, target=None): | |||||
r""" | |||||
:param torch.LongTensor words: [batch_size, mex_len] | |||||
:param torch.LongTensor seq_len:[batch_size, ] | |||||
:param torch.LongTensor target: [batch_size, max_len] | |||||
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||||
If truth is not None, return loss, a scalar. Used in training. | |||||
""" | |||||
words = words.long() | |||||
seq_len = seq_len.long() | |||||
mask = seq_len_to_mask(seq_len, max_len=words.size(1)) | |||||
target = target.long() if target is not None else None | |||||
if next(self.parameters()).is_cuda: | |||||
words = words.cuda() | |||||
x = self.Embedding(words) | |||||
x = self.norm1(x) | |||||
# [batch_size, max_len, word_emb_dim] | |||||
x, _ = self.Rnn(x, seq_len=seq_len) | |||||
x = self.Linear1(x) | |||||
x = self.norm2(x) | |||||
x = self.relu(x) | |||||
x = self.drop(x) | |||||
x = self.Linear2(x) | |||||
if target is not None: | |||||
return {"loss": self._internal_loss(x, target, mask)} | |||||
else: | |||||
return {"pred": self._decode(x, mask)} | |||||
def train_step(self, words, seq_len, target): | |||||
r""" | |||||
        :param torch.LongTensor words: [batch_size, max_len]
        :param torch.LongTensor seq_len: [batch_size, ]
        :param torch.LongTensor target: [batch_size, max_len], the gold tags
        :return torch.Tensor: a scalar loss
""" | |||||
return self(words, seq_len, target) | |||||
def evaluate_step(self, words, seq_len): | |||||
r""" | |||||
        :param torch.LongTensor words: [batch_size, max_len]
        :param torch.LongTensor seq_len: [batch_size, ]
        :return torch.LongTensor: [batch_size, max_len]
""" | |||||
return self(words, seq_len) |
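A minimal tagging sketch for BiLSTMCRF (vocabulary size, tag count and batch shapes are assumptions; no target_vocab is passed, so no transition constraints apply):

import torch

model = BiLSTMCRF(embed=(50, 25), num_classes=4, hidden_size=32)
words = torch.randint(1, 50, (2, 6))
seq_len = torch.tensor([6, 4])
target = torch.randint(0, 4, (2, 6))

loss = model.train_step(words, seq_len, target)['loss']
loss.backward()
tags = model.evaluate_step(words, seq_len)['pred']   # [2, 6] decoded tag ids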
@@ -0,0 +1,26 @@ | |||||
__all__ = [ | |||||
'ConditionalRandomField', | |||||
'allowed_transitions', | |||||
"State", | |||||
"Seq2SeqDecoder", | |||||
"LSTMSeq2SeqDecoder", | |||||
"TransformerSeq2SeqDecoder", | |||||
"LSTM", | |||||
"Seq2SeqEncoder", | |||||
"TransformerSeq2SeqEncoder", | |||||
"LSTMSeq2SeqEncoder", | |||||
"StarTransformer", | |||||
"VarRNN", | |||||
"VarLSTM", | |||||
"VarGRU", | |||||
'SequenceGenerator', | |||||
"TimestepDropout", | |||||
] | |||||
from .decoder import * | |||||
from .encoder import * | |||||
from .generator import * | |||||
from .dropout import TimestepDropout |
@@ -0,0 +1,321 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"MultiHeadAttention", | |||||
"BiAttention", | |||||
"SelfAttention", | |||||
] | |||||
import math | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from .decoder.seq2seq_state import TransformerState | |||||
class DotAttention(nn.Module): | |||||
r""" | |||||
Transformer当中的DotAttention | |||||
""" | |||||
def __init__(self, key_size, value_size, dropout=0.0): | |||||
super(DotAttention, self).__init__() | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.scale = math.sqrt(key_size) | |||||
self.drop = nn.Dropout(dropout) | |||||
self.softmax = nn.Softmax(dim=-1) | |||||
def forward(self, Q, K, V, mask_out=None): | |||||
r""" | |||||
:param Q: [..., seq_len_q, key_size] | |||||
:param K: [..., seq_len_k, key_size] | |||||
:param V: [..., seq_len_k, value_size] | |||||
:param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k] | |||||
""" | |||||
output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale | |||||
if mask_out is not None: | |||||
output.masked_fill_(mask_out, -1e9) | |||||
output = self.softmax(output) | |||||
output = self.drop(output) | |||||
return torch.matmul(output, V) | |||||
class MultiHeadAttention(nn.Module): | |||||
""" | |||||
Attention is all you need中提到的多头注意力 | |||||
""" | |||||
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||||
super(MultiHeadAttention, self).__init__() | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dropout = dropout | |||||
self.head_dim = d_model // n_head | |||||
self.layer_idx = layer_idx | |||||
assert d_model % n_head == 0, "d_model should be divisible by n_head" | |||||
self.scaling = self.head_dim ** -0.5 | |||||
self.q_proj = nn.Linear(d_model, d_model) | |||||
self.k_proj = nn.Linear(d_model, d_model) | |||||
self.v_proj = nn.Linear(d_model, d_model) | |||||
self.out_proj = nn.Linear(d_model, d_model) | |||||
self.reset_parameters() | |||||
def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): | |||||
""" | |||||
:param query: batch x seq x dim | |||||
:param key: batch x seq x dim | |||||
:param value: batch x seq x dim | |||||
:param key_mask: batch x seq 用于指示哪些key不要attend到;注意到mask为1的地方是要attend到的 | |||||
:param attn_mask: seq x seq, 用于mask掉attention map。 主要是用在训练时decoder端的self attention,下三角为1 | |||||
:param state: 过去的信息,在inference的时候会用到,比如encoder output、decoder的prev kv。这样可以减少计算。 | |||||
:return: | |||||
""" | |||||
assert key.size() == value.size() | |||||
if state is not None: | |||||
assert self.layer_idx is not None | |||||
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() | |||||
q = self.q_proj(query) # batch x seq x dim | |||||
q *= self.scaling | |||||
k = v = None | |||||
prev_k = prev_v = None | |||||
# 从state中取kv | |||||
if isinstance(state, TransformerState): # 说明此时在inference阶段 | |||||
if qkv_same: # 此时在decoder self attention | |||||
prev_k = state.decoder_prev_key[self.layer_idx] | |||||
prev_v = state.decoder_prev_value[self.layer_idx] | |||||
else: # 此时在decoder-encoder attention,直接将保存下来的key装载起来即可 | |||||
k = state.encoder_key[self.layer_idx] | |||||
v = state.encoder_value[self.layer_idx] | |||||
if k is None: | |||||
k = self.k_proj(key) | |||||
v = self.v_proj(value) | |||||
if prev_k is not None: | |||||
k = torch.cat((prev_k, k), dim=1) | |||||
v = torch.cat((prev_v, v), dim=1) | |||||
# 更新state | |||||
if isinstance(state, TransformerState): | |||||
if qkv_same: | |||||
state.decoder_prev_key[self.layer_idx] = k | |||||
state.decoder_prev_value[self.layer_idx] = v | |||||
else: | |||||
state.encoder_key[self.layer_idx] = k | |||||
state.encoder_value[self.layer_idx] = v | |||||
# 开始计算attention | |||||
batch_size, q_len, d_model = query.size() | |||||
k_len, v_len = k.size(1), v.size(1) | |||||
q = q.reshape(batch_size, q_len, self.n_head, self.head_dim) | |||||
k = k.reshape(batch_size, k_len, self.n_head, self.head_dim) | |||||
v = v.reshape(batch_size, v_len, self.n_head, self.head_dim) | |||||
attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head | |||||
if key_mask is not None: | |||||
_key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,1 | |||||
attn_weights = attn_weights.masked_fill(_key_mask, -float('inf')) | |||||
if attn_mask is not None: | |||||
_attn_mask = attn_mask[None, :, :, None].eq(0) # 1,q_len,k_len,n_head | |||||
attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf')) | |||||
attn_weights = F.softmax(attn_weights, dim=2) | |||||
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) | |||||
output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim | |||||
output = output.reshape(batch_size, q_len, -1) | |||||
output = self.out_proj(output) # batch,q_len,dim | |||||
return output, attn_weights | |||||
def reset_parameters(self): | |||||
nn.init.xavier_uniform_(self.q_proj.weight) | |||||
nn.init.xavier_uniform_(self.k_proj.weight) | |||||
nn.init.xavier_uniform_(self.v_proj.weight) | |||||
nn.init.xavier_uniform_(self.out_proj.weight) | |||||
def set_layer_idx(self, layer_idx): | |||||
self.layer_idx = layer_idx | |||||
class AttentionLayer(nn.Module): | |||||
    def __init__(self, input_size, key_dim, value_dim, bias=False):
        """
        Attention used in the decoding step of an LSTM-to-LSTM sequence-to-sequence model; at each decoding step it
        attends over the encoder outputs based on the previous step's hidden state.

        :param int input_size: size of the input
        :param int key_dim: usually the dimension of the encoder output
        :param int value_dim: size of the output, usually the size of the decoder hidden state
        :param bias:
        """
        super().__init__()
        self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
        self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)
def forward(self, input, encode_outputs, encode_mask): | |||||
""" | |||||
:param input: batch_size x input_size | |||||
:param encode_outputs: batch_size x max_len x key_dim | |||||
:param encode_mask: batch_size x max_len, 为0的地方为padding | |||||
:return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized过的 | |||||
""" | |||||
# x: bsz x encode_hidden_size | |||||
x = self.input_proj(input) | |||||
# compute attention | |||||
attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len | |||||
# don't attend over padding | |||||
if encode_mask is not None: | |||||
attn_scores = attn_scores.float().masked_fill_( | |||||
encode_mask.eq(0), | |||||
float('-inf') | |||||
).type_as(attn_scores) # FP16 support: cast to float and back | |||||
attn_scores = F.softmax(attn_scores, dim=-1) # srclen x bsz | |||||
# sum weighted sources | |||||
x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size | |||||
x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) | |||||
return x, attn_scores | |||||
def _masked_softmax(tensor, mask): | |||||
tensor_shape = tensor.size() | |||||
reshaped_tensor = tensor.view(-1, tensor_shape[-1]) | |||||
# Reshape the mask so it matches the size of the input tensor. | |||||
while mask.dim() < tensor.dim(): | |||||
mask = mask.unsqueeze(1) | |||||
mask = mask.expand_as(tensor).contiguous().float() | |||||
reshaped_mask = mask.view(-1, mask.size()[-1]) | |||||
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) | |||||
result = result * reshaped_mask | |||||
# 1e-13 is added to avoid divisions by zero. | |||||
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) | |||||
return result.view(*tensor_shape) | |||||
def _weighted_sum(tensor, weights, mask): | |||||
w_sum = weights.bmm(tensor) | |||||
while mask.dim() < w_sum.dim(): | |||||
mask = mask.unsqueeze(1) | |||||
mask = mask.transpose(-1, -2) | |||||
mask = mask.expand_as(w_sum).contiguous().float() | |||||
return w_sum * mask | |||||
class BiAttention(nn.Module): | |||||
r""" | |||||
Bi Attention module | |||||
对于给定的两个向量序列 :math:`a_i` 和 :math:`b_j` , BiAttention模块将通过以下的公式来计算attention结果 | |||||
.. math:: | |||||
\begin{array}{ll} \\ | |||||
e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\ | |||||
{\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\ | |||||
{\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\ | |||||
\end{array} | |||||
""" | |||||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||||
r""" | |||||
:param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size] | |||||
:param torch.Tensor premise_mask: [batch_size, a_seq_len] | |||||
:param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size] | |||||
:param torch.Tensor hypothesis_mask: [batch_size, b_seq_len] | |||||
:return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size] torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size] | |||||
""" | |||||
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) | |||||
.contiguous()) | |||||
prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask) | |||||
hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2) | |||||
.contiguous(), | |||||
premise_mask) | |||||
attended_premises = _weighted_sum(hypothesis_batch, | |||||
prem_hyp_attn, | |||||
premise_mask) | |||||
attended_hypotheses = _weighted_sum(premise_batch, | |||||
hyp_prem_attn, | |||||
hypothesis_mask) | |||||
return attended_premises, attended_hypotheses | |||||
class SelfAttention(nn.Module): | |||||
r""" | |||||
这是一个基于论文 `A structured self-attentive sentence embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ | |||||
的Self Attention Module. | |||||
""" | |||||
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5): | |||||
r""" | |||||
:param int input_size: 输入tensor的hidden维度 | |||||
:param int attention_unit: 输出tensor的hidden维度 | |||||
:param int attention_hops: | |||||
:param float drop: dropout概率,默认值为0.5 | |||||
""" | |||||
super(SelfAttention, self).__init__() | |||||
self.attention_hops = attention_hops | |||||
self.ws1 = nn.Linear(input_size, attention_unit, bias=False) | |||||
self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False) | |||||
self.I = torch.eye(attention_hops, requires_grad=False) | |||||
self.I_origin = self.I | |||||
self.drop = nn.Dropout(drop) | |||||
self.tanh = nn.Tanh() | |||||
def _penalization(self, attention): | |||||
r""" | |||||
compute the penalization term for attention module | |||||
""" | |||||
baz = attention.size(0) | |||||
size = self.I.size() | |||||
if len(size) != 3 or size[0] != baz: | |||||
self.I = self.I_origin.expand(baz, -1, -1) | |||||
self.I = self.I.to(device=attention.device) | |||||
attention_t = torch.transpose(attention, 1, 2).contiguous() | |||||
mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)] | |||||
ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5 | |||||
return torch.sum(ret) / size[0] | |||||
def forward(self, input, input_origin): | |||||
r""" | |||||
:param torch.Tensor input: [batch_size, seq_len, hidden_size] 要做attention的矩阵 | |||||
:param torch.Tensor input_origin: [batch_size, seq_len] 原始token的index组成的矩阵,含有pad部分内容 | |||||
:return torch.Tensor output1: [batch_size, multi-head, hidden_size] 经过attention操作后输入矩阵的结果 | |||||
:return torch.Tensor output2: [1] attention惩罚项,是一个标量 | |||||
""" | |||||
input = input.contiguous() | |||||
size = input.size() # [bsz, len, nhid] | |||||
input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] | |||||
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] | |||||
y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] | |||||
attention = self.ws2(y1).transpose(1, 2).contiguous() | |||||
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] | |||||
attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. | |||||
attention = F.softmax(attention, 2) # [baz ,hop, len] | |||||
return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid] |
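

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It only shows how
# SelfAttention above is typically called; the tensor sizes, hyper-parameters and
# the convention that index 0 is <pad> are assumptions for demonstration.
if __name__ == "__main__":
    bsz, seq_len, hidden = 2, 7, 16
    encoder_out = torch.randn(bsz, seq_len, hidden)      # e.g. the output of an LSTM encoder
    word_idx = torch.randint(1, 100, (bsz, seq_len))     # padded token indices, 0 == <pad>
    attn = SelfAttention(input_size=hidden, attention_unit=8, attention_hops=4)
    sent_mat, penalty = attn(encoder_out, word_idx)
    print(sent_mat.shape)   # torch.Size([2, 4, 16]) == [batch, attention_hops, hidden]
    print(penalty)          # scalar penalization term, usually added to the training loss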
@@ -0,0 +1,15 @@ | |||||
__all__ = [ | |||||
'ConditionalRandomField', | |||||
'allowed_transitions', | |||||
"State", | |||||
"Seq2SeqDecoder", | |||||
"LSTMSeq2SeqDecoder", | |||||
"TransformerSeq2SeqDecoder" | |||||
] | |||||
from .crf import ConditionalRandomField, allowed_transitions | |||||
from .seq2seq_state import State | |||||
from .seq2seq_decoder import LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, Seq2SeqDecoder |
@@ -0,0 +1,354 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"ConditionalRandomField", | |||||
"allowed_transitions" | |||||
] | |||||
from typing import Union, List | |||||
import torch | |||||
from torch import nn | |||||
from ....core.metrics.span_f1_pre_rec_metric import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type | |||||
from ....core.vocabulary import Vocabulary | |||||
def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False): | |||||
r""" | |||||
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 | |||||
:param tag_vocab: 支持类型为tag或tag-label。只有tag的,比如"B", "M"; 也可以是"B-NN", "M-NN", | |||||
tag和label之间一定要用"-"隔开。如果传入dict,格式需要形如{0:"O", 1:"B-tag1"},即index在前,tag在后。 | |||||
:param encoding_type: 支持``["bio", "bmes", "bmeso", "bioes"]``。默认为None,通过vocab自动推断 | |||||
:param include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||||
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | |||||
    :return: List[Tuple[int, int]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。
""" | |||||
if encoding_type is None: | |||||
encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||||
else: | |||||
encoding_type = encoding_type.lower() | |||||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||||
pad_token = '<pad>' | |||||
unk_token = '<unk>' | |||||
if isinstance(tag_vocab, Vocabulary): | |||||
id_label_lst = list(tag_vocab.idx2word.items()) | |||||
pad_token = tag_vocab.padding | |||||
unk_token = tag_vocab.unknown | |||||
else: | |||||
id_label_lst = list(tag_vocab.items()) | |||||
num_tags = len(tag_vocab) | |||||
start_idx = num_tags | |||||
end_idx = num_tags + 1 | |||||
allowed_trans = [] | |||||
if include_start_end: | |||||
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||||
def split_tag_label(from_label): | |||||
from_label = from_label.lower() | |||||
if from_label in ['start', 'end']: | |||||
from_tag = from_label | |||||
from_label = '' | |||||
else: | |||||
from_tag = from_label[:1] | |||||
from_label = from_label[2:] | |||||
return from_tag, from_label | |||||
for from_id, from_label in id_label_lst: | |||||
if from_label in [pad_token, unk_token]: | |||||
continue | |||||
from_tag, from_label = split_tag_label(from_label) | |||||
for to_id, to_label in id_label_lst: | |||||
if to_label in [pad_token, unk_token]: | |||||
continue | |||||
to_tag, to_label = split_tag_label(to_label) | |||||
if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
allowed_trans.append((from_id, to_id)) | |||||
return allowed_trans | |||||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||||
r""" | |||||
    :param str encoding_type: 支持"BIO", "BMES", "BMESO", "BIOES"。
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param str from_label: 比如"PER", "LOC"等label | |||||
:param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||||
:param str to_label: 比如"PER", "LOC"等label | |||||
:return: bool,能否跃迁 | |||||
""" | |||||
if to_tag == 'start' or from_tag == 'end': | |||||
return False | |||||
encoding_type = encoding_type.lower() | |||||
if encoding_type == 'bio': | |||||
r""" | |||||
第一行是to_tag, 第一列是from_tag. y任意条件下可转,-只有在label相同时可转,n不可转 | |||||
+-------+---+---+---+-------+-----+ | |||||
| | B | I | O | start | end | | |||||
+-------+---+---+---+-------+-----+ | |||||
| B | y | - | y | n | y | | |||||
+-------+---+---+---+-------+-----+ | |||||
| I | y | - | y | n | y | | |||||
+-------+---+---+---+-------+-----+ | |||||
| O | y | n | y | n | y | | |||||
+-------+---+---+---+-------+-----+ | |||||
| start | y | n | y | n | n | | |||||
+-------+---+---+---+-------+-----+ | |||||
| end | n | n | n | n | n | | |||||
+-------+---+---+---+-------+-----+ | |||||
""" | |||||
if from_tag == 'start': | |||||
return to_tag in ('b', 'o') | |||||
elif from_tag in ['b', 'i']: | |||||
return any([to_tag in ['end', 'b', 'o'], to_tag == 'i' and from_label == to_label]) | |||||
elif from_tag == 'o': | |||||
return to_tag in ['end', 'b', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||||
elif encoding_type == 'bmes': | |||||
r""" | |||||
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| | B | M | E | S | start | end | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| B | n | - | - | n | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| M | n | - | - | n | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| E | y | n | n | y | n | y | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| S | y | n | n | y | n | y | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| start | y | n | n | y | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
| end | n | n | n | n | n | n | | |||||
+-------+---+---+---+---+-------+-----+ | |||||
""" | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['m', 'e'] and from_label == to_label | |||||
elif from_tag == 'm': | |||||
return to_tag in ['m', 'e'] and from_label == to_label | |||||
elif from_tag in ['e', 's']: | |||||
return to_tag in ['b', 's', 'end'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||||
elif encoding_type == 'bmeso': | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's', 'o'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['m', 'e'] and from_label == to_label | |||||
elif from_tag == 'm': | |||||
return to_tag in ['m', 'e'] and from_label == to_label | |||||
elif from_tag in ['e', 's', 'o']: | |||||
return to_tag in ['b', 's', 'end', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||||
elif encoding_type == 'bioes': | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's', 'o'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['i', 'e'] and from_label == to_label | |||||
elif from_tag == 'i': | |||||
return to_tag in ['i', 'e'] and from_label == to_label | |||||
elif from_tag in ['e', 's', 'o']: | |||||
return to_tag in ['b', 's', 'end', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) | |||||
else: | |||||
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) | |||||
class ConditionalRandomField(nn.Module): | |||||
r""" | |||||
条件随机场。提供 forward() 以及 viterbi_decode() 两个方法,分别用于训练与inference。 | |||||
""" | |||||
def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None): | |||||
r""" | |||||
:param num_tags: 标签的数量 | |||||
:param include_start_end_trans: 是否考虑各个tag作为开始以及结尾的分数。 | |||||
:param allowed_transitions: 内部的Tuple[from_tag_id(int), | |||||
to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 | |||||
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 | |||||
""" | |||||
super(ConditionalRandomField, self).__init__() | |||||
self.include_start_end_trans = include_start_end_trans | |||||
self.num_tags = num_tags | |||||
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | |||||
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) | |||||
if self.include_start_end_trans: | |||||
self.start_scores = nn.Parameter(torch.randn(num_tags)) | |||||
self.end_scores = nn.Parameter(torch.randn(num_tags)) | |||||
if allowed_transitions is None: | |||||
constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||||
else: | |||||
constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float) | |||||
has_start = False | |||||
has_end = False | |||||
for from_tag_id, to_tag_id in allowed_transitions: | |||||
constrain[from_tag_id, to_tag_id] = 0 | |||||
if from_tag_id==num_tags: | |||||
has_start = True | |||||
if to_tag_id==num_tags+1: | |||||
has_end = True | |||||
if not has_start: | |||||
constrain[num_tags, :].fill_(0) | |||||
if not has_end: | |||||
constrain[:, num_tags+1].fill_(0) | |||||
self._constrain = nn.Parameter(constrain, requires_grad=False) | |||||
def _normalizer_likelihood(self, logits, mask): | |||||
r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||||
sum of the likelihoods across all possible state sequences. | |||||
:param logits:FloatTensor, max_len x batch_size x num_tags | |||||
:param mask:ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | |||||
""" | |||||
seq_len, batch_size, n_tags = logits.size() | |||||
alpha = logits[0] | |||||
if self.include_start_end_trans: | |||||
alpha = alpha + self.start_scores.view(1, -1) | |||||
flip_mask = mask.eq(False) | |||||
for i in range(1, seq_len): | |||||
emit_score = logits[i].view(batch_size, 1, n_tags) | |||||
trans_score = self.trans_m.view(1, n_tags, n_tags) | |||||
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | |||||
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||||
alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0) | |||||
if self.include_start_end_trans: | |||||
alpha = alpha + self.end_scores.view(1, -1) | |||||
return torch.logsumexp(alpha, 1) | |||||
def _gold_score(self, logits, tags, mask): | |||||
r""" | |||||
Compute the score for the gold path. | |||||
:param logits: FloatTensor, max_len x batch_size x num_tags | |||||
:param tags: LongTensor, max_len x batch_size | |||||
:param mask: ByteTensor, max_len x batch_size | |||||
:return:FloatTensor, batch_size | |||||
""" | |||||
seq_len, batch_size, _ = logits.size() | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||||
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||||
        # trans_score [L-1, B]
mask = mask.eq(True) | |||||
flip_mask = mask.eq(False) | |||||
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0) | |||||
# emit_score [L, B] | |||||
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0) | |||||
# score [L-1, B] | |||||
score = trans_score + emit_score[:seq_len - 1, :] | |||||
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) | |||||
if self.include_start_end_trans: | |||||
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||||
last_idx = mask.long().sum(0) - 1 | |||||
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||||
score = score + st_scores + ed_scores | |||||
# return [B,] | |||||
return score | |||||
def forward(self, feats, tags, mask): | |||||
r""" | |||||
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 | |||||
:param torch.FloatTensor feats: batch_size x max_len x num_tags,特征矩阵。 | |||||
:param torch.LongTensor tags: batch_size x max_len,标签矩阵。 | |||||
:param torch.ByteTensor mask: batch_size x max_len,为0的位置认为是padding。 | |||||
:return: torch.FloatTensor, (batch_size,) | |||||
""" | |||||
feats = feats.transpose(0, 1) | |||||
tags = tags.transpose(0, 1).long() | |||||
mask = mask.transpose(0, 1).float() | |||||
all_path_score = self._normalizer_likelihood(feats, mask) | |||||
gold_path_score = self._gold_score(feats, tags, mask) | |||||
return all_path_score - gold_path_score | |||||
def viterbi_decode(self, logits, mask, unpad=False): | |||||
r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||||
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||||
        :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad。
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||||
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||||
个sample的有效长度。 | |||||
:return: 返回 (paths, scores)。 | |||||
paths: 是解码后的路径, 其值参照unpad参数. | |||||
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||||
""" | |||||
batch_size, max_len, n_tags = logits.size() | |||||
seq_len = mask.long().sum(1) | |||||
logits = logits.transpose(0, 1).data # L, B, H | |||||
mask = mask.transpose(0, 1).data.eq(True) # L, B | |||||
flip_mask = mask.eq(False) | |||||
# dp | |||||
vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long) | |||||
vscore = logits[0] # bsz x n_tags | |||||
transitions = self._constrain.data.clone() | |||||
transitions[:n_tags, :n_tags] += self.trans_m.data | |||||
if self.include_start_end_trans: | |||||
transitions[n_tags, :n_tags] += self.start_scores.data | |||||
transitions[:n_tags, n_tags + 1] += self.end_scores.data | |||||
vscore += transitions[n_tags, :n_tags] | |||||
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||||
end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags | |||||
# 针对长度为1的句子 | |||||
vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \ | |||||
.masked_fill(seq_len.ne(1).view(-1, 1), 0) | |||||
for i in range(1, max_len): | |||||
prev_score = vscore.view(batch_size, n_tags, 1) | |||||
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score | |||||
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag | |||||
# 需要考虑当前位置是该序列的最后一个 | |||||
score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0) | |||||
best_score, best_dst = score.max(1) | |||||
vpath[i] = best_dst | |||||
# 由于最终是通过last_tags回溯,需要保持每个位置的vscore情况 | |||||
vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||||
vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||||
# backtrace | |||||
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||||
seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device) | |||||
lens = (seq_len - 1) | |||||
# idxes [L, B], batched idx from seq_len-1 to 0 | |||||
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len | |||||
ans = logits.new_empty((max_len, batch_size), dtype=torch.long) | |||||
ans_score, last_tags = vscore.max(1) | |||||
ans[idxes[0], batch_idx] = last_tags | |||||
for i in range(max_len - 1): | |||||
last_tags = vpath[idxes[i], batch_idx, last_tags] | |||||
ans[idxes[i + 1], batch_idx] = last_tags | |||||
ans = ans.transpose(0, 1) | |||||
if unpad: | |||||
paths = [] | |||||
            for idx, seq_len_i in enumerate(lens):
                paths.append(ans[idx, :seq_len_i + 1].tolist())
else: | |||||
paths = ans | |||||
return paths, ans_score |
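

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It builds the
# transition constraints from a plain id->tag dict, computes the CRF loss and
# decodes with Viterbi; the tiny BIO tag set and tensor sizes are assumptions.
if __name__ == "__main__":
    id2tag = {0: 'O', 1: 'B-PER', 2: 'I-PER'}
    trans = allowed_transitions(id2tag, encoding_type='bio', include_start_end=True)
    crf = ConditionalRandomField(num_tags=len(id2tag), include_start_end_trans=True,
                                 allowed_transitions=trans)
    bsz, max_len = 2, 5
    feats = torch.randn(bsz, max_len, len(id2tag))                # emission scores
    tags = torch.randint(0, len(id2tag), (bsz, max_len))
    mask = torch.ones(bsz, max_len, dtype=torch.bool)
    mask[1, 3:] = False                                           # the second sample has length 3
    loss = crf(feats, tags, mask).mean()                          # negative log-likelihood
    paths, scores = crf.viterbi_decode(feats, mask, unpad=True)
    print(loss.item(), paths)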
@@ -0,0 +1,416 @@ | |||||
r"""undocumented""" | |||||
from typing import Union, Tuple | |||||
import math | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from ..attention import AttentionLayer, MultiHeadAttention | |||||
from ....embeddings.torch.utils import get_embeddings | |||||
from ....embeddings.torch.static_embedding import StaticEmbedding | |||||
from .seq2seq_state import State, LSTMState, TransformerState | |||||
__all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder'] | |||||
class Seq2SeqDecoder(nn.Module): | |||||
""" | |||||
    Sequence-to-Sequence Decoder的基类。一定需要实现forward函数;decode函数已有默认实现,当forward的返回不是Tensor时需要重写。每个Seq2SeqDecoder都应该有相应的State对象
    用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, tokens, state, **kwargs): | |||||
""" | |||||
:param torch.LongTensor tokens: bsz x max_len | |||||
:param State state: state包含了encoder的输出以及decode之前的内容 | |||||
:return: 返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布 | |||||
""" | |||||
        raise NotImplementedError
def reorder_states(self, indices, states): | |||||
""" | |||||
根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。 | |||||
:param torch.LongTensor indices: | |||||
:param State states: | |||||
:return: | |||||
""" | |||||
assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" | |||||
states.reorder_state(indices) | |||||
def init_state(self, encoder_output, encoder_mask): | |||||
""" | |||||
初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 | |||||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||||
维度 | |||||
        :param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度 | |||||
:param kwargs: | |||||
:return: State, 返回一个State对象,记录了encoder的输出 | |||||
""" | |||||
state = State(encoder_output, encoder_mask) | |||||
return state | |||||
def decode(self, tokens, state): | |||||
""" | |||||
根据states中的内容,以及tokens中的内容进行之后的生成。 | |||||
:param torch.LongTensor tokens: bsz x max_len, 截止到上一个时刻所有的token输出。 | |||||
:param State state: 记录了encoder输出与decoder过去状态 | |||||
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 | |||||
""" | |||||
outputs = self(state=state, tokens=tokens) | |||||
if isinstance(outputs, torch.Tensor): | |||||
return outputs[:, -1] | |||||
else: | |||||
raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") | |||||
class TiedEmbedding(nn.Module): | |||||
""" | |||||
用于将weight和原始weight绑定 | |||||
""" | |||||
def __init__(self, weight): | |||||
super().__init__() | |||||
self.weight = weight # vocab_size x embed_size | |||||
def forward(self, x): | |||||
""" | |||||
:param torch.FloatTensor x: bsz x * x embed_size | |||||
:return: torch.FloatTensor bsz x * x vocab_size | |||||
""" | |||||
return torch.matmul(x, self.weight.t()) | |||||
def get_bind_decoder_output_embed(embed): | |||||
""" | |||||
给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding | |||||
:param embed: | |||||
:return: | |||||
""" | |||||
if isinstance(embed, StaticEmbedding): | |||||
for idx, map2idx in enumerate(embed.words_to_words): | |||||
assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ | |||||
"include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ | |||||
"`lower=True` or `min_freq!=1`." | |||||
elif not isinstance(embed, nn.Embedding): | |||||
raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") | |||||
return TiedEmbedding(embed.weight) | |||||
class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||||
""" | |||||
LSTM的Decoder | |||||
:param nn.Module,tuple embed: decoder输入的embedding. | |||||
:param int num_layers: 多少层LSTM | |||||
:param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 | |||||
:param dropout: Dropout的大小 | |||||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||||
:param bool attention: 是否使用attention | |||||
""" | |||||
def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||||
dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||||
super().__init__() | |||||
self.embed = get_embeddings(init_embed=embed) | |||||
        self.embed_dim = self.embed.embedding_dim  # 使用get_embeddings处理后的embed,以兼容传入(num_embeddings, embedding_dim)元组的情况
if bind_decoder_input_output_embed: | |||||
self.output_layer = get_bind_decoder_output_embed(self.embed) | |||||
else: # 不需要bind | |||||
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||||
self.hidden_size = hidden_size | |||||
self.num_layers = num_layers | |||||
self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, | |||||
batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) | |||||
self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None | |||||
self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||||
self.dropout_layer = nn.Dropout(dropout) | |||||
def forward(self, tokens, state, return_attention=False): | |||||
""" | |||||
:param torch.LongTensor tokens: batch x max_len | |||||
:param LSTMState state: 保存encoder输出和decode状态的State对象 | |||||
:param bool return_attention: 是否返回attention的的score | |||||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||||
""" | |||||
src_output = state.encoder_output | |||||
encoder_mask = state.encoder_mask | |||||
assert tokens.size(1)>state.decode_length, "The state does not match the tokens." | |||||
tokens = tokens[:, state.decode_length:] | |||||
x = self.embed(tokens) | |||||
attn_weights = [] if self.attention_layer is not None else None # 保存attention weight, batch,tgt_seq,src_seq | |||||
input_feed = state.input_feed | |||||
decoder_out = [] | |||||
cur_hidden = state.hidden | |||||
cur_cell = state.cell | |||||
# 开始计算 | |||||
for i in range(tokens.size(1)): | |||||
input = torch.cat( | |||||
(x[:, i:i + 1, :], | |||||
input_feed[:, None, :] | |||||
), | |||||
dim=2 | |||||
) # batch,1,2*dim | |||||
_, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell保持原来的size | |||||
if self.attention_layer is not None: | |||||
input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) | |||||
attn_weights.append(attn_weight) | |||||
else: | |||||
input_feed = cur_hidden[-1] | |||||
state.input_feed = input_feed # batch, hidden | |||||
state.hidden = cur_hidden | |||||
state.cell = cur_cell | |||||
state.decode_length += 1 | |||||
decoder_out.append(input_feed) | |||||
decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden | |||||
decoder_out = self.dropout_layer(decoder_out) | |||||
if attn_weights is not None: | |||||
attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len | |||||
decoder_out = self.output_proj(decoder_out) | |||||
feats = self.output_layer(decoder_out) | |||||
if return_attention: | |||||
return feats, attn_weights | |||||
return feats | |||||
def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||||
""" | |||||
:param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: | |||||
bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 | |||||
hidden state和cell state来赋值这两个值 | |||||
(2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 | |||||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend | |||||
:return: | |||||
""" | |||||
if not isinstance(encoder_output, torch.Tensor): | |||||
encoder_output, (hidden, cell) = encoder_output | |||||
else: | |||||
hidden = cell = None | |||||
assert encoder_output.ndim==3 | |||||
assert encoder_mask.size()==encoder_output.size()[:2] | |||||
assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ | |||||
"the hidden_size." | |||||
t = [hidden, cell] | |||||
for idx in range(2): | |||||
v = t[idx] | |||||
if v is None: | |||||
v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) | |||||
else: | |||||
assert v.dim()==2 | |||||
assert v.size(-1)==self.hidden_size | |||||
v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size | |||||
t[idx] = v | |||||
state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) | |||||
return state | |||||
class TransformerSeq2SeqDecoderLayer(nn.Module): | |||||
""" | |||||
:param int d_model: 输入、输出的维度 | |||||
:param int n_head: 多少个head,需要能被d_model整除 | |||||
:param int dim_ff: | |||||
:param float dropout: | |||||
:param int layer_idx: layer的编号 | |||||
""" | |||||
def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||||
super().__init__() | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.layer_idx = layer_idx # 记录layer的层索引,以方便获取state的信息 | |||||
self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||||
self.self_attn_layer_norm = nn.LayerNorm(d_model) | |||||
self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||||
self.encoder_attn_layer_norm = nn.LayerNorm(d_model) | |||||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||||
nn.ReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(self.dim_ff, self.d_model), | |||||
nn.Dropout(dropout)) | |||||
self.final_layer_norm = nn.LayerNorm(self.d_model) | |||||
def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): | |||||
""" | |||||
:param x: (batch, seq_len, dim), decoder端的输入 | |||||
:param encoder_output: (batch,src_seq_len,dim), encoder的输出 | |||||
:param encoder_mask: batch,src_seq_len, 为1的地方需要attend | |||||
:param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 | |||||
:param TransformerState state: 只在inference阶段传入 | |||||
:return: | |||||
""" | |||||
# self attention part | |||||
residual = x | |||||
x = self.self_attn_layer_norm(x) | |||||
x, _ = self.self_attn(query=x, | |||||
key=x, | |||||
value=x, | |||||
attn_mask=self_attn_mask, | |||||
state=state) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
x = residual + x | |||||
# encoder attention part | |||||
residual = x | |||||
x = self.encoder_attn_layer_norm(x) | |||||
x, attn_weight = self.encoder_attn(query=x, | |||||
key=encoder_output, | |||||
value=encoder_output, | |||||
key_mask=encoder_mask, | |||||
state=state) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
x = residual + x | |||||
# ffn | |||||
residual = x | |||||
x = self.final_layer_norm(x) | |||||
x = self.ffn(x) | |||||
x = residual + x | |||||
return x, attn_weight | |||||
class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||||
""" | |||||
:param embed: 输入token的embedding | |||||
:param nn.Module pos_embed: 位置embedding | |||||
    :param int d_model: 输入、输出的大小
:param int num_layers: 多少层 | |||||
:param int n_head: 多少个head | |||||
:param int dim_ff: FFN 的中间大小 | |||||
:param float dropout: Self-Attention和FFN中的dropout的大小 | |||||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||||
""" | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||||
d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||||
bind_decoder_input_output_embed = True): | |||||
super().__init__() | |||||
self.embed = get_embeddings(embed) | |||||
self.pos_embed = pos_embed | |||||
if bind_decoder_input_output_embed: | |||||
self.output_layer = get_bind_decoder_output_embed(self.embed) | |||||
else: # 不需要bind | |||||
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||||
self.num_layers = num_layers | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) | |||||
for layer_idx in range(num_layers)]) | |||||
self.embed_scale = math.sqrt(d_model) | |||||
self.layer_norm = nn.LayerNorm(d_model) | |||||
self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||||
def forward(self, tokens, state, return_attention=False): | |||||
""" | |||||
:param torch.LongTensor tokens: batch x tgt_len,decode的词 | |||||
:param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 | |||||
:param bool return_attention: 是否返回对encoder结果的attention score | |||||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||||
""" | |||||
encoder_output = state.encoder_output | |||||
encoder_mask = state.encoder_mask | |||||
assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens." | |||||
tokens = tokens[:, state.decode_length:] | |||||
device = tokens.device | |||||
x = self.embed_scale * self.embed(tokens) | |||||
if self.pos_embed is not None: | |||||
position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] | |||||
x += self.pos_embed(position) | |||||
x = self.input_fc(x) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
batch_size, max_tgt_len = tokens.size() | |||||
if max_tgt_len>1: | |||||
triangle_mask = self._get_triangle_mask(tokens) | |||||
else: | |||||
triangle_mask = None | |||||
for layer in self.layer_stacks: | |||||
x, attn_weight = layer(x=x, | |||||
encoder_output=encoder_output, | |||||
encoder_mask=encoder_mask, | |||||
self_attn_mask=triangle_mask, | |||||
state=state | |||||
) | |||||
x = self.layer_norm(x) # batch, tgt_len, dim | |||||
x = self.output_fc(x) | |||||
feats = self.output_layer(x) | |||||
if return_attention: | |||||
return feats, attn_weight | |||||
return feats | |||||
def init_state(self, encoder_output, encoder_mask): | |||||
""" | |||||
初始化一个TransformerState用于forward | |||||
:param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 | |||||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 | |||||
:return: TransformerState | |||||
""" | |||||
if isinstance(encoder_output, torch.Tensor): | |||||
encoder_output = encoder_output | |||||
elif isinstance(encoder_output, (list, tuple)): | |||||
encoder_output = encoder_output[0] # 防止是LSTMEncoder的输出结果 | |||||
else: | |||||
raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") | |||||
state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) | |||||
return state | |||||
@staticmethod | |||||
def _get_triangle_mask(tokens): | |||||
tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) | |||||
return torch.tril(tensor).byte() | |||||
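

# ---------------------------------------------------------------------------
# Illustrative greedy-decoding sketch (not part of the original module). It only
# exercises the init_state()/decode() pattern described in Seq2SeqDecoder; the
# vocabulary size, the assumption that id 0 is <sos> and the random encoder
# output are placeholders for demonstration.
if __name__ == "__main__":
    vocab_size, d_model, bsz, src_len = 20, 32, 2, 6
    decoder = TransformerSeq2SeqDecoder(embed=nn.Embedding(vocab_size, d_model),
                                        d_model=d_model, num_layers=2, n_head=4,
                                        dim_ff=64, dropout=0.1)
    encoder_output = torch.randn(bsz, src_len, d_model)
    encoder_mask = torch.ones(bsz, src_len, dtype=torch.bool)
    state = decoder.init_state(encoder_output, encoder_mask)
    tokens = torch.zeros(bsz, 1, dtype=torch.long)        # assumed <sos> id
    for _ in range(5):                                     # five greedy steps
        scores = decoder.decode(tokens, state)             # bsz x vocab_size
        next_token = scores.argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    print(tokens)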
@@ -0,0 +1,145 @@ | |||||
r""" | |||||
每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录 | |||||
""" | |||||
__all__ = [ | |||||
'State', | |||||
"LSTMState", | |||||
"TransformerState" | |||||
] | |||||
from typing import Union | |||||
import torch | |||||
class State: | |||||
def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||||
""" | |||||
每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 | |||||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||||
维度 | |||||
        :param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度 | |||||
:param kwargs: | |||||
""" | |||||
self.encoder_output = encoder_output | |||||
self.encoder_mask = encoder_mask | |||||
self._decode_length = 0 | |||||
@property | |||||
def num_samples(self): | |||||
""" | |||||
返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 | |||||
:return: | |||||
""" | |||||
if self.encoder_output is not None: | |||||
return self.encoder_output.size(0) | |||||
else: | |||||
return None | |||||
@property | |||||
def decode_length(self): | |||||
""" | |||||
当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 | |||||
:return: | |||||
""" | |||||
return self._decode_length | |||||
@decode_length.setter | |||||
def decode_length(self, value): | |||||
self._decode_length = value | |||||
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): | |||||
if isinstance(state, torch.Tensor): | |||||
state = state.index_select(index=indices, dim=dim) | |||||
elif isinstance(state, list): | |||||
for i in range(len(state)): | |||||
assert state[i] is not None | |||||
state[i] = self._reorder_state(state[i], indices, dim) | |||||
elif isinstance(state, tuple): | |||||
tmp_list = [] | |||||
for i in range(len(state)): | |||||
assert state[i] is not None | |||||
tmp_list.append(self._reorder_state(state[i], indices, dim)) | |||||
state = tuple(tmp_list) | |||||
else: | |||||
raise TypeError(f"Cannot reorder data of type:{type(state)}") | |||||
return state | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
if self.encoder_mask is not None: | |||||
self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||||
if self.encoder_output is not None: | |||||
self.encoder_output = self._reorder_state(self.encoder_output, indices) | |||||
class LSTMState(State): | |||||
def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||||
""" | |||||
LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 | |||||
:param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 | |||||
:param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding | |||||
:param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 | |||||
:param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 | |||||
""" | |||||
super().__init__(encoder_output, encoder_mask) | |||||
self.hidden = hidden | |||||
self.cell = cell | |||||
self._input_feed = hidden[0] # 默认是上一个时刻的输出 | |||||
@property | |||||
def input_feed(self): | |||||
""" | |||||
LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, | |||||
input_feed即上个时刻的hidden state, 否则是attention layer的输出。 | |||||
:return: torch.FloatTensor, bsz x hidden_size | |||||
""" | |||||
return self._input_feed | |||||
@input_feed.setter | |||||
def input_feed(self, value): | |||||
self._input_feed = value | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
super().reorder_state(indices) | |||||
self.hidden = self._reorder_state(self.hidden, indices, dim=1) | |||||
self.cell = self._reorder_state(self.cell, indices, dim=1) | |||||
if self.input_feed is not None: | |||||
self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) | |||||
class TransformerState(State): | |||||
def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||||
""" | |||||
与TransformerSeq2SeqDecoder对应的State, | |||||
:param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 | |||||
:param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend | |||||
:param int num_decoder_layer: decode有多少层 | |||||
""" | |||||
super().__init__(encoder_output, encoder_mask) | |||||
self.encoder_key = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x key_dim | |||||
self.encoder_value = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x value_dim | |||||
self.decoder_prev_key = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x key_dim | |||||
        self.decoder_prev_value = [None] * num_decoder_layer  # 每一个元素 bsz x decode_length x value_dim
def reorder_state(self, indices: torch.LongTensor): | |||||
super().reorder_state(indices) | |||||
self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||||
self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||||
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||||
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||||
@property | |||||
def decode_length(self): | |||||
if self.decoder_prev_key[0] is not None: | |||||
return self.decoder_prev_key[0].size(1) | |||||
return 0 | |||||
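

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module) of how a State gets
# re-ordered along the batch dimension, which is what beam search relies on;
# the tensor sizes and the index order are arbitrary assumptions.
if __name__ == "__main__":
    enc_out = torch.randn(3, 4, 8)
    enc_mask = torch.ones(3, 4, dtype=torch.bool)
    state = State(enc_out, enc_mask)
    state.reorder_state(torch.LongTensor([0, 0, 2]))        # duplicate sample 0, drop sample 1
    print(state.num_samples, state.encoder_output.shape)    # 3 torch.Size([3, 4, 8])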
@@ -0,0 +1,24 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"TimestepDropout" | |||||
] | |||||
import torch | |||||
class TimestepDropout(torch.nn.Dropout): | |||||
r""" | |||||
传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` | |||||
使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 | |||||
""" | |||||
def forward(self, x): | |||||
dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | |||||
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | |||||
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||||
if self.inplace: | |||||
x *= dropout_mask | |||||
return | |||||
else: | |||||
return x * dropout_mask |
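

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): with TimestepDropout the
# same dropout mask is shared by all timesteps of a sample, so the zeroed
# embedding positions coincide across timesteps. Sizes and p are assumptions.
if __name__ == "__main__":
    drop = TimestepDropout(p=0.5)
    drop.train()
    x = torch.ones(2, 4, 8)                 # (batch_size, num_timesteps, embedding_dim)
    y = drop(x)
    assert torch.equal(y[:, 0].eq(0), y[:, 1].eq(0))   # same positions dropped at every step
    print(y[0, 0])                          # entries are 0. or 2. (kept values scaled by 1/(1-p))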
@@ -1,5 +1,21 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"ConvMaxpool", | |||||
"LSTM", | "LSTM", | ||||
"Seq2SeqEncoder", | |||||
"TransformerSeq2SeqEncoder", | |||||
"LSTMSeq2SeqEncoder", | |||||
"StarTransformer", | |||||
"VarRNN", | |||||
"VarLSTM", | |||||
"VarGRU" | |||||
] | ] | ||||
from .lstm import LSTM | |||||
from .conv_maxpool import ConvMaxpool | |||||
from .lstm import LSTM | |||||
from .seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||||
from .star_transformer import StarTransformer | |||||
from .variational_rnn import VarRNN, VarLSTM, VarGRU |
@@ -0,0 +1,87 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"ConvMaxpool" | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
class ConvMaxpool(nn.Module): | |||||
r""" | |||||
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x | |||||
    sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,再通过在长度(max_len)
这一维进行max_pooling。最后得到每个sample的一个向量表示。 | |||||
""" | |||||
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | |||||
r""" | |||||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | |||||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||||
        :param int,tuple(int) kernel_sizes: 卷积核的大小。如果为list,则需要与out_channels的数量保持一致
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh | |||||
""" | |||||
super(ConvMaxpool, self).__init__() | |||||
        # kernel_sizes允许为单个int;卷积核需为奇数,保证padding后长度不变
        for kernel_size in (kernel_sizes if isinstance(kernel_sizes, (list, tuple)) else (kernel_sizes,)):
            assert kernel_size % 2 == 1, "kernel size has to be an odd number."
# convolution | |||||
if isinstance(kernel_sizes, (list, tuple, int)): | |||||
if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | |||||
out_channels = [out_channels] | |||||
kernel_sizes = [kernel_sizes] | |||||
elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | |||||
assert len(out_channels) == len( | |||||
kernel_sizes), "The number of out_channels should be equal to the number" \ | |||||
" of kernel_sizes." | |||||
else: | |||||
raise ValueError("The type of out_channels and kernel_sizes should be the same.") | |||||
self.convs = nn.ModuleList([nn.Conv1d( | |||||
in_channels=in_channels, | |||||
out_channels=oc, | |||||
kernel_size=ks, | |||||
stride=1, | |||||
padding=ks // 2, | |||||
dilation=1, | |||||
groups=1, | |||||
bias=False) | |||||
for oc, ks in zip(out_channels, kernel_sizes)]) | |||||
else: | |||||
raise Exception( | |||||
'Incorrect kernel sizes: should be list, tuple or int') | |||||
# activation function | |||||
if activation == 'relu': | |||||
self.activation = F.relu | |||||
        elif activation == 'sigmoid':
            self.activation = torch.sigmoid
        elif activation == 'tanh':
            self.activation = torch.tanh
else: | |||||
raise Exception( | |||||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||||
def forward(self, x, mask=None): | |||||
r""" | |||||
:param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 | |||||
:param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 | |||||
:return: | |||||
""" | |||||
# [N,L,C] -> [N,C,L] | |||||
x = torch.transpose(x, 1, 2) | |||||
# convolution | |||||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | |||||
if mask is not None: | |||||
mask = mask.unsqueeze(1) # B x 1 x L | |||||
xs = [x.masked_fill(mask.eq(False), float('-inf')) for x in xs] | |||||
# max-pooling | |||||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | |||||
for i in xs] # [[N, C], ...] | |||||
return torch.cat(xs, dim=-1) # [N, C] |
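

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module); the channel sizes,
# kernel sizes and sequence length below are arbitrary assumptions.
if __name__ == "__main__":
    conv_pool = ConvMaxpool(in_channels=16, out_channels=[8, 8, 8], kernel_sizes=[1, 3, 5])
    x = torch.randn(4, 10, 16)                  # batch_size x max_len x input_size
    mask = torch.ones(4, 10, dtype=torch.bool)
    mask[:, 7:] = False                         # the last three positions are padding
    out = conv_pool(x, mask)
    print(out.shape)                            # torch.Size([4, 24]) == sum(out_channels)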
@@ -0,0 +1,193 @@ | |||||
r"""undocumented""" | |||||
import torch.nn as nn | |||||
import torch | |||||
from torch.nn import LayerNorm | |||||
import torch.nn.functional as F | |||||
from typing import Union, Tuple | |||||
from ....core.utils import seq_len_to_mask | |||||
import math | |||||
from .lstm import LSTM | |||||
from ..attention import MultiHeadAttention | |||||
from ....embeddings.torch import StaticEmbedding | |||||
from ....embeddings.torch.utils import get_embeddings | |||||
__all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder'] | |||||
class Seq2SeqEncoder(nn.Module): | |||||
""" | |||||
所有Sequence2Sequence Encoder的基类。需要实现forward函数 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, tokens, seq_len): | |||||
""" | |||||
:param torch.LongTensor tokens: bsz x max_len, encoder的输入 | |||||
:param torch.LongTensor seq_len: bsz | |||||
:return: | |||||
""" | |||||
raise NotImplementedError | |||||
class TransformerSeq2SeqEncoderLayer(nn.Module): | |||||
""" | |||||
Self-Attention的Layer, | |||||
    :param int d_model: input和output的维度
:param int n_head: 多少个head,每个head的维度为d_model/n_head | |||||
:param int dim_ff: FFN的维度大小 | |||||
:param float dropout: Self-attention和FFN的dropout大小,0表示不drop | |||||
""" | |||||
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, | |||||
dropout: float = 0.1): | |||||
super(TransformerSeq2SeqEncoderLayer, self).__init__() | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.self_attn = MultiHeadAttention(d_model, n_head, dropout) | |||||
self.attn_layer_norm = LayerNorm(d_model) | |||||
self.ffn_layer_norm = LayerNorm(d_model) | |||||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||||
nn.ReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(self.dim_ff, self.d_model), | |||||
nn.Dropout(dropout)) | |||||
def forward(self, x, mask): | |||||
""" | |||||
:param x: batch x src_seq x d_model | |||||
:param mask: batch x src_seq,为0的地方为padding | |||||
:return: | |||||
""" | |||||
# attention | |||||
residual = x | |||||
x = self.attn_layer_norm(x) | |||||
x, _ = self.self_attn(query=x, | |||||
key=x, | |||||
value=x, | |||||
key_mask=mask) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
x = residual + x | |||||
# ffn | |||||
residual = x | |||||
x = self.ffn_layer_norm(x) | |||||
x = self.ffn(x) | |||||
x = residual + x | |||||
return x | |||||
class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||||
""" | |||||
基于Transformer的Encoder | |||||
:param embed: encoder输入token的embedding | |||||
:param nn.Module pos_embed: position embedding | |||||
:param int num_layers: 多少层的encoder | |||||
:param int d_model: 输入输出的维度 | |||||
:param int n_head: 多少个head | |||||
:param int dim_ff: FFN中间的维度大小 | |||||
:param float dropout: Attention和FFN的dropout大小 | |||||
""" | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||||
num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||||
super(TransformerSeq2SeqEncoder, self).__init__() | |||||
self.embed = get_embeddings(embed) | |||||
self.embed_scale = math.sqrt(d_model) | |||||
self.pos_embed = pos_embed | |||||
self.num_layers = num_layers | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) | |||||
for _ in range(num_layers)]) | |||||
self.layer_norm = LayerNorm(d_model) | |||||
def forward(self, tokens, seq_len): | |||||
""" | |||||
:param tokens: batch x max_len | |||||
:param seq_len: [batch] | |||||
:return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) | |||||
""" | |||||
x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||||
batch_size, max_src_len, _ = x.size() | |||||
device = x.device | |||||
if self.pos_embed is not None: | |||||
position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) | |||||
x += self.pos_embed(position) | |||||
x = self.input_fc(x) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len) | |||||
encoder_mask = encoder_mask.to(device) | |||||
for layer in self.layer_stacks: | |||||
x = layer(x, encoder_mask) | |||||
x = self.layer_norm(x) | |||||
return x, encoder_mask | |||||
class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||||
""" | |||||
LSTM的Encoder | |||||
:param embed: encoder的token embed | |||||
:param int num_layers: 多少层 | |||||
:param int hidden_size: LSTM隐藏层、输出的大小 | |||||
:param float dropout: LSTM层之间的Dropout是多少 | |||||
:param bool bidirectional: 是否使用双向 | |||||
""" | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||||
hidden_size = 400, dropout = 0.3, bidirectional=True): | |||||
super().__init__() | |||||
self.embed = get_embeddings(embed) | |||||
self.num_layers = num_layers | |||||
self.dropout = dropout | |||||
self.hidden_size = hidden_size | |||||
self.bidirectional = bidirectional | |||||
hidden_size = hidden_size//2 if bidirectional else hidden_size | |||||
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional,
batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||||
def forward(self, tokens, seq_len): | |||||
""" | |||||
:param torch.LongTensor tokens: bsz x max_len | |||||
:param torch.LongTensor seq_len: bsz | |||||
:return: (output, (hidden, cell)), encoder_mask | |||||
output: bsz x max_len x hidden_size, | |||||
hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 | |||||
encoder_mask: bsz x max_len, 为0的地方是padding | |||||
""" | |||||
x = self.embed(tokens) | |||||
device = x.device | |||||
x, (final_hidden, final_cell) = self.lstm(x, seq_len) | |||||
encoder_mask = seq_len_to_mask(seq_len).to(device) | |||||
# x: batch,seq_len,dim; h/c: num_layers*2,batch,dim | |||||
if self.bidirectional: | |||||
final_hidden = self.concat_bidir(final_hidden) # 将双向的hidden state拼接起来,用于接下来的decoder的input | |||||
final_cell = self.concat_bidir(final_cell) | |||||
return (x, (final_hidden[-1], final_cell[-1])), encoder_mask # 为了配合Seq2SeqBaseModel的forward,这边需要分为两个return | |||||
def concat_bidir(self, input): | |||||
output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) | |||||
return output.reshape(self.num_layers, input.size(1), -1) |
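

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module); the vocabulary and
# hidden sizes are arbitrary assumptions. The returned hidden/cell are what a
# Seq2SeqDecoder's init_state() expects as the second element of its input.
if __name__ == "__main__":
    encoder = LSTMSeq2SeqEncoder(embed=nn.Embedding(20, 16), num_layers=2,
                                 hidden_size=32, dropout=0.1)
    tokens = torch.randint(1, 20, (2, 6))
    seq_len = torch.LongTensor([6, 4])
    (output, (hidden, cell)), encoder_mask = encoder(tokens, seq_len)
    print(output.shape, hidden.shape, encoder_mask.shape)   # [2, 6, 32] [2, 32] [2, 6]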
@@ -0,0 +1,166 @@ | |||||
r"""undocumented | |||||
Star-Transformer 的encoder部分的 Pytorch 实现 | |||||
""" | |||||
__all__ = [ | |||||
"StarTransformer" | |||||
] | |||||
import numpy as np | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import functional as F | |||||
class StarTransformer(nn.Module): | |||||
r""" | |||||
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 | |||||
paper: https://arxiv.org/abs/1902.09113 | |||||
""" | |||||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | |||||
r""" | |||||
:param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 | |||||
:param int num_layers: star-transformer的层数 | |||||
:param int num_head: head的数量。 | |||||
:param int head_dim: 每个head的维度大小。 | |||||
:param float dropout: dropout 概率. Default: 0.1 | |||||
:param int max_len: int or None, 如果为int,输入序列的最大长度, | |||||
模型会为输入序列加上position embedding。 | |||||
若为`None`,忽略加上position embedding的步骤. Default: `None` | |||||
""" | |||||
super(StarTransformer, self).__init__() | |||||
self.iters = num_layers | |||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)]) | |||||
# self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) | |||||
self.emb_drop = nn.Dropout(dropout) | |||||
self.ring_att = nn.ModuleList( | |||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||||
for _ in range(self.iters)]) | |||||
self.star_att = nn.ModuleList( | |||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||||
for _ in range(self.iters)]) | |||||
if max_len is not None: | |||||
self.pos_emb = nn.Embedding(max_len, hidden_size) | |||||
else: | |||||
self.pos_emb = None | |||||
def forward(self, data, mask): | |||||
r""" | |||||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | |||||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||||
否则为 1 | |||||
:return: [batch, length, hidden] 编码后的输出序列 | |||||
[batch, hidden] 全局 relay 节点, 详见论文 | |||||
""" | |||||
def norm_func(f, x): | |||||
# B, H, L, 1 | |||||
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |||||
B, L, H = data.size() | |||||
mask = (mask.eq(False)) # flip the mask for masked_fill_ | |||||
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1) | |||||
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | |||||
if self.pos_emb: | |||||
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | |||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||||
embs = embs + P | |||||
embs = norm_func(self.emb_drop, embs) | |||||
nodes = embs | |||||
relay = embs.mean(2, keepdim=True) | |||||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | |||||
r_embs = embs.view(B, H, 1, L) | |||||
for i in range(self.iters): | |||||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | |||||
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||||
# nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) | |||||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | |||||
nodes = nodes.masked_fill_(ex_mask, 0) | |||||
nodes = nodes.view(B, H, L).permute(0, 2, 1) | |||||
return nodes, relay.view(B, H) | |||||
class _MSA1(nn.Module): | |||||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||||
super(_MSA1, self).__init__() | |||||
# Multi-head Self Attention Case 1, doing self-attention for small regions | |||||
# Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small | |||||
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||||
self.drop = nn.Dropout(dropout) | |||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||||
def forward(self, x, ax=None): | |||||
# x: B, H, L, 1, ax : B, H, X, L append features | |||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||||
B, H, L, _ = x.shape | |||||
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | |||||
if ax is not None: | |||||
aL = ax.shape[2] | |||||
ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | |||||
av = self.WV(ax).view(B, nhead, head_dim, aL, L) | |||||
q = q.view(B, nhead, head_dim, 1, L) | |||||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
if ax is not None: | |||||
k = torch.cat([k, ak], 3) | |||||
v = torch.cat([v, av], 3) | |||||
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3)) # B N L 1 U | |||||
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | |||||
ret = self.WO(att) | |||||
return ret | |||||
class _MSA2(nn.Module): | |||||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||||
# Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value | |||||
super(_MSA2, self).__init__() | |||||
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||||
self.drop = nn.Dropout(dropout) | |||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||||
def forward(self, x, y, mask=None): | |||||
# x: B, H, 1, 1, 1 y: B H L 1 | |||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||||
B, H, L, _ = y.shape | |||||
q, k, v = self.WQ(x), self.WK(y), self.WV(y) | |||||
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | |||||
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | |||||
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | |||||
pre_a = torch.matmul(q, k) / np.sqrt(head_dim) | |||||
if mask is not None: | |||||
pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) | |||||
alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L | |||||
att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 | |||||
return self.WO(att) |
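

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module); hidden size, head
# configuration and sequence lengths are arbitrary assumptions.
if __name__ == "__main__":
    star = StarTransformer(hidden_size=32, num_layers=2, num_head=4, head_dim=8,
                           dropout=0.1, max_len=40)
    data = torch.randn(2, 10, 32)
    mask = torch.ones(2, 10, dtype=torch.bool)
    mask[1, 6:] = False                          # the second sentence has length 6
    nodes, relay = star(data, mask)
    print(nodes.shape, relay.shape)              # torch.Size([2, 10, 32]) torch.Size([2, 32])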
@@ -0,0 +1,43 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"TransformerEncoder" | |||||
] | |||||
from torch import nn | |||||
from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||||
class TransformerEncoder(nn.Module): | |||||
r""" | |||||
transformer的encoder模块,不包含embedding层 | |||||
""" | |||||
def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||||
""" | |||||
:param int num_layers: 多少层Transformer | |||||
:param int d_model: input和output的大小 | |||||
:param int n_head: 多少个head | |||||
:param int dim_ff: FFN中间hidden大小 | |||||
:param float dropout: 多大概率drop attention和ffn中间的表示 | |||||
""" | |||||
super(TransformerEncoder, self).__init__() | |||||
self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||||
dropout = dropout) for _ in range(num_layers)]) | |||||
self.norm = nn.LayerNorm(d_model, eps=1e-6) | |||||
def forward(self, x, seq_mask=None): | |||||
r""" | |||||
:param x: [batch, seq_len, model_size] 输入序列 | |||||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 为1的地方需要attend | |||||
Default: ``None`` | |||||
:return: [batch, seq_len, model_size] 输出序列 | |||||
""" | |||||
output = x | |||||
if seq_mask is None: | |||||
seq_mask = x.new_ones(x.size(0), x.size(1)).bool() | |||||
for layer in self.layers: | |||||
output = layer(output, seq_mask) | |||||
return self.norm(output) |
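

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module); the model size and
# input shape are arbitrary assumptions. `torch` is imported here only for the demo.
if __name__ == "__main__":
    import torch
    encoder = TransformerEncoder(num_layers=2, d_model=32, n_head=4, dim_ff=64, dropout=0.1)
    x = torch.randn(2, 5, 32)
    out = encoder(x)                 # seq_mask defaults to an all-ones mask
    print(out.shape)                 # torch.Size([2, 5, 32])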
@@ -0,0 +1,303 @@ | |||||
r"""undocumented | |||||
Variational RNN 及相关模型的 fastNLP实现,相关论文参考: | |||||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | |||||
__all__ = [ | |||||
"VarRNN", | |||||
"VarLSTM", | |||||
"VarGRU" | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |||||
try: | |||||
from torch import flip | |||||
except ImportError: | |||||
def flip(x, dims): | |||||
indices = [slice(None)] * x.dim() | |||||
for dim in dims: | |||||
indices[dim] = torch.arange( | |||||
x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | |||||
return x[tuple(indices)] | |||||
class VarRnnCellWrapper(nn.Module): | |||||
r""" | |||||
Wrapper for normal RNN Cells, make it support variational dropout | |||||
""" | |||||
def __init__(self, cell, hidden_size, input_p, hidden_p): | |||||
super(VarRnnCellWrapper, self).__init__() | |||||
self.cell = cell | |||||
self.hidden_size = hidden_size | |||||
self.input_p = input_p | |||||
self.hidden_p = hidden_p | |||||
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | |||||
r""" | |||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | |||||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | |||||
for other RNN, h_0, [batch_size, hidden_size] | |||||
:param mask_x: [batch_size, input_size] dropout mask for input | |||||
:param mask_h: [batch_size, hidden_size] dropout mask for hidden | |||||
:return PackedSequence output: [seq_len, batch_size, hidden_size] | |||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||||
for other RNN, h_n, [batch_size, hidden_size] | |||||
""" | |||||
def get_hi(hi, h0, size): | |||||
h0_size = size - hi.size(0) | |||||
if h0_size > 0: | |||||
return torch.cat([hi, h0[:h0_size]], dim=0) | |||||
return hi[:size] | |||||
is_lstm = isinstance(hidden, tuple) | |||||
input, batch_sizes = input_x.data, input_x.batch_sizes | |||||
output = [] | |||||
cell = self.cell | |||||
if is_reversed: | |||||
batch_iter = flip(batch_sizes, [0]) | |||||
idx = input.size(0) | |||||
else: | |||||
batch_iter = batch_sizes | |||||
idx = 0 | |||||
if is_lstm: | |||||
hn = (hidden[0].clone(), hidden[1].clone()) | |||||
else: | |||||
hn = hidden.clone() | |||||
hi = hidden | |||||
for size in batch_iter: | |||||
if is_reversed: | |||||
input_i = input[idx - size: idx] * mask_x[:size] | |||||
idx -= size | |||||
else: | |||||
input_i = input[idx: idx + size] * mask_x[:size] | |||||
idx += size | |||||
mask_hi = mask_h[:size] | |||||
if is_lstm: | |||||
hx, cx = hi | |||||
hi = (get_hi(hx, hidden[0], size) * | |||||
mask_hi, get_hi(cx, hidden[1], size)) | |||||
hi = cell(input_i, hi) | |||||
hn[0][:size] = hi[0] | |||||
hn[1][:size] = hi[1] | |||||
output.append(hi[0]) | |||||
else: | |||||
hi = get_hi(hi, hidden, size) * mask_hi | |||||
hi = cell(input_i, hi) | |||||
hn[:size] = hi | |||||
output.append(hi) | |||||
if is_reversed: | |||||
output = list(reversed(output)) | |||||
output = torch.cat(output, dim=0) | |||||
return PackedSequence(output, batch_sizes), hn | |||||
class VarRNNBase(nn.Module): | |||||
r""" | |||||
Variational Dropout RNN 实现. | |||||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | |||||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | |||||
bias=True, batch_first=False, | |||||
input_dropout=0, hidden_dropout=0, bidirectional=False): | |||||
r""" | |||||
:param mode: rnn 模式, 'LSTM'/'GRU'/'RNN' 之一, 用于标识cell类型 | |||||
:param Cell: rnn cell 类型, (lstm, gru, etc) | |||||
:param input_size: 输入 `x` 的特征维度 | |||||
:param hidden_size: 隐状态 `h` 的特征维度 | |||||
:param num_layers: rnn的层数. Default: 1 | |||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||||
""" | |||||
super(VarRNNBase, self).__init__() | |||||
self.mode = mode | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.num_layers = num_layers | |||||
self.bias = bias | |||||
self.batch_first = batch_first | |||||
self.input_dropout = input_dropout | |||||
self.hidden_dropout = hidden_dropout | |||||
self.bidirectional = bidirectional | |||||
self.num_directions = 2 if bidirectional else 1 | |||||
self._all_cells = nn.ModuleList() | |||||
for layer in range(self.num_layers): | |||||
for direction in range(self.num_directions): | |||||
input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | |||||
cell = Cell(input_size, self.hidden_size, bias) | |||||
self._all_cells.append(VarRnnCellWrapper( | |||||
cell, self.hidden_size, input_dropout, hidden_dropout)) | |||||
self.is_lstm = (self.mode == "LSTM") | |||||
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | |||||
is_lstm = self.is_lstm | |||||
idx = self.num_directions * n_layer + n_direction | |||||
cell = self._all_cells[idx] | |||||
hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||||
output_x, hidden_x = cell( | |||||
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||||
return output_x, hidden_x | |||||
def forward(self, x, hx=None): | |||||
r""" | |||||
:param x: [batch, seq_len, input_size] 输入序列 | |||||
:param hx: [num_layers*num_directions, batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0张量. Default: ``None`` | |||||
:return (output, ht): [batch, seq_len, hidden_size*num_direction] 输出序列 | |||||
和 [batch, hidden_size*num_direction] 最后时刻隐状态 | |||||
""" | |||||
is_lstm = self.is_lstm | |||||
is_packed = isinstance(x, PackedSequence) | |||||
if not is_packed: | |||||
seq_len = x.size(1) if self.batch_first else x.size(0) | |||||
max_batch_size = x.size(0) if self.batch_first else x.size(1) | |||||
seq_lens = torch.LongTensor( | |||||
[seq_len for _ in range(max_batch_size)]) | |||||
x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||||
else: | |||||
max_batch_size = int(x.batch_sizes[0]) | |||||
x, batch_sizes = x.data, x.batch_sizes | |||||
if hx is None: | |||||
hx = x.new_zeros(self.num_layers * self.num_directions, | |||||
max_batch_size, self.hidden_size, requires_grad=True) | |||||
if is_lstm: | |||||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | |||||
mask_x = x.new_ones((max_batch_size, self.input_size)) | |||||
mask_out = x.new_ones( | |||||
(max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, | |||||
training=self.training, inplace=True) | |||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, | |||||
training=self.training, inplace=True) | |||||
hidden = x.new_zeros( | |||||
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
if is_lstm: | |||||
cellstate = x.new_zeros( | |||||
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
for layer in range(self.num_layers): | |||||
output_list = [] | |||||
input_seq = PackedSequence(x, batch_sizes) | |||||
mask_h = nn.functional.dropout( | |||||
mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||||
for direction in range(self.num_directions): | |||||
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | |||||
mask_x if layer == 0 else mask_out, mask_h) | |||||
output_list.append(output_x.data) | |||||
idx = self.num_directions * layer + direction | |||||
if is_lstm: | |||||
hidden[idx] = hidden_x[0] | |||||
cellstate[idx] = hidden_x[1] | |||||
else: | |||||
hidden[idx] = hidden_x | |||||
x = torch.cat(output_list, dim=-1) | |||||
if is_lstm: | |||||
hidden = (hidden, cellstate) | |||||
if is_packed: | |||||
output = PackedSequence(x, batch_sizes) | |||||
else: | |||||
x = PackedSequence(x, batch_sizes) | |||||
output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | |||||
return output, hidden | |||||
class VarLSTM(VarRNNBase): | |||||
r""" | |||||
Variational Dropout LSTM. | |||||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | |||||
:param hidden_size: 隐状态 `h` 的特征维度 | |||||
:param num_layers: rnn的层数. Default: 1 | |||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||||
:param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False`` | |||||
""" | |||||
super(VarLSTM, self).__init__( | |||||
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||||
def forward(self, x, hx=None): | |||||
return super(VarLSTM, self).forward(x, hx) | |||||
class VarRNN(VarRNNBase): | |||||
r""" | |||||
Variational Dropout RNN. | |||||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | |||||
:param hidden_size: 隐状态 `h` 的特征维度 | |||||
:param num_layers: rnn的层数. Default: 1 | |||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||||
""" | |||||
super(VarRNN, self).__init__( | |||||
mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||||
def forward(self, x, hx=None): | |||||
return super(VarRNN, self).forward(x, hx) | |||||
class VarGRU(VarRNNBase): | |||||
r""" | |||||
Variational Dropout GRU. | |||||
相关论文参考:`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||||
""" | |||||
def __init__(self, *args, **kwargs): | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | |||||
:param hidden_size: 隐状态 `h` 的特征维度 | |||||
:param num_layers: rnn的层数. Default: 1 | |||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: 对输入的dropout概率. Default: 0 | |||||
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 | |||||
:param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False`` | |||||
""" | |||||
super(VarGRU, self).__init__( | |||||
mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||||
def forward(self, x, hx=None): | |||||
return super(VarGRU, self).forward(x, hx) |
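# A minimal usage sketch of VarLSTM (illustrative only, not part of the library; the helper
# name `_var_lstm_demo` is hypothetical). Variational dropout reuses the same dropout mask at
# every timestep, which is what the cell wrapper above implements.
def _var_lstm_demo():
    lstm = VarLSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True,
                   input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
    x = torch.randn(4, 7, 8)                        # [batch, seq_len, input_size]
    output, (h_n, c_n) = lstm(x)
    assert output.size() == (4, 7, 16 * 2)          # hidden_size * num_directions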
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
'SequenceGenerator' | |||||
] | |||||
from .seq2seq_generator import SequenceGenerator |
@@ -0,0 +1,536 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [ | |||||
'SequenceGenerator' | |||||
] | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||||
from functools import partial | |||||
def _get_model_device(model): | |||||
r""" | |||||
传入一个nn.Module的模型,获取它所在的device | |||||
:param model: nn.Module | |||||
:return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 | |||||
""" | |||||
assert isinstance(model, nn.Module) | |||||
parameters = list(model.parameters()) | |||||
if len(parameters) == 0: | |||||
return None | |||||
else: | |||||
return parameters[0].device | |||||
class SequenceGenerator: | |||||
""" | |||||
给定一个Seq2SeqDecoder,decode出句子。输入的decoder对象需要有decode()函数, 接受的第一个参数为解码到目前位置为止的所有输出, | |||||
第二个参数为state。SequenceGenerator不会对state进行任何操作。 | |||||
""" | |||||
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, | |||||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||||
""" | |||||
:param Seq2SeqDecoder decoder: Decoder对象 | |||||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||||
:param int num_beams: beam search的大小 | |||||
:param bool do_sample: 是否通过采样的方式生成 | |||||
:param float temperature: 只有在do_sample为True才有意义 | |||||
:param int top_k: 只从top_k中采样 | |||||
:param float top_p: 只从top_p的token中采样, 即 nucleus sampling | |||||
:param int,None bos_token_id: 句子开头的token id | |||||
:param int,None eos_token_id: 句子结束的token id | |||||
:param float repetition_penalty: 多大程度上惩罚重复的token | |||||
:param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短句 | |||||
:param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 | |||||
""" | |||||
if do_sample: | |||||
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||||
num_beams=num_beams, | |||||
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | |||||
eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, | |||||
length_penalty=length_penalty, pad_token_id=pad_token_id) | |||||
else: | |||||
self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||||
num_beams=num_beams, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, | |||||
repetition_penalty=repetition_penalty, | |||||
length_penalty=length_penalty, pad_token_id=pad_token_id) | |||||
self.do_sample = do_sample | |||||
self.max_length = max_length | |||||
self.num_beams = num_beams | |||||
self.temperature = temperature | |||||
self.top_k = top_k | |||||
self.top_p = top_p | |||||
self.bos_token_id = bos_token_id | |||||
self.eos_token_id = eos_token_id | |||||
self.repetition_penalty = repetition_penalty | |||||
self.length_penalty = length_penalty | |||||
self.decoder = decoder | |||||
@torch.no_grad() | |||||
def generate(self, state, tokens=None): | |||||
""" | |||||
:param State state: encoder结果的State, 是与Decoder配套使用的 | |||||
:param torch.LongTensor,None tokens: batch_size x length, 开始的token。如果为None,则默认添加bos_token作为开头的token | |||||
进行生成。 | |||||
:return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id | |||||
""" | |||||
return self.generate_func(tokens=tokens, state=state) | |||||
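# Usage sketch (illustrative only): `decoder` must be a Seq2SeqDecoder and `state` the
# State produced by its paired encoder, e.g. from a trained Seq2SeqModel; neither object
# is constructed here.
#
#     generator = SequenceGenerator(decoder, max_length=30, num_beams=4, do_sample=False,
#                                   bos_token_id=1, eos_token_id=2, pad_token_id=0)
#     token_ids = generator.generate(state)   # bsz x max_length', each row ends with eos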
@torch.no_grad() | |||||
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, | |||||
bos_token_id=None, eos_token_id=None, pad_token_id=0, | |||||
repetition_penalty=1, length_penalty=1.0): | |||||
""" | |||||
贪婪地搜索句子 | |||||
:param Decoder decoder: Decoder对象 | |||||
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | |||||
:param State state: 应该包含encoder的一些输出。 | |||||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||||
:param int num_beams: 使用多大的beam进行解码。 | |||||
:param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 | |||||
:param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 | |||||
:param int pad_token_id: pad的token id | |||||
:param float repetition_penalty: 对重复出现的token多大的惩罚。 | |||||
:param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 | |||||
:return: | |||||
""" | |||||
if num_beams == 1: | |||||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||||
temperature=1, top_k=50, top_p=1, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
else: | |||||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||||
num_beams=num_beams, temperature=1, top_k=50, top_p=1, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
return token_ids | |||||
@torch.no_grad() | |||||
def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50, | |||||
top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | |||||
length_penalty=1.0): | |||||
""" | |||||
使用采样的方法生成句子 | |||||
:param Decoder decoder: Decoder对象 | |||||
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | |||||
:param State state: 应该包含encoder的一些输出。 | |||||
:param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len | |||||
:param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask | |||||
:param int num_beams: 使用多大的beam进行解码。 | |||||
:param float temperature: 采样时的退火大小 | |||||
:param int top_k: 只在top_k的sample里面采样 | |||||
:param float top_p: 介于0和1之间的值, 即 nucleus sampling 的阈值。 | |||||
:param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 | |||||
:param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 | |||||
:param int pad_token_id: pad的token id | |||||
:param float repetition_penalty: 对重复出现的token多大的惩罚。 | |||||
:param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 | |||||
:return: | |||||
""" | |||||
# 每个位置在生成的时候会sample生成 | |||||
if num_beams == 1: | |||||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||||
temperature=temperature, top_k=top_k, top_p=top_p, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
else: | |||||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||||
num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
return token_ids | |||||
def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50, | |||||
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||||
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | |||||
device = _get_model_device(decoder) | |||||
if tokens is None: | |||||
if bos_token_id is None: | |||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||||
batch_size = state.num_samples | |||||
if batch_size is None: | |||||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||||
batch_size = tokens.size(0) | |||||
if state.num_samples: | |||||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||||
if eos_token_id is None: | |||||
_eos_token_id = -1 | |||||
else: | |||||
_eos_token_id = eos_token_id | |||||
scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state | |||||
if _eos_token_id!=-1: # 防止第一个位置为结束 | |||||
scores[:, _eos_token_id] = -1e12 | |||||
next_tokens = scores.argmax(dim=-1, keepdim=True) | |||||
token_ids = torch.cat([tokens, next_tokens], dim=1) | |||||
cur_len = token_ids.size(1) | |||||
dones = token_ids.new_zeros(batch_size).eq(1) | |||||
# tokens = tokens[:, -1:] | |||||
if max_len_a!=0: | |||||
# (bsz x num_beams, ) | |||||
if state.encoder_mask is not None: | |||||
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length | |||||
else: | |||||
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) | |||||
real_max_length = max_lengths.max().item() | |||||
else: | |||||
real_max_length = max_length | |||||
if state.encoder_mask is not None: | |||||
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length | |||||
else: | |||||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) | |||||
while cur_len < real_max_length: | |||||
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size | |||||
if repetition_penalty != 1.0: | |||||
token_scores = scores.gather(dim=1, index=token_ids) | |||||
lt_zero_mask = token_scores.lt(0).float() | |||||
ge_zero_mask = lt_zero_mask.eq(0).float() | |||||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||||
scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||||
if eos_token_id is not None and length_penalty != 1.0: | |||||
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size | |||||
eos_mask = scores.new_ones(scores.size(1)) | |||||
eos_mask[eos_token_id] = 0 | |||||
eos_mask = eos_mask.unsqueeze(0).eq(1) | |||||
scores = scores.masked_scatter(eos_mask, token_scores) # 也即除了eos,其他词的分数经过了放大/缩小 | |||||
if do_sample: | |||||
if temperature > 0 and temperature != 1: | |||||
scores = scores / temperature | |||||
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) | |||||
# 加上1e-12是为了避免https://github.com/pytorch/pytorch/pull/27523 | |||||
probs = F.softmax(scores, dim=-1) + 1e-12 | |||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size | |||||
else: | |||||
next_tokens = torch.argmax(scores, dim=-1) # batch_size | |||||
# 如果已经达到对应的sequence长度了,就直接填为eos了 | |||||
if _eos_token_id!=-1: | |||||
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id) | |||||
next_tokens = next_tokens.masked_fill(dones, pad_token_id) # 对已经搜索完成的sample做padding | |||||
tokens = next_tokens.unsqueeze(1) | |||||
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len | |||||
end_mask = next_tokens.eq(_eos_token_id) | |||||
dones = dones.__or__(end_mask) | |||||
cur_len += 1 | |||||
if dones.min() == 1: | |||||
break | |||||
# if eos_token_id is not None: | |||||
# tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id) # 将最大长度位置设置为eos | |||||
# if cur_len == max_length: | |||||
# token_ids[:, -1].masked_fill_(~dones, eos_token_id) # 若到最长长度仍未到EOS,则强制将最后一个词替换成eos | |||||
return token_ids | |||||
def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0, | |||||
top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||||
repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | |||||
# 进行beam search | |||||
device = _get_model_device(decoder) | |||||
if tokens is None: | |||||
if bos_token_id is None: | |||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||||
batch_size = state.num_samples | |||||
if batch_size is None: | |||||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||||
batch_size = tokens.size(0) | |||||
if state.num_samples: | |||||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||||
if eos_token_id is None: | |||||
_eos_token_id = -1 | |||||
else: | |||||
_eos_token_id = eos_token_id | |||||
scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度 | |||||
if _eos_token_id!=-1: # 防止第一个位置为结束 | |||||
scores[:, _eos_token_id] = -1e12 | |||||
vocab_size = scores.size(1) | |||||
assert vocab_size >= num_beams, "num_beams should not be larger than the vocabulary size." | |||||
if do_sample: | |||||
probs = F.softmax(scores, dim=-1) + 1e-12 | |||||
next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams) | |||||
logits = probs.log() | |||||
next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams) | |||||
else: | |||||
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) | |||||
# 得到(batch_size, num_beams), (batch_size, num_beams) | |||||
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | |||||
# 根据index来做顺序的调转 | |||||
indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||||
indices = indices.repeat_interleave(num_beams) | |||||
state.reorder_state(indices) | |||||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||||
# 记录生成好的token (batch_size', cur_len) | |||||
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | |||||
dones = [False] * batch_size | |||||
beam_scores = next_scores.view(-1) # batch_size * num_beams | |||||
# 用来记录已经生成好的token的长度 | |||||
cur_len = token_ids.size(1) | |||||
if max_len_a!=0: | |||||
# (bsz x num_beams, ) | |||||
if state.encoder_mask is not None: | |||||
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length | |||||
else: | |||||
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) | |||||
real_max_length = max_lengths.max().item() | |||||
else: | |||||
real_max_length = max_length | |||||
if state.encoder_mask is not None: | |||||
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length | |||||
else: | |||||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) | |||||
hypos = [ | |||||
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size) | |||||
] | |||||
# 0, num_beams, 2*num_beams, ... | |||||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||||
while cur_len < real_max_length: | |||||
scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size) | |||||
if repetition_penalty != 1.0: | |||||
token_scores = scores.gather(dim=1, index=token_ids) | |||||
lt_zero_mask = token_scores.lt(0).float() | |||||
ge_zero_mask = lt_zero_mask.eq(0).float() | |||||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||||
scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||||
if _eos_token_id!=-1: | |||||
max_len_eos_mask = max_lengths.eq(cur_len+1) | |||||
eos_scores = scores[:, _eos_token_id] | |||||
# 如果已经达到最大长度,就把eos的分数加大 | |||||
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores) | |||||
if do_sample: | |||||
if temperature > 0 and temperature != 1: | |||||
scores = scores / temperature | |||||
# 多召回一个防止eos | |||||
scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) | |||||
# 加上1e-12是为了避免https://github.com/pytorch/pytorch/pull/27523 | |||||
probs = F.softmax(scores, dim=-1) + 1e-12 | |||||
# 保证至少有一个不是eos的值 | |||||
_tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) | |||||
logits = probs.log() | |||||
# 防止全是这个beam的被选中了,且需要考虑eos被选择的情况 | |||||
_scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1) | |||||
_scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1) | |||||
# 从这里面再选择top的2*num_beam个 | |||||
_scores = _scores.view(batch_size, num_beams * (num_beams + 1)) | |||||
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) | |||||
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) | |||||
next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams) | |||||
from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams) | |||||
else: | |||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) | |||||
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size) | |||||
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size) | |||||
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams) | |||||
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams) | |||||
next_tokens = ids % vocab_size # (batch_size, 2*num_beams) | |||||
# 接下来需要组装下一个batch的结果。 | |||||
# 需要选定哪些留下来 | |||||
# next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True) | |||||
# next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | |||||
# from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | |||||
not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos | |||||
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 | |||||
keep_mask = not_eos_mask.__and__(keep_mask) # 为1的地方是需要进行下一步search的 | |||||
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1) | |||||
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams) # 上面的token是来自哪个beam | |||||
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | |||||
beam_scores = _next_scores.view(-1) | |||||
flag = True | |||||
if cur_len+1 == real_max_length: | |||||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | |||||
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # beam内的位置索引 | |||||
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 | |||||
else: | |||||
# 将每个batch中在num_beam内的序列添加到结束中, 为1的地方需要结束了 | |||||
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams | |||||
if effective_eos_mask.sum().gt(0): | |||||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | |||||
# 是由于from_which_beam是 (batch_size, 2*num_beams)的,所以需要2*num_beams | |||||
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind | |||||
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos | |||||
else: | |||||
flag = False | |||||
if flag: | |||||
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1) | |||||
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | |||||
eos_beam_idx.tolist()): | |||||
if not dones[batch_idx]: | |||||
score = next_scores[batch_idx, beam_ind].item() | |||||
# 之后需要在结尾新增一个eos | |||||
if _eos_token_id!=-1: | |||||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||||
else: | |||||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score) | |||||
# 更改state状态, 重组token_ids | |||||
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 | |||||
state.reorder_state(reorder_inds) | |||||
# 重新组织token_ids的状态 | |||||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1) | |||||
for batch_idx in range(batch_size): | |||||
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \ | |||||
max_lengths[batch_idx*num_beams]==cur_len+1 | |||||
cur_len += 1 | |||||
if all(dones): | |||||
break | |||||
# select the best hypotheses | |||||
tgt_len = token_ids.new_zeros(batch_size) | |||||
best = [] | |||||
for i, hypotheses in enumerate(hypos): | |||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | |||||
# 在生成结果的结尾补上eos | |||||
if _eos_token_id!=-1: | |||||
best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id]) | |||||
tgt_len[i] = len(best_hyp) | |||||
best.append(best_hyp) | |||||
# generate target batch | |||||
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id) | |||||
for i, hypo in enumerate(best): | |||||
decoded[i, :tgt_len[i]] = hypo | |||||
return decoded | |||||
class BeamHypotheses(object): | |||||
def __init__(self, num_beams, max_length, length_penalty, early_stopping): | |||||
""" | |||||
Initialize n-best list of hypotheses. | |||||
""" | |||||
self.max_length = max_length - 1 # ignoring bos_token | |||||
self.length_penalty = length_penalty | |||||
self.early_stopping = early_stopping | |||||
self.num_beams = num_beams | |||||
self.hyp = [] | |||||
self.worst_score = 1e9 | |||||
def __len__(self): | |||||
""" | |||||
Number of hypotheses in the list. | |||||
""" | |||||
return len(self.hyp) | |||||
def add(self, hyp, sum_logprobs): | |||||
""" | |||||
Add a new hypothesis to the list. | |||||
""" | |||||
score = sum_logprobs / len(hyp) ** self.length_penalty | |||||
if len(self) < self.num_beams or score > self.worst_score: | |||||
self.hyp.append((score, hyp)) | |||||
if len(self) > self.num_beams: | |||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) | |||||
del self.hyp[sorted_scores[0][1]] | |||||
self.worst_score = sorted_scores[1][0] | |||||
else: | |||||
self.worst_score = min(score, self.worst_score) | |||||
def is_done(self, best_sum_logprobs): | |||||
""" | |||||
If there are enough hypotheses and that none of the hypotheses being generated | |||||
can become better than the worst one in the heap, then we are done with this sentence. | |||||
""" | |||||
if len(self) < self.num_beams: | |||||
return False | |||||
elif self.early_stopping: | |||||
return True | |||||
else: | |||||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty | |||||
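# Worked example of the normalisation used in `add` above (numbers purely illustrative):
# a 3-token hypothesis with sum_logprobs = -6.0 and length_penalty = 1.0 scores
# -6.0 / 3 ** 1.0 = -2.0, whereas length_penalty = 2.0 would give -6.0 / 3 ** 2.0 ≈ -0.67.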
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): | |||||
""" | |||||
根据top_k, top_p的值,将不满足条件的位置置为filter_value | |||||
:param torch.Tensor logits: bsz x vocab_size | |||||
:param int top_k: 如果大于0,则只保留最top_k的词汇的概率,剩下的位置被置为filter_value | |||||
:param float top_p: 根据(http://arxiv.org/abs/1904.09751)设置的筛选方式 | |||||
:param float filter_value: 不满足top_k/top_p条件的位置被置为该值 | |||||
:param int min_tokens_to_keep: 每个sample返回的分布中有概率的词不会低于这个值 | |||||
:return: | |||||
""" | |||||
if top_k > 0: | |||||
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check | |||||
# Remove all tokens with a probability less than the last token of the top-k | |||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | |||||
logits[indices_to_remove] = filter_value | |||||
if top_p < 1.0: | |||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |||||
# Remove tokens with cumulative probability above the threshold (token with 0 are kept) | |||||
sorted_indices_to_remove = cumulative_probs > top_p | |||||
if min_tokens_to_keep > 1: | |||||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | |||||
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | |||||
# Shift the indices to the right to keep also the first token above the threshold | |||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |||||
sorted_indices_to_remove[..., 0] = 0 | |||||
# scatter sorted tensors to original indexing | |||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | |||||
logits[indices_to_remove] = filter_value | |||||
return logits |
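# A small self-contained sketch of how `top_k_top_p_filtering` restricts a sampling
# distribution (illustrative only; the helper name `_filtering_demo` is hypothetical).
def _filtering_demo():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
    kept = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=1.0)
    # only the two largest logits survive; the rest are set to -inf
    assert torch.isinf(kept[0, 2:]).all() and not torch.isinf(kept[0, :2]).any()
    probs = F.softmax(kept, dim=-1)
    assert probs[0, 2:].sum().item() == 0.0          # filtered tokens get zero probability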
@@ -0,0 +1,142 @@ | |||||
""" | |||||
此模块可以非常方便的测试模型。 | |||||
若你的模型属于:文本分类,序列标注,自然语言推理(NLI),可以直接使用此模块测试 | |||||
若模型不属于上述类别,也可以自己准备假数据,设定loss和metric进行测试 | |||||
此模块的测试仅保证模型能使用fastNLP进行训练和测试,不测试模型实际性能 | |||||
Example:: | |||||
# import 全大写变量... | |||||
from model_runner import * | |||||
# 测试一个文本分类模型 | |||||
init_emb = (VOCAB_SIZE, 50) | |||||
model = SomeModel(init_emb, num_cls=NUM_CLS) | |||||
RUNNER.run_model_with_task(TEXT_CLS, model) | |||||
# 序列标注模型 | |||||
RUNNER.run_model_with_task(POS_TAGGING, model) | |||||
# NLI模型 | |||||
RUNNER.run_model_with_task(NLI, model) | |||||
# 自定义模型 | |||||
RUNNER.run_model(model, data=get_mydata(), | |||||
loss=Myloss(), metrics=Mymetric()) | |||||
""" | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch import optim | |||||
from fastNLP import Trainer, Evaluator, DataSet, Callback | |||||
from fastNLP import Accuracy | |||||
from random import randrange | |||||
from fastNLP import TorchDataLoader | |||||
VOCAB_SIZE = 100 | |||||
NUM_CLS = 100 | |||||
MAX_LEN = 10 | |||||
N_SAMPLES = 100 | |||||
N_EPOCHS = 1 | |||||
BATCH_SIZE = 5 | |||||
TEXT_CLS = 'text_cls' | |||||
POS_TAGGING = 'pos_tagging' | |||||
NLI = 'nli' | |||||
class ModelRunner(): | |||||
class Checker(Callback): | |||||
def on_backward_begin(self, trainer, outputs): | |||||
assert bool(outputs['loss'].isfinite().all())  # loss 必须是有限值 | |||||
def gen_seq(self, length, vocab_size): | |||||
"""generate fake sequence indexes with given length""" | |||||
# reserve 0 for padding | |||||
return [randrange(1, vocab_size) for _ in range(length)] | |||||
def gen_var_seq(self, max_len, vocab_size): | |||||
"""generate fake sequence indexes in variant length""" | |||||
length = randrange(3, max_len) # at least 3 words in a seq | |||||
return self.gen_seq(length, vocab_size) | |||||
def prepare_text_classification_data(self): | |||||
index = 'index' | |||||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||||
field_name=index, new_field_name='words') | |||||
ds.apply_field(lambda x: randrange(NUM_CLS), | |||||
field_name=index, new_field_name='target') | |||||
ds.apply_field(len, 'words', 'seq_len') | |||||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||||
return dl | |||||
def prepare_pos_tagging_data(self): | |||||
index = 'index' | |||||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||||
field_name=index, new_field_name='words') | |||||
ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS), | |||||
field_name='words', new_field_name='target') | |||||
ds.apply_field(len, 'words', 'seq_len') | |||||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||||
return dl | |||||
def prepare_nli_data(self): | |||||
index = 'index' | |||||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||||
field_name=index, new_field_name='words1') | |||||
ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||||
field_name=index, new_field_name='words2') | |||||
ds.apply_field(lambda x: randrange(NUM_CLS), | |||||
field_name=index, new_field_name='target') | |||||
ds.apply_field(len, 'words1', 'seq_len1') | |||||
ds.apply_field(len, 'words2', 'seq_len2') | |||||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||||
return dl | |||||
def run_text_classification(self, model, data=None): | |||||
if data is None: | |||||
data = self.prepare_text_classification_data() | |||||
metric = Accuracy() | |||||
self.run_model(model, data, metric) | |||||
def run_pos_tagging(self, model, data=None): | |||||
if data is None: | |||||
data = self.prepare_pos_tagging_data() | |||||
metric = Accuracy() | |||||
self.run_model(model, data, metric) | |||||
def run_nli(self, model, data=None): | |||||
if data is None: | |||||
data = self.prepare_nli_data() | |||||
metric = Accuracy() | |||||
self.run_model(model, data, metric) | |||||
def run_model(self, model, data, metrics): | |||||
"""run a model, test if it can run with fastNLP""" | |||||
print('testing model:', model.__class__.__name__) | |||||
tester = Evaluator(model, data, metrics={'metric': metrics}, driver='torch') | |||||
before_train = tester.run() | |||||
optimizer = optim.SGD(model.parameters(), lr=1e-3) | |||||
trainer = Trainer(model, driver='torch', train_dataloader=data, | |||||
n_epochs=N_EPOCHS, optimizers=optimizer) | |||||
trainer.run() | |||||
after_train = tester.run() | |||||
for metric_name, v1 in before_train.items(): | |||||
assert metric_name in after_train | |||||
# # at least we can be sure model params changed, even if we don't know performance | |||||
# v2 = after_train[metric_name] | |||||
# assert v1 != v2 | |||||
def run_model_with_task(self, task, model): | |||||
"""run a model with certain task""" | |||||
TASKS = { | |||||
TEXT_CLS: self.run_text_classification, | |||||
POS_TAGGING: self.run_pos_tagging, | |||||
NLI: self.run_nli, | |||||
} | |||||
assert task in TASKS | |||||
TASKS[task](model) | |||||
RUNNER = ModelRunner() |
@@ -0,0 +1,91 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from fastNLP.models.torch.biaffine_parser import BiaffineParser | |||||
from fastNLP import Metric, seq_len_to_mask | |||||
from .model_runner import * | |||||
class ParserMetric(Metric): | |||||
r""" | |||||
评估parser的性能 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.num_arc = 0 | |||||
self.num_label = 0 | |||||
self.num_sample = 0 | |||||
def get_metric(self, reset=True): | |||||
res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample} | |||||
if reset: | |||||
self.num_sample = self.num_label = self.num_arc = 0 | |||||
return res | |||||
def update(self, pred1, pred2, target1, target2, seq_len=None): | |||||
r""" | |||||
:param pred1: 边预测logits | |||||
:param pred2: label预测logits | |||||
:param target1: 真实边的标注 | |||||
:param target2: 真实类别的标注 | |||||
:param seq_len: 序列长度 | |||||
:return dict: 评估结果:: | |||||
UAS: 不带label时, 边预测的准确率 | |||||
LAS: 同时预测边和label的准确率 | |||||
""" | |||||
if seq_len is None: | |||||
seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | |||||
else: | |||||
seq_mask = seq_len_to_mask(seq_len.long()).long() | |||||
# mask out <root> tag | |||||
seq_mask[:, 0] = 0 | |||||
head_pred_correct = (pred1 == target1).long() * seq_mask | |||||
label_pred_correct = (pred2 == target2).long() * head_pred_correct | |||||
self.num_arc += head_pred_correct.sum().item() | |||||
self.num_label += label_pred_correct.sum().item() | |||||
self.num_sample += seq_mask.sum().item() | |||||
def prepare_parser_data(): | |||||
index = 'index' | |||||
ds = DataSet({index: list(range(N_SAMPLES))}) | |||||
ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||||
field_name=index, new_field_name='words1') | |||||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||||
field_name='words1', new_field_name='words2') | |||||
# target1 is heads, should in range(0, len(words)) | |||||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)), | |||||
field_name='words1', new_field_name='target1') | |||||
ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||||
field_name='words1', new_field_name='target2') | |||||
ds.apply_field(len, field_name='words1', new_field_name='seq_len') | |||||
dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||||
return dl | |||||
@pytest.mark.torch | |||||
class TestBiaffineParser: | |||||
def test_train(self): | |||||
model = BiaffineParser(embed=(VOCAB_SIZE, 10), | |||||
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, | |||||
rnn_hidden_size=10, | |||||
arc_mlp_size=10, | |||||
label_mlp_size=10, | |||||
num_label=NUM_CLS, encoder='var-lstm') | |||||
ds = prepare_parser_data() | |||||
RUNNER.run_model(model, ds, metrics=ParserMetric()) | |||||
def test_train2(self): | |||||
model = BiaffineParser(embed=(VOCAB_SIZE, 10), | |||||
pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, | |||||
rnn_hidden_size=16, | |||||
arc_mlp_size=10, | |||||
label_mlp_size=10, | |||||
num_label=NUM_CLS, encoder='transformer') | |||||
ds = prepare_parser_data() | |||||
RUNNER.run_model(model, ds, metrics=ParserMetric()) |
@@ -0,0 +1,33 @@ | |||||
import pytest | |||||
from .model_runner import * | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from fastNLP.models.torch.cnn_text_classification import CNNText | |||||
@pytest.mark.torch | |||||
class TestCNNText: | |||||
def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): | |||||
model = CNNText((VOCAB_SIZE, 30), | |||||
NUM_CLS, | |||||
kernel_nums=kernel_nums, | |||||
kernel_sizes=kernel_sizes) | |||||
return model | |||||
def test_case1(self): | |||||
# 测试能否正常运行CNN | |||||
model = self.init_model((1,3,5)) | |||||
RUNNER.run_model_with_task(TEXT_CLS, model) | |||||
def test_init_model(self): | |||||
with pytest.raises(Exception): | |||||
self.init_model(2, 4) | |||||
with pytest.raises(Exception): | |||||
self.init_model((2,)) | |||||
def test_output(self): | |||||
model = self.init_model((3,), (1,)) | |||||
global MAX_LEN | |||||
MAX_LEN = 2 | |||||
RUNNER.run_model_with_task(TEXT_CLS, model) |
@@ -0,0 +1,73 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel | |||||
import torch | |||||
from fastNLP.embeddings.torch import StaticEmbedding | |||||
from fastNLP import Vocabulary, DataSet | |||||
from fastNLP import Trainer, Accuracy | |||||
from fastNLP import Callback, TorchDataLoader | |||||
def prepare_env(): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||||
src_words_idx = [[3, 1, 2], [1, 2]] | |||||
# tgt_words_idx = [[1, 2, 3, 4], [2, 3]] | |||||
src_seq_len = [3, 2] | |||||
# tgt_seq_len = [4, 2] | |||||
ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, | |||||
'tgt_seq_len':src_seq_len}) | |||||
dl = TorchDataLoader(ds, batch_size=32) | |||||
return embed, dl | |||||
class ExitCallback(Callback): | |||||
def __init__(self): | |||||
super().__init__() | |||||
def on_valid_end(self, trainer, results): | |||||
if results['acc#acc'] == 1: | |||||
raise KeyboardInterrupt() | |||||
@pytest.mark.torch | |||||
class TestSeq2SeqGeneratorModel: | |||||
def test_run(self): | |||||
# 检测是否能够使用SequenceGeneratorModel训练, 透传预测 | |||||
embed, dl = prepare_env() | |||||
model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, | |||||
dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3) | |||||
trainer = Trainer(model1, driver='torch', optimizers=optimizer, train_dataloader=dl, | |||||
n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, | |||||
evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], | |||||
'seq_len': x['tgt_seq_len'], | |||||
**x}, | |||||
callbacks=ExitCallback()) | |||||
trainer.run() | |||||
embed, dl = prepare_env() | |||||
model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
num_layers=1, hidden_size=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True, attention=True) | |||||
optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) | |||||
trainer = Trainer(model2, driver='torch', optimizers=optimizer, train_dataloader=dl, | |||||
n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, | |||||
evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], | |||||
'seq_len': x['tgt_seq_len'], | |||||
**x}, | |||||
callbacks=ExitCallback()) | |||||
trainer.run() |
@@ -0,0 +1,113 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings.torch import StaticEmbedding | |||||
import torch | |||||
from torch import optim | |||||
import torch.nn.functional as F | |||||
from fastNLP import seq_len_to_mask | |||||
def prepare_env(): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
tgt_seq_len = torch.LongTensor([4, 2]) | |||||
return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len | |||||
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): | |||||
optimizer = optim.Adam(model.parameters(), lr=1e-2) | |||||
mask = seq_len_to_mask(tgt_seq_len).eq(0) | |||||
target = tgt_words_idx.masked_fill(mask, -100) | |||||
for i in range(100): | |||||
optimizer.zero_grad() | |||||
pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size | |||||
loss = F.cross_entropy(pred.transpose(1, 2), target) | |||||
loss.backward() | |||||
optimizer.step() | |||||
right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() | |||||
return right_count | |||||
@pytest.mark.torch | |||||
class TestTransformerSeq2SeqModel: | |||||
def test_run(self): | |||||
# 测试能否跑通 | |||||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
for pos_embed in ['learned', 'sin']: | |||||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||||
assert (output['pred'].size() == (2, 4, len(embed))) | |||||
for bind_encoder_decoder_embed in [True, False]: | |||||
tgt_embed = None | |||||
for bind_decoder_input_output_embed in [True, False]: | |||||
if not bind_encoder_decoder_embed: | |||||
tgt_embed = embed | |||||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||||
pos_embed='sin', max_position=20, num_layers=2, | |||||
d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||||
assert (output['pred'].size() == (2, 4, len(embed))) | |||||
def test_train(self): | |||||
# 测试能否train到overfit | |||||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||||
assert(right_count == tgt_words_idx.nelement()) | |||||
@pytest.mark.torch | |||||
class TestLSTMSeq2SeqModel: | |||||
def test_run(self): | |||||
# 测试能否跑通 | |||||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
for bind_encoder_decoder_embed in [True, False]: | |||||
tgt_embed = None | |||||
for bind_decoder_input_output_embed in [True, False]: | |||||
if not bind_encoder_decoder_embed: | |||||
tgt_embed = embed | |||||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||||
num_layers=2, hidden_size=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||||
assert (output['pred'].size() == (2, 4, len(embed))) | |||||
def test_train(self): | |||||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
num_layers=1, hidden_size=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||||
assert (right_count == tgt_words_idx.nelement()) | |||||
@@ -0,0 +1,47 @@ | |||||
import pytest | |||||
from .model_runner import * | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||||
@pytest.mark.torch | |||||
class TestBiLSTM: | |||||
def test_case1(self): | |||||
# 测试能否正常运行BiLSTMCRF | |||||
init_emb = (VOCAB_SIZE, 30) | |||||
model = BiLSTMCRF(init_emb, | |||||
hidden_size=30, | |||||
num_classes=NUM_CLS) | |||||
dl = RUNNER.prepare_pos_tagging_data() | |||||
metric = Accuracy() | |||||
RUNNER.run_model(model, dl, metric) | |||||
@pytest.mark.torch | |||||
class TestSeqLabel: | |||||
def test_case1(self): | |||||
# 测试能否正常运行SeqLabeling | |||||
init_emb = (VOCAB_SIZE, 30) | |||||
model = SeqLabeling(init_emb, | |||||
hidden_size=30, | |||||
num_classes=NUM_CLS) | |||||
dl = RUNNER.prepare_pos_tagging_data() | |||||
metric = Accuracy() | |||||
RUNNER.run_model(model, dl, metric) | |||||
@pytest.mark.torch | |||||
class TestAdvSeqLabel: | |||||
def test_case1(self): | |||||
# 测试能否正常运行AdvSeqLabel | |||||
init_emb = (VOCAB_SIZE, 30) | |||||
model = AdvSeqLabel(init_emb, | |||||
hidden_size=30, | |||||
num_classes=NUM_CLS) | |||||
dl = RUNNER.prepare_pos_tagging_data() | |||||
metric = Accuracy() | |||||
RUNNER.run_model(model, dl, metric) |
@@ -0,0 +1,327 @@ | |||||
import pytest | |||||
import os | |||||
from fastNLP import Vocabulary | |||||
@pytest.mark.torch | |||||
class TestCRF: | |||||
def test_case1(self): | |||||
from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||||
# 检查allowed_transitions()能否正确使用 | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | |||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||||
(2, 4), (3, 0), (3, 2)} | |||||
assert expected_res == set(allowed_transitions(id2label, include_start_end=True)) | |||||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||||
assert (expected_res == set( | |||||
allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||||
allowed_transitions(id2label, include_start_end=True) | |||||
labels = ['O'] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BI': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx:label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||||
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||||
labels = [] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BMES': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||||
assert (expected_res == set( | |||||
allowed_transitions(id2label, include_start_end=True))) | |||||
def test_case11(self): | |||||
# 测试自动推断encoding类型 | |||||
from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||||
id2label = {0: 'B', 1: 'I', 2: 'O'} | |||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||||
(2, 4), (3, 0), (3, 2)} | |||||
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||||
assert (expected_res == set( | |||||
allowed_transitions(id2label, include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||||
allowed_transitions(id2label, include_start_end=True) | |||||
labels = ['O'] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BI': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||||
assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||||
labels = [] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BMES': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||||
assert (expected_res == set( | |||||
allowed_transitions(id2label, include_start_end=True))) | |||||
def test_case12(self): | |||||
# check that the transition constraints can be built from a Vocabulary
from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||||
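# same expected transition sets as above, but the labels are supplied through a Vocabulary instead of an id-to-label dict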
id2label = {0: 'B', 1: 'I', 2: 'O'} | |||||
vocab = Vocabulary(unknown=None, padding=None) | |||||
for idx, tag in id2label.items(): | |||||
vocab.add_word(tag) | |||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||||
(2, 4), (3, 0), (3, 2)} | |||||
assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||||
vocab = Vocabulary(unknown=None, padding=None) | |||||
for idx, tag in id2label.items(): | |||||
vocab.add_word(tag) | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||||
assert (expected_res == set( | |||||
allowed_transitions(vocab, include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||||
vocab = Vocabulary() | |||||
for idx, tag in id2label.items(): | |||||
vocab.add_word(tag) | |||||
allowed_transitions(vocab, include_start_end=True) | |||||
labels = ['O'] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BI': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||||
vocab = Vocabulary(unknown=None, padding=None) | |||||
for idx, tag in id2label.items(): | |||||
vocab.add_word(tag) | |||||
assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) | |||||
labels = [] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BMES': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
vocab = Vocabulary(unknown=None, padding=None) | |||||
for idx, tag in id2label.items(): | |||||
vocab.add_word(tag) | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||||
assert (expected_res == set( | |||||
allowed_transitions(vocab, include_start_end=True))) | |||||
# def test_case2(self): | |||||
# # check that the CRF never decodes an illegal transition; verified against allennlp.
# pass | |||||
# import torch | |||||
# from fastNLP import seq_len_to_mask | |||||
# | |||||
# labels = ['O'] | |||||
# for label in ['X', 'Y']: | |||||
# for tag in 'BI': | |||||
# labels.append('{}-{}'.format(tag, label)) | |||||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||||
# num_tags = len(id2label) | |||||
# max_len = 10 | |||||
# batch_size = 4 | |||||
# bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() | |||||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||||
# include_start_end_transitions=False) | |||||
# bio_trans_m = allen_CRF.transitions | |||||
# bio_seq_lens = torch.randint(1, max_len, size=(batch_size,)) | |||||
# bio_seq_lens[0] = 1 | |||||
# bio_seq_lens[-1] = max_len | |||||
# mask = seq_len_to_mask(bio_seq_lens) | |||||
# allen_res = allen_CRF.viterbi_tags(bio_logits, mask) | |||||
# | |||||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||||
# include_start_end=True)) | |||||
# fast_CRF.trans_m = bio_trans_m | |||||
# fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) | |||||
# bio_scores = [round(score, 4) for _, score in allen_res] | |||||
# # score equal | |||||
# self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) | |||||
# # seq equal | |||||
# bio_path = [_ for _, score in allen_res] | |||||
# self.assertListEqual(bio_path, fast_res[0]) | |||||
# | |||||
# labels = [] | |||||
# for label in ['X', 'Y']: | |||||
# for tag in 'BMES': | |||||
# labels.append('{}-{}'.format(tag, label)) | |||||
# id2label = {idx: label for idx, label in enumerate(labels)} | |||||
# num_tags = len(id2label) | |||||
# | |||||
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||||
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||||
# include_start_end_transitions=False) | |||||
# bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() | |||||
# bmes_trans_m = allen_CRF.transitions | |||||
# bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,)) | |||||
# bmes_seq_lens[0] = 1 | |||||
# bmes_seq_lens[-1] = max_len | |||||
# mask = seq_len_to_mask(bmes_seq_lens) | |||||
# allen_res = allen_CRF.viterbi_tags(bmes_logits, mask) | |||||
# | |||||
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||||
# encoding_type='BMES', | |||||
# include_start_end=True)) | |||||
# fast_CRF.trans_m = bmes_trans_m | |||||
# fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) | |||||
# # score equal | |||||
# bmes_scores = [round(score, 4) for _, score in allen_res] | |||||
# self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) | |||||
# # seq equal | |||||
# bmes_path = [_ for _, score in allen_res] | |||||
# self.assertListEqual(bmes_path, fast_res[0]) | |||||
# | |||||
# data = { | |||||
# 'bio_logits': bio_logits.tolist(), | |||||
# 'bio_scores': bio_scores, | |||||
# 'bio_path': bio_path, | |||||
# 'bio_trans_m': bio_trans_m.tolist(), | |||||
# 'bio_seq_lens': bio_seq_lens.tolist(), | |||||
# 'bmes_logits': bmes_logits.tolist(), | |||||
# 'bmes_scores': bmes_scores, | |||||
# 'bmes_path': bmes_path, | |||||
# 'bmes_trans_m': bmes_trans_m.tolist(), | |||||
# 'bmes_seq_lens': bmes_seq_lens.tolist(), | |||||
# } | |||||
# | |||||
# with open('weights.json', 'w') as f: | |||||
# import json | |||||
# json.dump(data, f) | |||||
def test_case2(self): | |||||
# check that the CRF works correctly: decoded paths and scores must match the stored reference
import json | |||||
import torch | |||||
from fastNLP import seq_len_to_mask | |||||
folder = os.path.dirname(os.path.abspath(__file__)) | |||||
path = os.path.join(folder, '../../../', 'helpers/data/modules/decoder/crf.json') | |||||
with open(os.path.abspath(path), 'r') as f: | |||||
data = json.load(f) | |||||
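# the JSON fixture holds the logits, transition matrices and reference Viterbi paths/scores
# (produced by the commented-out allennlp comparison above)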
bio_logits = torch.FloatTensor(data['bio_logits']) | |||||
bio_scores = data['bio_scores'] | |||||
bio_path = data['bio_path'] | |||||
bio_trans_m = torch.FloatTensor(data['bio_trans_m']) | |||||
bio_seq_lens = torch.LongTensor(data['bio_seq_lens']) | |||||
bmes_logits = torch.FloatTensor(data['bmes_logits']) | |||||
bmes_scores = data['bmes_scores'] | |||||
bmes_path = data['bmes_path'] | |||||
bmes_trans_m = torch.FloatTensor(data['bmes_trans_m']) | |||||
bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens']) | |||||
labels = ['O'] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BI': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
num_tags = len(id2label) | |||||
mask = seq_len_to_mask(bio_seq_lens) | |||||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||||
include_start_end=True)) | |||||
fast_CRF.trans_m.data = bio_trans_m | |||||
fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) | |||||
# score equal | |||||
assert (bio_scores == [round(s, 4) for s in fast_res[1].tolist()]) | |||||
# seq equal | |||||
assert (bio_path == fast_res[0]) | |||||
labels = [] | |||||
for label in ['X', 'Y']: | |||||
for tag in 'BMES': | |||||
labels.append('{}-{}'.format(tag, label)) | |||||
id2label = {idx: label for idx, label in enumerate(labels)} | |||||
num_tags = len(id2label) | |||||
mask = seq_len_to_mask(bmes_seq_lens) | |||||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions | |||||
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||||
encoding_type='BMES', | |||||
include_start_end=True)) | |||||
fast_CRF.trans_m.data = bmes_trans_m | |||||
fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) | |||||
# score equal | |||||
assert (bmes_scores == [round(s, 4) for s in fast_res[1].tolist()]) | |||||
# seq equal | |||||
assert (bmes_path == fast_res[0]) | |||||
def test_case3(self): | |||||
# check that the CRF loss never becomes negative
import torch | |||||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField | |||||
from fastNLP.core.utils import seq_len_to_mask | |||||
from torch import optim | |||||
from torch import nn | |||||
num_tags, include_start_end_trans = 4, True | |||||
num_samples = 4 | |||||
lengths = torch.randint(3, 50, size=(num_samples, )).long() | |||||
max_len = lengths.max() | |||||
tags = torch.randint(num_tags, size=(num_samples, max_len)) | |||||
masks = seq_len_to_mask(lengths) | |||||
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | |||||
crf = ConditionalRandomField(num_tags, include_start_end_trans) | |||||
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | |||||
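# run a few SGD steps on random emissions; the CRF negative log-likelihood should stay positive throughout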
for _ in range(10): | |||||
loss = crf(feats, tags, masks).mean() | |||||
optimizer.zero_grad() | |||||
loss.backward() | |||||
optimizer.step() | |||||
if _ % 1000 == 0:
print(loss) | |||||
assert loss.item() > 0
def test_masking(self): | |||||
# check that pad masking in the CRF works correctly
import torch | |||||
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField | |||||
max_len = 5 | |||||
n_tags = 5 | |||||
pad_len = 5 | |||||
torch.manual_seed(4) | |||||
logit = torch.rand(1, max_len+pad_len, n_tags) | |||||
# logit[0, -1, :] = 0.0 | |||||
mask = torch.ones(1, max_len+pad_len) | |||||
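# zero out the trailing pad positions; decoding with and without the padded tail should agree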
mask[0, -pad_len:] = 0
model = ConditionalRandomField(n_tags) | |||||
pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len]) | |||||
mask_pred, mask_score = model.viterbi_decode(logit, mask) | |||||
assert (pred[0].tolist() == mask_pred[0,:-pad_len].tolist()) | |||||
@@ -0,0 +1,49 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings.torch import StaticEmbedding | |||||
from fastNLP.modules.torch import TransformerSeq2SeqDecoder | |||||
from fastNLP.modules.torch import LSTMSeq2SeqDecoder | |||||
from fastNLP import seq_len_to_mask | |||||
@pytest.mark.torch | |||||
class TestTransformerSeq2SeqDecoder: | |||||
@pytest.mark.parametrize("flag", [True, False]) | |||||
def test_case(self, flag): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, embedding_dim=10) | |||||
encoder_output = torch.randn(2, 3, 10) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
encoder_mask = seq_len_to_mask(src_seq_len) | |||||
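# a small 2-layer Transformer decoder; the parametrized `flag` toggles tying the input and output embeddings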
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=None,
d_model=10, num_layers=2, n_head=5, dim_ff=20, dropout=0.1,
bind_decoder_input_output_embed=flag)
state = decoder.init_state(encoder_output, encoder_mask) | |||||
output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) | |||||
assert (output.size() == (2, 4, len(vocab))) | |||||
@pytest.mark.torch | |||||
class TestLSTMDecoder: | |||||
@pytest.mark.parametrize("flag", [True, False]) | |||||
@pytest.mark.parametrize("attention", [True, False]) | |||||
def test_case(self, flag, attention): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) | |||||
encoder_output = torch.randn(2, 3, 10) | |||||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
encoder_mask = seq_len_to_mask(src_seq_len) | |||||
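# LSTM decoder with optional attention; `flag` toggles tying the input and output embeddings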
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers = 2, hidden_size = 10, | |||||
dropout = 0.3, bind_decoder_input_output_embed=flag, attention=attention) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
output = decoder(tgt_words_idx, state) | |||||
assert tuple(output.size()) == (2, 4, len(vocab)) |
@@ -0,0 +1,33 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings.torch import StaticEmbedding | |||||
@pytest.mark.torch
class TestTransformerSeq2SeqEncoder:
def test_case(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||||
encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) | |||||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||||
seq_len = torch.LongTensor([3]) | |||||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||||
assert (encoder_output.size() == (1, 3, 10)) | |||||
@pytest.mark.torch
class TestBiLSTMEncoder:
def test_case(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||||
encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) | |||||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||||
seq_len = torch.LongTensor([3]) | |||||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||||
assert (encoder_mask.size() == (1, 3)) |
@@ -0,0 +1,18 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from fastNLP.modules.torch.encoder.star_transformer import StarTransformer | |||||
@pytest.mark.torch | |||||
class TestStarTransformer: | |||||
def test_1(self): | |||||
model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100) | |||||
x = torch.rand(16, 45, 100) | |||||
mask = torch.ones(16, 45).byte() | |||||
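# the encoder returns per-token states `y` and the sentence-level relay node state `yn`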
y, yn = model(x, mask) | |||||
assert (tuple(y.size()) == (16, 45, 100)) | |||||
assert (tuple(yn.size()) == (16, 100)) |
@@ -0,0 +1,27 @@ | |||||
import pytest | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM | |||||
@pytest.mark.torch | |||||
class TestMaskedRnn: | |||||
def test_case_1(self): | |||||
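# smoke test: a bidirectional VarLSTM on a tiny (1, 2, 1) input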
masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||||
x = torch.tensor([[[1.0], [2.0]]]) | |||||
print(x.size()) | |||||
y = masked_rnn(x) | |||||
def test_case_2(self): | |||||
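# shape check: a unidirectional VarLSTM keeps the batch and time dimensions and outputs hidden_size features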
input_size = 12 | |||||
batch = 16 | |||||
hidden = 10 | |||||
masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) | |||||
xx = torch.randn((batch, 32, input_size)) | |||||
y, _ = masked_rnn(xx) | |||||
assert(tuple(y.shape) == (batch, 32, hidden)) |
@@ -0,0 +1 @@ | |||||
@@ -0,0 +1,146 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from fastNLP.modules.torch.generator import SequenceGenerator | |||||
from fastNLP.modules.torch import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings.torch import StaticEmbedding | |||||
from torch import nn | |||||
from fastNLP import seq_len_to_mask | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as State | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Seq2SeqDecoder | |||||
def prepare_env(): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||||
encoder_output = torch.randn(2, 3, 10) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
encoder_mask = seq_len_to_mask(src_seq_len) | |||||
return embed, encoder_output, encoder_mask | |||||
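# a dummy decoder that replays pre-computed scores, so the expected greedy output path is known in advance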
class GreedyDummyDecoder(Seq2SeqDecoder): | |||||
def __init__(self, decoder_output): | |||||
super().__init__() | |||||
self.cur_length = 0 | |||||
self.decoder_output = decoder_output | |||||
def decode(self, tokens, state): | |||||
self.cur_length += 1 | |||||
scores = self.decoder_output[:, self.cur_length] | |||||
return scores | |||||
class DummyState(State): | |||||
def __init__(self, decoder): | |||||
super().__init__() | |||||
self.decoder = decoder | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) | |||||
@pytest.mark.torch | |||||
class TestSequenceGenerator: | |||||
def test_run(self): | |||||
# smoke test: (1) initialise the decoder, (2) run one round of generation
embed, encoder_output, encoder_mask = prepare_env() | |||||
for do_sample in [True, False]: | |||||
for num_beams in [1, 3, 5]: | |||||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, | |||||
dropout=0.3, bind_decoder_input_output_embed=True, attention=True) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||||
generator.generate(state=state, tokens=None) | |||||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, | |||||
bind_decoder_input_output_embed=True) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||||
generator.generate(state=state, tokens=None) | |||||
# also exercise a different set of generation hyper-parameters
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, | |||||
dropout=0.1, | |||||
bind_decoder_input_output_embed=True) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, | |||||
eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) | |||||
generator.generate(state=state, tokens=None) | |||||
def test_greedy_decode(self): | |||||
# check that generation recovers the expected sequences
# greedy decoding
for beam_search in [1, 3]: | |||||
decoder_output = torch.randn(2, 10, 5) | |||||
path = decoder_output.argmax(dim=-1) # 2 x 10 | |||||
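# with a deterministic dummy decoder the generated path must match this argmax path for any beam size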
decoder = GreedyDummyDecoder(decoder_output) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||||
do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, | |||||
eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||||
decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||||
assert (decode_path.eq(path).sum() == path.numel()) | |||||
# greedy check eos_token_id | |||||
for beam_search in [1, 3]: | |||||
decoder_output = torch.randn(2, 10, 5) | |||||
decoder_output[:, :7, 4].fill_(-100) | |||||
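# make the eos tag (id 4) unreachable during the first 7 steps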
decoder_output[0, 7, 4] = 1000  # sequence 0 ends at the 8th step
decoder_output[1, 5, 4] = 1000 | |||||
path = decoder_output.argmax(dim=-1)  # 2 x 10
decoder = GreedyDummyDecoder(decoder_output) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||||
do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||||
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||||
decode_path = generator.generate(DummyState(decoder), | |||||
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||||
assert (decode_path.size(1) == 8)  # decoding stops after 8 steps
assert (decode_path[0].eq(path[0, :8]).sum() == 8) | |||||
assert (decode_path[1, :6].eq(path[1, :6]).sum() == 6) | |||||
def test_sample_decoder(self): | |||||
# sampling; check that eos_token_id terminates generation correctly
for beam_search in [1, 3]: | |||||
decode_paths = [] | |||||
# sampling is random, so repeat the test several times; if at least one run matches the reference, the generator is most likely fine
num_tests = 10 | |||||
for i in range(num_tests): | |||||
decoder_output = torch.randn(2, 10, 5) * 10 | |||||
decoder_output[:, :7, 4].fill_(-100) | |||||
decoder_output[0, 7, 4] = 10000  # sequence 0 ends at the 8th step
decoder_output[1, 5, 4] = 10000 | |||||
path = decoder_output.argmax(dim=-1)  # 2 x 10
decoder = GreedyDummyDecoder(decoder_output) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||||
do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||||
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||||
decode_path = generator.generate(DummyState(decoder), | |||||
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||||
decode_paths.append([decode_path, path]) | |||||
sizes = [] | |||||
eqs = [] | |||||
eq2s = [] | |||||
for i in range(num_tests): | |||||
decode_path, path = decode_paths[i] | |||||
sizes.append(decode_path.size(1)==8) | |||||
eqs.append(decode_path[0].eq(path[0, :8]).sum()==8) | |||||
eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum()==6) | |||||
assert any(sizes) | |||||
assert any(eqs) | |||||
assert any(eq2s) |