@@ -138,6 +138,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
        msg = f"Failed to get padder for field:{field_name}. " + e.msg + " To view more " \
              "information please set logger's level to DEBUG."
        if must_pad:
            logger.error(msg)
            raise type(e)(msg=msg)
        logger.debug(msg)

    return NullPadder()
@@ -16,6 +16,7 @@ if _NEED_IMPORT_TORCH:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.nn import LSTM

from .embedding import TokenEmbedding
from .static_embedding import StaticEmbedding
@@ -23,7 +24,6 @@ from .utils import _construct_char_vocab_from_vocab
from .utils import get_embeddings
from ...core import logger
from ...core.vocabulary import Vocabulary
from ...modules.torch.encoder.lstm import LSTM


class CNNCharEmbedding(TokenEmbedding):
@@ -0,0 +1,21 @@
__all__ = [
    "BiaffineParser",
    "CNNText",
    "SequenceGeneratorModel",
    "Seq2SeqModel",
    "TransformerSeq2SeqModel",
    "LSTMSeq2SeqModel",
    "SeqLabeling",
    "AdvSeqLabel",
    "BiLSTMCRF",
]

from .biaffine_parser import BiaffineParser
from .cnn_text_classification import CNNText
from .seq2seq_generator import SequenceGeneratorModel
from .seq2seq_model import *
from .sequence_labeling import *
@@ -0,0 +1,475 @@
r"""
A PyTorch implementation of the Biaffine Dependency Parser.
"""
__all__ = [
    "BiaffineParser",
    "GraphParser"
]

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core.utils import seq_len_to_mask
from ...embeddings.torch.utils import get_embeddings
from ...modules.torch.dropout import TimestepDropout
from ...modules.torch.encoder.transformer import TransformerEncoder
from ...modules.torch.encoder.variational_rnn import VarLSTM


def _mst(scores):
    r"""
    MST (maximum spanning tree) decoding, with some modification to support parser output.
    Adapted from
    https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
    """
    length = scores.shape[0]
    min_score = scores.min() - 1
    eye = np.eye(length)
    scores = scores * (1 - eye) + min_score * eye
    heads = np.argmax(scores, axis=1)
    heads[0] = 0
    tokens = np.arange(1, length)
    roots = np.where(heads[tokens] == 0)[0] + 1
    if len(roots) < 1:
        root_scores = scores[tokens, 0]
        head_scores = scores[tokens, heads[tokens]]
        new_root = tokens[np.argmax(root_scores / head_scores)]
        heads[new_root] = 0
    elif len(roots) > 1:
        root_scores = scores[roots, 0]
        scores[roots, 0] = 0
        new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1
        new_root = roots[np.argmin(
            scores[roots, new_heads] / root_scores)]
        heads[roots] = new_heads
        heads[new_root] = 0

    edges = defaultdict(set)
    vertices = set((0,))
    for dep, head in enumerate(heads[tokens]):
        vertices.add(dep + 1)
        edges[head].add(dep + 1)
    for cycle in _find_cycle(vertices, edges):
        dependents = set()
        to_visit = set(cycle)
        while len(to_visit) > 0:
            node = to_visit.pop()
            if node not in dependents:
                dependents.add(node)
                to_visit.update(edges[node])
        cycle = np.array(list(cycle))
        old_heads = heads[cycle]
        old_scores = scores[cycle, old_heads]
        non_heads = np.array(list(dependents))
        scores[np.repeat(cycle, len(non_heads)),
               np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score
        new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1
        new_scores = scores[cycle, new_heads] / old_scores
        change = np.argmax(new_scores)
        changed_cycle = cycle[change]
        old_head = old_heads[change]
        new_head = new_heads[change]
        heads[changed_cycle] = new_head
        edges[new_head].add(changed_cycle)
        edges[old_head].remove(changed_cycle)

    return heads
def _find_cycle(vertices, edges):
    r"""
    Tarjan's strongly connected components algorithm:
    https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
    https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py
    """
    _index = 0
    _stack = []
    _indices = {}
    _lowlinks = {}
    _onstack = defaultdict(lambda: False)
    _SCCs = []

    def _strongconnect(v):
        nonlocal _index
        _indices[v] = _index
        _lowlinks[v] = _index
        _index += 1
        _stack.append(v)
        _onstack[v] = True

        for w in edges[v]:
            if w not in _indices:
                _strongconnect(w)
                _lowlinks[v] = min(_lowlinks[v], _lowlinks[w])
            elif _onstack[w]:
                _lowlinks[v] = min(_lowlinks[v], _indices[w])

        if _lowlinks[v] == _indices[v]:
            SCC = set()
            while True:
                w = _stack.pop()
                _onstack[w] = False
                SCC.add(w)
                if w == v:
                    break
            _SCCs.append(SCC)

    for v in vertices:
        if v not in _indices:
            _strongconnect(v)

    return [SCC for SCC in _SCCs if len(SCC) > 1]
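# --- Usage sketch (illustrative, not part of the original module) ---
# In _mst, scores[i, j] is the score of token i choosing token j as its head,
# with row/column 0 reserved for the artificial ROOT. The toy matrix below is
# made up for illustration.
if __name__ == '__main__':
    _toy_scores = np.array([[0.0, 0.0, 0.0],
                            [9.0, 0.0, 1.0],   # token 1 prefers ROOT
                            [1.0, 8.0, 0.0]])  # token 2 prefers token 1
    print(_mst(_toy_scores))  # -> [0 0 1]: ROOT heads token 1, token 1 heads token 2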
class GraphParser(nn.Module):
    r"""
    Base class for graph-based parsers. Supports greedy decoding and maximum
    spanning tree (MST) decoding.
    """

    def __init__(self):
        super(GraphParser, self).__init__()

    @staticmethod
    def greedy_decoder(arc_matrix, mask=None):
        r"""
        Greedy decoding: takes the arc score matrix and returns the greedily decoded
        parse. The result is not guaranteed to form a valid tree.

        :param arc_matrix: [batch, seq_len, seq_len] arc score matrix
        :param mask: [batch, seq_len] padding mask of the input, 1 for real tokens and
            0 for padding. If ``None``, an all-ones mask is assumed. Default: ``None``
        :return heads: [batch, seq_len] the predicted head (parent) of each token
        """
        _, seq_len, _ = arc_matrix.shape
        matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
        if mask is not None:
            flip_mask = mask.eq(False)
            matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
        _, heads = torch.max(matrix, dim=2)
        if mask is not None:
            heads *= mask.long()
        return heads

    @staticmethod
    def mst_decoder(arc_matrix, mask=None):
        r"""
        Maximum spanning tree decoding: computes the parse with the MST algorithm,
        guaranteeing a valid tree structure.

        :param arc_matrix: [batch, seq_len, seq_len] arc score matrix
        :param mask: [batch, seq_len] padding mask of the input, 1 for real tokens and
            0 for padding. If ``None``, an all-ones mask is assumed. Default: ``None``
        :return heads: [batch, seq_len] the predicted head (parent) of each token
        """
        batch_size, seq_len, _ = arc_matrix.shape
        matrix = arc_matrix.clone()
        ans = matrix.new_zeros(batch_size, seq_len).long()
        lens = mask.long().sum(1) if mask is not None \
            else torch.full((batch_size,), seq_len, dtype=torch.long)
        for i, graph in enumerate(matrix):
            len_i = lens[i]
            ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device)
        if mask is not None:
            ans *= mask.long()
        return ans
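# --- Usage sketch (illustrative, not part of the original module) ---
# Both decoders take a [batch, seq_len, seq_len] arc score matrix and a padding
# mask; mst_decoder additionally guarantees a valid tree. The random scores are
# made up, only the shapes matter here.
if __name__ == '__main__':
    _arc = torch.randn(2, 5, 5)
    _mask = torch.tensor([[1, 1, 1, 1, 0],
                          [1, 1, 1, 0, 0]])
    print(GraphParser.greedy_decoder(_arc, _mask).shape)  # torch.Size([2, 5])
    print(GraphParser.mst_decoder(_arc, _mask).shape)     # torch.Size([2, 5])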
class ArcBiaffine(nn.Module):
    r"""
    Submodule of the Biaffine Dependency Parser; builds the graph of arc (edge) predictions.
    """

    def __init__(self, hidden_size, bias=True):
        r"""
        :param hidden_size: dimension of the input features
        :param bias: whether to use a bias term. Default: ``True``
        """
        super(ArcBiaffine, self).__init__()
        self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
        self.has_bias = bias
        if self.has_bias:
            self.bias = nn.Parameter(torch.Tensor(hidden_size), requires_grad=True)
        else:
            self.register_parameter("bias", None)

    def forward(self, head, dep):
        r"""
        :param head: arc-head tensor [batch, length, hidden]
        :param dep: arc-dependent tensor [batch, length, hidden]
        :return output: tensor [batch, length, length]
        """
        output = dep.matmul(self.U)
        output = output.bmm(head.transpose(-1, -2))
        if self.has_bias:
            output = output + head.matmul(self.bias).unsqueeze(1)
        return output
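# --- Usage sketch (illustrative, not part of the original module) ---
# ArcBiaffine scores every (head, dependent) pair at once: dep @ U @ head^T plus
# an optional head bias. Its parameters are created with torch.Tensor
# (uninitialized memory), so they are initialized explicitly here; in
# BiaffineParser below, reset_parameters() takes care of this.
if __name__ == '__main__':
    _biaffine = ArcBiaffine(hidden_size=8)
    nn.init.normal_(_biaffine.U)
    nn.init.zeros_(_biaffine.bias)
    _head = torch.randn(2, 5, 8)
    _dep = torch.randn(2, 5, 8)
    print(_biaffine(_head, _dep).shape)  # torch.Size([2, 5, 5]): score of dep i attaching to head j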
class LabelBilinear(nn.Module):
    r"""
    Submodule of the Biaffine Dependency Parser; builds the graph of arc-label predictions.
    """

    def __init__(self, in1_features, in2_features, num_label, bias=True):
        r"""
        :param in1_features: dimension of the first input features
        :param in2_features: dimension of the second input features
        :param num_label: number of arc label classes
        :param bias: whether to use a bias term. Default: ``True``
        """
        super(LabelBilinear, self).__init__()
        self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
        self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)

    def forward(self, x1, x2):
        r"""
        :param x1: [batch, seq_len, hidden] first input features, i.e. label-head
        :param x2: [batch, seq_len, hidden] second input features, i.e. label-dep
        :return output: [batch, seq_len, num_cls] label scores for each element
        """
        output = self.bilinear(x1, x2)
        output = output + self.lin(torch.cat([x1, x2], dim=2))
        return output
class BiaffineParser(GraphParser):
    r"""
    An implementation of the Biaffine Dependency Parser. See
    `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .
    """

    def __init__(self,
                 embed,
                 pos_vocab_size,
                 pos_emb_dim,
                 num_label,
                 rnn_layers=1,
                 rnn_hidden_size=200,
                 arc_mlp_size=100,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='lstm',
                 use_greedy_infer=False):
        r"""
        :param embed: the word embedding; can be a tuple (num_embeddings, embedding_dim) giving the
            vocabulary size and the dimension of each word vector, or an nn.Embedding object that is
            used directly as the embedding
        :param pos_vocab_size: size of the part-of-speech vocabulary
        :param pos_emb_dim: dimension of the part-of-speech vectors
        :param num_label: number of arc label classes
        :param rnn_layers: number of layers of the RNN encoder
        :param rnn_hidden_size: hidden size of the RNN encoder
        :param arc_mlp_size: MLP dimension for arc prediction
        :param label_mlp_size: MLP dimension for label prediction
        :param dropout: dropout probability
        :param encoder: encoder type, one of ('lstm', 'var-lstm', 'transformer'). Default: lstm
        :param use_greedy_infer: whether to use the greedy algorithm at inference time. If ``False``,
            the more accurate but slower MST algorithm is used instead. Default: ``False``
        """
        super(BiaffineParser, self).__init__()
        rnn_out_size = 2 * rnn_hidden_size
        word_hid_dim = pos_hid_dim = rnn_hidden_size
        self.word_embedding = get_embeddings(embed)
        word_emb_dim = self.word_embedding.embedding_dim
        self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
        self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
        self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
        self.word_norm = nn.LayerNorm(word_hid_dim)
        self.pos_norm = nn.LayerNorm(pos_hid_dim)
        self.encoder_name = encoder
        self.max_len = 512
        if encoder == 'var-lstm':
            self.encoder = VarLSTM(input_size=word_hid_dim + pos_hid_dim,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   input_dropout=dropout,
                                   hidden_dropout=dropout,
                                   bidirectional=True)
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(input_size=word_hid_dim + pos_hid_dim,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   dropout=dropout,
                                   bidirectional=True)
        elif encoder == 'transformer':
            n_head = 16
            d_k = d_v = int(rnn_out_size / n_head)
            if (d_k * n_head) != rnn_out_size:
                raise ValueError('Unsupported rnn_out_size: {} for transformer'.format(rnn_out_size))
            self.position_emb = nn.Embedding(num_embeddings=self.max_len,
                                             embedding_dim=rnn_out_size)
            self.encoder = TransformerEncoder(num_layers=rnn_layers, d_model=rnn_out_size,
                                              n_head=n_head, dim_ff=1024, dropout=dropout)
        else:
            raise ValueError('Unsupported encoder type: {}'.format(encoder))

        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                 nn.ELU(),
                                 TimestepDropout(p=dropout))
        self.arc_mlp_size = arc_mlp_size
        self.label_mlp_size = label_mlp_size
        self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
        self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
        self.use_greedy_infer = use_greedy_infer
        self.reset_parameters()
        self.dropout = dropout

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Embedding):
                continue
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 0.1)
                nn.init.constant_(m.bias, 0)
            else:
                for p in m.parameters():
                    nn.init.normal_(p, 0, 0.1)
    def forward(self, words1, words2, seq_len, target1=None):
        r"""Model forward pass.

        :param words1: [batch_size, seq_len] input word id sequence
        :param words2: [batch_size, seq_len] input POS id sequence
        :param seq_len: [batch_size, ] lengths of the input sequences
        :param target1: [batch_size, seq_len] gold heads, only used during training to train the
            label classifier. If ``None``, the predicted heads are fed to the label classifier
            instead. Default: ``None``
        :return dict: parsing results::

            pred1: [batch_size, seq_len, seq_len] arc prediction logits
            pred2: [batch_size, seq_len, num_label] label prediction logits
            pred3: [batch_size, seq_len] predicted heads, returned when ``target1=None``
        """
        # prepare embeddings
        batch_size, length = words1.shape
        # get sequence mask
        mask = seq_len_to_mask(seq_len, max_len=length).long()

        word = self.word_embedding(words1)  # [N,L] -> [N,L,C_0]
        pos = self.pos_embedding(words2)  # [N,L] -> [N,L,C_1]

        word, pos = self.word_fc(word), self.pos_fc(pos)
        word, pos = self.word_norm(word), self.pos_norm(pos)
        x = torch.cat([word, pos], dim=2)  # -> [N,L,C]

        # encoder, extract features
        if self.encoder_name.endswith('lstm'):
            sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
            x = x[sort_idx]
            x = nn.utils.rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=True)
            feat, _ = self.encoder(x)  # -> [N,L,C]
            feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
            _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
            feat = feat[unsort_idx]
        else:
            seq_range = torch.arange(length, dtype=torch.long, device=x.device)[None, :]
            x = x + self.position_emb(seq_range)
            feat = self.encoder(x, mask.float())

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
        label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

        # use gold or predicted arcs to predict labels
        if target1 is None or not self.training:
            # use greedy decoding during training
            if self.training or self.use_greedy_infer:
                heads = self.greedy_decoder(arc_pred, mask)
            else:
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be in training mode
            if target1 is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = target1

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep)  # [N, L, num_label]
        res_dict = {'pred1': arc_pred, 'pred2': label_pred}
        if head_pred is not None:
            res_dict['pred3'] = head_pred
        return res_dict
    def train_step(self, words1, words2, seq_len, target1, target2):
        res = self(words1, words2, seq_len, target1)
        arc_pred = res['pred1']
        label_pred = res['pred2']
        loss = self.loss(pred1=arc_pred, pred2=label_pred, target1=target1, target2=target2, seq_len=seq_len)
        return {'loss': loss}

    @staticmethod
    def loss(pred1, pred2, target1, target2, seq_len):
        r"""
        Compute the parser loss.

        :param pred1: [batch_size, seq_len, seq_len] arc prediction logits
        :param pred2: [batch_size, seq_len, num_label] label prediction logits
        :param target1: [batch_size, seq_len] gold heads
        :param target2: [batch_size, seq_len] gold labels
        :param seq_len: [batch_size, ] true lengths of the targets
        :return loss: scalar
        """
        batch_size, length, _ = pred1.shape
        mask = seq_len_to_mask(seq_len, max_len=length)
        flip_mask = (mask.eq(False))
        _arc_pred = pred1.clone()
        _arc_pred = _arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
        arc_logits = F.log_softmax(_arc_pred, dim=2)
        label_logits = F.log_softmax(pred2, dim=2)
        batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
        child_index = torch.arange(length, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
        arc_loss = arc_logits[batch_index, child_index, target1]
        label_loss = label_logits[batch_index, child_index, target2]

        arc_loss = arc_loss.masked_fill(flip_mask, 0)
        label_loss = label_loss.masked_fill(flip_mask, 0)
        arc_nll = -arc_loss.mean()
        label_nll = -label_loss.mean()
        return arc_nll + label_nll

    def evaluate_step(self, words1, words2, seq_len):
        r"""Model prediction API.

        :param words1: [batch_size, seq_len] input word id sequence
        :param words2: [batch_size, seq_len] input POS id sequence
        :param seq_len: [batch_size, ] lengths of the input sequences
        :return dict: parsing results::

            pred1: [batch_size, seq_len] predicted heads
            pred2: [batch_size, seq_len] predicted label ids
        """
        res = self(words1, words2, seq_len)
        output = {}
        output['pred1'] = res.pop('pred3')
        _, label_pred = res.pop('pred2').max(2)
        output['pred2'] = label_pred
        return output
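# --- Usage sketch (illustrative, not part of the original module) ---
# All sizes below are hypothetical. evaluate_step returns the predicted heads
# ('pred1', [batch, seq_len]) and label ids ('pred2', [batch, seq_len]).
if __name__ == '__main__':
    _parser = BiaffineParser(embed=(100, 50), pos_vocab_size=20, pos_emb_dim=16, num_label=10)
    _parser.eval()  # with use_greedy_infer=False this exercises MST decoding
    _words = torch.randint(1, 100, (2, 7))
    _pos = torch.randint(0, 20, (2, 7))
    _seq_len = torch.tensor([7, 5])
    _out = _parser.evaluate_step(_words, _pos, _seq_len)
    print(_out['pred1'].shape, _out['pred2'].shape)  # torch.Size([2, 7]) torch.Size([2, 7])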
@@ -0,0 +1,92 @@
r"""
.. todo::
    doc
"""

__all__ = [
    "CNNText"
]

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core.utils import seq_len_to_mask
from ...embeddings.torch import embedding
from ...modules.torch import encoder


class CNNText(torch.nn.Module):
    r"""
    A model that uses a CNN for text classification:
    'Yoon Kim. 2014. Convolutional Neural Networks for Sentence Classification.'
    """

    def __init__(self, embed,
                 num_classes,
                 kernel_nums=(30, 40, 50),
                 kernel_sizes=(1, 3, 5),
                 dropout=0.5):
        r"""
        :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding
            specification; a tuple(int, int) gives (vocab_size, embed_dim), while a Tensor, Embedding
            or ndarray is used directly to initialize the Embedding
        :param int num_classes: number of target classes
        :param int,tuple(int) kernel_nums: number of output channels for each kernel size
        :param int,tuple(int) kernel_sizes: sizes of the convolution kernels
        :param float dropout: dropout probability
        """
        super(CNNText, self).__init__()

        # no support for pre-trained embedding currently
        self.embed = embedding.Embedding(embed)
        self.conv_pool = encoder.ConvMaxpool(
            in_channels=self.embed.embedding_dim,
            out_channels=kernel_nums,
            kernel_sizes=kernel_sizes)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(sum(kernel_nums), num_classes)

    def forward(self, words, seq_len=None):
        r"""
        :param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
        :param torch.LongTensor seq_len: [batch,] length of each sentence
        :return output: dict with key 'pred' holding the classification logits
        """
        x = self.embed(words)  # [N,L] -> [N,L,C]
        if seq_len is not None:
            mask = seq_len_to_mask(seq_len)
            x = self.conv_pool(x, mask)
        else:
            x = self.conv_pool(x)  # [N,L,C] -> [N,C]
        x = self.dropout(x)
        x = self.fc(x)  # [N,C] -> [N, N_class]
        res = {'pred': x}
        return res

    def train_step(self, words, target, seq_len=None):
        """
        :param words: [batch_size, seq_len], word indices of the sentences
        :param target: [batch_size,] gold class of each sample
        :param seq_len: [batch_size,] length of each sentence
        :return: dict with key 'loss' holding the cross-entropy loss
        """
        res = self(words, seq_len)
        x = res['pred']
        loss = F.cross_entropy(x, target)
        return {'loss': loss}

    def evaluate_step(self, words, seq_len=None):
        r"""
        :param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
        :param torch.LongTensor seq_len: [batch,] length of each sentence
        :return predict: dict of torch.LongTensor, [batch_size, ]
        """
        output = self(words, seq_len)
        _, predict = output['pred'].max(dim=1)
        return {'pred': predict}
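# --- Usage sketch (illustrative, not part of the original module) ---
# All sizes below are hypothetical: a 100-word vocabulary, 64-dim embeddings and
# 3 target classes. train_step returns the loss, evaluate_step the class ids.
if __name__ == '__main__':
    _model = CNNText(embed=(100, 64), num_classes=3)
    _words = torch.randint(1, 100, (4, 12))
    _seq_len = torch.tensor([12, 9, 7, 3])
    _target = torch.tensor([0, 2, 1, 0])
    print(_model.train_step(_words, _target, _seq_len)['loss'])
    print(_model.evaluate_step(_words, _seq_len)['pred'])  # shape [4]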
@@ -0,0 +1,81 @@
r"""undocumented"""

import torch
from torch import nn
import torch.nn.functional as F

from fastNLP import seq_len_to_mask
from .seq2seq_model import Seq2SeqModel
from ...modules.torch.generator.seq2seq_generator import SequenceGenerator

__all__ = ['SequenceGeneratorModel']


class SequenceGeneratorModel(nn.Module):
    """
    Wraps a seq2seq_model so that it can be used both for training and for generation.
    During training, forward() of this model is called; during generation, its
    evaluate_step() is called.
    """

    def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0,
                 num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
                 repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
        """
        :param Seq2SeqModel seq2seq_model: the sequence-to-sequence model
        :param int,None bos_token_id: token id marking the start of a sentence
        :param int,None eos_token_id: token id marking the end of a sentence
        :param int max_length: maximum length of a generated sentence; each sentence is decoded for
            at most max_length + max_len_a*src_len steps
        :param float max_len_a: each sentence is decoded for max_length + max_len_a*src_len steps;
            if nonzero, the State must contain encoder_mask
        :param int num_beams: beam size for beam search
        :param bool do_sample: whether to generate by sampling
        :param float temperature: only meaningful when do_sample is True
        :param int top_k: sample only from the top_k tokens
        :param float top_p: sample only from tokens within cumulative probability top_p (nucleus sampling)
        :param float repetition_penalty: how strongly repeated tokens are penalized
        :param float length_penalty: length penalty; values below 1 encourage long sentences,
            values above 1 encourage short ones
        :param int pad_token_id: once a sentence has finished, subsequent positions are filled with
            pad_token_id
        """
        super().__init__()
        self.seq2seq_model = seq2seq_model
        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a,
                                           num_beams=num_beams,
                                           do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
                                           bos_token_id=bos_token_id,
                                           eos_token_id=eos_token_id,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)

    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        """
        Delegates directly to the seq2seq_model's forward.

        :param torch.LongTensor src_tokens: bsz x max_len
        :param torch.LongTensor tgt_tokens: bsz x max_len'
        :param torch.LongTensor src_seq_len: bsz
        :param torch.LongTensor tgt_seq_len: bsz
        :return:
        """
        return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)

    def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)
        pred = res['pred']
        if tgt_seq_len is not None:
            mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
            tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
        loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
        return {'loss': loss}

    def evaluate_step(self, src_tokens, src_seq_len=None):
        """
        Given the source tokens, output the generated tokens.

        :param torch.LongTensor src_tokens: bsz x max_len
        :param torch.LongTensor src_seq_len: bsz
        :return:
        """
        state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len)
        result = self.generator.generate(state)
        return {'pred': result}
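# --- Illustration (not part of the original module) of the masking trick used
# in train_step above: padded target positions are filled with -100, which
# F.cross_entropy ignores by default (its default ignore_index is -100), so the
# loss is averaged over real tokens only.
if __name__ == '__main__':
    _pred = torch.randn(1, 4, 7)          # bsz x max_len x vocab_size
    _tgt = torch.tensor([[3, 5, 2, 0]])   # the last position is padding
    _mask = seq_len_to_mask(torch.tensor([3]), max_len=4)
    _tgt = _tgt.masked_fill(_mask.eq(0), -100)
    print(F.cross_entropy(_pred.transpose(1, 2), _tgt))  # loss over the 3 real tokens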
@@ -0,0 +1,196 @@
r"""
Components that make up Sequence-to-Sequence models.
"""

import torch
from torch import nn
import torch.nn.functional as F

from fastNLP import seq_len_to_mask
from ...embeddings.torch.utils import get_embeddings
from ...embeddings.torch.utils import get_sinusoid_encoding_table
from ...modules.torch.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder
from ...modules.torch.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder

__all__ = ['Seq2SeqModel', 'TransformerSeq2SeqModel', 'LSTMSeq2SeqModel']


class Seq2SeqModel(nn.Module):
    def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
        """
        A Seq2Seq model that can be trained in a Trainer. Normally a subclass only needs to implement
        the classmethod build_model. To use the model for generation, wrap it in
        :class:`~fastNLP.models.SequenceGeneratorModel`. In this model, forward() feeds the encoder
        output into the decoder and returns the decoder's output.

        :param encoder: a Seq2SeqEncoder object whose forward() takes two arguments, the bsz x max_len
            source tokens and the bsz source lengths, and returns two tensors: encoder_outputs of shape
            bsz x max_len x hidden_size and encoder_mask of shape bsz x max_len, with 1 at positions
            that should be attended to. If the encoder's inputs or outputs differ, override
            prepare_state() or forward() of this model.
        :param decoder: a Seq2SeqDecoder object that implements init_state(), which takes two
            arguments: the encoder output of shape bsz x max_len x hidden_size, and the encoder mask
            of shape bsz x max_len with 0 at padding positions. If the decoder needs more inputs,
            override prepare_state() or forward() of this model.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        """
        :param torch.LongTensor src_tokens: source tokens
        :param torch.LongTensor tgt_tokens: target tokens
        :param torch.LongTensor src_seq_len: source lengths
        :param torch.LongTensor tgt_seq_len: target lengths, usually unused
        :return: {'pred': torch.Tensor}, where pred has shape bsz x max_len x vocab_size
        """
        state = self.prepare_state(src_tokens, src_seq_len)
        decoder_output = self.decoder(tgt_tokens, state)
        if isinstance(decoder_output, torch.Tensor):
            return {'pred': decoder_output}
        elif isinstance(decoder_output, (tuple, list)):
            return {'pred': decoder_output[0]}
        else:
            raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}")

    def train_step(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        res = self(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)
        pred = res['pred']
        if tgt_seq_len is not None:
            mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
            tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
        loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
        return {'loss': loss}

    def prepare_state(self, src_tokens, src_seq_len=None):
        """
        Runs the encoder and passes its encoder_output and encoder_mask to decoder.init_state()
        to initialize a state.

        :param src_tokens:
        :param src_seq_len:
        :return:
        """
        encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
        state = self.decoder.init_state(encoder_output, encoder_mask)
        return state

    @classmethod
    def build_model(cls, *args, **kwargs):
        """
        Subclasses must implement this method to construct the Seq2SeqModel.

        :return:
        """
        raise NotImplementedError


class TransformerSeq2SeqModel(Seq2SeqModel):
    """
    A model with a TransformerSeq2SeqEncoder as encoder and a TransformerSeq2SeqDecoder as decoder,
    initialized via build_model.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, src_embed, tgt_embed=None,
                    pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1,
                    bind_encoder_decoder_embed=False,
                    bind_decoder_input_output_embed=True):
        """
        Initialize a TransformerSeq2SeqModel.

        :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: the source embedding
        :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: the target embedding; must not
            be given when bind_encoder_decoder_embed is True
        :param str pos_embed: positional embedding type, 'sin' or 'learned'
        :param int max_position: maximum supported sequence length
        :param int num_layers: number of encoder and decoder layers
        :param int d_model: input/output size of the encoder and decoder
        :param int n_head: number of attention heads in the encoder and decoder
        :param int dim_ff: inner dimension of the FFN in the encoder and decoder
        :param float dropout: dropout applied in the attention and FFN layers
        :param bool bind_encoder_decoder_embed: whether the encoder and decoder share the same embedding
        :param bool bind_decoder_input_output_embed: whether the decoder's output embedding shares
            weights with its input embedding
        :return: TransformerSeq2SeqModel
        """
        if bind_encoder_decoder_embed and tgt_embed is not None:
            raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.")

        src_embed = get_embeddings(src_embed)

        if bind_encoder_decoder_embed:
            tgt_embed = src_embed
        else:
            assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`"
            tgt_embed = get_embeddings(tgt_embed)

        if pos_embed == 'sin':
            encoder_pos_embed = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0),
                freeze=True)  # position id 0 is reserved for padding
            decoder_pos_embed = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0),
                freeze=True)  # position id 0 is reserved for padding
        elif pos_embed == 'learned':
            encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0)
            decoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=1)
        else:
            raise ValueError("pos_embed only supports sin or learned.")

        encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed,
                                            num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff,
                                            dropout=dropout)
        decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed,
                                            d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff,
                                            dropout=dropout,
                                            bind_decoder_input_output_embed=bind_decoder_input_output_embed)

        return cls(encoder, decoder)


class LSTMSeq2SeqModel(Seq2SeqModel):
    """
    A model that uses an LSTMSeq2SeqEncoder and an LSTMSeq2SeqDecoder.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, src_embed, tgt_embed=None,
                    num_layers=3, hidden_size=400, dropout=0.3, bidirectional=True,
                    attention=True, bind_encoder_decoder_embed=False,
                    bind_decoder_input_output_embed=True):
        """
        :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: the source embedding
        :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: the target embedding; must not
            be given when bind_encoder_decoder_embed is True
        :param int num_layers: number of encoder and decoder layers
        :param int hidden_size: hidden size of the encoder and decoder
        :param float dropout: dropout applied between layers
        :param bool bidirectional: whether the encoder uses a bidirectional LSTM
        :param bool attention: whether the decoder attends over all encoder states
        :param bool bind_encoder_decoder_embed: whether the encoder and decoder share the same embedding
        :param bool bind_decoder_input_output_embed: whether the decoder's output embedding shares
            weights with its input embedding
        :return: LSTMSeq2SeqModel
        """
        if bind_encoder_decoder_embed and tgt_embed is not None:
            raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.")

        src_embed = get_embeddings(src_embed)

        if bind_encoder_decoder_embed:
            tgt_embed = src_embed
        else:
            assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`"
            tgt_embed = get_embeddings(tgt_embed)

        encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers=num_layers,
                                     hidden_size=hidden_size, dropout=dropout, bidirectional=bidirectional)
        decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers=num_layers, hidden_size=hidden_size,
                                     dropout=dropout, bind_decoder_input_output_embed=bind_decoder_input_output_embed,
                                     attention=attention)
        return cls(encoder, decoder)
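# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal end-to-end example with hypothetical sizes, assuming the default
# LSTMSeq2SeqEncoder/Decoder accept them: build an LSTM encoder-decoder and run
# one training step. Tokens are bsz x max_len, lengths are bsz.
if __name__ == '__main__':
    _model = LSTMSeq2SeqModel.build_model(src_embed=(100, 32), tgt_embed=(90, 32),
                                          num_layers=1, hidden_size=32)
    _src = torch.randint(3, 100, (2, 6))
    _tgt = torch.randint(3, 90, (2, 5))
    print(_model.train_step(_src, _tgt, torch.tensor([6, 4]), torch.tensor([5, 3]))['loss'])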
@@ -0,0 +1,271 @@
r"""
This module implements several sequence labeling models.
"""
__all__ = [
    "SeqLabeling",
    "AdvSeqLabel",
    "BiLSTMCRF"
]

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...core.utils import seq_len_to_mask
from ...embeddings.torch.utils import get_embeddings
from ...modules.torch.decoder import ConditionalRandomField
from ...modules.torch.encoder import LSTM
from ...modules.torch import decoder, encoder
from ...modules.torch.decoder.crf import allowed_transitions


class BiLSTMCRF(nn.Module):
    r"""
    Structure: embedding + BiLSTM + FC + Dropout + CRF.
    """

    def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5,
                 target_vocab=None):
        r"""
        :param embed: supports (1) any fastNLP Embedding, or (2) a tuple (num_embeddings, dimension),
            e.g. (1000, 100)
        :param num_classes: number of target classes
        :param num_layers: number of BiLSTM layers
        :param hidden_size: hidden size of the BiLSTM; the effective hidden size is twice this value
            (forward plus backward)
        :param dropout: dropout probability, 0 for no dropout
        :param target_vocab: a Vocabulary mapping targets to indices. If given, illegal decoded
            sequences are automatically avoided.
        """
        super().__init__()
        self.embed = get_embeddings(embed)

        if num_layers > 1:
            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size,
                             bidirectional=True, batch_first=True, dropout=dropout)
        else:
            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size,
                             bidirectional=True, batch_first=True)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

        trans = None
        if target_vocab is not None:
            assert len(target_vocab) == num_classes, \
                "The number of classes should be same with the length of target vocabulary."
            trans = allowed_transitions(target_vocab.idx2word, include_start_end=True)

        self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans)

    def forward(self, words, seq_len=None, target=None):
        words = self.embed(words)
        feats, _ = self.lstm(words, seq_len=seq_len)
        feats = self.fc(feats)
        feats = self.dropout(feats)
        logits = F.log_softmax(feats, dim=-1)
        mask = seq_len_to_mask(seq_len)
        if target is None:
            pred, _ = self.crf.viterbi_decode(logits, mask)
            return {'pred': pred}
        else:
            loss = self.crf(logits, target, mask).mean()
            return {'loss': loss}

    def train_step(self, words, seq_len, target):
        return self(words, seq_len, target)

    def evaluate_step(self, words, seq_len):
        return self(words, seq_len)
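# --- Usage sketch (illustrative, not part of the original module) ---
# Hypothetical sizes; without target_vocab the CRF applies no transition
# constraints. train_step returns the loss, evaluate_step the decoded tag ids.
if __name__ == '__main__':
    _model = BiLSTMCRF(embed=(100, 32), num_classes=5)
    _words = torch.randint(1, 100, (2, 6))
    _seq_len = torch.tensor([6, 4])
    _target = torch.randint(0, 5, (2, 6))
    print(_model.train_step(_words, _seq_len, _target)['loss'])
    print(_model.evaluate_step(_words, _seq_len)['pred'])  # shape [2, 6]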
class SeqLabeling(nn.Module):
    r"""
    A basic sequence labeling model, meant as a base for sequence labeling tasks.
    Structure: one Embedding layer, one LSTM (unidirectional, single layer), one FC layer, and one CRF.
    """

    def __init__(self, embed, hidden_size, num_classes):
        r"""
        :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding
            specification; a tuple(int, int) gives (vocab_size, embed_dim), while a Tensor, Embedding
            or ndarray is used directly to initialize the Embedding
        :param int hidden_size: hidden size of the LSTM
        :param int num_classes: number of target classes
        """
        super(SeqLabeling, self).__init__()

        self.embedding = get_embeddings(embed)
        self.rnn = encoder.LSTM(self.embedding.embedding_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, num_classes)
        self.crf = decoder.ConditionalRandomField(num_classes)

    def forward(self, words, seq_len):
        r"""
        :param torch.LongTensor words: [batch_size, max_len], token index sequences
        :param torch.LongTensor seq_len: [batch_size,], lengths of the sequences
        :return: {'pred': [batch_size, max_len, num_classes]} prediction scores
        """
        x = self.embedding(words)
        # [batch_size, max_len, word_emb_dim]
        x, _ = self.rnn(x, seq_len)
        # [batch_size, max_len, hidden_size * direction]
        x = self.fc(x)
        # [batch_size, max_len, num_classes]
        return {'pred': x}

    def train_step(self, words, seq_len, target):
        res = self(words, seq_len)
        pred = res['pred']
        mask = seq_len_to_mask(seq_len, max_len=target.size(1))
        return {'loss': self._internal_loss(pred, target, mask)}

    def evaluate_step(self, words, seq_len):
        r"""
        Used at prediction time.

        :param torch.LongTensor words: [batch_size, max_len]
        :param torch.LongTensor seq_len: [batch_size,]
        :return: {'pred': xx}, [batch_size, max_len]
        """
        mask = seq_len_to_mask(seq_len, max_len=words.size(1))

        res = self(words, seq_len)
        pred = res['pred']
        # [batch_size, max_len, num_classes]
        pred = self._decode(pred, mask)
        return {'pred': pred}

    def _internal_loss(self, x, y, mask):
        r"""
        Negative log likelihood loss.

        :param x: Tensor, [batch_size, max_len, tag_size]
        :param y: Tensor, [batch_size, max_len]
        :param mask: Tensor, [batch_size, max_len]
        :return loss: a scalar Tensor
        """
        x = x.float()
        y = y.long()
        total_loss = self.crf(x, y, mask)
        return torch.mean(total_loss)

    def _decode(self, x, mask):
        r"""
        :param torch.FloatTensor x: [batch_size, max_len, tag_size]
        :param torch.ByteTensor mask: [batch_size, max_len]
        :return prediction: [batch_size, max_len]
        """
        tag_seq, _ = self.crf.viterbi_decode(x, mask)
        return tag_seq
class AdvSeqLabel(nn.Module):
    r"""
    A more elaborate sequence labeling model.
    Structure: Embedding, LayerNorm, bidirectional LSTM (two layers), FC, LayerNorm, Dropout, FC, CRF.
    """

    def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'):
        r"""
        :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the embedding
            specification; a tuple(int, int) gives (vocab_size, embed_dim), while a Tensor, Embedding
            or ndarray is used directly to initialize the Embedding
        :param int hidden_size: hidden size of the LSTM
        :param int num_classes: number of target classes
        :param float dropout: dropout probability used in the LSTM and the Dropout layer
        :param dict id2words: mapping from tag id to tag word, used to prevent illegal sequences during
            CRF decoding; for example, under the 'BMES' scheme, 'S' may not follow 'B'. Tags such as
            'B-NN' are also supported, where the part before '-' is the positional tag and the part
            after it is the concrete label; in that case it is guaranteed not only that 'B-NN' is never
            followed by 'S-NN', but also that 'B-NN' is never followed by any 'M-xx' other than 'M-NN'
            (or 'E-NN').
        :param str encoding_type: one of "BIO", "BMES", "BMESO"; only used when id2words is not None.
        """
        super().__init__()

        self.Embedding = get_embeddings(embed)
        self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim)
        self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size,
                                num_layers=2, dropout=dropout,
                                bidirectional=True, batch_first=True)
        self.Linear1 = nn.Linear(hidden_size * 2, hidden_size * 2 // 3)
        self.norm2 = torch.nn.LayerNorm(hidden_size * 2 // 3)
        self.relu = torch.nn.LeakyReLU()
        self.drop = torch.nn.Dropout(dropout)
        self.Linear2 = nn.Linear(hidden_size * 2 // 3, num_classes)

        if id2words is None:
            self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False)
        else:
            self.Crf = decoder.crf.ConditionalRandomField(num_classes, include_start_end_trans=False,
                                                          allowed_transitions=allowed_transitions(id2words,
                                                                                                  encoding_type=encoding_type))

    def _decode(self, x, mask):
        r"""
        :param torch.FloatTensor x: [batch_size, max_len, tag_size]
        :param torch.ByteTensor mask: [batch_size, max_len]
        :return torch.LongTensor, [batch_size, max_len]
        """
        tag_seq, _ = self.Crf.viterbi_decode(x, mask)
        return tag_seq

    def _internal_loss(self, x, y, mask):
        r"""
        Negative log likelihood loss.

        :param x: Tensor, [batch_size, max_len, tag_size]
        :param y: Tensor, [batch_size, max_len]
        :param mask: Tensor, [batch_size, max_len]
        :return loss: a scalar Tensor
        """
        x = x.float()
        y = y.long()
        total_loss = self.Crf(x, y, mask)
        return torch.mean(total_loss)

    def forward(self, words, seq_len, target=None):
        r"""
        :param torch.LongTensor words: [batch_size, max_len]
        :param torch.LongTensor seq_len: [batch_size, ]
        :param torch.LongTensor target: [batch_size, max_len]
        :return: if target is None, {'pred': ...} with the decoded paths, used in testing and
            prediction; if target is given, {'loss': ...} with a scalar loss, used in training.
        """
        words = words.long()
        seq_len = seq_len.long()
        mask = seq_len_to_mask(seq_len, max_len=words.size(1))

        target = target.long() if target is not None else None

        if next(self.parameters()).is_cuda:
            words = words.cuda()

        x = self.Embedding(words)
        x = self.norm1(x)
        # [batch_size, max_len, word_emb_dim]

        x, _ = self.Rnn(x, seq_len=seq_len)

        x = self.Linear1(x)
        x = self.norm2(x)
        x = self.relu(x)
        x = self.drop(x)
        x = self.Linear2(x)
        if target is not None:
            return {"loss": self._internal_loss(x, target, mask)}
        else:
            return {"pred": self._decode(x, mask)}

    def train_step(self, words, seq_len, target):
        r"""
        :param torch.LongTensor words: [batch_size, max_len]
        :param torch.LongTensor seq_len: [batch_size, ]
        :param torch.LongTensor target: [batch_size, max_len], the gold tags
        :return torch.Tensor: a scalar loss
        """
        return self(words, seq_len, target)

    def evaluate_step(self, words, seq_len):
        r"""
        :param torch.LongTensor words: [batch_size, max_len]
        :param torch.LongTensor seq_len: [batch_size, ]
        :return torch.LongTensor: [batch_size, max_len]
        """
        return self(words, seq_len)
| @@ -0,0 +1,26 @@ | |||
| __all__ = [ | |||
| 'ConditionalRandomField', | |||
| 'allowed_transitions', | |||
| "State", | |||
| "Seq2SeqDecoder", | |||
| "LSTMSeq2SeqDecoder", | |||
| "TransformerSeq2SeqDecoder", | |||
| "LSTM", | |||
| "Seq2SeqEncoder", | |||
| "TransformerSeq2SeqEncoder", | |||
| "LSTMSeq2SeqEncoder", | |||
| "StarTransformer", | |||
| "VarRNN", | |||
| "VarLSTM", | |||
| "VarGRU", | |||
| 'SequenceGenerator', | |||
| "TimestepDropout", | |||
| ] | |||
| from .decoder import * | |||
| from .encoder import * | |||
| from .generator import * | |||
| from .dropout import TimestepDropout | |||
@@ -0,0 +1,321 @@
r"""undocumented"""

__all__ = [
    "MultiHeadAttention",
    "BiAttention",
    "SelfAttention",
]

import math

import torch
import torch.nn.functional as F
from torch import nn

from .decoder.seq2seq_state import TransformerState


class DotAttention(nn.Module):
    r"""
    The scaled dot-product attention used in the Transformer.
    """

    def __init__(self, key_size, value_size, dropout=0.0):
        super(DotAttention, self).__init__()
        self.key_size = key_size
        self.value_size = value_size
        self.scale = math.sqrt(key_size)
        self.drop = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, mask_out=None):
        r"""
        :param Q: [..., seq_len_q, key_size]
        :param K: [..., seq_len_k, key_size]
        :param V: [..., seq_len_k, value_size]
        :param mask_out: [..., 1, seq_len] or [..., seq_len_q, seq_len_k]
        """
        output = torch.matmul(Q, K.transpose(-1, -2)) / self.scale
        if mask_out is not None:
            output.masked_fill_(mask_out, -1e9)
        output = self.softmax(output)
        output = self.drop(output)
        return torch.matmul(output, V)
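# --- Usage sketch (illustrative, not part of the original module) ---
# Scaled dot-product attention over made-up tensors; only the shapes matter.
if __name__ == '__main__':
    _attn = DotAttention(key_size=16, value_size=16)
    _q = torch.randn(2, 4, 16)   # [..., seq_len_q, key_size]
    _kv = torch.randn(2, 6, 16)  # [..., seq_len_k, key_size/value_size]
    print(_attn(_q, _kv, _kv).shape)  # torch.Size([2, 4, 16])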
class MultiHeadAttention(nn.Module):
    """
    The multi-head attention from *Attention Is All You Need*.
    """

    def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dropout = dropout
        self.head_dim = d_model // n_head
        self.layer_idx = layer_idx
        assert d_model % n_head == 0, "d_model should be divisible by n_head"
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.reset_parameters()

    def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None):
        """
        :param query: batch x seq x dim
        :param key: batch x seq x dim
        :param value: batch x seq x dim
        :param key_mask: batch x seq, marks which keys may be attended to; note that positions where
            the mask is 1 ARE attended to
        :param attn_mask: seq x seq, masks out entries of the attention map; mainly used for the
            decoder-side self attention during training, where the lower triangle is 1
        :param state: past information used at inference time, e.g. the encoder output and the
            decoder's previous key/value tensors, to save computation
        :return:
        """
        assert key.size() == value.size()
        if state is not None:
            assert self.layer_idx is not None
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()

        q = self.q_proj(query)  # batch x seq x dim
        q *= self.scaling
        k = v = None
        prev_k = prev_v = None

        # fetch cached k/v from the state
        if isinstance(state, TransformerState):  # we are in the inference stage
            if qkv_same:  # decoder self-attention
                prev_k = state.decoder_prev_key[self.layer_idx]
                prev_v = state.decoder_prev_value[self.layer_idx]
            else:  # decoder-encoder attention: simply load the cached keys/values
                k = state.encoder_key[self.layer_idx]
                v = state.encoder_value[self.layer_idx]

        if k is None:
            k = self.k_proj(key)
            v = self.v_proj(value)

        if prev_k is not None:
            k = torch.cat((prev_k, k), dim=1)
            v = torch.cat((prev_v, v), dim=1)

        # update the state
        if isinstance(state, TransformerState):
            if qkv_same:
                state.decoder_prev_key[self.layer_idx] = k
                state.decoder_prev_value[self.layer_idx] = v
            else:
                state.encoder_key[self.layer_idx] = k
                state.encoder_value[self.layer_idx] = v

        # compute the attention
        batch_size, q_len, d_model = query.size()
        k_len, v_len = k.size(1), v.size(1)
        q = q.reshape(batch_size, q_len, self.n_head, self.head_dim)
        k = k.reshape(batch_size, k_len, self.n_head, self.head_dim)
        v = v.reshape(batch_size, v_len, self.n_head, self.head_dim)

        attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k)  # bs,q_len,k_len,n_head
        if key_mask is not None:
            _key_mask = ~key_mask[:, None, :, None].bool()  # batch,1,k_len,1
            attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))

        if attn_mask is not None:
            _attn_mask = attn_mask[None, :, :, None].eq(0)  # 1,q_len,k_len,n_head
            attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))

        attn_weights = F.softmax(attn_weights, dim=2)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
        output = output.reshape(batch_size, q_len, -1)
        output = self.out_proj(output)  # batch,q_len,dim

        return output, attn_weights

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def set_layer_idx(self, layer_idx):
        self.layer_idx = layer_idx
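# --- Usage sketch (illustrative, not part of the original module) ---
# Self-attention without cached state; key_mask is 1 at positions that may be
# attended to. The returned weights have one slice per head.
if __name__ == '__main__':
    _mha = MultiHeadAttention(d_model=32, n_head=4)
    _x = torch.randn(2, 5, 32)
    _key_mask = torch.ones(2, 5, dtype=torch.long)
    _out, _weights = _mha(_x, _x, _x, key_mask=_key_mask)
    print(_out.shape, _weights.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 5, 5, 4])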
class AttentionLayer(nn.Module):
    def __init__(self, input_size, key_dim, value_dim, bias=False):
        """
        Attention usable in the decoding step of an LSTM-to-LSTM sequence-to-sequence model; at each
        decoding step it attends over the encoder output based on the previous step's hidden state.

        :param int input_size: size of the input
        :param int key_dim: usually the dimension of the encoder_output
        :param int value_dim: size of the output, usually the size of the decoder hidden state
        :param bias: whether the projections use a bias term
        """
        super().__init__()

        self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
        self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)

    def forward(self, input, encode_outputs, encode_mask):
        """
        :param input: batch_size x input_size
        :param encode_outputs: batch_size x max_len x key_dim
        :param encode_mask: batch_size x max_len, 0 at padding positions
        :return: hidden: batch_size x value_dim, scores: batch_size x max_len (normalized)
        """
        # x: bsz x encode_hidden_size
        x = self.input_proj(input)

        # compute attention
        attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1)  # b x max_len

        # don't attend over padding
        if encode_mask is not None:
            attn_scores = attn_scores.float().masked_fill_(
                encode_mask.eq(0),
                float('-inf')
            ).type_as(attn_scores)  # FP16 support: cast to float and back

        attn_scores = F.softmax(attn_scores, dim=-1)  # bsz x max_len

        # sum weighted sources
        x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1)  # b x encode_hidden_size
        x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1)))
        return x, attn_scores
def _masked_softmax(tensor, mask):
    tensor_shape = tensor.size()
    reshaped_tensor = tensor.view(-1, tensor_shape[-1])

    # Reshape the mask so it matches the size of the input tensor.
    while mask.dim() < tensor.dim():
        mask = mask.unsqueeze(1)
    mask = mask.expand_as(tensor).contiguous().float()
    reshaped_mask = mask.view(-1, mask.size()[-1])
    result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1)
    result = result * reshaped_mask
    # 1e-13 is added to avoid divisions by zero.
    result = result / (result.sum(dim=-1, keepdim=True) + 1e-13)
    return result.view(*tensor_shape)


def _weighted_sum(tensor, weights, mask):
    w_sum = weights.bmm(tensor)
    while mask.dim() < w_sum.dim():
        mask = mask.unsqueeze(1)
    mask = mask.transpose(-1, -2)
    mask = mask.expand_as(w_sum).contiguous().float()
    return w_sum * mask
class BiAttention(nn.Module):
    r"""
    Bi-attention module.

    For two given vector sequences :math:`a_i` and :math:`b_j`, the BiAttention module computes the
    attention result through the following formulas:

    .. math::

        \begin{array}{ll} \\
            e_{ij} = {a}^{\mathrm{T}}_{i}{b}_{j} \\
            {\hat{a}}_{i} = \sum_{j=1}^{\mathcal{l}_{b}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{b}}{\mathrm{exp}(e_{ik})}}}{b}_{j} \\
            {\hat{b}}_{j} = \sum_{i=1}^{\mathcal{l}_{a}}{\frac{\mathrm{exp}(e_{ij})}{\sum_{k=1}^{\mathcal{l}_{a}}{\mathrm{exp}(e_{ik})}}}{a}_{i} \\
        \end{array}
    """

    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
        r"""
        :param torch.Tensor premise_batch: [batch_size, a_seq_len, hidden_size]
        :param torch.Tensor premise_mask: [batch_size, a_seq_len]
        :param torch.Tensor hypothesis_batch: [batch_size, b_seq_len, hidden_size]
        :param torch.Tensor hypothesis_mask: [batch_size, b_seq_len]
        :return: torch.Tensor attended_premises: [batch_size, a_seq_len, hidden_size],
            torch.Tensor attended_hypotheses: [batch_size, b_seq_len, hidden_size]
        """
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
                                              .contiguous())

        prem_hyp_attn = _masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = _masked_softmax(similarity_matrix.transpose(1, 2)
                                        .contiguous(),
                                        premise_mask)

        attended_premises = _weighted_sum(hypothesis_batch,
                                          prem_hyp_attn,
                                          premise_mask)
        attended_hypotheses = _weighted_sum(premise_batch,
                                            hyp_prem_attn,
                                            hypothesis_mask)

        return attended_premises, attended_hypotheses


class SelfAttention(nn.Module):
    r"""
    A self-attention module based on the paper
    `A Structured Self-Attentive Sentence Embedding <https://arxiv.org/pdf/1703.03130.pdf>`_ .
    """

    def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5):
        r"""
        :param int input_size: hidden dimension of the input tensor
        :param int attention_unit: dimension of the internal attention unit
        :param int attention_hops: number of attention hops
        :param float drop: dropout probability, default 0.5
        """
        super(SelfAttention, self).__init__()

        self.attention_hops = attention_hops
        self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
        self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
        self.I = torch.eye(attention_hops, requires_grad=False)
        self.I_origin = self.I
        self.drop = nn.Dropout(drop)
        self.tanh = nn.Tanh()

    def _penalization(self, attention):
        r"""
        Compute the penalization term for the attention module.
        """
        baz = attention.size(0)
        size = self.I.size()
        if len(size) != 3 or size[0] != baz:
            self.I = self.I_origin.expand(baz, -1, -1)
            self.I = self.I.to(device=attention.device)
        attention_t = torch.transpose(attention, 1, 2).contiguous()
        mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
        ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
        return torch.sum(ret) / size[0]

    def forward(self, input, input_origin):
        r"""
        :param torch.Tensor input: [batch_size, seq_len, hidden_size] the matrix to attend over
        :param torch.Tensor input_origin: [batch_size, seq_len] the original token indices,
            including the padded part
        :return torch.Tensor output1: [batch_size, attention_hops, hidden_size] result of applying
            attention to the input matrix
        :return torch.Tensor output2: [1] the attention penalty term, a scalar
        """
        input = input.contiguous()
        size = input.size()  # [bsz, len, nhid]

        input_origin = input_origin.expand(self.attention_hops, -1, -1)  # [hops, baz, len]
        input_origin = input_origin.transpose(0, 1).contiguous()  # [baz, hops, len]

        y1 = self.tanh(self.ws1(self.drop(input)))  # [baz, len, dim] --> [bsz, len, attention-unit]
        attention = self.ws2(y1).transpose(1, 2).contiguous()
        # [bsz, len, attention-unit] --> [bsz, len, hop] --> [baz, hop, len]

        attention = attention + (-999999 * (input_origin == 0).float())  # remove the weight on padding tokens
        attention = F.softmax(attention, 2)  # [baz, hop, len]
        return torch.bmm(attention, input), self._penalization(attention)  # output1 --> [baz, hop, nhid]
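# --- Usage sketch (illustrative, not part of the original module) ---
# SelfAttention produces one weighted sum of the inputs per attention hop, plus
# a scalar penalty that discourages the hops from all attending alike.
# input_origin carries the original token ids so padding (id 0) is masked out.
if __name__ == '__main__':
    _sa = SelfAttention(input_size=16, attention_unit=8, attention_hops=10)
    _hidden = torch.randn(2, 7, 16)
    _token_ids = torch.randint(1, 50, (2, 7))
    _m, _penalty = _sa(_hidden, _token_ids)
    print(_m.shape, _penalty)  # torch.Size([2, 10, 16]) and a scalar tensor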
@@ -0,0 +1,15 @@
__all__ = [
    "ConditionalRandomField",
    "allowed_transitions",

    "State",

    "Seq2SeqDecoder",
    "LSTMSeq2SeqDecoder",
    "TransformerSeq2SeqDecoder"
]

from .crf import ConditionalRandomField, allowed_transitions
from .seq2seq_state import State
from .seq2seq_decoder import LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder, Seq2SeqDecoder
| @@ -0,0 +1,354 @@ | |||
| r"""undocumented""" | |||
| __all__ = [ | |||
| "ConditionalRandomField", | |||
| "allowed_transitions" | |||
| ] | |||
| from typing import Union, List | |||
| import torch | |||
| from torch import nn | |||
| from ....core.metrics.span_f1_pre_rec_metric import _get_encoding_type_from_tag_vocab, _check_tag_vocab_and_encoding_type | |||
| from ....core.vocabulary import Vocabulary | |||
| def allowed_transitions(tag_vocab:Union[Vocabulary, dict], encoding_type:str=None, include_start_end:bool=False): | |||
| r""" | |||
| Given a mapping from id to label, return the list of all allowed (from_tag_id, to_tag_id) transitions. | |||
| :param tag_vocab: supports tag-only entries such as "B", "M", or tag-label entries such as "B-NN", "M-NN", | |||
| where tag and label must be separated by "-". If a dict is passed, it must map index to tag, | |||
| e.g. {0: "O", 1: "B-tag1"}, i.e. index first, tag second. | |||
| :param encoding_type: one of ``["bio", "bmes", "bmeso", "bioes"]``. Defaults to None, in which case it is inferred from the vocab | |||
| :param include_start_end: whether to include start/end transitions. For example, in bio, b/o may start a sequence while i may not; | |||
| if True, the result contains (start_idx, b_idx) and (start_idx, o_idx) but not (start_idx, i_idx), | |||
| where start_idx=len(id2label) and end_idx=len(id2label)+1. If False, no start- or end-related transitions are returned | |||
| :return: List[Tuple(int, int)], each inner Tuple is an allowed (from_tag_id, to_tag_id) transition. | |||
| """ | |||
| if encoding_type is None: | |||
| encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
| else: | |||
| encoding_type = encoding_type.lower() | |||
| _check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
| pad_token = '<pad>' | |||
| unk_token = '<unk>' | |||
| if isinstance(tag_vocab, Vocabulary): | |||
| id_label_lst = list(tag_vocab.idx2word.items()) | |||
| pad_token = tag_vocab.padding | |||
| unk_token = tag_vocab.unknown | |||
| else: | |||
| id_label_lst = list(tag_vocab.items()) | |||
| num_tags = len(tag_vocab) | |||
| start_idx = num_tags | |||
| end_idx = num_tags + 1 | |||
| allowed_trans = [] | |||
| if include_start_end: | |||
| id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] | |||
| def split_tag_label(from_label): | |||
| from_label = from_label.lower() | |||
| if from_label in ['start', 'end']: | |||
| from_tag = from_label | |||
| from_label = '' | |||
| else: | |||
| from_tag = from_label[:1] | |||
| from_label = from_label[2:] | |||
| return from_tag, from_label | |||
| for from_id, from_label in id_label_lst: | |||
| if from_label in [pad_token, unk_token]: | |||
| continue | |||
| from_tag, from_label = split_tag_label(from_label) | |||
| for to_id, to_label in id_label_lst: | |||
| if to_label in [pad_token, unk_token]: | |||
| continue | |||
| to_tag, to_label = split_tag_label(to_label) | |||
| if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| allowed_trans.append((from_id, to_id)) | |||
| return allowed_trans | |||
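| # A minimal usage sketch with a hypothetical tag map: | |||
| # allowed_transitions({0: 'B-PER', 1: 'I-PER', 2: 'O'}, encoding_type='bio') | |||
| # -> [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 2)] | |||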
| def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
| r""" | |||
| :param str encoding_type: one of "bio", "bmes", "bmeso", "bioes". | |||
| :param str from_tag: a tag such as "B" or "M", or one of the two special tags start and end | |||
| :param str from_label: a label such as "PER" or "LOC" | |||
| :param str to_tag: a tag such as "B" or "M", or one of the two special tags start and end | |||
| :param str to_label: a label such as "PER" or "LOC" | |||
| :return: bool, whether the transition is allowed | |||
| """ | |||
| if to_tag == 'start' or from_tag == 'end': | |||
| return False | |||
| encoding_type = encoding_type.lower() | |||
| if encoding_type == 'bio': | |||
| r""" | |||
| The first row is to_tag, the first column is from_tag. y: always allowed; -: allowed only when the labels match; n: never allowed | |||
| +-------+---+---+---+-------+-----+ | |||
| | | B | I | O | start | end | | |||
| +-------+---+---+---+-------+-----+ | |||
| | B | y | - | y | n | y | | |||
| +-------+---+---+---+-------+-----+ | |||
| | I | y | - | y | n | y | | |||
| +-------+---+---+---+-------+-----+ | |||
| | O | y | n | y | n | y | | |||
| +-------+---+---+---+-------+-----+ | |||
| | start | y | n | y | n | n | | |||
| +-------+---+---+---+-------+-----+ | |||
| | end | n | n | n | n | n | | |||
| +-------+---+---+---+-------+-----+ | |||
| """ | |||
| if from_tag == 'start': | |||
| return to_tag in ('b', 'o') | |||
| elif from_tag in ['b', 'i']: | |||
| return to_tag in ('end', 'b', 'o') or (to_tag == 'i' and from_label == to_label) | |||
| elif from_tag == 'o': | |||
| return to_tag in ['end', 'b', 'o'] | |||
| else: | |||
| raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) | |||
| elif encoding_type == 'bmes': | |||
| r""" | |||
| The first row is to_tag, the first column is from_tag. y: always allowed; -: allowed only when the labels match; n: never allowed | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | | B | M | E | S | start | end | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | B | n | - | - | n | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | M | n | - | - | n | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | E | y | n | n | y | n | y | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | S | y | n | n | y | n | y | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | start | y | n | n | y | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| | end | n | n | n | n | n | n | | |||
| +-------+---+---+---+---+-------+-----+ | |||
| """ | |||
| if from_tag == 'start': | |||
| return to_tag in ['b', 's'] | |||
| elif from_tag == 'b': | |||
| return to_tag in ['m', 'e'] and from_label == to_label | |||
| elif from_tag == 'm': | |||
| return to_tag in ['m', 'e'] and from_label == to_label | |||
| elif from_tag in ['e', 's']: | |||
| return to_tag in ['b', 's', 'end'] | |||
| else: | |||
| raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S'.".format(from_tag)) | |||
| elif encoding_type == 'bmeso': | |||
| if from_tag == 'start': | |||
| return to_tag in ['b', 's', 'o'] | |||
| elif from_tag == 'b': | |||
| return to_tag in ['m', 'e'] and from_label == to_label | |||
| elif from_tag == 'm': | |||
| return to_tag in ['m', 'e'] and from_label == to_label | |||
| elif from_tag in ['e', 's', 'o']: | |||
| return to_tag in ['b', 's', 'end', 'o'] | |||
| else: | |||
| raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
| elif encoding_type == 'bioes': | |||
| if from_tag == 'start': | |||
| return to_tag in ['b', 's', 'o'] | |||
| elif from_tag == 'b': | |||
| return to_tag in ['i', 'e'] and from_label == to_label | |||
| elif from_tag == 'i': | |||
| return to_tag in ['i', 'e'] and from_label == to_label | |||
| elif from_tag in ['e', 's', 'o']: | |||
| return to_tag in ['b', 's', 'end', 'o'] | |||
| else: | |||
| raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) | |||
| else: | |||
| raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) | |||
| class ConditionalRandomField(nn.Module): | |||
| r""" | |||
| Conditional Random Field. Provides forward() for training and viterbi_decode() for inference. | |||
| """ | |||
| def __init__(self, num_tags:int, include_start_end_trans:bool=False, allowed_transitions:List=None): | |||
| r""" | |||
| :param num_tags: number of tags | |||
| :param include_start_end_trans: whether to learn scores for each tag appearing at the start or end of a sequence. | |||
| :param allowed_transitions: each inner Tuple[from_tag_id(int), to_tag_id(int)] is treated as an allowed | |||
| transition; transitions not listed are forbidden. Can be obtained from the | |||
| allowed_transitions() function; if None, all transitions are legal | |||
| """ | |||
| super(ConditionalRandomField, self).__init__() | |||
| self.include_start_end_trans = include_start_end_trans | |||
| self.num_tags = num_tags | |||
| # the meaning of entry in this matrix is (from_tag_id, to_tag_id) score | |||
| self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) | |||
| if self.include_start_end_trans: | |||
| self.start_scores = nn.Parameter(torch.randn(num_tags)) | |||
| self.end_scores = nn.Parameter(torch.randn(num_tags)) | |||
| if allowed_transitions is None: | |||
| constrain = torch.zeros(num_tags + 2, num_tags + 2) | |||
| else: | |||
| constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float) | |||
| has_start = False | |||
| has_end = False | |||
| for from_tag_id, to_tag_id in allowed_transitions: | |||
| constrain[from_tag_id, to_tag_id] = 0 | |||
| if from_tag_id==num_tags: | |||
| has_start = True | |||
| if to_tag_id==num_tags+1: | |||
| has_end = True | |||
| if not has_start: | |||
| constrain[num_tags, :].fill_(0) | |||
| if not has_end: | |||
| constrain[:, num_tags+1].fill_(0) | |||
| self._constrain = nn.Parameter(constrain, requires_grad=False) | |||
| def _normalizer_likelihood(self, logits, mask): | |||
| r"""Computes the (batch_size,) denominator term for the log-likelihood, which is the | |||
| sum of the likelihoods across all possible state sequences. | |||
| :param logits:FloatTensor, max_len x batch_size x num_tags | |||
| :param mask:ByteTensor, max_len x batch_size | |||
| :return:FloatTensor, batch_size | |||
| """ | |||
| seq_len, batch_size, n_tags = logits.size() | |||
| alpha = logits[0] | |||
| if self.include_start_end_trans: | |||
| alpha = alpha + self.start_scores.view(1, -1) | |||
| flip_mask = mask.eq(False) | |||
| for i in range(1, seq_len): | |||
| emit_score = logits[i].view(batch_size, 1, n_tags) | |||
| trans_score = self.trans_m.view(1, n_tags, n_tags) | |||
| tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score | |||
| alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||
| alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0) | |||
| if self.include_start_end_trans: | |||
| alpha = alpha + self.end_scores.view(1, -1) | |||
| return torch.logsumexp(alpha, 1) | |||
| def _gold_score(self, logits, tags, mask): | |||
| r""" | |||
| Compute the score for the gold path. | |||
| :param logits: FloatTensor, max_len x batch_size x num_tags | |||
| :param tags: LongTensor, max_len x batch_size | |||
| :param mask: ByteTensor, max_len x batch_size | |||
| :return:FloatTensor, batch_size | |||
| """ | |||
| seq_len, batch_size, _ = logits.size() | |||
| batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||
| seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) | |||
| # trans_score [L-1, B] | |||
| mask = mask.eq(True) | |||
| flip_mask = mask.eq(False) | |||
| trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0) | |||
| # emit_score [L, B] | |||
| emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0) | |||
| # score [L-1, B] | |||
| score = trans_score + emit_score[:seq_len - 1, :] | |||
| score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) | |||
| if self.include_start_end_trans: | |||
| st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] | |||
| last_idx = mask.long().sum(0) - 1 | |||
| ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] | |||
| score = score + st_scores + ed_scores | |||
| # return [B,] | |||
| return score | |||
| def forward(self, feats, tags, mask): | |||
| r""" | |||
| Computes the CRF forward loss; returns a FloatTensor of shape (batch_size,). You may need mean() to get the final loss. | |||
| :param torch.FloatTensor feats: batch_size x max_len x num_tags, the feature matrix. | |||
| :param torch.LongTensor tags: batch_size x max_len, the tag matrix. | |||
| :param torch.ByteTensor mask: batch_size x max_len, positions equal to 0 are treated as padding. | |||
| :return: torch.FloatTensor, (batch_size,) | |||
| """ | |||
| feats = feats.transpose(0, 1) | |||
| tags = tags.transpose(0, 1).long() | |||
| mask = mask.transpose(0, 1).float() | |||
| all_path_score = self._normalizer_likelihood(feats, mask) | |||
| gold_path_score = self._gold_score(feats, tags, mask) | |||
| return all_path_score - gold_path_score | |||
| def viterbi_decode(self, logits, mask, unpad=False): | |||
| r"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 | |||
| :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 | |||
| :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 | |||
| :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 | |||
| List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 | |||
| 个sample的有效长度。 | |||
| :return: 返回 (paths, scores)。 | |||
| paths: 是解码后的路径, 其值参照unpad参数. | |||
| scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 | |||
| """ | |||
| batch_size, max_len, n_tags = logits.size() | |||
| seq_len = mask.long().sum(1) | |||
| logits = logits.transpose(0, 1).data # L, B, H | |||
| mask = mask.transpose(0, 1).data.eq(True) # L, B | |||
| flip_mask = mask.eq(False) | |||
| # dp | |||
| vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long) | |||
| vscore = logits[0] # bsz x n_tags | |||
| transitions = self._constrain.data.clone() | |||
| transitions[:n_tags, :n_tags] += self.trans_m.data | |||
| if self.include_start_end_trans: | |||
| transitions[n_tags, :n_tags] += self.start_scores.data | |||
| transitions[:n_tags, n_tags + 1] += self.end_scores.data | |||
| vscore += transitions[n_tags, :n_tags] | |||
| trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data | |||
| end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags | |||
| # handle sequences of length 1 | |||
| vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \ | |||
| .masked_fill(seq_len.ne(1).view(-1, 1), 0) | |||
| for i in range(1, max_len): | |||
| prev_score = vscore.view(batch_size, n_tags, 1) | |||
| cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score | |||
| score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag | |||
| # account for the case where the current position is the last one of its sequence | |||
| score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0) | |||
| best_score, best_dst = score.max(1) | |||
| vpath[i] = best_dst | |||
| # backtracking uses last_tags, so vscore must be kept intact at every position | |||
| vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ | |||
| vscore.masked_fill(mask[i].view(batch_size, 1), 0) | |||
| # backtrace | |||
| batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) | |||
| seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device) | |||
| lens = (seq_len - 1) | |||
| # idxes [L, B], batched idx from seq_len-1 to 0 | |||
| idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len | |||
| ans = logits.new_empty((max_len, batch_size), dtype=torch.long) | |||
| ans_score, last_tags = vscore.max(1) | |||
| ans[idxes[0], batch_idx] = last_tags | |||
| for i in range(max_len - 1): | |||
| last_tags = vpath[idxes[i], batch_idx, last_tags] | |||
| ans[idxes[i + 1], batch_idx] = last_tags | |||
| ans = ans.transpose(0, 1) | |||
| if unpad: | |||
| paths = [] | |||
| for idx, seq_len_i in enumerate(lens):  # avoid shadowing max_len | |||
| paths.append(ans[idx, :seq_len_i + 1].tolist()) | |||
| else: | |||
| paths = ans | |||
| return paths, ans_score | |||
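| # A minimal usage sketch (hypothetical sizes): train with forward(), decode with viterbi_decode(). | |||
| # crf = ConditionalRandomField(num_tags=5) | |||
| # feats = torch.randn(2, 7, 5)                  # batch_size x max_len x num_tags | |||
| # tags = torch.randint(0, 5, (2, 7))            # batch_size x max_len | |||
| # mask = torch.ones(2, 7, dtype=torch.uint8)    # 0 marks padding | |||
| # loss = crf(feats, tags, mask).mean() | |||
| # paths, scores = crf.viterbi_decode(feats, mask) | |||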
| @@ -0,0 +1,416 @@ | |||
| r"""undocumented""" | |||
| from typing import Union, Tuple | |||
| import math | |||
| import torch | |||
| from torch import nn | |||
| import torch.nn.functional as F | |||
| from ..attention import AttentionLayer, MultiHeadAttention | |||
| from ....embeddings.torch.utils import get_embeddings | |||
| from ....embeddings.torch.static_embedding import StaticEmbedding | |||
| from .seq2seq_state import State, LSTMState, TransformerState | |||
| __all__ = ['Seq2SeqDecoder', 'TransformerSeq2SeqDecoder', 'LSTMSeq2SeqDecoder'] | |||
| class Seq2SeqDecoder(nn.Module): | |||
| """ | |||
| Base class for all sequence-to-sequence decoders. The forward and decode functions must be implemented; the remaining | |||
| functions are optional. Every Seq2SeqDecoder should have a matching State object that carries the encoder output it | |||
| needs and the history the decoder has to keep (e.g. the hidden state of an LSTM). | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def forward(self, tokens, state, **kwargs): | |||
| """ | |||
| :param torch.LongTensor tokens: bsz x max_len | |||
| :param State state: the state carrying the encoder output and everything decoded so far | |||
| :return: may be a bsz x max_len x vocab_size Tensor, or a list whose first element must be the predicted token distribution | |||
| """ | |||
| raise NotImplementedError | |||
| def reorder_states(self, indices, states): | |||
| """ | |||
| Reorder the states according to indices; used during beam-search generation. | |||
| :param torch.LongTensor indices: | |||
| :param State states: | |||
| :return: | |||
| """ | |||
| assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" | |||
| states.reorder_state(indices) | |||
| def init_state(self, encoder_output, encoder_mask): | |||
| """ | |||
| Initialize a State object that records the encoder output and the part already decoded. | |||
| :param Union[torch.Tensor, list, tuple] encoder_output: if not None, the inner elements must be torch.Tensor, | |||
| whose first dimension is assumed to be the batch dimension | |||
| :param Union[torch.Tensor, list, tuple] encoder_mask: if not None, the inner elements must be torch.Tensor, | |||
| whose first dimension is assumed to be the batch dimension | |||
| :param kwargs: | |||
| :return: State, a State object recording the encoder output | |||
| """ | |||
| state = State(encoder_output, encoder_mask) | |||
| return state | |||
| def decode(self, tokens, state): | |||
| """ | |||
| Continue generation based on the contents of states and tokens. | |||
| :param torch.LongTensor tokens: bsz x max_len, all tokens output up to the previous step. | |||
| :param State state: records the encoder output and the decoder's past state | |||
| :return: torch.FloatTensor: bsz x vocab_size, the distribution for the next step | |||
| """ | |||
| outputs = self(state=state, tokens=tokens) | |||
| if isinstance(outputs, torch.Tensor): | |||
| return outputs[:, -1] | |||
| else: | |||
| raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") | |||
| class TiedEmbedding(nn.Module): | |||
| """ | |||
| Ties the output projection to an existing embedding weight. | |||
| """ | |||
| def __init__(self, weight): | |||
| super().__init__() | |||
| self.weight = weight # vocab_size x embed_size | |||
| def forward(self, x): | |||
| """ | |||
| :param torch.FloatTensor x: bsz x * x embed_size | |||
| :return: torch.FloatTensor bsz x * x vocab_size | |||
| """ | |||
| return torch.matmul(x, self.weight.t()) | |||
| def get_bind_decoder_output_embed(embed): | |||
| """ | |||
| Given an embedding, return the tied output embedding as a TiedEmbedding object. | |||
| :param embed: | |||
| :return: | |||
| """ | |||
| if isinstance(embed, StaticEmbedding): | |||
| for idx, map2idx in enumerate(embed.words_to_words): | |||
| assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check: (1) whether the vocabulary " \ | |||
| "includes `no_create_entry=True` words; (2) StaticEmbedding should not be initialized " \ | |||
| "with `lower=True` or `min_freq!=1`." | |||
| elif not isinstance(embed, nn.Embedding): | |||
| raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") | |||
| return TiedEmbedding(embed.weight) | |||
| class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
| """ | |||
| LSTM decoder. | |||
| :param nn.Module,tuple embed: embedding of the decoder input. | |||
| :param int num_layers: number of LSTM layers | |||
| :param int hidden_size: hidden size; also assumed to be the output dimension of the encoder | |||
| :param dropout: dropout probability | |||
| :param bool bind_decoder_input_output_embed: whether to tie the output layer and the input embedding (i.e. share the same | |||
| weight). If embed is a StaticEmbedding, its vocab must not contain no_create_entry tokens, and the | |||
| StaticEmbedding must be initialized with lower=False and min_freq=1. | |||
| :param bool attention: whether to use attention | |||
| """ | |||
| def __init__(self, embed: Union[nn.Module, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||
| dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||
| super().__init__() | |||
| self.embed = get_embeddings(init_embed=embed) | |||
| self.embed_dim = self.embed.embedding_dim  # use self.embed so a (num_embeddings, embedding_dim) tuple also works | |||
| if bind_decoder_input_output_embed: | |||
| self.output_layer = get_bind_decoder_output_embed(self.embed) | |||
| else:  # no weight tying | |||
| self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
| self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
| self.hidden_size = hidden_size | |||
| self.num_layers = num_layers | |||
| self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, | |||
| batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) | |||
| self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None | |||
| self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||
| self.dropout_layer = nn.Dropout(dropout) | |||
| def forward(self, tokens, state, return_attention=False): | |||
| """ | |||
| :param torch.LongTensor tokens: batch x max_len | |||
| :param LSTMState state: the State object holding the encoder output and the decoding state | |||
| :param bool return_attention: whether to return the attention scores | |||
| :return: bsz x max_len x vocab_size; if return_attention=True, also returns bsz x max_len x encode_length | |||
| """ | |||
| src_output = state.encoder_output | |||
| encoder_mask = state.encoder_mask | |||
| assert tokens.size(1)>state.decode_length, "The state does not match the tokens." | |||
| tokens = tokens[:, state.decode_length:] | |||
| x = self.embed(tokens) | |||
| attn_weights = [] if self.attention_layer is not None else None  # store attention weights, batch x tgt_seq x src_seq | |||
| input_feed = state.input_feed | |||
| decoder_out = [] | |||
| cur_hidden = state.hidden | |||
| cur_cell = state.cell | |||
| # decode step by step | |||
| for i in range(tokens.size(1)): | |||
| input = torch.cat( | |||
| (x[:, i:i + 1, :], | |||
| input_feed[:, None, :] | |||
| ), | |||
| dim=2 | |||
| ) # batch,1,2*dim | |||
| _, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell))  # hidden/cell keep their original size | |||
| if self.attention_layer is not None: | |||
| input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) | |||
| attn_weights.append(attn_weight) | |||
| else: | |||
| input_feed = cur_hidden[-1] | |||
| state.input_feed = input_feed # batch, hidden | |||
| state.hidden = cur_hidden | |||
| state.cell = cur_cell | |||
| state.decode_length += 1 | |||
| decoder_out.append(input_feed) | |||
| decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden | |||
| decoder_out = self.dropout_layer(decoder_out) | |||
| if attn_weights is not None: | |||
| attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len | |||
| decoder_out = self.output_proj(decoder_out) | |||
| feats = self.output_layer(decoder_out) | |||
| if return_attention: | |||
| return feats, attn_weights | |||
| return feats | |||
| def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||
| """ | |||
| :param encoder_output: two input forms are accepted: (1) a tuple (encoder_output, (hidden, cell)), where encoder_output: | |||
| bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell: bsz x hidden_size; typically the last-layer | |||
| hidden state and cell state of an LSTMEncoder are used for these two values | |||
| (2) only encoder_output: bsz x max_len x hidden_size, in which case hidden and cell are zero-initialized | |||
| :param torch.ByteTensor encoder_mask: bsz x max_len, positions equal to 0 are padding, indicating which source positions should not be attended to | |||
| :return: | |||
| """ | |||
| if not isinstance(encoder_output, torch.Tensor): | |||
| encoder_output, (hidden, cell) = encoder_output | |||
| else: | |||
| hidden = cell = None | |||
| assert encoder_output.ndim==3 | |||
| assert encoder_mask.size()==encoder_output.size()[:2] | |||
| assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ | |||
| "the hidden_size." | |||
| t = [hidden, cell] | |||
| for idx in range(2): | |||
| v = t[idx] | |||
| if v is None: | |||
| v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) | |||
| else: | |||
| assert v.dim()==2 | |||
| assert v.size(-1)==self.hidden_size | |||
| v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size | |||
| t[idx] = v | |||
| state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) | |||
| return state | |||
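| # A minimal usage sketch (hypothetical sizes; the encoder output dimension must equal hidden_size): | |||
| # decoder = LSTMSeq2SeqDecoder(embed=(1000, 100), hidden_size=300) | |||
| # state = decoder.init_state(torch.randn(2, 9, 300), torch.ones(2, 9, dtype=torch.uint8)) | |||
| # feats = decoder(torch.randint(0, 1000, (2, 4)), state)   # 2 x 4 x 1000 | |||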
| class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
| """ | |||
| :param int d_model: input and output dimension | |||
| :param int n_head: number of heads; must divide d_model | |||
| :param int dim_ff: hidden size of the feed-forward network | |||
| :param float dropout: dropout probability | |||
| :param int layer_idx: index of this layer | |||
| """ | |||
| def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||
| super().__init__() | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.layer_idx = layer_idx  # record the layer index so the matching entries in state can be looked up | |||
| self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
| self.self_attn_layer_norm = nn.LayerNorm(d_model) | |||
| self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
| self.encoder_attn_layer_norm = nn.LayerNorm(d_model) | |||
| self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout), | |||
| nn.Linear(self.dim_ff, self.d_model), | |||
| nn.Dropout(dropout)) | |||
| self.final_layer_norm = nn.LayerNorm(self.d_model) | |||
| def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): | |||
| """ | |||
| :param x: (batch, seq_len, dim), decoder-side input | |||
| :param encoder_output: (batch, src_seq_len, dim), encoder output | |||
| :param encoder_mask: batch x src_seq_len, positions equal to 1 should be attended to | |||
| :param self_attn_mask: seq_len x seq_len, lower-triangular mask matrix, only passed during training | |||
| :param TransformerState state: only passed during inference | |||
| :return: | |||
| """ | |||
| # self attention part | |||
| residual = x | |||
| x = self.self_attn_layer_norm(x) | |||
| x, _ = self.self_attn(query=x, | |||
| key=x, | |||
| value=x, | |||
| attn_mask=self_attn_mask, | |||
| state=state) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| x = residual + x | |||
| # encoder attention part | |||
| residual = x | |||
| x = self.encoder_attn_layer_norm(x) | |||
| x, attn_weight = self.encoder_attn(query=x, | |||
| key=encoder_output, | |||
| value=encoder_output, | |||
| key_mask=encoder_mask, | |||
| state=state) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| x = residual + x | |||
| # ffn | |||
| residual = x | |||
| x = self.final_layer_norm(x) | |||
| x = self.ffn(x) | |||
| x = residual + x | |||
| return x, attn_weight | |||
| class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
| """ | |||
| :param embed: embedding of the input tokens | |||
| :param nn.Module pos_embed: position embedding | |||
| :param int d_model: input and output dimension | |||
| :param int num_layers: number of layers | |||
| :param int n_head: number of heads | |||
| :param int dim_ff: hidden size of the FFN | |||
| :param float dropout: dropout probability in self-attention and FFN | |||
| :param bool bind_decoder_input_output_embed: whether to tie the output layer and the input embedding (i.e. share the same | |||
| weight). If embed is a StaticEmbedding, its vocab must not contain no_create_entry tokens, and the | |||
| StaticEmbedding must be initialized with lower=False and min_freq=1. | |||
| """ | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||
| d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||
| bind_decoder_input_output_embed = True): | |||
| super().__init__() | |||
| self.embed = get_embeddings(embed) | |||
| self.pos_embed = pos_embed | |||
| if bind_decoder_input_output_embed: | |||
| self.output_layer = get_bind_decoder_output_embed(self.embed) | |||
| else:  # no weight tying | |||
| self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
| self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
| self.num_layers = num_layers | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
| self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) | |||
| for layer_idx in range(num_layers)]) | |||
| self.embed_scale = math.sqrt(d_model) | |||
| self.layer_norm = nn.LayerNorm(d_model) | |||
| self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||
| def forward(self, tokens, state, return_attention=False): | |||
| """ | |||
| :param torch.LongTensor tokens: batch x tgt_len, the tokens decoded so far | |||
| :param TransformerState state: the object recording the encoder output and the decoding state; obtained via init_state() | |||
| :param bool return_attention: whether to return the attention scores over the encoder output | |||
| :return: bsz x max_len x vocab_size; if return_attention=True, also returns bsz x max_len x encode_length | |||
| """ | |||
| encoder_output = state.encoder_output | |||
| encoder_mask = state.encoder_mask | |||
| assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens." | |||
| tokens = tokens[:, state.decode_length:] | |||
| device = tokens.device | |||
| x = self.embed_scale * self.embed(tokens) | |||
| if self.pos_embed is not None: | |||
| position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] | |||
| x += self.pos_embed(position) | |||
| x = self.input_fc(x) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| batch_size, max_tgt_len = tokens.size() | |||
| if max_tgt_len>1: | |||
| triangle_mask = self._get_triangle_mask(tokens) | |||
| else: | |||
| triangle_mask = None | |||
| for layer in self.layer_stacks: | |||
| x, attn_weight = layer(x=x, | |||
| encoder_output=encoder_output, | |||
| encoder_mask=encoder_mask, | |||
| self_attn_mask=triangle_mask, | |||
| state=state | |||
| ) | |||
| x = self.layer_norm(x) # batch, tgt_len, dim | |||
| x = self.output_fc(x) | |||
| feats = self.output_layer(x) | |||
| if return_attention: | |||
| return feats, attn_weight | |||
| return feats | |||
| def init_state(self, encoder_output, encoder_mask): | |||
| """ | |||
| Initialize a TransformerState for forward | |||
| :param torch.FloatTensor encoder_output: bsz x max_len x d_model, the encoder output | |||
| :param torch.ByteTensor encoder_mask: bsz x max_len, positions equal to 1 should be attended to. | |||
| :return: TransformerState | |||
| """ | |||
| if isinstance(encoder_output, (list, tuple)): | |||
| encoder_output = encoder_output[0]  # in case this is the output of an LSTM encoder | |||
| elif not isinstance(encoder_output, torch.Tensor): | |||
| raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") | |||
| state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) | |||
| return state | |||
| @staticmethod | |||
| def _get_triangle_mask(tokens): | |||
| tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) | |||
| return torch.tril(tensor).byte() | |||
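| # A minimal usage sketch (hypothetical sizes): one decoding step with a cached TransformerState. | |||
| # decoder = TransformerSeq2SeqDecoder(embed=(1000, 512), d_model=512) | |||
| # state = decoder.init_state(torch.randn(2, 9, 512), torch.ones(2, 9, dtype=torch.uint8)) | |||
| # next_dist = decoder.decode(torch.full((2, 1), 1, dtype=torch.long), state)   # 2 x 1000 | |||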
| @@ -0,0 +1,145 @@ | |||
| r""" | |||
| Each Decoder has a matching State that records the encoder output and the decoding history | |||
| """ | |||
| __all__ = [ | |||
| 'State', | |||
| "LSTMState", | |||
| "TransformerState" | |||
| ] | |||
| from typing import Union | |||
| import torch | |||
| class State: | |||
| def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||
| """ | |||
| Each Decoder has a matching State object that carries the encoder output and the decoding state before the current step. | |||
| :param Union[torch.Tensor, list, tuple] encoder_output: if not None, the inner elements must be torch.Tensor, | |||
| whose first dimension is assumed to be the batch dimension | |||
| :param Union[torch.Tensor, list, tuple] encoder_mask: if not None, the inner elements must be torch.Tensor, | |||
| whose first dimension is assumed to be the batch dimension | |||
| :param kwargs: | |||
| """ | |||
| self.encoder_output = encoder_output | |||
| self.encoder_mask = encoder_mask | |||
| self._decode_length = 0 | |||
| @property | |||
| def num_samples(self): | |||
| """ | |||
| Number of samples whose encoder state is contained in this State; mainly used to determine the batch size during generation. | |||
| :return: | |||
| """ | |||
| if self.encoder_output is not None: | |||
| return self.encoder_output.size(0) | |||
| else: | |||
| return None | |||
| @property | |||
| def decode_length(self): | |||
| """ | |||
| How many tokens have been decoded so far; the decoder only decodes tokens after decode_length. 0 means decoding has not started. | |||
| :return: | |||
| """ | |||
| return self._decode_length | |||
| @decode_length.setter | |||
| def decode_length(self, value): | |||
| self._decode_length = value | |||
| def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): | |||
| if isinstance(state, torch.Tensor): | |||
| state = state.index_select(index=indices, dim=dim) | |||
| elif isinstance(state, list): | |||
| for i in range(len(state)): | |||
| assert state[i] is not None | |||
| state[i] = self._reorder_state(state[i], indices, dim) | |||
| elif isinstance(state, tuple): | |||
| tmp_list = [] | |||
| for i in range(len(state)): | |||
| assert state[i] is not None | |||
| tmp_list.append(self._reorder_state(state[i], indices, dim)) | |||
| state = tuple(tmp_list) | |||
| else: | |||
| raise TypeError(f"Cannot reorder data of type:{type(state)}") | |||
| return state | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| if self.encoder_mask is not None: | |||
| self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||
| if self.encoder_output is not None: | |||
| self.encoder_output = self._reorder_state(self.encoder_output, indices) | |||
| class LSTMState(State): | |||
| def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||
| """ | |||
| State for the LSTM decoder; stores the encoder output and some intermediate states of LSTM decoding | |||
| :param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size, the encoder output | |||
| :param torch.BoolTensor encoder_mask: bsz x src_seq_len, positions equal to 0 are padding | |||
| :param torch.FloatTensor hidden: num_layers x bsz x hidden_size, hidden state of the previous step | |||
| :param torch.FloatTensor cell: num_layers x bsz x hidden_size, cell state of the previous step | |||
| """ | |||
| super().__init__(encoder_output, encoder_mask) | |||
| self.hidden = hidden | |||
| self.cell = cell | |||
| self._input_feed = hidden[0]  # defaults to the previous step's output | |||
| @property | |||
| def input_feed(self): | |||
| """ | |||
| At each step the LSTM decoder concatenates the previous token's embedding with input_feed as the next input. When the | |||
| decoder does not use attention, input_feed is the previous hidden state; otherwise it is the output of the attention layer. | |||
| :return: torch.FloatTensor, bsz x hidden_size | |||
| """ | |||
| return self._input_feed | |||
| @input_feed.setter | |||
| def input_feed(self, value): | |||
| self._input_feed = value | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| super().reorder_state(indices) | |||
| self.hidden = self._reorder_state(self.hidden, indices, dim=1) | |||
| self.cell = self._reorder_state(self.cell, indices, dim=1) | |||
| if self.input_feed is not None: | |||
| self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) | |||
| class TransformerState(State): | |||
| def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||
| """ | |||
| State matching TransformerSeq2SeqDecoder. | |||
| :param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, the encoder output | |||
| :param torch.ByteTensor encoder_mask: bsz x encode_max_len, positions equal to 1 should be attended to | |||
| :param int num_decoder_layer: number of decoder layers | |||
| """ | |||
| super().__init__(encoder_output, encoder_mask) | |||
| self.encoder_key = [None] * num_decoder_layer  # each element: bsz x encoder_max_len x key_dim | |||
| self.encoder_value = [None] * num_decoder_layer  # each element: bsz x encoder_max_len x value_dim | |||
| self.decoder_prev_key = [None] * num_decoder_layer  # each element: bsz x decode_length x key_dim | |||
| self.decoder_prev_value = [None] * num_decoder_layer  # each element: bsz x decode_length x value_dim | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| super().reorder_state(indices) | |||
| self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||
| self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||
| self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||
| self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||
| @property | |||
| def decode_length(self): | |||
| if self.decoder_prev_key[0] is not None: | |||
| return self.decoder_prev_key[0].size(1) | |||
| return 0 | |||
| @@ -0,0 +1,24 @@ | |||
| r"""undocumented""" | |||
| __all__ = [ | |||
| "TimestepDropout" | |||
| ] | |||
| import torch | |||
| class TimestepDropout(torch.nn.Dropout): | |||
| r""" | |||
| Expects input of shape ``(batch_size, num_timesteps, embedding_dim)`` and applies the same | |||
| ``(batch_size, embedding_dim)`` dropout mask at every timestep. | |||
| """ | |||
| def forward(self, x): | |||
| dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) | |||
| torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) | |||
| dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] | |||
| if self.inplace: | |||
| x *= dropout_mask | |||
| return | |||
| else: | |||
| return x * dropout_mask | |||
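| # A minimal usage sketch: the same (batch_size, embedding_dim) mask is reused at every timestep, | |||
| # so a dropped embedding dimension stays dropped across the whole sequence. | |||
| # drop = TimestepDropout(p=0.5) | |||
| # y = drop(torch.randn(2, 7, 16))   # zeroed positions agree along the timestep axis | |||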
| @@ -1,5 +1,21 @@ | |||
| __all__ = [ | |||
| "ConvMaxpool", | |||
| "LSTM", | |||
| "Seq2SeqEncoder", | |||
| "TransformerSeq2SeqEncoder", | |||
| "LSTMSeq2SeqEncoder", | |||
| "StarTransformer", | |||
| "VarRNN", | |||
| "VarLSTM", | |||
| "VarGRU" | |||
| ] | |||
| from .conv_maxpool import ConvMaxpool | |||
| from .lstm import LSTM | |||
| from .seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
| from .star_transformer import StarTransformer | |||
| from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
| @@ -0,0 +1,87 @@ | |||
| r"""undocumented""" | |||
| __all__ = [ | |||
| "ConvMaxpool" | |||
| ] | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class ConvMaxpool(nn.Module): | |||
| r""" | |||
| A layer combining Convolution and Max-Pooling. Given an input of batch_size x max_len x input_size, returns a matrix | |||
| of size batch_size x sum(output_channels). Internally, the input is convolved by a CNN, passed through an activation, | |||
| and then max-pooled over the length (max_len) dimension, yielding one vector representation per sample. | |||
| """ | |||
| def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | |||
| r""" | |||
| :param int in_channels: input channel size, usually the embedding dimension or the encoder output dimension | |||
| :param int,tuple(int) out_channels: number of output channels. If a list, its length must match kernel_sizes | |||
| :param int,tuple(int) kernel_sizes: kernel size of each output channel. | |||
| :param str activation: activation applied to the convolution output before max-pooling. Supports relu, sigmoid, tanh | |||
| """ | |||
| super(ConvMaxpool, self).__init__() | |||
| # convolution | |||
| if isinstance(kernel_sizes, (list, tuple, int)): | |||
| if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | |||
| out_channels = [out_channels] | |||
| kernel_sizes = [kernel_sizes] | |||
| elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): | |||
| assert len(out_channels) == len( | |||
| kernel_sizes), "The number of out_channels should be equal to the number" \ | |||
| " of kernel_sizes." | |||
| else: | |||
| raise ValueError("The type of out_channels and kernel_sizes should be the same.") | |||
| for kernel_size in kernel_sizes:  # check only after kernel_sizes has been normalized to a list | |||
| assert kernel_size % 2 == 1, "kernel size has to be an odd number." | |||
| self.convs = nn.ModuleList([nn.Conv1d( | |||
| in_channels=in_channels, | |||
| out_channels=oc, | |||
| kernel_size=ks, | |||
| stride=1, | |||
| padding=ks // 2, | |||
| dilation=1, | |||
| groups=1, | |||
| bias=False) | |||
| for oc, ks in zip(out_channels, kernel_sizes)]) | |||
| else: | |||
| raise Exception( | |||
| 'Incorrect kernel sizes: should be list, tuple or int') | |||
| # activation function | |||
| if activation == 'relu': | |||
| self.activation = F.relu | |||
| elif activation == 'sigmoid': | |||
| self.activation = torch.sigmoid  # F.sigmoid is deprecated | |||
| elif activation == 'tanh': | |||
| self.activation = torch.tanh  # F.tanh is deprecated | |||
| else: | |||
| raise Exception( | |||
| "Undefined activation function: choose from: relu, tanh, sigmoid") | |||
| def forward(self, x, mask=None): | |||
| r""" | |||
| :param torch.FloatTensor x: batch_size x max_len x input_size, usually the output of an embedding layer | |||
| :param mask: batch_size x max_len, padding positions are 0. Padding does not affect the convolution itself, and max-pooling is guaranteed never to select a padded position | |||
| :return: | |||
| """ | |||
| # [N,L,C] -> [N,C,L] | |||
| x = torch.transpose(x, 1, 2) | |||
| # convolution | |||
| xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | |||
| if mask is not None: | |||
| mask = mask.unsqueeze(1) # B x 1 x L | |||
| xs = [x.masked_fill(mask.eq(False), float('-inf')) for x in xs] | |||
| # max-pooling | |||
| xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | |||
| for i in xs] # [[N, C], ...] | |||
| return torch.cat(xs, dim=-1) # [N, C] | |||
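| # A minimal usage sketch (hypothetical sizes): | |||
| # conv_pool = ConvMaxpool(in_channels=100, out_channels=[30, 40], kernel_sizes=[3, 5]) | |||
| # x = torch.randn(2, 7, 100)                    # batch_size x max_len x input_size | |||
| # mask = torch.ones(2, 7, dtype=torch.uint8)    # 0 marks padding | |||
| # out = conv_pool(x, mask)                      # 2 x 70, i.e. sum(out_channels) | |||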
| @@ -0,0 +1,193 @@ | |||
| r"""undocumented""" | |||
| import torch.nn as nn | |||
| import torch | |||
| from torch.nn import LayerNorm | |||
| import torch.nn.functional as F | |||
| from typing import Union, Tuple | |||
| from ....core.utils import seq_len_to_mask | |||
| import math | |||
| from .lstm import LSTM | |||
| from ..attention import MultiHeadAttention | |||
| from ....embeddings.torch import StaticEmbedding | |||
| from ....embeddings.torch.utils import get_embeddings | |||
| __all__ = ['Seq2SeqEncoder', 'TransformerSeq2SeqEncoder', 'LSTMSeq2SeqEncoder'] | |||
| class Seq2SeqEncoder(nn.Module): | |||
| """ | |||
| Base class for all sequence-to-sequence encoders. The forward function must be implemented | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def forward(self, tokens, seq_len): | |||
| """ | |||
| :param torch.LongTensor tokens: bsz x max_len, encoder input | |||
| :param torch.LongTensor seq_len: bsz | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
| """ | |||
| Self-attention encoder layer. | |||
| :param int d_model: input and output dimension | |||
| :param int n_head: number of heads, each of dimension d_model/n_head | |||
| :param int dim_ff: FFN dimension | |||
| :param float dropout: dropout for self-attention and FFN; 0 means no dropout | |||
| """ | |||
| def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, | |||
| dropout: float = 0.1): | |||
| super(TransformerSeq2SeqEncoderLayer, self).__init__() | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.self_attn = MultiHeadAttention(d_model, n_head, dropout) | |||
| self.attn_layer_norm = LayerNorm(d_model) | |||
| self.ffn_layer_norm = LayerNorm(d_model) | |||
| self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout), | |||
| nn.Linear(self.dim_ff, self.d_model), | |||
| nn.Dropout(dropout)) | |||
| def forward(self, x, mask): | |||
| """ | |||
| :param x: batch x src_seq x d_model | |||
| :param mask: batch x src_seq, positions equal to 0 are padding | |||
| :return: | |||
| """ | |||
| # attention | |||
| residual = x | |||
| x = self.attn_layer_norm(x) | |||
| x, _ = self.self_attn(query=x, | |||
| key=x, | |||
| value=x, | |||
| key_mask=mask) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| x = residual + x | |||
| # ffn | |||
| residual = x | |||
| x = self.ffn_layer_norm(x) | |||
| x = self.ffn(x) | |||
| x = residual + x | |||
| return x | |||
| class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
| """ | |||
| Transformer-based encoder | |||
| :param embed: embedding of the encoder input tokens | |||
| :param nn.Module pos_embed: position embedding | |||
| :param int num_layers: number of encoder layers | |||
| :param int d_model: input and output dimension | |||
| :param int n_head: number of heads | |||
| :param int dim_ff: FFN intermediate dimension | |||
| :param float dropout: dropout for attention and FFN | |||
| """ | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||
| num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||
| super(TransformerSeq2SeqEncoder, self).__init__() | |||
| self.embed = get_embeddings(embed) | |||
| self.embed_scale = math.sqrt(d_model) | |||
| self.pos_embed = pos_embed | |||
| self.num_layers = num_layers | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
| self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) | |||
| for _ in range(num_layers)]) | |||
| self.layer_norm = LayerNorm(d_model) | |||
| def forward(self, tokens, seq_len): | |||
| """ | |||
| :param tokens: batch x max_len | |||
| :param seq_len: [batch] | |||
| :return: bsz x max_len x d_model, bsz x max_len (positions equal to 0 are padding) | |||
| """ | |||
| x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||
| batch_size, max_src_len, _ = x.size() | |||
| device = x.device | |||
| if self.pos_embed is not None: | |||
| position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) | |||
| x += self.pos_embed(position) | |||
| x = self.input_fc(x) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| encoder_mask = seq_len_to_mask(seq_len, max_len=max_src_len) | |||
| encoder_mask = encoder_mask.to(device) | |||
| for layer in self.layer_stacks: | |||
| x = layer(x, encoder_mask) | |||
| x = self.layer_norm(x) | |||
| return x, encoder_mask | |||
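| # A minimal usage sketch (hypothetical sizes): | |||
| # encoder = TransformerSeq2SeqEncoder(embed=(1000, 512), d_model=512) | |||
| # x, mask = encoder(torch.randint(1, 1000, (2, 7)), torch.tensor([7, 4]))   # x: 2 x 7 x 512 | |||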
| class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||
| """ | |||
| LSTM encoder | |||
| :param embed: token embedding of the encoder | |||
| :param int num_layers: number of layers | |||
| :param int hidden_size: LSTM hidden and output size | |||
| :param float dropout: dropout between LSTM layers | |||
| :param bool bidirectional: whether to use a bidirectional LSTM | |||
| """ | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||
| hidden_size = 400, dropout = 0.3, bidirectional=True): | |||
| super().__init__() | |||
| self.embed = get_embeddings(embed) | |||
| self.num_layers = num_layers | |||
| self.dropout = dropout | |||
| self.hidden_size = hidden_size | |||
| self.bidirectional = bidirectional | |||
| hidden_size = hidden_size//2 if bidirectional else hidden_size | |||
| self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional, | |||
| batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||
| def forward(self, tokens, seq_len): | |||
| """ | |||
| :param torch.LongTensor tokens: bsz x max_len | |||
| :param torch.LongTensor seq_len: bsz | |||
| :return: (output, (hidden, cell)), encoder_mask | |||
| output: bsz x max_len x hidden_size, | |||
| hidden, cell: batch_size x hidden_size, the last layer's hidden state and cell state | |||
| encoder_mask: bsz x max_len, positions equal to 0 are padding | |||
| """ | |||
| x = self.embed(tokens) | |||
| device = x.device | |||
| x, (final_hidden, final_cell) = self.lstm(x, seq_len) | |||
| encoder_mask = seq_len_to_mask(seq_len).to(device) | |||
| # x: batch,seq_len,dim; h/c: num_layers*2,batch,dim | |||
| if self.bidirectional: | |||
| final_hidden = self.concat_bidir(final_hidden)  # concatenate the bidirectional hidden states as input for the decoder | |||
| final_cell = self.concat_bidir(final_cell) | |||
| return (x, (final_hidden[-1], final_cell[-1])), encoder_mask  # two return values, to match the forward of Seq2SeqBaseModel | |||
| def concat_bidir(self, input): | |||
| output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) | |||
| return output.reshape(self.num_layers, input.size(1), -1) | |||
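| # A minimal usage sketch (hypothetical sizes): | |||
| # encoder = LSTMSeq2SeqEncoder(embed=(1000, 100), hidden_size=400) | |||
| # (output, (hidden, cell)), mask = encoder(torch.randint(1, 1000, (2, 7)), torch.tensor([7, 5])) | |||
| # output: 2 x 7 x 400; hidden, cell: 2 x 400; suitable for LSTMSeq2SeqDecoder.init_state | |||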
| @@ -0,0 +1,166 @@ | |||
| r"""undocumented | |||
| PyTorch implementation of the encoder part of Star-Transformer | |||
| """ | |||
| __all__ = [ | |||
| "StarTransformer" | |||
| ] | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.nn import functional as F | |||
| class StarTransformer(nn.Module): | |||
| r""" | |||
| Encoder part of Star-Transformer. Takes a 3-d text input and returns an encoding of the same length. | |||
| paper: https://arxiv.org/abs/1902.09113 | |||
| """ | |||
| def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | |||
| r""" | |||
| :param int hidden_size: input dimension; also the output dimension. | |||
| :param int num_layers: number of star-transformer layers | |||
| :param int num_head: number of heads. | |||
| :param int head_dim: dimension of each head. | |||
| :param float dropout: dropout probability. Default: 0.1 | |||
| :param int max_len: int or None. If int, the maximum input sequence length, | |||
| and a position embedding is added to the input sequence. | |||
| If `None`, the position embedding step is skipped. Default: `None` | |||
| """ | |||
| super(StarTransformer, self).__init__() | |||
| self.iters = num_layers | |||
| self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)]) | |||
| # self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1) | |||
| self.emb_drop = nn.Dropout(dropout) | |||
| self.ring_att = nn.ModuleList( | |||
| [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||
| for _ in range(self.iters)]) | |||
| self.star_att = nn.ModuleList( | |||
| [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0) | |||
| for _ in range(self.iters)]) | |||
| if max_len is not None: | |||
| self.pos_emb = nn.Embedding(max_len, hidden_size) | |||
| else: | |||
| self.pos_emb = None | |||
| def forward(self, data, mask): | |||
| r""" | |||
| :param FloatTensor data: [batch, length, hidden] input sequence | |||
| :param ByteTensor mask: [batch, length] padding mask of the input sequence: 0 where there is no content (padding), | |||
| 1 otherwise | |||
| :return: [batch, length, hidden] the encoded output sequence, | |||
| [batch, hidden] the global relay node; see the paper for details | |||
| """ | |||
| def norm_func(f, x): | |||
| # B, H, L, 1 | |||
| return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |||
| B, L, H = data.size() | |||
| mask = (mask.eq(False)) # flip the mask for masked_fill_ | |||
| smask = torch.cat([torch.zeros(B, 1).byte().to(mask), mask], 1) | |||
| embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1 | |||
| if self.pos_emb: | |||
| P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \ | |||
| .view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||
| embs = embs + P | |||
| embs = norm_func(self.emb_drop, embs) | |||
| nodes = embs | |||
| relay = embs.mean(2, keepdim=True) | |||
| ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | |||
| r_embs = embs.view(B, H, 1, L) | |||
| for i in range(self.iters): | |||
| ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | |||
| nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||
| # nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax)) | |||
| relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | |||
| nodes = nodes.masked_fill_(ex_mask, 0) | |||
| nodes = nodes.view(B, H, L).permute(0, 2, 1) | |||
| return nodes, relay.view(B, H) | |||
| class _MSA1(nn.Module): | |||
| def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||
| super(_MSA1, self).__init__() | |||
| # Multi-head Self Attention Case 1, doing self-attention for small regions | |||
| # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small | |||
| self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
| self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
| self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
| self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||
| self.drop = nn.Dropout(dropout) | |||
| # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||
| self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||
| def forward(self, x, ax=None): | |||
| # x: B, H, L, 1, ax : B, H, X, L append features | |||
| nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||
| B, H, L, _ = x.shape | |||
| q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | |||
| if ax is not None: | |||
| aL = ax.shape[2] | |||
| ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | |||
| av = self.WV(ax).view(B, nhead, head_dim, aL, L) | |||
| q = q.view(B, nhead, head_dim, 1, L) | |||
| k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||
| .view(B, nhead, head_dim, unfold_size, L) | |||
| v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0)) \ | |||
| .view(B, nhead, head_dim, unfold_size, L) | |||
| if ax is not None: | |||
| k = torch.cat([k, ak], 3) | |||
| v = torch.cat([v, av], 3) | |||
| alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / np.sqrt(head_dim), 3)) # B N L 1 U | |||
| att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | |||
| ret = self.WO(att) | |||
| return ret | |||
| class _MSA2(nn.Module): | |||
| def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||
| # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value | |||
| super(_MSA2, self).__init__() | |||
| self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
| self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
| self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||
| self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||
| self.drop = nn.Dropout(dropout) | |||
| # print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||
| self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||
| def forward(self, x, y, mask=None): | |||
| # x: B, H, 1, 1, 1 y: B H L 1 | |||
| nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||
| B, H, L, _ = y.shape | |||
| q, k, v = self.WQ(x), self.WK(y), self.WV(y) | |||
| q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | |||
| k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | |||
| v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | |||
| pre_a = torch.matmul(q, k) / np.sqrt(head_dim) | |||
| if mask is not None: | |||
| pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) | |||
| alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L | |||
| att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 | |||
| return self.WO(att) | |||
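| # A minimal usage sketch (hypothetical sizes): | |||
| # star = StarTransformer(hidden_size=100, num_layers=2, num_head=5, head_dim=20) | |||
| # nodes, relay = star(torch.randn(2, 7, 100), torch.ones(2, 7, dtype=torch.uint8)) | |||
| # nodes: 2 x 7 x 100, relay: 2 x 100 | |||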
| @@ -0,0 +1,43 @@ | |||
| r"""undocumented""" | |||
| __all__ = [ | |||
| "TransformerEncoder" | |||
| ] | |||
| from torch import nn | |||
| from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
| class TransformerEncoder(nn.Module): | |||
| r""" | |||
| The encoder module of the transformer, without the embedding layer | |||
| """ | |||
| def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||
| """ | |||
| :param int num_layers: number of Transformer layers | |||
| :param int d_model: input and output size | |||
| :param int n_head: number of heads | |||
| :param int dim_ff: hidden size inside the FFN | |||
| :param float dropout: dropout probability on the attention and intermediate FFN representations | |||
| """ | |||
| super(TransformerEncoder, self).__init__() | |||
| self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||
| dropout = dropout) for _ in range(num_layers)]) | |||
| self.norm = nn.LayerNorm(d_model, eps=1e-6) | |||
| def forward(self, x, seq_mask=None): | |||
| r""" | |||
| :param x: [batch, seq_len, model_size] input sequence | |||
| :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated. | |||
| Positions equal to 1 are attended to. Default: ``None`` | |||
| :return: [batch, seq_len, model_size] output sequence | |||
| """ | |||
| output = x | |||
| if seq_mask is None: | |||
| seq_mask = x.new_ones(x.size(0), x.size(1)).bool() | |||
| for layer in self.layers: | |||
| output = layer(output, seq_mask) | |||
| return self.norm(output) | |||
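| # A minimal usage sketch (hypothetical sizes): | |||
| # encoder = TransformerEncoder(num_layers=2, d_model=512) | |||
| # out = encoder(torch.randn(2, 7, 512), torch.ones(2, 7, dtype=torch.bool))   # 2 x 7 x 512 | |||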
| @@ -0,0 +1,303 @@ | |||
| r"""undocumented | |||
| fastNLP implementation of Variational RNN and related models; reference: | |||
| `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
| """ | |||
| __all__ = [ | |||
| "VarRNN", | |||
| "VarLSTM", | |||
| "VarGRU" | |||
| ] | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence | |||
| try: | |||
| from torch import flip | |||
| except ImportError: | |||
| def flip(x, dims): | |||
| indices = [slice(None)] * x.dim() | |||
| for dim in dims: | |||
| indices[dim] = torch.arange( | |||
| x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) | |||
| return x[tuple(indices)] | |||
| class VarRnnCellWrapper(nn.Module): | |||
| r""" | |||
| Wrapper for plain RNN cells that adds support for variational dropout | |||
| """ | |||
| def __init__(self, cell, hidden_size, input_p, hidden_p): | |||
| super(VarRnnCellWrapper, self).__init__() | |||
| self.cell = cell | |||
| self.hidden_size = hidden_size | |||
| self.input_p = input_p | |||
| self.hidden_p = hidden_p | |||
| def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False): | |||
| r""" | |||
| :param PackedSequence input_x: [seq_len, batch_size, input_size] | |||
| :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | |||
| for other RNN, h_0, [batch_size, hidden_size] | |||
| :param mask_x: [batch_size, input_size] dropout mask for input | |||
| :param mask_h: [batch_size, hidden_size] dropout mask for hidden | |||
| :return PackedSequence output: [seq_len, batch_size, hidden_size] | |||
| hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||
| for other RNN, h_n, [batch_size, hidden_size] | |||
| """ | |||
| def get_hi(hi, h0, size): | |||
| h0_size = size - hi.size(0) | |||
| if h0_size > 0: | |||
| return torch.cat([hi, h0[:h0_size]], dim=0) | |||
| return hi[:size] | |||
| is_lstm = isinstance(hidden, tuple) | |||
| input, batch_sizes = input_x.data, input_x.batch_sizes | |||
| output = [] | |||
| cell = self.cell | |||
| if is_reversed: | |||
| batch_iter = flip(batch_sizes, [0]) | |||
| idx = input.size(0) | |||
| else: | |||
| batch_iter = batch_sizes | |||
| idx = 0 | |||
| if is_lstm: | |||
| hn = (hidden[0].clone(), hidden[1].clone()) | |||
| else: | |||
| hn = hidden.clone() | |||
| hi = hidden | |||
| for size in batch_iter: | |||
| if is_reversed: | |||
| input_i = input[idx - size: idx] * mask_x[:size] | |||
| idx -= size | |||
| else: | |||
| input_i = input[idx: idx + size] * mask_x[:size] | |||
| idx += size | |||
| mask_hi = mask_h[:size] | |||
| if is_lstm: | |||
| hx, cx = hi | |||
| hi = (get_hi(hx, hidden[0], size) * | |||
| mask_hi, get_hi(cx, hidden[1], size)) | |||
| hi = cell(input_i, hi) | |||
| hn[0][:size] = hi[0] | |||
| hn[1][:size] = hi[1] | |||
| output.append(hi[0]) | |||
| else: | |||
| hi = get_hi(hi, hidden, size) * mask_hi | |||
| hi = cell(input_i, hi) | |||
| hn[:size] = hi | |||
| output.append(hi) | |||
| if is_reversed: | |||
| output = list(reversed(output)) | |||
| output = torch.cat(output, dim=0) | |||
| return PackedSequence(output, batch_sizes), hn | |||
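| # The core idea the wrapper implements, sketched (illustrative): a single Bernoulli | |||
| # dropout mask is sampled once per sequence and reused at every timestep, instead | |||
| # of being resampled per step as in ordinary dropout: | |||
| # | |||
| #     import torch | |||
| #     import torch.nn.functional as F | |||
| #     mask_x = F.dropout(torch.ones(8, 16), p=0.3, training=True)  # one mask per sample | |||
| #     # inside the step loop, every input_t is multiplied by the same mask_x | |||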
| class VarRNNBase(nn.Module): | |||
| r""" | |||
| Implementation of an RNN with variational dropout. | |||
| Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
| """ | |||
| def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | |||
| bias=True, batch_first=False, | |||
| input_dropout=0, hidden_dropout=0, bidirectional=False): | |||
| r""" | |||
| :param mode: RNN mode ("LSTM" or not) | |||
| :param Cell: RNN cell type (LSTMCell, GRUCell, etc.) | |||
| :param input_size: feature dimension of the input `x` | |||
| :param hidden_size: feature dimension of the hidden state `h` | |||
| :param num_layers: number of RNN layers. Default: 1 | |||
| :param bias: if ``False``, the layer does not use bias weights. Default: ``True`` | |||
| :param batch_first: if ``True``, the input and output ``Tensor`` shapes are | |||
| (batch, seq, feature). Default: ``False`` | |||
| :param input_dropout: dropout probability applied to the input. Default: 0 | |||
| :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||
| :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | |||
| """ | |||
| super(VarRNNBase, self).__init__() | |||
| self.mode = mode | |||
| self.input_size = input_size | |||
| self.hidden_size = hidden_size | |||
| self.num_layers = num_layers | |||
| self.bias = bias | |||
| self.batch_first = batch_first | |||
| self.input_dropout = input_dropout | |||
| self.hidden_dropout = hidden_dropout | |||
| self.bidirectional = bidirectional | |||
| self.num_directions = 2 if bidirectional else 1 | |||
| self._all_cells = nn.ModuleList() | |||
| for layer in range(self.num_layers): | |||
| for direction in range(self.num_directions): | |||
| input_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions | |||
| cell = Cell(input_size, self.hidden_size, bias) | |||
| self._all_cells.append(VarRnnCellWrapper( | |||
| cell, self.hidden_size, input_dropout, hidden_dropout)) | |||
| self.is_lstm = (self.mode == "LSTM") | |||
| def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h): | |||
| is_lstm = self.is_lstm | |||
| idx = self.num_directions * n_layer + n_direction | |||
| cell = self._all_cells[idx] | |||
| hi = (hx[0][idx], hx[1][idx]) if is_lstm else hx[idx] | |||
| output_x, hidden_x = cell( | |||
| input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | |||
| return output_x, hidden_x | |||
| def forward(self, x, hx=None): | |||
| r""" | |||
| :param x: [batch, seq_len, input_size] input sequence | |||
| :param hx: [batch, hidden_size] initial hidden state; if ``None``, it is initialized to zeros. Default: ``None`` | |||
| :return (output, ht): [batch, seq_len, hidden_size*num_direction] output sequence | |||
| and [batch, hidden_size*num_direction] hidden state at the final timestep | |||
| """ | |||
| is_lstm = self.is_lstm | |||
| is_packed = isinstance(x, PackedSequence) | |||
| if not is_packed: | |||
| seq_len = x.size(1) if self.batch_first else x.size(0) | |||
| max_batch_size = x.size(0) if self.batch_first else x.size(1) | |||
| seq_lens = torch.LongTensor( | |||
| [seq_len for _ in range(max_batch_size)]) | |||
| x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||
| else: | |||
| max_batch_size = int(x.batch_sizes[0]) | |||
| x, batch_sizes = x.data, x.batch_sizes | |||
| if hx is None: | |||
| hx = x.new_zeros(self.num_layers * self.num_directions, | |||
| max_batch_size, self.hidden_size, requires_grad=True) | |||
| if is_lstm: | |||
| hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | |||
| mask_x = x.new_ones((max_batch_size, self.input_size)) | |||
| mask_out = x.new_ones( | |||
| (max_batch_size, self.hidden_size * self.num_directions)) | |||
| mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | |||
| nn.functional.dropout(mask_x, p=self.input_dropout, | |||
| training=self.training, inplace=True) | |||
| nn.functional.dropout(mask_out, p=self.hidden_dropout, | |||
| training=self.training, inplace=True) | |||
| hidden = x.new_zeros( | |||
| (self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||
| if is_lstm: | |||
| cellstate = x.new_zeros( | |||
| (self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||
| for layer in range(self.num_layers): | |||
| output_list = [] | |||
| input_seq = PackedSequence(x, batch_sizes) | |||
| mask_h = nn.functional.dropout( | |||
| mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | |||
| for direction in range(self.num_directions): | |||
| output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | |||
| mask_x if layer == 0 else mask_out, mask_h) | |||
| output_list.append(output_x.data) | |||
| idx = self.num_directions * layer + direction | |||
| if is_lstm: | |||
| hidden[idx] = hidden_x[0] | |||
| cellstate[idx] = hidden_x[1] | |||
| else: | |||
| hidden[idx] = hidden_x | |||
| x = torch.cat(output_list, dim=-1) | |||
| if is_lstm: | |||
| hidden = (hidden, cellstate) | |||
| if is_packed: | |||
| output = PackedSequence(x, batch_sizes) | |||
| else: | |||
| x = PackedSequence(x, batch_sizes) | |||
| output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | |||
| return output, hidden | |||
| class VarLSTM(VarRNNBase): | |||
| r""" | |||
| Variational Dropout LSTM. | |||
| Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| r""" | |||
| :param input_size: feature dimension of the input `x` | |||
| :param hidden_size: feature dimension of the hidden state `h` | |||
| :param num_layers: number of RNN layers. Default: 1 | |||
| :param bias: if ``False``, the layer does not use bias weights. Default: ``True`` | |||
| :param batch_first: if ``True``, the input and output ``Tensor`` shapes are | |||
| (batch, seq, feature). Default: ``False`` | |||
| :param input_dropout: dropout probability applied to the input. Default: 0 | |||
| :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||
| :param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False`` | |||
| """ | |||
| super(VarLSTM, self).__init__( | |||
| mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs) | |||
| def forward(self, x, hx=None): | |||
| return super(VarLSTM, self).forward(x, hx) | |||
| class VarRNN(VarRNNBase): | |||
| r""" | |||
| Variational Dropout RNN. | |||
| Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| r""" | |||
| :param input_size: feature dimension of the input `x` | |||
| :param hidden_size: feature dimension of the hidden state `h` | |||
| :param num_layers: number of RNN layers. Default: 1 | |||
| :param bias: if ``False``, the layer does not use bias weights. Default: ``True`` | |||
| :param batch_first: if ``True``, the input and output ``Tensor`` shapes are | |||
| (batch, seq, feature). Default: ``False`` | |||
| :param input_dropout: dropout probability applied to the input. Default: 0 | |||
| :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||
| :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | |||
| """ | |||
| super(VarRNN, self).__init__( | |||
| mode="RNN", Cell=nn.RNNCell, *args, **kwargs) | |||
| def forward(self, x, hx=None): | |||
| return super(VarRNN, self).forward(x, hx) | |||
| class VarGRU(VarRNNBase): | |||
| r""" | |||
| Variational Dropout GRU. | |||
| Reference: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) <https://arxiv.org/abs/1512.05287>`_ | |||
| """ | |||
| def __init__(self, *args, **kwargs): | |||
| r""" | |||
| :param input_size: feature dimension of the input `x` | |||
| :param hidden_size: feature dimension of the hidden state `h` | |||
| :param num_layers: number of RNN layers. Default: 1 | |||
| :param bias: if ``False``, the layer does not use bias weights. Default: ``True`` | |||
| :param batch_first: if ``True``, the input and output ``Tensor`` shapes are | |||
| (batch, seq, feature). Default: ``False`` | |||
| :param input_dropout: dropout probability applied to the input. Default: 0 | |||
| :param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||
| :param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False`` | |||
| """ | |||
| super(VarGRU, self).__init__( | |||
| mode="GRU", Cell=nn.GRUCell, *args, **kwargs) | |||
| def forward(self, x, hx=None): | |||
| return super(VarGRU, self).forward(x, hx) | |||
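| # Usage sketch for the classes above (illustrative sizes): | |||
| # | |||
| #     import torch | |||
| #     lstm = VarLSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True, | |||
| #                    input_dropout=0.3, hidden_dropout=0.3, bidirectional=True) | |||
| #     x = torch.randn(4, 7, 10)          # [batch, seq_len, input_size] | |||
| #     output, (h_n, c_n) = lstm(x)       # output: [4, 7, 40] | |||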
| @@ -0,0 +1,6 @@ | |||
| __all__ = [ | |||
| 'SequenceGenerator' | |||
| ] | |||
| from .seq2seq_generator import SequenceGenerator | |||
| @@ -0,0 +1,536 @@ | |||
| r""" | |||
| """ | |||
| __all__ = [ | |||
| 'SequenceGenerator' | |||
| ] | |||
| import torch | |||
| from torch import nn | |||
| import torch.nn.functional as F | |||
| from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||
| from functools import partial | |||
| def _get_model_device(model): | |||
| r""" | |||
| Given an nn.Module, return the device its parameters live on. | |||
| :param model: nn.Module | |||
| :return: torch.device or None; ``None`` means the model has no parameters. | |||
| """ | |||
| assert isinstance(model, nn.Module) | |||
| parameters = list(model.parameters()) | |||
| if len(parameters) == 0: | |||
| return None | |||
| else: | |||
| return parameters[0].device | |||
| class SequenceGenerator: | |||
| """ | |||
| Decode sentences from a given Seq2SeqDecoder. The decoder object must provide a decode() function whose first argument is | |||
| all tokens decoded so far and whose second argument is the state. SequenceGenerator itself performs no operations on the state. | |||
| """ | |||
| def __init__(self, decoder: Seq2SeqDecoder, max_length=20, max_len_a=0.0, num_beams=1, | |||
| do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
| """ | |||
| :param Seq2SeqDecoder decoder: the Decoder object | |||
| :param int max_length: maximum length of the generated sentences; each sentence is decoded for max_length + max_len_a*src_len steps | |||
| :param float max_len_a: each sentence is decoded for max_length + max_len_a*src_len steps. If nonzero, the State must contain encoder_mask | |||
| :param int num_beams: beam size for beam search | |||
| :param bool do_sample: whether to generate by sampling | |||
| :param float temperature: only meaningful when do_sample is True | |||
| :param int top_k: sample only from the top_k tokens | |||
| :param float top_p: sample only from tokens within cumulative probability top_p (nucleus sampling) | |||
| :param int,None bos_token_id: token id that starts a sentence | |||
| :param int,None eos_token_id: token id that ends a sentence | |||
| :param float repetition_penalty: how strongly repeated tokens are penalized | |||
| :param float length_penalty: length penalty; values below 1 encourage longer sentences, values above 1 encourage shorter ones | |||
| :param int pad_token_id: once a sentence has finished, the remaining positions are filled with pad_token_id | |||
| """ | |||
| if do_sample: | |||
| self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||
| num_beams=num_beams, | |||
| temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | |||
| eos_token_id=eos_token_id, repetition_penalty=repetition_penalty, | |||
| length_penalty=length_penalty, pad_token_id=pad_token_id) | |||
| else: | |||
| self.generate_func = partial(greedy_generate, decoder=decoder, max_length=max_length, max_len_a=max_len_a, | |||
| num_beams=num_beams, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, | |||
| repetition_penalty=repetition_penalty, | |||
| length_penalty=length_penalty, pad_token_id=pad_token_id) | |||
| self.do_sample = do_sample | |||
| self.max_length = max_length | |||
| self.num_beams = num_beams | |||
| self.temperature = temperature | |||
| self.top_k = top_k | |||
| self.top_p = top_p | |||
| self.bos_token_id = bos_token_id | |||
| self.eos_token_id = eos_token_id | |||
| self.repetition_penalty = repetition_penalty | |||
| self.length_penalty = length_penalty | |||
| self.decoder = decoder | |||
| @torch.no_grad() | |||
| def generate(self, state, tokens=None): | |||
| """ | |||
| :param State state: the State produced by the encoder; it must match the Decoder in use | |||
| :param torch.LongTensor,None tokens: batch_size x length, the starting tokens. If None, bos_token is used as the | |||
| first token of each sequence. | |||
| :return: bsz x max_length' generated token sequences. If eos_token_id is not None, every sequence ends with eos_token_id | |||
| """ | |||
| return self.generate_func(tokens=tokens, state=state) | |||
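| # Usage sketch (illustrative; `decoder` is a concrete Seq2SeqDecoder and `state` | |||
| # the State built from its encoder's output, both assumed to exist): | |||
| # | |||
| #     generator = SequenceGenerator(decoder, max_length=20, num_beams=4, | |||
| #                                   do_sample=False, bos_token_id=1, eos_token_id=2) | |||
| #     tokens = generator.generate(state)   # bsz x max_length' | |||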
| @torch.no_grad() | |||
| def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, | |||
| bos_token_id=None, eos_token_id=None, pad_token_id=0, | |||
| repetition_penalty=1, length_penalty=1.0): | |||
| """ | |||
| Decode sentences greedily. | |||
| :param Decoder decoder: the Decoder object | |||
| :param torch.LongTensor tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id | |||
| :param State state: should contain the outputs of the encoder. | |||
| :param int max_length: maximum length of the generated sentences; each sentence is decoded for max_length + max_len_a*src_len steps | |||
| :param float max_len_a: each sentence is decoded for max_length + max_len_a*src_len steps. If nonzero, the State must contain encoder_mask | |||
| :param int num_beams: beam size used for decoding. | |||
| :param int bos_token_id: if tokens is None, decoding starts from bos_token_id. | |||
| :param int eos_token_id: end-of-sentence token; if None, decoding always runs to max_length. | |||
| :param int pad_token_id: token id used for padding | |||
| :param float repetition_penalty: how strongly repeated tokens are penalized. | |||
| :param float length_penalty: length-based penalty applied to every token except eos. | |||
| :return: | |||
| """ | |||
| if num_beams == 1: | |||
| token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
| temperature=1, top_k=50, top_p=1, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| else: | |||
| token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
| num_beams=num_beams, temperature=1, top_k=50, top_p=1, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| return token_ids | |||
| @torch.no_grad() | |||
| def sample_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1, temperature=1.0, top_k=50, | |||
| top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | |||
| length_penalty=1.0): | |||
| """ | |||
| Generate sentences by sampling. | |||
| :param Decoder decoder: the Decoder object | |||
| :param torch.LongTensor tokens: batch_size x len, input tokens for decoding; if None, generation starts from bos_token_id | |||
| :param State state: should contain the outputs of the encoder. | |||
| :param int max_length: maximum length of the generated sentences; each sentence is decoded for max_length + max_len_a*src_len steps | |||
| :param float max_len_a: each sentence is decoded for max_length + max_len_a*src_len steps. If nonzero, the State must contain encoder_mask | |||
| :param int num_beams: beam size used for decoding. | |||
| :param float temperature: sampling temperature | |||
| :param int top_k: sample only from the top_k tokens | |||
| :param float top_p: a value between 0 and 1; the nucleus-sampling threshold. | |||
| :param int bos_token_id: if tokens is None, decoding starts from bos_token_id. | |||
| :param int eos_token_id: end-of-sentence token; if None, decoding always runs to max_length. | |||
| :param int pad_token_id: token id used for padding | |||
| :param float repetition_penalty: how strongly repeated tokens are penalized. | |||
| :param float length_penalty: length-based penalty applied to every token except eos. | |||
| :return: | |||
| """ | |||
| # each position is generated by sampling | |||
| if num_beams == 1: | |||
| token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
| temperature=temperature, top_k=top_k, top_p=top_p, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| else: | |||
| token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a, | |||
| num_beams=num_beams, temperature=temperature, top_k=top_k, top_p=top_p, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| return token_ids | |||
| def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, max_len_a=0.0, temperature=1.0, top_k=50, | |||
| top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
| repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | |||
| device = _get_model_device(decoder) | |||
| if tokens is None: | |||
| if bos_token_id is None: | |||
| raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
| batch_size = state.num_samples | |||
| if batch_size is None: | |||
| raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
| tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
| batch_size = tokens.size(0) | |||
| if state.num_samples: | |||
| assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
| if eos_token_id is None: | |||
| _eos_token_id = -1 | |||
| else: | |||
| _eos_token_id = eos_token_id | |||
| scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state | |||
| if _eos_token_id != -1:  # prevent eos at the very first position | |||
| scores[:, _eos_token_id] = -1e12 | |||
| next_tokens = scores.argmax(dim=-1, keepdim=True) | |||
| token_ids = torch.cat([tokens, next_tokens], dim=1) | |||
| cur_len = token_ids.size(1) | |||
| dones = token_ids.new_zeros(batch_size).eq(1) | |||
| # tokens = tokens[:, -1:] | |||
| if max_len_a!=0: | |||
| # (bsz x num_beams, ) | |||
| if state.encoder_mask is not None: | |||
| max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length | |||
| else: | |||
| max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) | |||
| real_max_length = max_lengths.max().item() | |||
| else: | |||
| real_max_length = max_length | |||
| if state.encoder_mask is not None: | |||
| max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length | |||
| else: | |||
| max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) | |||
| while cur_len < real_max_length: | |||
| scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size | |||
| if repetition_penalty != 1.0: | |||
| token_scores = scores.gather(dim=1, index=token_ids) | |||
| lt_zero_mask = token_scores.lt(0).float() | |||
| ge_zero_mask = lt_zero_mask.eq(0).float() | |||
| token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||
| scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||
| if eos_token_id is not None and length_penalty != 1.0: | |||
| token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size | |||
| eos_mask = scores.new_ones(scores.size(1)) | |||
| eos_mask[eos_token_id] = 0 | |||
| eos_mask = eos_mask.unsqueeze(0).eq(1) | |||
| scores = scores.masked_scatter(eos_mask, token_scores)  # i.e. the scores of all tokens except eos are scaled up/down | |||
| if do_sample: | |||
| if temperature > 0 and temperature != 1: | |||
| scores = scores / temperature | |||
| scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) | |||
| # the 1e-12 avoids the issue in https://github.com/pytorch/pytorch/pull/27523 | |||
| probs = F.softmax(scores, dim=-1) + 1e-12 | |||
| next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size | |||
| else: | |||
| next_tokens = torch.argmax(scores, dim=-1) # batch_size | |||
| # once a sequence has reached its maximum length, emit eos directly | |||
| if _eos_token_id != -1: | |||
| next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id) | |||
| next_tokens = next_tokens.masked_fill(dones, pad_token_id)  # pad samples whose generation has already finished | |||
| tokens = next_tokens.unsqueeze(1) | |||
| token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len | |||
| end_mask = next_tokens.eq(_eos_token_id) | |||
| dones = dones.__or__(end_mask) | |||
| cur_len += 1 | |||
| if dones.min() == 1: | |||
| break | |||
| # if eos_token_id is not None: | |||
| # tokens.scatter(index=max_lengths[:, None], dim=1, value=eos_token_id)  # set the position at the maximum length to eos | |||
| # if cur_len == max_length: | |||
| # token_ids[:, -1].masked_fill_(~dones, eos_token_id)  # if no EOS was produced by the maximum length, force the last token to eos | |||
| return token_ids | |||
| def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4, temperature=1.0, | |||
| top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
| repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | |||
| # perform beam search | |||
| device = _get_model_device(decoder) | |||
| if tokens is None: | |||
| if bos_token_id is None: | |||
| raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
| batch_size = state.num_samples | |||
| if batch_size is None: | |||
| raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
| tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
| batch_size = tokens.size(0) | |||
| if state.num_samples: | |||
| assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
| if eos_token_id is None: | |||
| _eos_token_id = -1 | |||
| else: | |||
| _eos_token_id = eos_token_id | |||
| scores = decoder.decode(tokens=tokens, state=state)  # the full sequence generated so far is passed in here | |||
| if _eos_token_id != -1:  # prevent eos at the very first position | |||
| scores[:, _eos_token_id] = -1e12 | |||
| vocab_size = scores.size(1) | |||
| assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." | |||
| if do_sample: | |||
| probs = F.softmax(scores, dim=-1) + 1e-12 | |||
| next_tokens = torch.multinomial(probs, num_samples=num_beams) # (batch_size, num_beams) | |||
| logits = probs.log() | |||
| next_scores = logits.gather(dim=1, index=next_tokens) # (batch_size, num_beams) | |||
| else: | |||
| scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) | |||
| # obtain (batch_size, num_beams), (batch_size, num_beams) | |||
| next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | |||
| # reorder according to the indices | |||
| indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||
| indices = indices.repeat_interleave(num_beams) | |||
| state.reorder_state(indices) | |||
| tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||
| # record the generated tokens (batch_size', cur_len) | |||
| token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | |||
| dones = [False] * batch_size | |||
| beam_scores = next_scores.view(-1) # batch_size * num_beams | |||
| # track the length of the tokens generated so far | |||
| cur_len = token_ids.size(1) | |||
| if max_len_a!=0: | |||
| # (bsz x num_beams, ) | |||
| if state.encoder_mask is not None: | |||
| max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length | |||
| else: | |||
| max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long) | |||
| real_max_length = max_lengths.max().item() | |||
| else: | |||
| real_max_length = max_length | |||
| if state.encoder_mask is not None: | |||
| max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length | |||
| else: | |||
| max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) | |||
| hypos = [ | |||
| BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size) | |||
| ] | |||
| # 0, num_beams, 2*num_beams, ... | |||
| batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||
| while cur_len < real_max_length: | |||
| scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size) | |||
| if repetition_penalty != 1.0: | |||
| token_scores = scores.gather(dim=1, index=token_ids) | |||
| lt_zero_mask = token_scores.lt(0).float() | |||
| ge_zero_mask = lt_zero_mask.eq(0).float() | |||
| token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores | |||
| scores.scatter_(dim=1, index=token_ids, src=token_scores) | |||
| if _eos_token_id!=-1: | |||
| max_len_eos_mask = max_lengths.eq(cur_len+1) | |||
| eos_scores = scores[:, _eos_token_id] | |||
| # if the maximum length has been reached, boost the eos score | |||
| scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores) | |||
| if do_sample: | |||
| if temperature > 0 and temperature != 1: | |||
| scores = scores / temperature | |||
| # recall one extra token to guard against eos | |||
| scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) | |||
| # the 1e-12 avoids the issue in https://github.com/pytorch/pytorch/pull/27523 | |||
| probs = F.softmax(scores, dim=-1) + 1e-12 | |||
| # guarantee at least one non-eos candidate | |||
| _tokens = torch.multinomial(probs, num_samples=num_beams + 1) # batch_size' x (num_beams+1) | |||
| logits = probs.log() | |||
| # avoid all picks coming from the same beam, and account for eos being selected | |||
| _scores = logits.gather(dim=1, index=_tokens) # batch_size' x (num_beams+1) | |||
| _scores = _scores + beam_scores[:, None] # batch_size' x (num_beams+1) | |||
| # from these, select the top 2*num_beams | |||
| _scores = _scores.view(batch_size, num_beams * (num_beams + 1)) | |||
| next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) | |||
| _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) | |||
| next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams) | |||
| from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams) | |||
| else: | |||
| scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) | |||
| _scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size) | |||
| _scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size) | |||
| next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams) | |||
| from_which_beam = ids // vocab_size # (batch_size, 2*num_beams) | |||
| next_tokens = ids % vocab_size # (batch_size, 2*num_beams) | |||
| # next, assemble the results for the following step. | |||
| # decide which candidates to keep | |||
| # next_scores, sorted_inds = next_scores.sort(dim=-1, descending=True) | |||
| # next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | |||
| # from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | |||
| not_eos_mask = next_tokens.ne(_eos_token_id)  # positions with 1 are not eos | |||
| keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)  # positions with 1 should be kept | |||
| keep_mask = not_eos_mask.__and__(keep_mask)  # positions with 1 proceed to the next search step | |||
| _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1) | |||
| _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)  # which beam each kept token came from | |||
| _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | |||
| beam_scores = _next_scores.view(-1) | |||
| flag = True | |||
| if cur_len+1 == real_max_length: | |||
| eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | |||
| eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)  # beam indices | |||
| eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)  # which beam each token was taken from | |||
| else: | |||
| # add the sequences within the first num_beams of each batch to the finished set; positions with 1 are done | |||
| effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams | |||
| if effective_eos_mask.sum().gt(0): | |||
| eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | |||
| # from_which_beam is (batch_size, 2*num_beams), hence the factor of 2*num_beams | |||
| eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind | |||
| eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]  # recover which beam the eos actually came from | |||
| else: | |||
| flag = False | |||
| if flag: | |||
| _token_ids = torch.cat([token_ids, _next_tokens], dim=-1) | |||
| for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | |||
| eos_beam_idx.tolist()): | |||
| if not dones[batch_idx]: | |||
| score = next_scores[batch_idx, beam_ind].item() | |||
| # an eos will be appended at the end later | |||
| if _eos_token_id!=-1: | |||
| hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||
| else: | |||
| hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score) | |||
| # update the state and reassemble token_ids | |||
| reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten to 1-D | |||
| state.reorder_state(reorder_inds) | |||
| # reorganize token_ids | |||
| token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1) | |||
| for batch_idx in range(batch_size): | |||
| dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \ | |||
| max_lengths[batch_idx*num_beams]==cur_len+1 | |||
| cur_len += 1 | |||
| if all(dones): | |||
| break | |||
| # select the best hypotheses | |||
| tgt_len = token_ids.new_zeros(batch_size) | |||
| best = [] | |||
| for i, hypotheses in enumerate(hypos): | |||
| best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | |||
| # append the eos back (it was stripped when the hypothesis was added) | |||
| if _eos_token_id!=-1: | |||
| best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id]) | |||
| tgt_len[i] = len(best_hyp) | |||
| best.append(best_hyp) | |||
| # generate target batch | |||
| decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id) | |||
| for i, hypo in enumerate(best): | |||
| decoded[i, :tgt_len[i]] = hypo | |||
| return decoded | |||
| class BeamHypotheses(object): | |||
| def __init__(self, num_beams, max_length, length_penalty, early_stopping): | |||
| """ | |||
| Initialize n-best list of hypotheses. | |||
| """ | |||
| self.max_length = max_length - 1 # ignoring bos_token | |||
| self.length_penalty = length_penalty | |||
| self.early_stopping = early_stopping | |||
| self.num_beams = num_beams | |||
| self.hyp = [] | |||
| self.worst_score = 1e9 | |||
| def __len__(self): | |||
| """ | |||
| Number of hypotheses in the list. | |||
| """ | |||
| return len(self.hyp) | |||
| def add(self, hyp, sum_logprobs): | |||
| """ | |||
| Add a new hypothesis to the list. | |||
| """ | |||
| score = sum_logprobs / len(hyp) ** self.length_penalty | |||
| if len(self) < self.num_beams or score > self.worst_score: | |||
| self.hyp.append((score, hyp)) | |||
| if len(self) > self.num_beams: | |||
| sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)]) | |||
| del self.hyp[sorted_scores[0][1]] | |||
| self.worst_score = sorted_scores[1][0] | |||
| else: | |||
| self.worst_score = min(score, self.worst_score) | |||
| def is_done(self, best_sum_logprobs): | |||
| """ | |||
| If there are enough hypotheses and that none of the hypotheses being generated | |||
| can become better than the worst one in the heap, then we are done with this sentence. | |||
| """ | |||
| if len(self) < self.num_beams: | |||
| return False | |||
| elif self.early_stopping: | |||
| return True | |||
| else: | |||
| return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty | |||
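| # Worked example of the scoring rule in add() above: with sum_logprobs = -6.0 and | |||
| # len(hyp) = 4, length_penalty = 1.0 gives -6.0 / 4 = -1.5, while | |||
| # length_penalty = 2.0 gives -6.0 / 16 = -0.375. | |||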
| def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1): | |||
| """ | |||
| According to top_k and top_p, set the values that fail the filter to filter_value. | |||
| :param torch.Tensor logits: bsz x vocab_size | |||
| :param int top_k: if greater than 0, keep only the logits of the top_k tokens; all other positions are set to filter_value | |||
| :param float top_p: nucleus filtering threshold as in (http://arxiv.org/abs/1904.09751) | |||
| :param float filter_value: value assigned to filtered-out positions | |||
| :param int min_tokens_to_keep: each sample keeps at least this many tokens with probability mass in the returned distribution | |||
| :return: | |||
| """ | |||
| if top_k > 0: | |||
| top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check | |||
| # Remove all tokens with a probability less than the last token of the top-k | |||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] | |||
| logits[indices_to_remove] = filter_value | |||
| if top_p < 1.0: | |||
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||
| cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) | |||
| # Remove tokens with cumulative probability above the threshold (token with 0 are kept) | |||
| sorted_indices_to_remove = cumulative_probs > top_p | |||
| if min_tokens_to_keep > 1: | |||
| # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) | |||
| sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 | |||
| # Shift the indices to the right to keep also the first token above the threshold | |||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |||
| sorted_indices_to_remove[..., 0] = 0 | |||
| # scatter sorted tensors to original indexing | |||
| indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) | |||
| logits[indices_to_remove] = filter_value | |||
| return logits | |||
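| # Worked example (illustrative): with top_k=2 only the two largest logits survive; | |||
| # the rest become -inf, so softmax assigns them zero probability: | |||
| # | |||
| #     import torch | |||
| #     logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]]) | |||
| #     filtered = top_k_top_p_filtering(logits.clone(), top_k=2) | |||
| #     # filtered == [[2.0, 1.0, -inf, -inf]] | |||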
| @@ -0,0 +1,142 @@ | |||
| """ | |||
| This module makes it easy to smoke-test models. | |||
| If your model does text classification, sequence labeling, or natural language inference (NLI), it can be tested with this module directly. | |||
| If it does not, you can still prepare fake data yourself and supply a loss and metric for testing. | |||
| The tests here only guarantee that a model can be trained and evaluated with fastNLP; they do not test actual model performance. | |||
| Example:: | |||
| # import the ALL-CAPS variables... | |||
| from model_runner import * | |||
| # test a text classification model | |||
| init_emb = (VOCAB_SIZE, 50) | |||
| model = SomeModel(init_emb, num_cls=NUM_CLS) | |||
| RUNNER.run_model_with_task(TEXT_CLS, model) | |||
| # a sequence labeling model | |||
| RUNNER.run_model_with_task(POS_TAGGING, model) | |||
| # an NLI model | |||
| RUNNER.run_model_with_task(NLI, model) | |||
| # a custom model | |||
| RUNNER.run_model(model, data=get_mydata(), | |||
| loss=Myloss(), metrics=Mymetric()) | |||
| """ | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| from torch import optim | |||
| from fastNLP import Trainer, Evaluator, DataSet, Callback | |||
| from fastNLP import Accuracy | |||
| from random import randrange | |||
| from fastNLP import TorchDataLoader | |||
| VOCAB_SIZE = 100 | |||
| NUM_CLS = 100 | |||
| MAX_LEN = 10 | |||
| N_SAMPLES = 100 | |||
| N_EPOCHS = 1 | |||
| BATCH_SIZE = 5 | |||
| TEXT_CLS = 'text_cls' | |||
| POS_TAGGING = 'pos_tagging' | |||
| NLI = 'nli' | |||
| class ModelRunner(): | |||
| class Checker(Callback): | |||
| def on_backward_begin(self, trainer, outputs): | |||
| assert outputs['loss'].to('cpu').isfinite().all() | |||
| def gen_seq(self, length, vocab_size): | |||
| """generate fake sequence indexes with given length""" | |||
| # reserve 0 for padding | |||
| return [randrange(1, vocab_size) for _ in range(length)] | |||
| def gen_var_seq(self, max_len, vocab_size): | |||
| """generate fake sequence indexes in variant length""" | |||
| length = randrange(3, max_len) # at least 3 words in a seq | |||
| return self.gen_seq(length, vocab_size) | |||
| def prepare_text_classification_data(self): | |||
| index = 'index' | |||
| ds = DataSet({index: list(range(N_SAMPLES))}) | |||
| ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
| field_name=index, new_field_name='words') | |||
| ds.apply_field(lambda x: randrange(NUM_CLS), | |||
| field_name=index, new_field_name='target') | |||
| ds.apply_field(len, 'words', 'seq_len') | |||
| dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
| return dl | |||
| def prepare_pos_tagging_data(self): | |||
| index = 'index' | |||
| ds = DataSet({index: list(range(N_SAMPLES))}) | |||
| ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
| field_name=index, new_field_name='words') | |||
| ds.apply_field(lambda x: self.gen_seq(len(x), NUM_CLS), | |||
| field_name='words', new_field_name='target') | |||
| ds.apply_field(len, 'words', 'seq_len') | |||
| dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
| return dl | |||
| def prepare_nli_data(self): | |||
| index = 'index' | |||
| ds = DataSet({index: list(range(N_SAMPLES))}) | |||
| ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
| field_name=index, new_field_name='words1') | |||
| ds.apply_field(lambda x: self.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
| field_name=index, new_field_name='words2') | |||
| ds.apply_field(lambda x: randrange(NUM_CLS), | |||
| field_name=index, new_field_name='target') | |||
| ds.apply_field(len, 'words1', 'seq_len1') | |||
| ds.apply_field(len, 'words2', 'seq_len2') | |||
| dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
| return dl | |||
| def run_text_classification(self, model, data=None): | |||
| if data is None: | |||
| data = self.prepare_text_classification_data() | |||
| metric = Accuracy() | |||
| self.run_model(model, data, metric) | |||
| def run_pos_tagging(self, model, data=None): | |||
| if data is None: | |||
| data = self.prepare_pos_tagging_data() | |||
| metric = Accuracy() | |||
| self.run_model(model, data, metric) | |||
| def run_nli(self, model, data=None): | |||
| if data is None: | |||
| data = self.prepare_nli_data() | |||
| metric = Accuracy() | |||
| self.run_model(model, data, metric) | |||
| def run_model(self, model, data, metrics): | |||
| """run a model, test if it can run with fastNLP""" | |||
| print('testing model:', model.__class__.__name__) | |||
| tester = Evaluator(model, data, metrics={'metric': metrics}, driver='torch') | |||
| before_train = tester.run() | |||
| optimizer = optim.SGD(model.parameters(), lr=1e-3) | |||
| trainer = Trainer(model, driver='torch', train_dataloader=data, | |||
| n_epochs=N_EPOCHS, optimizers=optimizer) | |||
| trainer.run() | |||
| after_train = tester.run() | |||
| for metric_name, v1 in before_train.items(): | |||
| assert metric_name in after_train | |||
| # # at least we can be sure the model params changed, even if we don't know the performance | |||
| # v2 = after_train[metric_name] | |||
| # assert v1 != v2 | |||
| def run_model_with_task(self, task, model): | |||
| """run a model with certain task""" | |||
| TASKS = { | |||
| TEXT_CLS: self.run_text_classification, | |||
| POS_TAGGING: self.run_pos_tagging, | |||
| NLI: self.run_nli, | |||
| } | |||
| assert task in TASKS | |||
| TASKS[task](model) | |||
| RUNNER = ModelRunner() | |||
| @@ -0,0 +1,91 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from fastNLP.models.torch.biaffine_parser import BiaffineParser | |||
| from fastNLP import Metric, seq_len_to_mask | |||
| from .model_runner import * | |||
| class ParserMetric(Metric): | |||
| r""" | |||
| Evaluate the performance of a dependency parser. | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.num_arc = 0 | |||
| self.num_label = 0 | |||
| self.num_sample = 0 | |||
| def get_metric(self, reset=True): | |||
| res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample} | |||
| if reset: | |||
| self.num_sample = self.num_label = self.num_arc = 0 | |||
| return res | |||
| def update(self, pred1, pred2, target1, target2, seq_len=None): | |||
| r""" | |||
| :param pred1: logits of the arc (head) predictions | |||
| :param pred2: logits of the label predictions | |||
| :param target1: gold arcs | |||
| :param target2: gold labels | |||
| :param seq_len: sequence lengths | |||
| :return dict: evaluation results:: | |||
| UAS: accuracy of the arc predictions, ignoring labels | |||
| LAS: accuracy of predicting both the arc and its label | |||
| """ | |||
| if seq_len is None: | |||
| seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) | |||
| else: | |||
| seq_mask = seq_len_to_mask(seq_len.long()).long() | |||
| # mask out <root> tag | |||
| seq_mask[:, 0] = 0 | |||
| head_pred_correct = (pred1 == target1).long() * seq_mask | |||
| label_pred_correct = (pred2 == target2).long() * head_pred_correct | |||
| self.num_arc += head_pred_correct.sum().item() | |||
| self.num_label += label_pred_correct.sum().item() | |||
| self.num_sample += seq_mask.sum().item() | |||
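| # Worked example of the counts above (illustrative): for one sentence with 4 real | |||
| # tokens where 3 predicted heads are correct and 2 of those also carry the correct | |||
| # label, the metric yields UAS = 3/4 and LAS = 2/4. | |||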
| def prepare_parser_data(): | |||
| index = 'index' | |||
| ds = DataSet({index: list(range(N_SAMPLES))}) | |||
| ds.apply_field(lambda x: RUNNER.gen_var_seq(MAX_LEN, VOCAB_SIZE), | |||
| field_name=index, new_field_name='words1') | |||
| ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||
| field_name='words1', new_field_name='words2') | |||
| # target1 is heads, should in range(0, len(words)) | |||
| ds.apply_field(lambda x: RUNNER.gen_seq(len(x), len(x)), | |||
| field_name='words1', new_field_name='target1') | |||
| ds.apply_field(lambda x: RUNNER.gen_seq(len(x), NUM_CLS), | |||
| field_name='words1', new_field_name='target2') | |||
| ds.apply_field(len, field_name='words1', new_field_name='seq_len') | |||
| dl = TorchDataLoader(ds, batch_size=BATCH_SIZE) | |||
| return dl | |||
| @pytest.mark.torch | |||
| class TestBiaffineParser: | |||
| def test_train(self): | |||
| model = BiaffineParser(embed=(VOCAB_SIZE, 10), | |||
| pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, | |||
| rnn_hidden_size=10, | |||
| arc_mlp_size=10, | |||
| label_mlp_size=10, | |||
| num_label=NUM_CLS, encoder='var-lstm') | |||
| ds = prepare_parser_data() | |||
| RUNNER.run_model(model, ds, metrics=ParserMetric()) | |||
| def test_train2(self): | |||
| model = BiaffineParser(embed=(VOCAB_SIZE, 10), | |||
| pos_vocab_size=VOCAB_SIZE, pos_emb_dim=10, | |||
| rnn_hidden_size=16, | |||
| arc_mlp_size=10, | |||
| label_mlp_size=10, | |||
| num_label=NUM_CLS, encoder='transformer') | |||
| ds = prepare_parser_data() | |||
| RUNNER.run_model(model, ds, metrics=ParserMetric()) | |||
| @@ -0,0 +1,33 @@ | |||
| import pytest | |||
| from .model_runner import * | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| from fastNLP.models.torch.cnn_text_classification import CNNText | |||
| @pytest.mark.torch | |||
| class TestCNNText: | |||
| def init_model(self, kernel_sizes, kernel_nums=(1,3,5)): | |||
| model = CNNText((VOCAB_SIZE, 30), | |||
| NUM_CLS, | |||
| kernel_nums=kernel_nums, | |||
| kernel_sizes=kernel_sizes) | |||
| return model | |||
| def test_case1(self): | |||
| # test that the CNN text classifier runs | |||
| model = self.init_model((1,3,5)) | |||
| RUNNER.run_model_with_task(TEXT_CLS, model) | |||
| def test_init_model(self): | |||
| with pytest.raises(Exception): | |||
| self.init_model(2, 4) | |||
| with pytest.raises(Exception): | |||
| self.init_model((2,)) | |||
| def test_output(self): | |||
| model = self.init_model((3,), (1,)) | |||
| global MAX_LEN | |||
| MAX_LEN = 2 | |||
| RUNNER.run_model_with_task(TEXT_CLS, model) | |||
| @@ -0,0 +1,73 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| from fastNLP.models.torch import LSTMSeq2SeqModel, TransformerSeq2SeqModel | |||
| import torch | |||
| from fastNLP.embeddings.torch import StaticEmbedding | |||
| from fastNLP import Vocabulary, DataSet | |||
| from fastNLP import Trainer, Accuracy | |||
| from fastNLP import Callback, TorchDataLoader | |||
| def prepare_env(): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
| src_words_idx = [[3, 1, 2], [1, 2]] | |||
| # tgt_words_idx = [[1, 2, 3, 4], [2, 3]] | |||
| src_seq_len = [3, 2] | |||
| # tgt_seq_len = [4, 2] | |||
| ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, | |||
| 'tgt_seq_len':src_seq_len}) | |||
| dl = TorchDataLoader(ds, batch_size=32) | |||
| return embed, dl | |||
| class ExitCallback(Callback): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def on_valid_end(self, trainer, results): | |||
| if results['acc#acc'] == 1: | |||
| raise KeyboardInterrupt() | |||
| @pytest.mark.torch | |||
| class TestSeq2SeqGeneratorModel: | |||
| def test_run(self): | |||
| # check that SequenceGeneratorModel can be trained; inputs are passed through for prediction | |||
| embed, dl = prepare_env() | |||
| model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, | |||
| dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| optimizer = torch.optim.Adam(model1.parameters(), lr=1e-3) | |||
| trainer = Trainer(model1, driver='torch', optimizers=optimizer, train_dataloader=dl, | |||
| n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, | |||
| evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], | |||
| 'seq_len': x['tgt_seq_len'], | |||
| **x}, | |||
| callbacks=ExitCallback()) | |||
| trainer.run() | |||
| embed, dl = prepare_env() | |||
| model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| num_layers=1, hidden_size=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True, attention=True) | |||
| optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) | |||
| trainer = Trainer(model2, driver='torch', optimizers=optimizer, train_dataloader=dl, | |||
| n_epochs=100, evaluate_dataloaders=dl, metrics={'acc': Accuracy()}, | |||
| evaluate_input_mapping=lambda x: {'target': x['tgt_tokens'], | |||
| 'seq_len': x['tgt_seq_len'], | |||
| **x}, | |||
| callbacks=ExitCallback()) | |||
| trainer.run() | |||
| @@ -0,0 +1,113 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| from fastNLP.models.torch.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings.torch import StaticEmbedding | |||
| import torch | |||
| from torch import optim | |||
| import torch.nn.functional as F | |||
| from fastNLP import seq_len_to_mask | |||
| def prepare_env(): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
| src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| tgt_seq_len = torch.LongTensor([4, 2]) | |||
| return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len | |||
| def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): | |||
| optimizer = optim.Adam(model.parameters(), lr=1e-2) | |||
| mask = seq_len_to_mask(tgt_seq_len).eq(0) | |||
| target = tgt_words_idx.masked_fill(mask, -100) | |||
| for i in range(100): | |||
| optimizer.zero_grad() | |||
| pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size | |||
| loss = F.cross_entropy(pred.transpose(1, 2), target) | |||
| loss.backward() | |||
| optimizer.step() | |||
| right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() | |||
| return right_count | |||
| @pytest.mark.torch | |||
| class TestTransformerSeq2SeqModel: | |||
| def test_run(self): | |||
| # test that the model runs end to end | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| for pos_embed in ['learned', 'sin']: | |||
| model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
| assert (output['pred'].size() == (2, 4, len(embed))) | |||
| for bind_encoder_decoder_embed in [True, False]: | |||
| tgt_embed = None | |||
| for bind_decoder_input_output_embed in [True, False]: | |||
| if not bind_encoder_decoder_embed: | |||
| tgt_embed = embed | |||
| model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
| pos_embed='sin', max_position=20, num_layers=2, | |||
| d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
| output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
| assert (output['pred'].size() == (2, 4, len(embed))) | |||
| def test_train(self): | |||
| # test that the model can be trained to overfit | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
| assert(right_count == tgt_words_idx.nelement()) | |||
| @pytest.mark.torch | |||
| class TestLSTMSeq2SeqModel: | |||
| def test_run(self): | |||
| # test that the model runs end to end | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| for bind_encoder_decoder_embed in [True, False]: | |||
| tgt_embed = None | |||
| for bind_decoder_input_output_embed in [True, False]: | |||
| if not bind_encoder_decoder_embed: | |||
| tgt_embed = embed | |||
| model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
| num_layers=2, hidden_size=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
| output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
| assert (output['pred'].size() == (2, 4, len(embed))) | |||
| def test_train(self): | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| num_layers=1, hidden_size=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
| assert (right_count == tgt_words_idx.nelement()) | |||
| @@ -0,0 +1,47 @@ | |||
| import pytest | |||
| from .model_runner import * | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
| @pytest.mark.torch | |||
| class TestBiLSTM: | |||
| def test_case1(self): | |||
| # test that BiLSTMCRF runs | |||
| init_emb = (VOCAB_SIZE, 30) | |||
| model = BiLSTMCRF(init_emb, | |||
| hidden_size=30, | |||
| num_classes=NUM_CLS) | |||
| dl = RUNNER.prepare_pos_tagging_data() | |||
| metric = Accuracy() | |||
| RUNNER.run_model(model, dl, metric) | |||
| @pytest.mark.torch | |||
| class TestSeqLabel: | |||
| def test_case1(self): | |||
| # test that SeqLabeling runs | |||
| init_emb = (VOCAB_SIZE, 30) | |||
| model = SeqLabeling(init_emb, | |||
| hidden_size=30, | |||
| num_classes=NUM_CLS) | |||
| dl = RUNNER.prepare_pos_tagging_data() | |||
| metric = Accuracy() | |||
| RUNNER.run_model(model, dl, metric) | |||
| @pytest.mark.torch | |||
| class TestAdvSeqLabel: | |||
| def test_case1(self): | |||
| # test that AdvSeqLabel runs | |||
| init_emb = (VOCAB_SIZE, 30) | |||
| model = AdvSeqLabel(init_emb, | |||
| hidden_size=30, | |||
| num_classes=NUM_CLS) | |||
| dl = RUNNER.prepare_pos_tagging_data() | |||
| metric = Accuracy() | |||
| RUNNER.run_model(model, dl, metric) | |||
| @@ -0,0 +1,327 @@ | |||
| import pytest | |||
| import os | |||
| from fastNLP import Vocabulary | |||
| @pytest.mark.torch | |||
| class TestCRF: | |||
| def test_case1(self): | |||
| from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||
| # check that allowed_transitions() works as expected | |||
| id2label = {0: 'B', 1: 'I', 2:'O'} | |||
| expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
| (2, 4), (3, 0), (3, 2)} | |||
| assert expected_res == set(allowed_transitions(id2label, include_start_end=True)) | |||
| id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
| assert (expected_res == set( | |||
| allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
| id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
| allowed_transitions(id2label, include_start_end=True) | |||
| labels = ['O'] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BI': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx:label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
| (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
| (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
| assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||
| labels = [] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BMES': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
| (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
| (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
| assert (expected_res == set( | |||
| allowed_transitions(id2label, include_start_end=True))) | |||
| def test_case11(self): | |||
| # test automatic inference of the encoding type | |||
| from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||
| id2label = {0: 'B', 1: 'I', 2: 'O'} | |||
| expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
| (2, 4), (3, 0), (3, 2)} | |||
| assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||
| id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
| assert (expected_res == set( | |||
| allowed_transitions(id2label, include_start_end=True))) | |||
| id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||
| allowed_transitions(id2label, include_start_end=True) | |||
| labels = ['O'] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BI': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
| (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
| (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
| assert (expected_res == set(allowed_transitions(id2label, include_start_end=True))) | |||
| labels = [] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BMES': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
| (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
| (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
| assert (expected_res == set( | |||
| allowed_transitions(id2label, include_start_end=True))) | |||
| def test_case12(self): | |||
| # test that the transition constraints can be built from a Vocabulary | |||
| from fastNLP.modules.torch.decoder.crf import allowed_transitions | |||
| id2label = {0: 'B', 1: 'I', 2: 'O'} | |||
| vocab = Vocabulary(unknown=None, padding=None) | |||
| for idx, tag in id2label.items(): | |||
| vocab.add_word(tag) | |||
| expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
| (2, 4), (3, 0), (3, 2)} | |||
| assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) | |||
| id2label = {0: 'B', 1: 'M', 2: 'E', 3: 'S'} | |||
| vocab = Vocabulary(unknown=None, padding=None) | |||
| for idx, tag in id2label.items(): | |||
| vocab.add_word(tag) | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
| assert (expected_res == set( | |||
| allowed_transitions(vocab, include_start_end=True))) | |||
| id2label = {0: 'B', 1: 'I', 2: 'O', 3: '<pad>', 4: "<unk>"} | |||
| vocab = Vocabulary() | |||
| for idx, tag in id2label.items(): | |||
| vocab.add_word(tag) | |||
| allowed_transitions(vocab, include_start_end=True) | |||
| labels = ['O'] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BI': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
| (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
| (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
| vocab = Vocabulary(unknown=None, padding=None) | |||
| for idx, tag in id2label.items(): | |||
| vocab.add_word(tag) | |||
| assert (expected_res == set(allowed_transitions(vocab, include_start_end=True))) | |||
| labels = [] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BMES': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| vocab = Vocabulary(unknown=None, padding=None) | |||
| for idx, tag in id2label.items(): | |||
| vocab.add_word(tag) | |||
| expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
| (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
| (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
| assert (expected_res == set( | |||
| allowed_transitions(vocab, include_start_end=True))) | |||
| # def test_case2(self): | |||
| # # Test that the CRF never decodes illegal transitions; verified against allennlp. | |||
| # pass | |||
| # import torch | |||
| # from fastNLP import seq_len_to_mask | |||
| # | |||
| # labels = ['O'] | |||
| # for label in ['X', 'Y']: | |||
| # for tag in 'BI': | |||
| # labels.append('{}-{}'.format(tag, label)) | |||
| # id2label = {idx: label for idx, label in enumerate(labels)} | |||
| # num_tags = len(id2label) | |||
| # max_len = 10 | |||
| # batch_size = 4 | |||
| # bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() | |||
| # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
| # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), | |||
| # include_start_end_transitions=False) | |||
| # bio_trans_m = allen_CRF.transitions | |||
| # bio_seq_lens = torch.randint(1, max_len, size=(batch_size,)) | |||
| # bio_seq_lens[0] = 1 | |||
| # bio_seq_lens[-1] = max_len | |||
| # mask = seq_len_to_mask(bio_seq_lens) | |||
| # allen_res = allen_CRF.viterbi_tags(bio_logits, mask) | |||
| # | |||
| # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||
| # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
| # include_start_end=True)) | |||
| # fast_CRF.trans_m = bio_trans_m | |||
| # fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) | |||
| # bio_scores = [round(score, 4) for _, score in allen_res] | |||
| # # score equal | |||
| # self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) | |||
| # # seq equal | |||
| # bio_path = [_ for _, score in allen_res] | |||
| # self.assertListEqual(bio_path, fast_res[0]) | |||
| # | |||
| # labels = [] | |||
| # for label in ['X', 'Y']: | |||
| # for tag in 'BMES': | |||
| # labels.append('{}-{}'.format(tag, label)) | |||
| # id2label = {idx: label for idx, label in enumerate(labels)} | |||
| # num_tags = len(id2label) | |||
| # | |||
| # from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions | |||
| # allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), | |||
| # include_start_end_transitions=False) | |||
| # bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() | |||
| # bmes_trans_m = allen_CRF.transitions | |||
| # bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,)) | |||
| # bmes_seq_lens[0] = 1 | |||
| # bmes_seq_lens[-1] = max_len | |||
| # mask = seq_len_to_mask(bmes_seq_lens) | |||
| # allen_res = allen_CRF.viterbi_tags(bmes_logits, mask) | |||
| # | |||
| # from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions | |||
| # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
| # encoding_type='BMES', | |||
| # include_start_end=True)) | |||
| # fast_CRF.trans_m = bmes_trans_m | |||
| # fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) | |||
| # # score equal | |||
| # bmes_scores = [round(score, 4) for _, score in allen_res] | |||
| # self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) | |||
| # # seq equal | |||
| # bmes_path = [_ for _, score in allen_res] | |||
| # self.assertListEqual(bmes_path, fast_res[0]) | |||
| # | |||
| # data = { | |||
| # 'bio_logits': bio_logits.tolist(), | |||
| # 'bio_scores': bio_scores, | |||
| # 'bio_path': bio_path, | |||
| # 'bio_trans_m': bio_trans_m.tolist(), | |||
| # 'bio_seq_lens': bio_seq_lens.tolist(), | |||
| # 'bmes_logits': bmes_logits.tolist(), | |||
| # 'bmes_scores': bmes_scores, | |||
| # 'bmes_path': bmes_path, | |||
| # 'bmes_trans_m': bmes_trans_m.tolist(), | |||
| # 'bmes_seq_lens': bmes_seq_lens.tolist(), | |||
| # } | |||
| # | |||
| # with open('weights.json', 'w') as f: | |||
| # import json | |||
| # json.dump(data, f) | |||
| def test_case2(self): | |||
| # Test that the CRF works end to end. | |||
| import json | |||
| import torch | |||
| from fastNLP import seq_len_to_mask | |||
| folder = os.path.dirname(os.path.abspath(__file__)) | |||
| path = os.path.join(folder, '../../../', 'helpers/data/modules/decoder/crf.json') | |||
| with open(os.path.abspath(path), 'r') as f: | |||
| data = json.load(f) | |||
| bio_logits = torch.FloatTensor(data['bio_logits']) | |||
| bio_scores = data['bio_scores'] | |||
| bio_path = data['bio_path'] | |||
| bio_trans_m = torch.FloatTensor(data['bio_trans_m']) | |||
| bio_seq_lens = torch.LongTensor(data['bio_seq_lens']) | |||
| bmes_logits = torch.FloatTensor(data['bmes_logits']) | |||
| bmes_scores = data['bmes_scores'] | |||
| bmes_path = data['bmes_path'] | |||
| bmes_trans_m = torch.FloatTensor(data['bmes_trans_m']) | |||
| bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens']) | |||
| labels = ['O'] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BI': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| num_tags = len(id2label) | |||
| mask = seq_len_to_mask(bio_seq_lens) | |||
| from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions | |||
| fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
| include_start_end=True)) | |||
| fast_CRF.trans_m.data = bio_trans_m | |||
| fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) | |||
| # score equal | |||
| assert (bio_scores == [round(s, 4) for s in fast_res[1].tolist()]) | |||
| # seq equal | |||
| assert (bio_path == fast_res[0]) | |||
| labels = [] | |||
| for label in ['X', 'Y']: | |||
| for tag in 'BMES': | |||
| labels.append('{}-{}'.format(tag, label)) | |||
| id2label = {idx: label for idx, label in enumerate(labels)} | |||
| num_tags = len(id2label) | |||
| mask = seq_len_to_mask(bmes_seq_lens) | |||
| from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions | |||
| fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, | |||
| encoding_type='BMES', | |||
| include_start_end=True)) | |||
| fast_CRF.trans_m.data = bmes_trans_m | |||
| fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) | |||
| # score equal | |||
| assert (bmes_scores == [round(s, 4) for s in fast_res[1].tolist()]) | |||
| # seq equal | |||
| assert (bmes_path == fast_res[0]) | |||
| def test_case3(self): | |||
| # Test that the CRF loss never becomes negative | |||
| import torch | |||
| from fastNLP.modules.torch.decoder.crf import ConditionalRandomField | |||
| from fastNLP.core.utils import seq_len_to_mask | |||
| from torch import optim | |||
| from torch import nn | |||
| num_tags, include_start_end_trans = 4, True | |||
| num_samples = 4 | |||
| lengths = torch.randint(3, 50, size=(num_samples, )).long() | |||
| max_len = lengths.max() | |||
| tags = torch.randint(num_tags, size=(num_samples, max_len)) | |||
| masks = seq_len_to_mask(lengths) | |||
| feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags)) | |||
| crf = ConditionalRandomField(num_tags, include_start_end_trans) | |||
| optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1) | |||
| for _ in range(10): | |||
| loss = crf(feats, tags, masks).mean() | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| optimizer.step() | |||
| assert loss.item() > 0 | |||
| def test_masking(self): | |||
| # Test that CRF padding masks are handled correctly | |||
| import torch | |||
| from fastNLP.modules.torch.decoder.crf import ConditionalRandomField | |||
| max_len = 5 | |||
| n_tags = 5 | |||
| pad_len = 5 | |||
| torch.manual_seed(4) | |||
| logit = torch.rand(1, max_len+pad_len, n_tags) | |||
| # logit[0, -1, :] = 0.0 | |||
| mask = torch.ones(1, max_len+pad_len) | |||
| mask[0, -pad_len:] = 0 | |||
| model = ConditionalRandomField(n_tags) | |||
| pred, score = model.viterbi_decode(logit[:, :-pad_len], mask[:, :-pad_len]) | |||
| mask_pred, mask_score = model.viterbi_decode(logit, mask) | |||
| assert (pred[0].tolist() == mask_pred[0, :-pad_len].tolist()) | |||
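Taken together, these tests pin down the constrained-CRF API. A minimal end-to-end sketch of the same calls, with an illustrative tag set and shapes:

```python
import torch
from fastNLP import seq_len_to_mask
from fastNLP.modules.torch.decoder.crf import ConditionalRandomField, allowed_transitions

# allowed_transitions maps an id-to-label dict (or a Vocabulary) to the legal
# (from_idx, to_idx) tag transitions. With n real tags, index n is the virtual
# start tag and index n + 1 the virtual end tag, which is why the expected
# sets above contain pairs such as (4, 0) ("may begin with B") for 4-tag BMES.
id2label = {0: 'B', 1: 'I', 2: 'O'}
trans = allowed_transitions(id2label, include_start_end=True)

crf = ConditionalRandomField(num_tags=len(id2label), allowed_transitions=trans)

feats = torch.randn(2, 5, len(id2label))           # (batch, seq_len, num_tags)
tags = torch.randint(len(id2label), size=(2, 5))   # gold tag indices
mask = seq_len_to_mask(torch.LongTensor([5, 3]))   # False at padding

loss = crf(feats, tags, mask).mean()               # NLL per sample; always >= 0
paths, scores = crf.viterbi_decode(feats, mask, unpad=True)
```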
| @@ -0,0 +1,49 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings.torch import StaticEmbedding | |||
| from fastNLP.modules.torch import TransformerSeq2SeqDecoder | |||
| from fastNLP.modules.torch import LSTMSeq2SeqDecoder | |||
| from fastNLP import seq_len_to_mask | |||
| @pytest.mark.torch | |||
| class TestTransformerSeq2SeqDecoder: | |||
| @pytest.mark.parametrize("flag", [True, False]) | |||
| def test_case(self, flag): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=10) | |||
| encoder_output = torch.randn(2, 3, 10) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_mask = seq_len_to_mask(src_seq_len) | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=None, | |||
| d_model=10, num_layers=2, n_head=5, dim_ff=20, dropout=0.1, | |||
| bind_decoder_input_output_embed=flag) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) | |||
| assert (output.size() == (2, 4, len(vocab))) | |||
| @pytest.mark.torch | |||
| class TestLSTMDecoder: | |||
| @pytest.mark.parametrize("flag", [True, False]) | |||
| @pytest.mark.parametrize("attention", [True, False]) | |||
| def test_case(self, flag, attention): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) | |||
| encoder_output = torch.randn(2, 3, 10) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_mask = seq_len_to_mask(src_seq_len) | |||
| decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=2, hidden_size=10, | |||
| dropout=0.3, bind_decoder_input_output_embed=flag, attention=attention) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| output = decoder(tgt_words_idx, state) | |||
| assert tuple(output.size()) == (2, 4, len(vocab)) | |||
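Both decoder classes share the same two-phase contract that these tests exercise: `init_state` caches the encoder outputs, and each forward call scores every position of the given prefix. A sketch with illustrative shapes (the greedy next-token line is an assumption about typical usage, not something the tests assert):

```python
import torch
from fastNLP import Vocabulary, seq_len_to_mask
from fastNLP.embeddings.torch import StaticEmbedding
from fastNLP.modules.torch import LSTMSeq2SeqDecoder

vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=2, hidden_size=10,
                             dropout=0.3, bind_decoder_input_output_embed=True,
                             attention=True)

encoder_output = torch.randn(2, 3, 10)                    # (batch, src_len, dim)
encoder_mask = seq_len_to_mask(torch.LongTensor([3, 2]))  # True at real tokens
state = decoder.init_state(encoder_output, encoder_mask)

prefix = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]])
logits = decoder(prefix, state)                # (batch, prefix_len, len(vocab))
next_token = logits[:, -1].argmax(dim=-1)      # greedy pick for the next step
```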
| @@ -0,0 +1,33 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings.torch import StaticEmbedding | |||
| @pytest.mark.torch | |||
| class TestTransformerSeq2SeqEncoder: | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=5) | |||
| encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) | |||
| words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
| seq_len = torch.LongTensor([3]) | |||
| encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
| assert (encoder_output.size() == (1, 3, 10)) | |||
| @pytest.mark.torch | |||
| class TestBiLSTMEncoder: | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=5) | |||
| encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) | |||
| words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
| seq_len = torch.LongTensor([3]) | |||
| encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
| assert (encoder_mask.size() == (1, 3)) | |||
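Both encoder tests rely on `seq_len_to_mask` for the returned mask; a quick sketch of its semantics:

```python
import torch
from fastNLP import seq_len_to_mask

# Expands per-sequence lengths into a (batch, max_len) boolean mask in which
# positions beyond each sequence's length are masked out; lengths [3, 2]
# yield [[True, True, True], [True, True, False]].
mask = seq_len_to_mask(torch.LongTensor([3, 2]))
assert tuple(mask.shape) == (2, 3)
```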
| @@ -0,0 +1,18 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from fastNLP.modules.torch.encoder.star_transformer import StarTransformer | |||
| @pytest.mark.torch | |||
| class TestStarTransformer: | |||
| def test_1(self): | |||
| model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100) | |||
| x = torch.rand(16, 45, 100) | |||
| mask = torch.ones(16, 45).byte() | |||
| y, yn = model(x, mask) | |||
| assert (tuple(y.size()) == (16, 45, 100)) | |||
| assert (tuple(yn.size()) == (16, 100)) | |||
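The two checked outputs are the per-token node states and the per-sequence relay (global) node. A sketch of the same contract with smaller, illustrative sizes:

```python
import torch
from fastNLP.modules.torch.encoder.star_transformer import StarTransformer

model = StarTransformer(num_layers=2, hidden_size=64, num_head=8, head_dim=8,
                        max_len=50)
x = torch.rand(4, 20, 64)          # (batch, seq_len, hidden_size)
mask = torch.ones(4, 20).byte()    # 1 marks real tokens
nodes, relay = model(x, mask)      # (4, 20, 64) token states, (4, 64) relay
```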
| @@ -0,0 +1,27 @@ | |||
| import pytest | |||
| import numpy as np | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM | |||
| @pytest.mark.torch | |||
| class TestMaskedRnn: | |||
| def test_case_1(self): | |||
| masked_rnn = VarLSTM(input_size=1, hidden_size=1, bidirectional=True, batch_first=True) | |||
| x = torch.tensor([[[1.0], [2.0]]]) | |||
| y, _ = masked_rnn(x) | |||
| def test_case_2(self): | |||
| input_size = 12 | |||
| batch = 16 | |||
| hidden = 10 | |||
| masked_rnn = VarLSTM(input_size=input_size, hidden_size=hidden, bidirectional=False, batch_first=True) | |||
| xx = torch.randn((batch, 32, input_size)) | |||
| y, _ = masked_rnn(xx) | |||
| assert(tuple(y.shape) == (batch, 32, hidden)) | |||
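As the tuple unpacking above suggests, `VarLSTM` follows the `nn.LSTM` calling convention while tying its dropout masks across timesteps. A sketch, assuming the usual doubling of the feature dimension in the bidirectional case:

```python
import torch
from fastNLP.modules.torch.encoder.variational_rnn import VarLSTM

rnn = VarLSTM(input_size=12, hidden_size=10, batch_first=True,
              bidirectional=True)
x = torch.randn(16, 32, 12)
output, hidden = rnn(x)            # same convention as nn.LSTM
assert tuple(output.shape) == (16, 32, 20)  # 2 * hidden_size when bidirectional
```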
| @@ -0,0 +1 @@ | |||
| @@ -0,0 +1,146 @@ | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from fastNLP.modules.torch.generator import SequenceGenerator | |||
| from fastNLP.modules.torch import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings.torch import StaticEmbedding | |||
| from torch import nn | |||
| from fastNLP import seq_len_to_mask | |||
| else: | |||
| from fastNLP.core.utils.dummy_class import DummyClass as State | |||
| from fastNLP.core.utils.dummy_class import DummyClass as Seq2SeqDecoder | |||
| def prepare_env(): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
| encoder_output = torch.randn(2, 3, 10) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_mask = seq_len_to_mask(src_seq_len) | |||
| return embed, encoder_output, encoder_mask | |||
| class GreedyDummyDecoder(Seq2SeqDecoder): | |||
| def __init__(self, decoder_output): | |||
| super().__init__() | |||
| self.cur_length = 0 | |||
| self.decoder_output = decoder_output | |||
| def decode(self, tokens, state): | |||
| self.cur_length += 1 | |||
| scores = self.decoder_output[:, self.cur_length] | |||
| return scores | |||
| class DummyState(State): | |||
| def __init__(self, decoder): | |||
| super().__init__() | |||
| self.decoder = decoder | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) | |||
| @pytest.mark.torch | |||
| class TestSequenceGenerator: | |||
| def test_run(self): | |||
| # Smoke test: (1) construct the decoder, (2) run one generate call | |||
| embed, encoder_output, encoder_mask = prepare_env() | |||
| for do_sample in [True, False]: | |||
| for num_beams in [1, 3, 5]: | |||
| decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, | |||
| dropout=0.3, bind_decoder_input_output_embed=True, attention=True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
| generator.generate(state=state, tokens=None) | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
| d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, | |||
| bind_decoder_input_output_embed=True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
| generator.generate(state=state, tokens=None) | |||
| # also exercise non-default parameter values | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
| d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, | |||
| dropout=0.1, | |||
| bind_decoder_input_output_embed=True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, | |||
| eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) | |||
| generator.generate(state=state, tokens=None) | |||
| def test_greedy_decode(self): | |||
| # Test that generate recovers the expected sequences | |||
| # greedy | |||
| for beam_search in [1, 3]: | |||
| decoder_output = torch.randn(2, 10, 5) | |||
| path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
| decoder = GreedyDummyDecoder(decoder_output) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
| do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, | |||
| eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
| decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
| assert (decode_path.eq(path).sum() == path.numel()) | |||
| # greedy check eos_token_id | |||
| for beam_search in [1, 3]: | |||
| decoder_output = torch.randn(2, 10, 5) | |||
| decoder_output[:, :7, 4].fill_(-100) | |||
| decoder_output[0, 7, 4] = 1000 # ends at the 8th step | |||
| decoder_output[1, 5, 4] = 1000 | |||
| path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
| decoder = GreedyDummyDecoder(decoder_output) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
| do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||
| eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
| decode_path = generator.generate(DummyState(decoder), | |||
| tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
| assert (decode_path.size(1) == 8) # length is 8 | |||
| assert (decode_path[0].eq(path[0, :8]).sum() == 8) | |||
| assert (decode_path[1, :6].eq(path[1, :6]).sum() == 6) | |||
| def test_sample_decoder(self): | |||
| # sampling: check eos_token_id handling | |||
| for beam_search in [1, 3]: | |||
| decode_paths = [] | |||
| # Sampling is stochastic, so run several trials; if at least one matches the greedy path, the implementation is most likely fine | |||
| num_tests = 10 | |||
| for i in range(num_tests): | |||
| decoder_output = torch.randn(2, 10, 5) * 10 | |||
| decoder_output[:, :7, 4].fill_(-100) | |||
| decoder_output[0, 7, 4] = 10000 # ends at the 8th step | |||
| decoder_output[1, 5, 4] = 10000 | |||
| path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
| decoder = GreedyDummyDecoder(decoder_output) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
| do_sample=True, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||
| eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
| decode_path = generator.generate(DummyState(decoder), | |||
| tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
| decode_paths.append([decode_path, path]) | |||
| sizes = [] | |||
| eqs = [] | |||
| eq2s = [] | |||
| for i in range(num_tests): | |||
| decode_path, path = decode_paths[i] | |||
| sizes.append(decode_path.size(1) == 8) | |||
| eqs.append(decode_path[0].eq(path[0, :8]).sum() == 8) | |||
| eq2s.append(decode_path[1, :6].eq(path[1, :6]).sum() == 6) | |||
| assert any(sizes) | |||
| assert any(eqs) | |||
| assert any(eq2s) | |||
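For completeness, a sketch of driving `SequenceGenerator` with a real decoder rather than the dummy one; constructor arguments mirror the tests above, and the vocabulary, shapes, and token ids are illustrative:

```python
import torch
from fastNLP import Vocabulary, seq_len_to_mask
from fastNLP.embeddings.torch import StaticEmbedding
from fastNLP.modules.torch import LSTMSeq2SeqDecoder
from fastNLP.modules.torch.generator import SequenceGenerator

vocab = Vocabulary().add_word_lst("This is a test .".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10)
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10,
                             dropout=0.0, bind_decoder_input_output_embed=True,
                             attention=True)

encoder_output = torch.randn(2, 3, 10)
encoder_mask = seq_len_to_mask(torch.LongTensor([3, 2]))
state = decoder.init_state(encoder_output, encoder_mask)

generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=3,
                              do_sample=False, temperature=1.0, top_k=50,
                              top_p=1.0, bos_token_id=1, eos_token_id=None,
                              repetition_penalty=1.0, length_penalty=1.0,
                              pad_token_id=0)
tokens = generator.generate(state=state, tokens=None)  # (batch, generated_len)
```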