| @@ -16,6 +16,7 @@ from ._logger import logger | |||
| from .dataset import DataSet | |||
| from .utils import Option | |||
| from .utils import _is_iterable | |||
| import io | |||
| class VocabularyOption(Option): | |||
| @@ -487,76 +488,99 @@ class Vocabulary(object): | |||
| def save(self, filepath): | |||
| r""" | |||
| :param str filepath: path where the Vocabulary is saved | |||
| :param str,io.StringIO filepath: path or writable stream to which the Vocabulary is saved | |||
| :return: | |||
| """ | |||
| with open(filepath, 'w', encoding='utf-8') as f: | |||
| f.write(f'max_size\t{self.max_size}\n') | |||
| f.write(f'min_freq\t{self.min_freq}\n') | |||
| f.write(f'unknown\t{self.unknown}\n') | |||
| f.write(f'padding\t{self.padding}\n') | |||
| f.write(f'rebuild\t{self.rebuild}\n') | |||
| f.write('\n') | |||
| # idx: -2 means the vocabulary has not been built yet; -1 means the word was not indexed | |||
| # no_create_entry: 1 means the word is a no_create_entry word; 0 otherwise | |||
| # line format: word \t count \t idx \t no_create_entry \n | |||
| idx = -2 | |||
| for word, count in self.word_count.items(): | |||
| if self._word2idx is not None: | |||
| idx = self._word2idx.get(word, -1) | |||
| is_no_create_entry = int(self._is_word_no_create_entry(word)) | |||
| f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') | |||
| if isinstance(filepath, io.IOBase): | |||
| assert filepath.writable() | |||
| f = filepath | |||
| elif isinstance(filepath, str): | |||
| f = open(filepath, 'w', encoding='utf-8') | |||
| else: | |||
| raise TypeError("Illegal `filepath`.") | |||
| f.write(f'max_size\t{self.max_size}\n') | |||
| f.write(f'min_freq\t{self.min_freq}\n') | |||
| f.write(f'unknown\t{self.unknown}\n') | |||
| f.write(f'padding\t{self.padding}\n') | |||
| f.write(f'rebuild\t{self.rebuild}\n') | |||
| f.write('\n') | |||
| # idx: -2 means the vocabulary has not been built yet; -1 means the word was not indexed | |||
| # no_create_entry: 1 means the word is a no_create_entry word; 0 otherwise | |||
| # line format: word \t count \t idx \t no_create_entry \n | |||
| idx = -2 | |||
| for word, count in self.word_count.items(): | |||
| if self._word2idx is not None: | |||
| idx = self._word2idx.get(word, -1) | |||
| is_no_create_entry = int(self._is_word_no_create_entry(word)) | |||
| f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') | |||
| if isinstance(filepath, str): # only close the handle if we opened it from a path | |||
| f.close() | |||
| @staticmethod | |||
| def load(filepath): | |||
| r""" | |||
| :param str filepath: path from which the Vocabulary is read | |||
| :param str,io.StringIO filepath: path or readable stream from which the Vocabulary is read | |||
| :return: Vocabulary | |||
| """ | |||
| with open(filepath, 'r', encoding='utf-8') as f: | |||
| vocab = Vocabulary() | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| name, value = line.split() | |||
| if name in ('max_size', 'min_freq'): | |||
| value = int(value) if value!='None' else None | |||
| setattr(vocab, name, value) | |||
| elif name in ('unknown', 'padding'): | |||
| value = value if value!='None' else None | |||
| setattr(vocab, name, value) | |||
| elif name == 'rebuild': | |||
| vocab.rebuild = True if value=='True' else False | |||
| else: | |||
| break | |||
| word_counter = {} | |||
| no_create_entry_counter = {} | |||
| word2idx = {} | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| parts = line.split('\t') | |||
| word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) | |||
| if idx >= 0: | |||
| word2idx[word] = idx | |||
| word_counter[word] = count | |||
| if no_create_entry: | |||
| no_create_entry_counter[word] = count | |||
| word_counter = Counter(word_counter) | |||
| no_create_entry_counter = Counter(no_create_entry_counter) | |||
| if len(word2idx)>0: | |||
| if vocab.padding: | |||
| word2idx[vocab.padding] = 0 | |||
| if vocab.unknown: | |||
| word2idx[vocab.unknown] = 1 if vocab.padding else 0 | |||
| idx2word = {value:key for key,value in word2idx.items()} | |||
| vocab.word_count = word_counter | |||
| vocab._no_create_word = no_create_entry_counter | |||
| if word2idx: | |||
| vocab._word2idx = word2idx | |||
| vocab._idx2word = idx2word | |||
| if isinstance(filepath, io.IOBase): | |||
| assert filepath.readable()  # we are reading here, so the stream must be readable | |||
| f = filepath | |||
| elif isinstance(filepath, str): | |||
| f = open(filepath, 'r', encoding='utf-8') | |||
| else: | |||
| raise TypeError("Illegal `filepath`.") | |||
| vocab = Vocabulary() | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| name, value = line.split() | |||
| if name in ('max_size', 'min_freq'): | |||
| value = int(value) if value!='None' else None | |||
| setattr(vocab, name, value) | |||
| elif name in ('unknown', 'padding'): | |||
| value = value if value!='None' else None | |||
| setattr(vocab, name, value) | |||
| elif name == 'rebuild': | |||
| vocab.rebuild = True if value=='True' else False | |||
| else: | |||
| break | |||
| word_counter = {} | |||
| no_create_entry_counter = {} | |||
| word2idx = {} | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| parts = line.split('\t') | |||
| word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) | |||
| if idx >= 0: | |||
| word2idx[word] = idx | |||
| word_counter[word] = count | |||
| if no_create_entry: | |||
| no_create_entry_counter[word] = count | |||
| word_counter = Counter(word_counter) | |||
| no_create_entry_counter = Counter(no_create_entry_counter) | |||
| if len(word2idx)>0: | |||
| if vocab.padding: | |||
| word2idx[vocab.padding] = 0 | |||
| if vocab.unknown: | |||
| word2idx[vocab.unknown] = 1 if vocab.padding else 0 | |||
| idx2word = {value:key for key,value in word2idx.items()} | |||
| vocab.word_count = word_counter | |||
| vocab._no_create_word = no_create_entry_counter | |||
| if word2idx: | |||
| vocab._word2idx = word2idx | |||
| vocab._idx2word = idx2word | |||
| if isinstance(filepath, str): # only close the handle if we opened it from a path | |||
| f.close() | |||
| return vocab | |||
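To make the new path-or-stream behaviour concrete, here is a minimal sketch of the round-trip; the words and the file name are made up, and `Vocabulary` is assumed to be importable from the top-level `fastNLP` package as usual:

```python
import io
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(["apple", "banana", "apple"])
vocab.build_vocab()

# save to / load from a plain path
vocab.save("my_vocab.txt")
restored = Vocabulary.load("my_vocab.txt")
assert restored.to_index("apple") == vocab.to_index("apple")

# save to / load from an in-memory buffer; no file is touched
buf = io.StringIO()
vocab.save(buf)
buf.seek(0)          # rewind before reading it back
restored_again = Vocabulary.load(buf)
```

The saved text starts with the `max_size`/`min_freq`/`unknown`/`padding`/`rebuild` header, a blank line, and then one `word \t count \t idx \t no_create_entry` row per word, as the comments above describe.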
| @@ -22,8 +22,9 @@ __all__ = [ | |||
| "StackEmbedding", | |||
| "LSTMCharEmbedding", | |||
| "CNNCharEmbedding", | |||
| "get_embeddings", | |||
| "get_embeddings", | |||
| "get_sinusoid_encoding_table" | |||
| ] | |||
| from .embedding import Embedding, TokenEmbedding | |||
| @@ -34,7 +35,7 @@ from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder | |||
| from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding | |||
| from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | |||
| from .stack_embedding import StackEmbedding | |||
| from .utils import get_embeddings | |||
| from .utils import get_embeddings, get_sinusoid_encoding_table | |||
| import sys | |||
| from ..doc_utils import doc_process | |||
| @@ -8,11 +8,11 @@ __all__ = [ | |||
| "BertWordPieceEncoder" | |||
| ] | |||
| import collections | |||
| import os | |||
| import warnings | |||
| from itertools import chain | |||
| from functools import partial | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| @@ -24,6 +24,13 @@ from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR | |||
| from ..modules.encoder.bert import BertModel | |||
| from ..modules.tokenizer import BertTokenizer | |||
| # TODO: rework this so that the encoder can load the embedding weights directly | |||
| VOCAB_NAME = 'vocab.txt' | |||
| BERT_EMBED_HYPER = 'bert_hyper.json' | |||
| BERT_EMBED_FOLDER = 'bert' | |||
| BERT_ENCODER_HYPER = 'bert_hyper.json' | |||
| BERT_ENCODER_FOLDER = 'bert' | |||
| class BertEmbedding(ContextualEmbedding): | |||
| r""" | |||
| @@ -82,10 +89,7 @@ class BertEmbedding(ContextualEmbedding): | |||
| word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | |||
| 来进行分类的任务将auto_truncate置为True。 | |||
| :param kwargs: | |||
| bool only_use_pretrain_bpe: only use BPE tokens that appear in the pretrained vocabulary; words that cannot be tokenized fall back to unk. Recommended to set to True when the embedding will not be updated. | |||
| int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are added to BERT's BPE vocabulary | |||
| bool truncate_embed: whether to keep only the BPE tokens that are actually used (reduces memory usage and speeds things up) | |||
| int min_freq: words occurring fewer than this many times are replaced with unk | |||
| """ | |||
| super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
| @@ -106,14 +110,11 @@ class BertEmbedding(ContextualEmbedding): | |||
| if '[CLS]' in vocab: | |||
| self._word_cls_index = vocab['[CLS]'] | |||
| only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||
| truncate_embed = kwargs.get('truncate_embed', True) | |||
| min_freq = kwargs.get('min_freq', 2) | |||
| self._min_freq = min_freq | |||
| self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
| pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
| pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, | |||
| only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) | |||
| pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate) | |||
| self.requires_grad = requires_grad | |||
| self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
| @@ -160,6 +161,57 @@ class BertEmbedding(ContextualEmbedding): | |||
| words = words.masked_fill(mask, self._word_unk_index) | |||
| return words | |||
| def save(self, folder): | |||
| """ | |||
| Save the embedding under `folder`. Three items are written: vocab.txt, bert_hyper.json and a bert/ sub-folder | |||
| containing config.json, pytorch_model.bin and vocab.txt (that sub-folder can also be read directly by BertModel). | |||
| :param str folder: | |||
| :return: | |||
| """ | |||
| os.makedirs(folder, exist_ok=True) | |||
| self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME)) | |||
| hyper = {} | |||
| hyper['min_freq'] = self._min_freq | |||
| hyper['layers'] = ','.join(map(str, self.model.layers)) | |||
| hyper['pool_method'] = self.model.pool_method | |||
| hyper['dropout'] = self.dropout_layer.p | |||
| hyper['word_dropout'] = self.word_dropout | |||
| hyper['include_cls_sep'] = self.model.include_cls_sep | |||
| hyper['pooled_cls'] = self.model.pooled_cls | |||
| hyper['auto_truncate'] = self.model.auto_truncate | |||
| hyper['requires_grad'] = bool(self.requires_grad) | |||
| with open(os.path.join(folder, BERT_EMBED_HYPER), 'w', encoding='utf-8') as f: | |||
| json.dump(hyper, f, indent=2) | |||
| os.makedirs(os.path.join(folder, BERT_EMBED_FOLDER), exist_ok=True) | |||
| self.model.save(os.path.join(folder, BERT_EMBED_FOLDER)) | |||
| logger.debug(f"BERTEmbedding has been saved in {folder}") | |||
| @classmethod | |||
| def load(cls, folder): | |||
| """ | |||
| Load a BertEmbedding from `folder`, which must contain vocab.txt, bert_hyper.json and the bert/ sub-folder. | |||
| :param str folder: | |||
| :return: | |||
| """ | |||
| for name in [VOCAB_NAME, BERT_EMBED_FOLDER, BERT_EMBED_HYPER]: | |||
| assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||
| vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME)) | |||
| with open(os.path.join(folder, BERT_EMBED_HYPER), 'r', encoding='utf-8') as f: | |||
| hyper = json.load(f) | |||
| model_dir_or_name = os.path.join(os.path.join(folder, BERT_EMBED_FOLDER)) | |||
| bert_embed = cls(vocab=vocab, model_dir_or_name=model_dir_or_name, **hyper) | |||
| return bert_embed | |||
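As a hedged sketch of how the new pair of methods is meant to be used (the model tag and the target directory are placeholders, and constructing the embedding will fetch the pretrained weights):

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is an example .".split())

embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False)
embed.save('saved_bert_embed')               # writes vocab.txt, bert_hyper.json and bert/
reloaded = BertEmbedding.load('saved_bert_embed')
```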
| class BertWordPieceEncoder(nn.Module): | |||
| r""" | |||
| @@ -180,7 +232,7 @@ class BertWordPieceEncoder(nn.Module): | |||
| """ | |||
| def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | |||
| word_dropout=0, dropout=0, requires_grad: bool = True): | |||
| word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs): | |||
| r""" | |||
| :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | |||
| @@ -270,11 +322,53 @@ class BertWordPieceEncoder(nn.Module): | |||
| words = words.masked_fill(mask, self._wordpiece_unk_index) | |||
| return words | |||
| def save(self, folder): | |||
| """ | |||
| Create bert_hyper.json and a bert/ sub-folder under `folder`; bert/ contains config.json, | |||
| pytorch_model.bin and vocab.txt (that sub-folder can also be read directly by BertModel). | |||
| :param str folder: | |||
| :return: | |||
| """ | |||
| os.makedirs(folder, exist_ok=True) | |||
| hyper = {} | |||
| hyper['layers'] = ','.join(map(str, self.model.layers)) | |||
| hyper['dropout'] = self.dropout_layer.p | |||
| hyper['word_dropout'] = self.word_dropout | |||
| hyper['pooled_cls'] = self.model.pooled_cls | |||
| hyper['requires_grad'] = bool(self.requires_grad) | |||
| with open(os.path.join(folder, BERT_ENCODER_HYPER), 'w', encoding='utf-8') as f: | |||
| json.dump(hyper, f, indent=2) | |||
| os.makedirs(os.path.join(folder, BERT_ENCODER_FOLDER), exist_ok=True) | |||
| self.model.save(os.path.join(folder, BERT_ENCODER_FOLDER)) | |||
| logger.debug(f"BertWordPieceEncoder has been saved in {folder}") | |||
| @classmethod | |||
| def load(cls, folder): | |||
| """ | |||
| Load a BertWordPieceEncoder from `folder`, which must contain bert_hyper.json and the bert/ sub-folder. | |||
| :param folder: | |||
| :return: | |||
| """ | |||
| for name in [BERT_ENCODER_HYPER, BERT_ENCODER_FOLDER]: | |||
| assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||
| with open(os.path.join(folder, BERT_ENCODER_HYPER), 'r', encoding='utf-8') as f: | |||
| hyper = json.load(f) | |||
| model_dir_or_name = os.path.join(os.path.join(folder, BERT_ENCODER_FOLDER)) | |||
| bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper) | |||
| return bert_encoder | |||
| class _BertWordModel(nn.Module): | |||
| def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
| include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||
| only_use_pretrain_bpe=False, truncate_embed=True): | |||
| include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||
| super().__init__() | |||
| self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | |||
| @@ -303,73 +397,8 @@ class _BertWordModel(nn.Module): | |||
| self.auto_truncate = auto_truncate | |||
| # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | |||
| logger.info("Start to generate word pieces for word.") | |||
| self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids | |||
| # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||
| word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 | |||
| new_add_to_bpe_vocab = 0 | |||
| unsegment_count = 0 | |||
| if '[sep]' in vocab: | |||
| warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") | |||
| if "[CLS]" in vocab: | |||
| warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the begin " | |||
| "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin" | |||
| " and end.") | |||
| for word, index in vocab: | |||
| if index == vocab.padding_idx: # pad是个特殊的符号 | |||
| word = '[PAD]' | |||
| elif index == vocab.unknown_idx: | |||
| word = '[UNK]' | |||
| _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() | |||
| word_pieces = [] | |||
| for w in _words: | |||
| word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w)) | |||
| if len(word_pieces) == 1: | |||
| if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
| if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面 | |||
| if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||
| word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 | |||
| word_piece_dict[word] = 1 # 新增一个值 | |||
| new_add_to_bpe_vocab += 1 | |||
| unsegment_count += 1 | |||
| continue | |||
| for word_piece in word_pieces: | |||
| word_piece_dict[word_piece] = 1 | |||
| original_embed = self.encoder.embeddings.word_embeddings.weight.data | |||
| # 特殊词汇要特殊处理 | |||
| if not truncate_embed:# 如果不删除的话需要将已有的加上 | |||
| word_piece_dict.update(self.tokenzier.vocab) | |||
| embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed | |||
| new_word_piece_vocab = collections.OrderedDict() | |||
| for index, token in enumerate(['[PAD]', '[UNK]']): | |||
| index = word_piece_dict.pop(token, None) | |||
| if index is not None: | |||
| new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
| embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.vocab[token]] | |||
| for token in word_piece_dict.keys(): | |||
| if token not in new_word_piece_vocab: | |||
| new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
| index = new_word_piece_vocab[token] | |||
| if token in self.tokenzier.vocab: | |||
| embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]] | |||
| else: | |||
| embed.weight.data[index] = original_embed[self.tokenzier.vocab['[UNK]']] | |||
| self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||
| self.encoder.embeddings.word_embeddings = embed | |||
| self.encoder.config.vocab_size = len(new_word_piece_vocab) | |||
| if unsegment_count>0: | |||
| if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: | |||
| logger.info(f"{unsegment_count} words are unsegmented.") | |||
| else: | |||
| logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") | |||
| word_to_wordpieces = [] | |||
| word_pieces_lengths = [] | |||
| for word, index in vocab: | |||
| @@ -377,6 +406,8 @@ class _BertWordModel(nn.Module): | |||
| word = '[PAD]' | |||
| elif index == vocab.unknown_idx: | |||
| word = '[UNK]' | |||
| elif vocab.word_count[word]<min_freq: | |||
| word = '[UNK]' | |||
| word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
| word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||
| word_to_wordpieces.append(word_pieces) | |||
| @@ -504,6 +535,16 @@ class _BertWordModel(nn.Module): | |||
| # 3. 最终的embedding结果 | |||
| return outputs | |||
| def save(self, folder): | |||
| """ | |||
| Save pytorch_model.bin, config.json and vocab.txt into the given folder. | |||
| :param str folder: | |||
| :return: | |||
| """ | |||
| self.tokenzier.save_pretrained(folder) | |||
| self.encoder.save_pretrained(folder) | |||
| class _BertWordPieceModel(nn.Module): | |||
| r""" | |||
| @@ -580,4 +621,14 @@ class _BertWordPieceModel(nn.Module): | |||
| if l in (len(bert_outputs)-1, -1) and self.pooled_cls: | |||
| bert_output[:, 0] = pooled_cls | |||
| outputs[l_index] = bert_output | |||
| return outputs | |||
| return outputs | |||
| def save(self, folder): | |||
| """ | |||
| Save pytorch_model.bin, config.json and vocab.txt into the given folder. | |||
| :param folder: | |||
| :return: | |||
| """ | |||
| self.tokenzier.save_pretrained(folder) | |||
| self.encoder.save_pretrained(folder) | |||
| @@ -78,7 +78,7 @@ class Embedding(nn.Module): | |||
| if isinstance(self.embed, nn.Embedding): | |||
| return self.embed.weight.size(0) | |||
| else: | |||
| return self.embed.num_embedding | |||
| return self.embed.num_embeddings | |||
| def __len__(self): | |||
| return len(self.embed) | |||
| @@ -188,7 +188,7 @@ class TokenEmbedding(nn.Module): | |||
| return self._embed_size | |||
| @property | |||
| def num_embedding(self) -> int: | |||
| def num_embeddings(self) -> int: | |||
| r""" | |||
| 这个值可能会大于实际的embedding矩阵的大小。 | |||
| :return: | |||
| @@ -205,7 +205,7 @@ class TokenEmbedding(nn.Module): | |||
| @property | |||
| def size(self): | |||
| return torch.Size(self.num_embedding, self._embed_size) | |||
| return torch.Size(self.num_embeddings, self._embed_size) | |||
| @abstractmethod | |||
| def forward(self, words): | |||
| @@ -10,8 +10,8 @@ __all__ = [ | |||
| from functools import partial | |||
| import collections | |||
| import warnings | |||
| import os | |||
| import json | |||
| from itertools import chain | |||
| import numpy as np | |||
| @@ -24,6 +24,13 @@ from ..modules.encoder.roberta import RobertaModel | |||
| from ..modules.tokenizer import RobertaTokenizer | |||
| VOCAB_NAME = 'vocab.txt' | |||
| ROBERTA_EMBED_HYPER = 'roberta_hyper.json' | |||
| ROBERTA_ENCODER_HYPER = 'roberta_hyper.json' | |||
| ROBERTA_EMBED_FOLDER = 'roberta' | |||
| ROBERTA_ENCODER_FOLDER = 'roberta' | |||
| class RobertaEmbedding(ContextualEmbedding): | |||
| r""" | |||
| 使用RoBERTa对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | |||
| @@ -71,10 +78,7 @@ class RobertaEmbedding(ContextualEmbedding): | |||
| word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> | |||
| 来进行分类的任务将auto_truncate置为True。 | |||
| :param kwargs: | |||
| bool only_use_pretrain_bpe: only use BPE tokens that appear in the pretrained vocabulary; words that cannot be tokenized fall back to unk. Recommended to set to True when the embedding will not be updated. | |||
| int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times are added to RoBERTa's BPE vocabulary | |||
| bool truncate_embed: whether to keep only the BPE tokens that are actually used (reduces memory usage and speeds things up) | |||
| int min_freq: words occurring fewer than this many times are replaced with unk | |||
| """ | |||
| super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
| @@ -89,14 +93,12 @@ class RobertaEmbedding(ContextualEmbedding): | |||
| if '<s>' in vocab: | |||
| self._word_cls_index = vocab['<s>'] | |||
| only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||
| truncate_embed = kwargs.get('truncate_embed', True) | |||
| min_freq = kwargs.get('min_freq', 2) | |||
| self._min_freq = min_freq | |||
| self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
| pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
| pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, | |||
| only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) | |||
| pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq) | |||
| self.requires_grad = requires_grad | |||
| self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
| @@ -142,11 +144,56 @@ class RobertaEmbedding(ContextualEmbedding): | |||
| words = words.masked_fill(mask, self._word_unk_index) | |||
| return words | |||
| def save(self, folder): | |||
| """ | |||
| Save the RoBERTa embedding under `folder`; afterwards it contains vocab.txt, roberta_hyper.json and a roberta/ sub-folder. | |||
| :param str folder: 保存地址 | |||
| :return: | |||
| """ | |||
| os.makedirs(folder, exist_ok=True) | |||
| self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME)) | |||
| hyper = {} | |||
| hyper['min_freq'] = self._min_freq | |||
| hyper['layers'] = ','.join(map(str, self.model.layers)) | |||
| hyper['pool_method'] = self.model.pool_method | |||
| hyper['dropout'] = self.dropout_layer.p | |||
| hyper['word_dropout'] = self.word_dropout | |||
| hyper['include_cls_sep'] = self.model.include_cls_sep | |||
| hyper['pooled_cls'] = self.model.pooled_cls | |||
| hyper['auto_truncate'] = self.model.auto_truncate | |||
| hyper['requires_grad'] = bool(self.requires_grad) | |||
| with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'w', encoding='utf-8') as f: | |||
| json.dump(hyper, f, indent=2) | |||
| os.makedirs(os.path.join(folder, ROBERTA_EMBED_FOLDER), exist_ok=True) | |||
| self.model.save(os.path.join(folder, ROBERTA_EMBED_FOLDER)) | |||
| @classmethod | |||
| def load(cls, folder): | |||
| """ | |||
| Initialise a RobertaEmbedding from the data stored in `folder`. | |||
| :param folder: | |||
| :return: | |||
| """ | |||
| for name in [VOCAB_NAME, ROBERTA_EMBED_HYPER, ROBERTA_EMBED_FOLDER]: | |||
| assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||
| vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME)) | |||
| with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'r', encoding='utf-8') as f: | |||
| hyper = json.load(f) | |||
| model_name_or_path = os.path.join(folder, ROBERTA_EMBED_FOLDER) | |||
| roberta = cls(vocab=vocab, model_dir_or_name=model_name_or_path, **hyper) | |||
| return roberta | |||
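The RoBERTa variant mirrors the BERT one; a sketch with an illustrative model tag and directory:

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import RobertaEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is an example .".split())

embed = RobertaEmbedding(vocab, model_dir_or_name='en', requires_grad=False)
embed.save('saved_roberta')                  # writes vocab.txt, roberta_hyper.json and roberta/
reloaded = RobertaEmbedding.load('saved_roberta')   # hyper-parameters come from roberta_hyper.json
```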
| class _RobertaWordModel(nn.Module): | |||
| def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
| include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||
| only_use_pretrain_bpe=False, truncate_embed=True): | |||
| include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||
| super().__init__() | |||
| self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) | |||
| @@ -177,72 +224,6 @@ class _RobertaWordModel(nn.Module): | |||
| self.pooled_cls = pooled_cls | |||
| self.auto_truncate = auto_truncate | |||
| # 将所有vocab中word的wordpiece计算出来, 需要额外考虑<s>和</s> | |||
| logger.info("Start to generate word pieces for word.") | |||
| # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||
| word_piece_dict = {'<s>': 1, '</s>': 1} # 用到的word_piece以及新增的 | |||
| found_count = 0 | |||
| new_add_to_bpe_vocab = 0 | |||
| unsegment_count = 0 | |||
| if "<s>" in vocab: | |||
| warnings.warn("<s> detected in your vocabulary. RobertaEmbedding will add <s> and </s> to the begin " | |||
| "and end of the input automatically, make sure you don't add <s> and </s> at the begin" | |||
| " and end.") | |||
| for word, index in vocab: | |||
| if index == vocab.padding_idx: # pad是个特殊的符号 | |||
| word = '<pad>' | |||
| elif index == vocab.unknown_idx: | |||
| word = '<unk>' | |||
| # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() # 这里暂时不考虑中文内容 | |||
| word_pieces = [] | |||
| # 如果这个word不是在句子开头 | |||
| word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True)) | |||
| if len(word_pieces) == 1: | |||
| if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
| if index != vocab.unknown_idx and word_pieces[0] == '<unk>': # 说明这个词不在原始的word里面 | |||
| if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||
| word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 | |||
| word_piece_dict[word] = 1 # 新增一个值 | |||
| new_add_to_bpe_vocab += 1 | |||
| unsegment_count += 1 | |||
| continue | |||
| found_count += 1 | |||
| for word_piece in word_pieces: | |||
| word_piece_dict[word_piece] = 1 | |||
| # 如果这个word是在句子开头 | |||
| original_embed = self.encoder.embeddings.word_embeddings.weight.data | |||
| # 特殊词汇要特殊处理 | |||
| if not truncate_embed: # 如果不删除的话需要将已有的加上 | |||
| word_piece_dict.update(self.tokenzier.encoder) | |||
| embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed | |||
| new_word_piece_vocab = collections.OrderedDict() | |||
| for index, token in enumerate(['<s>', '<pad>', '</s>', '<unk>']): | |||
| index = word_piece_dict.pop(token, None) | |||
| if index is not None: | |||
| new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
| embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]] | |||
| for token in word_piece_dict.keys(): | |||
| if token not in new_word_piece_vocab: | |||
| new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
| index = new_word_piece_vocab[token] | |||
| if token in self.tokenzier.encoder: | |||
| embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]] | |||
| else: | |||
| embed.weight.data[index] = original_embed[self.tokenzier.encoder['<unk>']] | |||
| self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||
| self.encoder.embeddings.word_embeddings = embed | |||
| self.encoder.config.vocab_size = len(new_word_piece_vocab) | |||
| if unsegment_count>0: | |||
| if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: | |||
| logger.info(f"{unsegment_count} words are unsegmented.") | |||
| else: | |||
| logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") | |||
| word_to_wordpieces = [] | |||
| word_pieces_lengths = [] | |||
| for word, index in vocab: | |||
| @@ -250,6 +231,8 @@ class _RobertaWordModel(nn.Module): | |||
| word = '<pad>' | |||
| elif index == vocab.unknown_idx: | |||
| word = '<unk>' | |||
| elif vocab.word_count[word]<min_freq: | |||
| word = '<unk>' | |||
| word_pieces = self.tokenzier.tokenize(word) | |||
| word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||
| word_to_wordpieces.append(word_pieces) | |||
| @@ -368,6 +351,10 @@ class _RobertaWordModel(nn.Module): | |||
| # 3. 最终的embedding结果 | |||
| return outputs | |||
| def save(self, folder): | |||
| self.tokenzier.save_pretrained(folder) | |||
| self.encoder.save_pretrained(folder) | |||
| class RobertaWordPieceEncoder(nn.Module): | |||
| r""" | |||
| @@ -380,7 +367,7 @@ class RobertaWordPieceEncoder(nn.Module): | |||
| """ | |||
| def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False, | |||
| word_dropout=0, dropout=0, requires_grad: bool = True): | |||
| word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs): | |||
| r""" | |||
| :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | |||
| @@ -462,6 +449,36 @@ class RobertaWordPieceEncoder(nn.Module): | |||
| words = words.masked_fill(mask, self._wordpiece_unk_index) | |||
| return words | |||
| def save(self, folder): | |||
| os.makedirs(folder, exist_ok=True) | |||
| hyper = {} | |||
| hyper['layers'] = ','.join(map(str, self.model.layers)) | |||
| hyper['dropout'] = self.dropout_layer.p | |||
| hyper['word_dropout'] = self.word_dropout | |||
| hyper['pooled_cls'] = self.model.pooled_cls | |||
| hyper['requires_grad'] = bool(self.requires_grad) | |||
| with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'w', encoding='utf-8') as f: | |||
| json.dump(hyper, f, indent=2) | |||
| os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True) | |||
| self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER)) | |||
| logger.debug(f"BertWordPieceEncoder has been saved in {folder}") | |||
| @classmethod | |||
| def load(cls, folder): | |||
| for name in [ROBERTA_ENCODER_HYPER, ROBERTA_ENCODER_FOLDER]: | |||
| assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||
| with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'r', encoding='utf-8') as f: | |||
| hyper = json.load(f) | |||
| model_dir_or_name = os.path.join(os.path.join(folder, ROBERTA_ENCODER_FOLDER)) | |||
| bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper) | |||
| return bert_encoder | |||
| class _WordPieceRobertaModel(nn.Module): | |||
| def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False): | |||
| @@ -535,4 +552,8 @@ class _WordPieceRobertaModel(nn.Module): | |||
| if l in (len(roberta_output)-1, -1) and self.pooled_cls: | |||
| roberta_output[:, 0] = pooled_cls | |||
| outputs[l_index] = roberta_output | |||
| return outputs | |||
| return outputs | |||
| def save(self, folder): | |||
| self.tokenzier.save_pretrained(folder) | |||
| self.encoder.save_pretrained(folder) | |||
| @@ -10,6 +10,8 @@ import os | |||
| import warnings | |||
| from collections import defaultdict | |||
| from copy import deepcopy | |||
| import json | |||
| from typing import Union | |||
| import numpy as np | |||
| import torch | |||
| @@ -19,7 +21,12 @@ from .embedding import TokenEmbedding | |||
| from ..core import logger | |||
| from ..core.vocabulary import Vocabulary | |||
| from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||
| from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||
| from ..io.file_utils import _get_file_name_base_on_postfix | |||
| VOCAB_FILENAME = 'vocab.txt' | |||
| STATIC_HYPER_FILENAME = 'static_hyper.json' | |||
| STATIC_EMBED_FILENAME = 'static.txt' | |||
| class StaticEmbedding(TokenEmbedding): | |||
| @@ -70,7 +77,7 @@ class StaticEmbedding(TokenEmbedding): | |||
| """ | |||
| def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | |||
| def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True, | |||
| init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||
| r""" | |||
| @@ -95,8 +102,8 @@ class StaticEmbedding(TokenEmbedding): | |||
| """ | |||
| super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
| if embedding_dim > 0: | |||
| if model_dir_or_name is not None: | |||
| warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" | |||
| if model_dir_or_name: | |||
| logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" | |||
| f" dimension {embedding_dim}. If you want to use pre-trained embedding, " | |||
| f"set `embedding_dim` to 0.") | |||
| model_dir_or_name = None | |||
| @@ -116,7 +123,9 @@ class StaticEmbedding(TokenEmbedding): | |||
| model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | |||
| else: | |||
| raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
| kwargs['min_freq'] = min_freq | |||
| kwargs['lower'] = lower | |||
| # shrink the vocab according to min_freq | |||
| truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq) | |||
| if truncate_vocab: | |||
| @@ -143,7 +152,7 @@ class StaticEmbedding(TokenEmbedding): | |||
| truncated_words_to_words = torch.arange(len(vocab)).long() | |||
| for word, index in vocab: | |||
| truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||
| logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
| logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.") | |||
| vocab = truncated_vocab | |||
| self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) | |||
| @@ -198,6 +207,7 @@ class StaticEmbedding(TokenEmbedding): | |||
| sparse=False, _weight=embedding) | |||
| self._embed_size = self.embedding.weight.size(1) | |||
| self.requires_grad = requires_grad | |||
| self.kwargs = kwargs | |||
| @property | |||
| def weight(self): | |||
| @@ -321,3 +331,71 @@ class StaticEmbedding(TokenEmbedding): | |||
| words = self.embedding(words) | |||
| words = self.dropout(words) | |||
| return words | |||
| def save(self, folder): | |||
| """ | |||
| Save the embedding under `folder`; it can later be restored with the load method. | |||
| :param str folder: three files are written into this folder: vocab.txt, static.txt and static_hyper.json. | |||
| vocab.txt can be read back with Vocabulary.load; static.txt follows the word2vec text format with space-separated fields, | |||
| the first line holding only the vocabulary size and the dimension, and every following line a word followed by its vector values; static_hyper.json stores the StaticEmbedding hyper-parameters. | |||
| :return: | |||
| """ | |||
| os.makedirs(folder, exist_ok=True) | |||
| vocab = self.get_word_vocab() | |||
| vocab_fp = os.path.join(folder, VOCAB_FILENAME) | |||
| vocab.save(vocab_fp) | |||
| kwargs = self.kwargs.copy() | |||
| kwargs['dropout'] = self.dropout_layer.p | |||
| kwargs['word_dropout'] = self.word_dropout | |||
| kwargs['requires_grad'] = self.requires_grad | |||
| kwargs['only_norm_found_vector'] = False | |||
| kwargs['only_use_pretrain_word'] = True | |||
| with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f: | |||
| json.dump(kwargs, f, indent=2) | |||
| with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f: | |||
| f.write('{}\n'.format(' '*30)) # reserve header space; it is filled in after all vectors are written | |||
| word_count = 0 | |||
| saved_word = {} | |||
| valid_word_count = 0 | |||
| for i in range(len(self.words_to_words)): | |||
| word = vocab.to_word(i) | |||
| if not vocab._is_word_no_create_entry(word): | |||
| word_count += 1 | |||
| if kwargs['lower']: | |||
| word = word.lower() | |||
| if word in saved_word: | |||
| continue | |||
| saved_word[word] = 1 | |||
| vec_i = self.words_to_words[i] | |||
| if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx: | |||
| continue | |||
| vec = self.embedding.weight.data[vec_i].tolist() | |||
| vec_str = ' '.join(map(str, vec)) | |||
| f.write(f'{word} {vec_str}\n') | |||
| valid_word_count += 1 | |||
| f.seek(0) | |||
| f.write('{} {}'.format(valid_word_count, self.embedding_dim)) | |||
| logger.debug(f"StaticEmbedding has been saved to {folder}.") | |||
| @classmethod | |||
| def load(cls, folder): | |||
| """ | |||
| :param str folder: the folder must contain vocab.txt, static.txt and static_hyper.json | |||
| :return: | |||
| """ | |||
| for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]: | |||
| assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||
| vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME)) | |||
| with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f: | |||
| hyper = json.load(f) | |||
| logger.info(f"Load StaticEmbedding from {folder}.") | |||
| embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper) | |||
| return embed | |||
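A sketch of the StaticEmbedding round-trip; the GloVe tag is only an example, and `load` simply feeds the saved word2vec-style `static.txt` back through the normal constructor together with the stored hyper-parameters:

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox".split())

embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', min_freq=1)
embed.save('saved_static')                   # writes vocab.txt, static.txt, static_hyper.json
reloaded = StaticEmbedding.load('saved_static')
```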
| @@ -9,7 +9,8 @@ from torch import nn as nn | |||
| from ..core.vocabulary import Vocabulary | |||
| __all__ = [ | |||
| 'get_embeddings' | |||
| 'get_embeddings', | |||
| 'get_sinusoid_encoding_table' | |||
| ] | |||
| @@ -31,7 +32,7 @@ def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, inclu | |||
| return char_vocab | |||
| def get_embeddings(init_embed): | |||
| def get_embeddings(init_embed, padding_idx=None): | |||
| r""" | |||
| Return an Embedding object built from init_embed. If a tuple is passed, a randomly initialised nn.Embedding is created; if a | |||
| numpy.ndarray is passed, the nn.Embedding is initialised from its values; if a torch.Tensor is passed, the nn.Embedding is initialised from it; a fastNLP embedding is returned unchanged. | |||
| @@ -40,11 +41,12 @@ def get_embeddings(init_embed): | |||
| :param init_embed: either a tuple (num_embeddings, embedding_dim), i.e. the size of the embedding and the dimension per word; | |||
| or an nn.Embedding object, which is used as the embedding directly; or an np.ndarray / torch.Tensor, whose values are | |||
| used to initialise the Embedding. | |||
| :param padding_idx: only effective when a tuple is passed | |||
| :return nn.Embedding: embeddings | |||
| """ | |||
| if isinstance(init_embed, tuple): | |||
| res = nn.Embedding( | |||
| num_embeddings=init_embed[0], embedding_dim=init_embed[1]) | |||
| num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) | |||
| nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), | |||
| b=np.sqrt(3 / res.weight.data.size(1))) | |||
| elif isinstance(init_embed, nn.Module): | |||
| @@ -58,3 +60,32 @@ def get_embeddings(init_embed): | |||
| raise TypeError( | |||
| 'invalid init_embed type: {}'.format((type(init_embed)))) | |||
| return res | |||
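For example, the new `padding_idx` argument only matters for the tuple case, where a fresh `nn.Embedding` is created (sizes below are arbitrary):

```python
import numpy as np
from torch import nn
from fastNLP.embeddings import get_embeddings

# tuple input: a new, uniformly initialised nn.Embedding with row 0 registered as padding
embed = get_embeddings((100, 50), padding_idx=0)
assert isinstance(embed, nn.Embedding)

# ndarray input: the array values become the initial weights; padding_idx is ignored here
embed2 = get_embeddings(np.random.rand(100, 50))
```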
| def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
| """ | |||
| Sinusoidal position embedding: in each position's representation, even dimensions (0, 2, 4, ...) use sin and odd dimensions (1, 3, 5, ...) use cos. | |||
| :param int n_position: total number of positions | |||
| :param int d_hid: number of dimensions; must be even | |||
| :param padding_idx: index whose row is zeroed out (typically the padding position) | |||
| :return: torch.FloatTensor of shape n_position x d_hid | |||
| """ | |||
| def cal_angle(position, hid_idx): | |||
| return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||
| def get_posi_angle_vec(position): | |||
| return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||
| sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||
| sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||
| sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||
| if padding_idx is not None: | |||
| # zero vector for padding dimension | |||
| sinusoid_table[padding_idx] = 0. | |||
| return torch.FloatTensor(sinusoid_table) | |||
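A quick sanity check of the table produced above (numbers chosen arbitrarily):

```python
from fastNLP.embeddings import get_sinusoid_encoding_table

table = get_sinusoid_encoding_table(n_position=10, d_hid=8, padding_idx=0)
print(table.shape)    # torch.Size([10, 8])
print(table[0])       # all zeros, because padding_idx=0 zeroes that row
# even columns hold sin values, odd columns hold cos values, e.g. table[1, 0] == sin(1.0)
```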
| @@ -9,18 +9,18 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models | |||
| """ | |||
| __all__ = [ | |||
| "CNNText", | |||
| "SeqLabeling", | |||
| "AdvSeqLabel", | |||
| "BiLSTMCRF", | |||
| "ESIM", | |||
| "StarTransEnc", | |||
| "STSeqLabel", | |||
| "STNLICls", | |||
| "STSeqCls", | |||
| "BiaffineParser", | |||
| "GraphParser", | |||
| @@ -28,7 +28,13 @@ __all__ = [ | |||
| "BertForSentenceMatching", | |||
| "BertForMultipleChoice", | |||
| "BertForTokenClassification", | |||
| "BertForQuestionAnswering" | |||
| "BertForQuestionAnswering", | |||
| "TransformerSeq2SeqModel", | |||
| "LSTMSeq2SeqModel", | |||
| "Seq2SeqModel", | |||
| 'SequenceGeneratorModel' | |||
| ] | |||
| from .base_model import BaseModel | |||
| @@ -39,7 +45,9 @@ from .cnn_text_classification import CNNText | |||
| from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
| from .snli import ESIM | |||
| from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | |||
| from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel | |||
| from .seq2seq_generator import SequenceGeneratorModel | |||
| import sys | |||
| from ..doc_utils import doc_process | |||
| doc_process(sys.modules[__name__]) | |||
| doc_process(sys.modules[__name__]) | |||
| @@ -39,7 +39,7 @@ from torch import nn | |||
| from .base_model import BaseModel | |||
| from ..core._logger import logger | |||
| from ..core.const import Const | |||
| from ..embeddings import BertEmbedding | |||
| from ..embeddings.bert_embedding import BertEmbedding | |||
| class BertForSequenceClassification(BaseModel): | |||
| @@ -314,13 +314,8 @@ class BiaffineParser(GraphParser): | |||
| raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) | |||
| self.position_emb = nn.Embedding(num_embeddings=self.max_len, | |||
| embedding_dim=rnn_out_size, ) | |||
| self.encoder = TransformerEncoder(num_layers=rnn_layers, | |||
| model_size=rnn_out_size, | |||
| inner_size=1024, | |||
| key_size=d_k, | |||
| value_size=d_v, | |||
| num_head=n_head, | |||
| dropout=dropout, ) | |||
| self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size, | |||
| n_head=n_head, dim_ff=1024, dropout=dropout) | |||
| else: | |||
| raise ValueError('unsupported encoder type: {}'.format(encoder)) | |||
| @@ -0,0 +1,62 @@ | |||
| import torch | |||
| from torch import nn | |||
| from .seq2seq_model import Seq2SeqModel | |||
| from ..modules.generator.seq2seq_generator import SequenceGenerator | |||
| class SequenceGeneratorModel(nn.Module): | |||
| """ | |||
| Wraps a Seq2SeqModel so that it can be used for generation tasks. | |||
| """ | |||
| def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, num_beams=1, | |||
| do_sample=True, temperature=1.0, top_k=50, top_p=1.0, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
| """ | |||
| :param Seq2SeqModel seq2seq_model: the sequence-to-sequence model | |||
| :param int,None bos_token_id: token id that starts a sentence | |||
| :param int,None eos_token_id: token id that ends a sentence | |||
| :param int max_length: maximum length of a generated sentence | |||
| :param int num_beams: beam size for beam search | |||
| :param bool do_sample: whether to generate by sampling | |||
| :param float temperature: only meaningful when do_sample is True | |||
| :param int top_k: sample only from the top_k tokens | |||
| :param float top_p: sample only from tokens within cumulative probability top_p (nucleus sampling) | |||
| :param float repetition_penalty: how strongly repeated tokens are penalised | |||
| :param float length_penalty: length penalty; values below 1 favour longer sentences, values above 1 favour shorter ones | |||
| :param int pad_token_id: once a sentence has finished, the remaining positions are filled with pad_token_id | |||
| """ | |||
| super().__init__() | |||
| self.seq2seq_model = seq2seq_model | |||
| self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, | |||
| bos_token_id=bos_token_id, | |||
| eos_token_id=eos_token_id, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||
| """ | |||
| Passes the call straight through to seq2seq_model's forward. | |||
| :param torch.LongTensor src_tokens: bsz x max_len | |||
| :param torch.LongTensor tgt_tokens: bsz x max_len' | |||
| :param torch.LongTensor src_seq_len: bsz | |||
| :param torch.LongTensor tgt_seq_len: bsz | |||
| :return: | |||
| """ | |||
| return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||
| def predict(self, src_tokens, src_seq_len=None): | |||
| """ | |||
| Given the source tokens, return the generated output. | |||
| :param torch.LongTensor src_tokens: bsz x max_len | |||
| :param torch.LongTensor src_seq_len: bsz | |||
| :return: | |||
| """ | |||
| state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) | |||
| result = self.generator.generate(state) | |||
| return {'pred': result} | |||
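A hedged end-to-end sketch (vocabulary sizes, token ids and the batch are all made up; the seq2seq model itself is built with the helper shown further below):

```python
import torch
from fastNLP.models import TransformerSeq2SeqModel, SequenceGeneratorModel

seq2seq = TransformerSeq2SeqModel.build_model(src_embed=(1000, 64), tgt_embed=(1000, 64),
                                              num_layers=2, d_model=64, n_head=4, dim_ff=128)
generator = SequenceGeneratorModel(seq2seq, bos_token_id=1, eos_token_id=2,
                                   max_length=20, num_beams=3, do_sample=False)

src_tokens = torch.randint(3, 1000, (4, 11))               # bsz x max_len
src_seq_len = torch.full((4,), 11, dtype=torch.long)       # bsz
pred = generator.predict(src_tokens, src_seq_len)['pred']  # bsz x generated_len
```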
| @@ -0,0 +1,176 @@ | |||
| r""" | |||
| Contains the models that make up the sequence-to-sequence framework. | |||
| """ | |||
| import torch | |||
| from torch import nn | |||
| from ..embeddings import get_embeddings | |||
| from ..embeddings.utils import get_sinusoid_encoding_table | |||
| from ..modules.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder | |||
| from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
| class Seq2SeqModel(nn.Module): | |||
| def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): | |||
| """ | |||
| A Seq2Seq model that can be trained with the Trainer. Normally, a subclass only needs to implement the classmethod build_model. | |||
| :param encoder: Encoder | |||
| :param decoder: Decoder | |||
| """ | |||
| super().__init__() | |||
| self.encoder = encoder | |||
| self.decoder = decoder | |||
| def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||
| """ | |||
| :param torch.LongTensor src_tokens: source tokens | |||
| :param torch.LongTensor tgt_tokens: target tokens | |||
| :param torch.LongTensor src_seq_len: lengths of the source sequences | |||
| :param torch.LongTensor tgt_seq_len: lengths of the target sequences (unused by default) | |||
| :return: {'pred': torch.Tensor}, where pred has shape bsz x max_len x vocab_size | |||
| """ | |||
| state = self.prepare_state(src_tokens, src_seq_len) | |||
| decoder_output = self.decoder(tgt_tokens, state) | |||
| if isinstance(decoder_output, torch.Tensor): | |||
| return {'pred': decoder_output} | |||
| elif isinstance(decoder_output, (tuple, list)): | |||
| return {'pred': decoder_output[0]} | |||
| else: | |||
| raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}") | |||
| def prepare_state(self, src_tokens, src_seq_len=None): | |||
| """ | |||
| Runs the encoder and passes its encoder_output and encoder_mask to decoder.init_state to initialise a state. | |||
| :param src_tokens: | |||
| :param src_seq_len: | |||
| :return: | |||
| """ | |||
| encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len) | |||
| state = self.decoder.init_state(encoder_output, encoder_mask) | |||
| return state | |||
| @classmethod | |||
| def build_model(cls, *args, **kwargs): | |||
| """ | |||
| Subclasses must implement this method to construct the Seq2SeqModel. | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| class TransformerSeq2SeqModel(Seq2SeqModel): | |||
| """ | |||
| The encoder is a TransformerSeq2SeqEncoder and the decoder a TransformerSeq2SeqDecoder; instances are created via build_model. | |||
| """ | |||
| def __init__(self, encoder, decoder): | |||
| super().__init__(encoder, decoder) | |||
| @classmethod | |||
| def build_model(cls, src_embed, tgt_embed=None, | |||
| pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, | |||
| bind_encoder_decoder_embed=False, | |||
| bind_decoder_input_output_embed=True): | |||
| """ | |||
| Initialise a TransformerSeq2SeqModel. | |||
| :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: embedding for the source side | |||
| :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: embedding for the target side; do not pass it when | |||
| bind_encoder_decoder_embed is True | |||
| :param str pos_embed: either sin or learned | |||
| :param int max_position: maximum supported sequence length | |||
| :param int num_layers: number of layers in both encoder and decoder | |||
| :param int d_model: input/output size of encoder and decoder | |||
| :param int n_head: number of attention heads in encoder and decoder | |||
| :param int dim_ff: hidden size of the FFN inside encoder and decoder | |||
| :param float dropout: dropout applied to attention and the FFN | |||
| :param bool bind_encoder_decoder_embed: whether encoder and decoder share the same embedding | |||
| :param bool bind_decoder_input_output_embed: whether the decoder output embedding shares weights with its input embedding | |||
| :return: TransformerSeq2SeqModel | |||
| """ | |||
| if bind_encoder_decoder_embed and tgt_embed is not None: | |||
| raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||
| src_embed = get_embeddings(src_embed) | |||
| if bind_encoder_decoder_embed: | |||
| tgt_embed = src_embed | |||
| else: | |||
| assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||
| tgt_embed = get_embeddings(tgt_embed) | |||
| if pos_embed == 'sin': | |||
| encoder_pos_embed = nn.Embedding.from_pretrained( | |||
| get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0), | |||
| freeze=True) # index 0 is reserved for padding | |||
| decoder_pos_embed = nn.Embedding.from_pretrained( | |||
| get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0), | |||
| freeze=True) # index 0 is reserved for padding | |||
| elif pos_embed == 'learned': | |||
| encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0) | |||
| decoder_pos_embed = get_embeddings((max_position + 1, tgt_embed.embedding_dim), padding_idx=0) | |||
| else: | |||
| raise ValueError("pos_embed only supports sin or learned.") | |||
| encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed, | |||
| num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, | |||
| dropout=dropout) | |||
| decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed, | |||
| d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, | |||
| dropout=dropout, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
| return cls(encoder, decoder) | |||
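For illustration, a tiny configuration and a forward pass (all sizes are arbitrary; with tuple embeddings the embedding dimension is kept equal to `d_model` to avoid any projection concerns):

```python
import torch
from fastNLP.models import TransformerSeq2SeqModel

model = TransformerSeq2SeqModel.build_model(src_embed=(1000, 64), tgt_embed=(1200, 64),
                                            pos_embed='sin', max_position=256,
                                            num_layers=2, d_model=64, n_head=4, dim_ff=128,
                                            bind_decoder_input_output_embed=True)

src = torch.randint(1, 1000, (2, 7))
tgt = torch.randint(1, 1200, (2, 5))
out = model(src, tgt, src_seq_len=torch.tensor([7, 6]), tgt_seq_len=torch.tensor([5, 4]))
print(out['pred'].shape)   # expected: bsz x tgt_len x tgt_vocab_size, i.e. [2, 5, 1200]
```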
| class LSTMSeq2SeqModel(Seq2SeqModel): | |||
| """ | |||
| A model that uses LSTMSeq2SeqEncoder and LSTMSeq2SeqDecoder. | |||
| """ | |||
| def __init__(self, encoder, decoder): | |||
| super().__init__(encoder, decoder) | |||
| @classmethod | |||
| def build_model(cls, src_embed, tgt_embed=None, | |||
| num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True, | |||
| attention=True, bind_encoder_decoder_embed=False, | |||
| bind_decoder_input_output_embed=True): | |||
| """ | |||
| :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: embedding for the source side | |||
| :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: embedding for the target side; do not pass it when | |||
| bind_encoder_decoder_embed is True | |||
| :param int num_layers: number of layers in encoder and decoder | |||
| :param int hidden_size: hidden size of encoder and decoder | |||
| :param float dropout: dropout applied between layers | |||
| :param bool bidirectional: whether the encoder uses a bidirectional LSTM | |||
| :param bool attention: whether the decoder attends over all encoder time steps | |||
| :param bool bind_encoder_decoder_embed: whether encoder and decoder share the same embedding | |||
| :param bool bind_decoder_input_output_embed: whether the decoder output embedding shares weights with its input embedding | |||
| :return: LSTMSeq2SeqModel | |||
| """ | |||
| if bind_encoder_decoder_embed and tgt_embed is not None: | |||
| raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||
| src_embed = get_embeddings(src_embed) | |||
| if bind_encoder_decoder_embed: | |||
| tgt_embed = src_embed | |||
| else: | |||
| assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||
| tgt_embed = get_embeddings(tgt_embed) | |||
| encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers=num_layers, | |||
| hidden_size=hidden_size, dropout=dropout, bidirectional=bidirectional) | |||
| decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers=num_layers, hidden_size=hidden_size, | |||
| dropout=dropout, bind_decoder_input_output_embed=bind_decoder_input_output_embed, | |||
| attention=attention) | |||
| return cls(encoder, decoder) | |||
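And the LSTM counterpart, again with made-up sizes (the embedding dimension is kept equal to `hidden_size` for simplicity):

```python
from fastNLP.models import LSTMSeq2SeqModel

model = LSTMSeq2SeqModel.build_model(src_embed=(1000, 64), tgt_embed=(1000, 64),
                                     num_layers=1, hidden_size=64, dropout=0.1,
                                     bidirectional=True, attention=True)
```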
| @@ -14,9 +14,9 @@ import torch.nn.functional as F | |||
| from .base_model import BaseModel | |||
| from ..core.const import Const as C | |||
| from ..core.utils import seq_len_to_mask | |||
| from ..embeddings import get_embeddings | |||
| from ..modules import ConditionalRandomField | |||
| from ..modules import LSTM | |||
| from ..embeddings.utils import get_embeddings | |||
| from ..modules.decoder import ConditionalRandomField | |||
| from ..modules.encoder import LSTM | |||
| from ..modules import decoder, encoder | |||
| from ..modules.decoder.crf import allowed_transitions | |||
| @@ -58,7 +58,21 @@ __all__ = [ | |||
| "RobertaModel", | |||
| "GPT2Model", | |||
| "GPT2Tokenizer" | |||
| "GPT2Tokenizer", | |||
| "TransformerSeq2SeqEncoder", | |||
| "LSTMSeq2SeqEncoder", | |||
| "Seq2SeqEncoder", | |||
| "TransformerSeq2SeqDecoder", | |||
| "LSTMSeq2SeqDecoder", | |||
| "Seq2SeqDecoder", | |||
| "TransformerState", | |||
| "LSTMState", | |||
| "State", | |||
| "SequenceGenerator" | |||
| ] | |||
| import sys | |||
| @@ -68,6 +82,7 @@ from . import encoder | |||
| from .decoder import * | |||
| from .dropout import TimestepDropout | |||
| from .encoder import * | |||
| from .generator import * | |||
| from .utils import summary | |||
| from ..doc_utils import doc_process | |||
| from .tokenizer import * | |||
| @@ -12,7 +12,8 @@ import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| from fastNLP.modules.utils import initial_parameter | |||
| from .utils import initial_parameter | |||
| from .decoder.seq2seq_state import TransformerState | |||
| class DotAttention(nn.Module): | |||
| @@ -45,64 +46,153 @@ class DotAttention(nn.Module): | |||
| class MultiHeadAttention(nn.Module): | |||
| r""" | |||
| The MultiHeadAttention used inside the Transformer. | |||
| """ | |||
| The multi-head attention described in "Attention Is All You Need". | |||
| def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | |||
| r""" | |||
| :param input_size: int, input dimension; also the output dimension. | |||
| :param key_size: int, dimension of each head's key. | |||
| :param value_size: int, dimension of each head's value. | |||
| :param num_head: int, number of heads. | |||
| :param dropout: float. | |||
| """ | |||
| """ | |||
| def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||
| super(MultiHeadAttention, self).__init__() | |||
| self.input_size = input_size | |||
| self.key_size = key_size | |||
| self.value_size = value_size | |||
| self.num_head = num_head | |||
| in_size = key_size * num_head | |||
| self.q_in = nn.Linear(input_size, in_size) | |||
| self.k_in = nn.Linear(input_size, in_size) | |||
| self.v_in = nn.Linear(input_size, in_size) | |||
| self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) | |||
| self.out = nn.Linear(value_size * num_head, input_size) | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dropout = dropout | |||
| self.head_dim = d_model // n_head | |||
| self.layer_idx = layer_idx | |||
| assert d_model % n_head == 0, "d_model should be divisible by n_head" | |||
| self.scaling = self.head_dim ** -0.5 | |||
| self.q_proj = nn.Linear(d_model, d_model) | |||
| self.k_proj = nn.Linear(d_model, d_model) | |||
| self.v_proj = nn.Linear(d_model, d_model) | |||
| self.out_proj = nn.Linear(d_model, d_model) | |||
| self.reset_parameters() | |||
| def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): | |||
| """ | |||
| :param query: batch x seq x dim | |||
| :param key: batch x seq x dim | |||
| :param value: batch x seq x dim | |||
| :param key_mask: batch x seq, marks which keys may be attended to; positions where the mask is 1 ARE attended to | |||
| :param attn_mask: seq x seq, masks out entries of the attention map; mainly used for decoder self-attention during training, with the lower triangle set to 1 | |||
| :param state: past information, used at inference time (e.g. encoder output, the decoder's previous k/v) to avoid recomputation | |||
| :return: | |||
| """ | |||
| assert key.size() == value.size() | |||
| if state is not None: | |||
| assert self.layer_idx is not None | |||
| qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() | |||
| q = self.q_proj(query) # batch x seq x dim | |||
| q *= self.scaling | |||
| k = v = None | |||
| prev_k = prev_v = None | |||
| # fetch k/v from the state | |||
| if isinstance(state, TransformerState): # we are in the inference stage | |||
| if qkv_same: # decoder self-attention | |||
| prev_k = state.decoder_prev_key[self.layer_idx] | |||
| prev_v = state.decoder_prev_value[self.layer_idx] | |||
| else: # decoder-encoder attention: just load the keys that were saved earlier | |||
| k = state.encoder_key[self.layer_idx] | |||
| v = state.encoder_value[self.layer_idx] | |||
| if k is None: | |||
| k = self.k_proj(key) | |||
| v = self.v_proj(value) | |||
| if prev_k is not None: | |||
| k = torch.cat((prev_k, k), dim=1) | |||
| v = torch.cat((prev_v, v), dim=1) | |||
| # update the state | |||
| if isinstance(state, TransformerState): | |||
| if qkv_same: | |||
| state.decoder_prev_key[self.layer_idx] = k | |||
| state.decoder_prev_value[self.layer_idx] = v | |||
| else: | |||
| state.encoder_key[self.layer_idx] = k | |||
| state.encoder_value[self.layer_idx] = v | |||
| # compute the attention | |||
| batch_size, q_len, d_model = query.size() | |||
| k_len, v_len = k.size(1), v.size(1) | |||
| q = q.reshape(batch_size, q_len, self.n_head, self.head_dim) | |||
| k = k.reshape(batch_size, k_len, self.n_head, self.head_dim) | |||
| v = v.reshape(batch_size, v_len, self.n_head, self.head_dim) | |||
| attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head | |||
| if key_mask is not None: | |||
| _key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,1 | |||
| attn_weights = attn_weights.masked_fill(_key_mask, -float('inf')) | |||
| if attn_mask is not None: | |||
| _attn_mask = attn_mask[None, :, :, None].eq(0) # 1,q_len,k_len,n_head | |||
| attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf')) | |||
| attn_weights = F.softmax(attn_weights, dim=2) | |||
| attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) | |||
| output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim | |||
| output = output.reshape(batch_size, q_len, -1) | |||
| output = self.out_proj(output) # batch,q_len,dim | |||
| return output, attn_weights | |||
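| A minimal usage sketch of the rewritten interface (sizes are made up; shapes follow the docstring above): | |||
| import torch | |||
| from fastNLP.modules.attention import MultiHeadAttention | |||
| mha = MultiHeadAttention(d_model=512, n_head=8, dropout=0.1) | |||
| x = torch.rand(2, 5, 512)  # batch x seq x dim | |||
| key_mask = torch.ones(2, 5).bool()  # 1 = positions that may be attended to | |||
| out, attn = mha(query=x, key=x, value=x, key_mask=key_mask) | |||
| print(out.shape, attn.shape)  # (2, 5, 512) and (2, 5, 5, 8) | |||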
| def reset_parameters(self): | |||
| sqrt = math.sqrt | |||
| nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||
| nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||
| nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||
| nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||
| nn.init.xavier_uniform_(self.q_proj.weight) | |||
| nn.init.xavier_uniform_(self.k_proj.weight) | |||
| nn.init.xavier_uniform_(self.v_proj.weight) | |||
| nn.init.xavier_uniform_(self.out_proj.weight) | |||
| def set_layer_idx(self, layer_idx): | |||
| self.layer_idx = layer_idx | |||
| def forward(self, Q, K, V, atte_mask_out=None): | |||
| r""" | |||
| :param Q: [batch, seq_len_q, model_size] | |||
| :param K: [batch, seq_len_k, model_size] | |||
| :param V: [batch, seq_len_k, model_size] | |||
| :param seq_mask: [batch, seq_len] | |||
| class AttentionLayer(nn.Module): | |||
| def __init__(self, input_size, key_dim, value_dim, bias=False): | |||
| """ | |||
| batch, sq, _ = Q.size() | |||
| sk = K.size(1) | |||
| d_k, d_v, n_head = self.key_size, self.value_size, self.num_head | |||
| # input linear | |||
| q = self.q_in(Q).view(batch, sq, n_head, d_k).transpose(1, 2) | |||
| k = self.k_in(K).view(batch, sk, n_head, d_k).transpose(1, 2) | |||
| v = self.v_in(V).view(batch, sk, n_head, d_v).transpose(1, 2) | |||
| if atte_mask_out is not None: | |||
| atte_mask_out = atte_mask_out[:,None,:,:] # [bsz,1,1,len] | |||
| atte = self.attention(q, k, v, atte_mask_out).view(batch, n_head, sq, d_v) | |||
| # concat all heads, do output linear | |||
| atte = atte.transpose(1, 2).contiguous().view(batch, sq, -1) | |||
| output = self.out(atte) | |||
| return output | |||
| Can be used in the decoding step of an LSTM-to-LSTM sequence-to-sequence model; during decoding this attention uses the hidden state of the previous step to compute attention over the encoder outputs | |||
| :param int input_size: size of the input | |||
| :param int key_dim: usually the dimension of encoder_output | |||
| :param int value_dim: size of the output dimension, usually the size of the decoder hidden state | |||
| :param bias: | |||
| """ | |||
| super().__init__() | |||
| self.input_proj = nn.Linear(input_size, key_dim, bias=bias) | |||
| self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias) | |||
| def forward(self, input, encode_outputs, encode_mask): | |||
| """ | |||
| :param input: batch_size x input_size | |||
| :param encode_outputs: batch_size x max_len x key_dim | |||
| :param encode_mask: batch_size x max_len, positions that are 0 are padding | |||
| :return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized | |||
| """ | |||
| # x: bsz x encode_hidden_size | |||
| x = self.input_proj(input) | |||
| # compute attention | |||
| attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len | |||
| # don't attend over padding | |||
| if encode_mask is not None: | |||
| attn_scores = attn_scores.float().masked_fill_( | |||
| encode_mask.eq(0), | |||
| float('-inf') | |||
| ).type_as(attn_scores) # FP16 support: cast to float and back | |||
| attn_scores = F.softmax(attn_scores, dim=-1) # bsz x max_len | |||
| # sum weighted sources | |||
| x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size | |||
| x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) | |||
| return x, attn_scores | |||
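| A small sketch of how AttentionLayer is called during LSTM decoding (hypothetical sizes; the import path follows the imports used elsewhere in this PR): | |||
| import torch | |||
| from fastNLP.modules.attention import AttentionLayer | |||
| layer = AttentionLayer(input_size=300, key_dim=300, value_dim=300) | |||
| prev_hidden = torch.rand(2, 300)  # batch_size x input_size, last decoder hidden | |||
| encoder_out = torch.rand(2, 7, 300)  # batch_size x max_len x key_dim | |||
| encoder_mask = torch.ones(2, 7)  # 0 marks padding | |||
| context, scores = layer(prev_hidden, encoder_out, encoder_mask) | |||
| print(context.shape, scores.shape)  # (2, 300) and (2, 7) | |||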
| def _masked_softmax(tensor, mask): | |||
| @@ -6,10 +6,20 @@ __all__ = [ | |||
| "MLP", | |||
| "ConditionalRandomField", | |||
| "viterbi_decode", | |||
| "allowed_transitions" | |||
| "allowed_transitions", | |||
| "LSTMState", | |||
| "TransformerState", | |||
| "State", | |||
| "TransformerSeq2SeqDecoder", | |||
| "LSTMSeq2SeqDecoder", | |||
| "Seq2SeqDecoder" | |||
| ] | |||
| from .crf import ConditionalRandomField | |||
| from .crf import allowed_transitions | |||
| from .mlp import MLP | |||
| from .utils import viterbi_decode | |||
| from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder | |||
| from .seq2seq_state import State, LSTMState, TransformerState | |||
| @@ -1,109 +1,413 @@ | |||
| # coding=utf-8 | |||
| __all__ = [ | |||
| "TransformerPast", | |||
| "Past", | |||
| "Decoder" | |||
| ] | |||
| from typing import Union, Tuple | |||
| import math | |||
| import torch | |||
| from torch import nn | |||
| import abc | |||
| import torch.nn.functional as F | |||
| from ..attention import AttentionLayer, MultiHeadAttention | |||
| from ...embeddings import StaticEmbedding | |||
| import numpy as np | |||
| from typing import Union, Tuple | |||
| from ...embeddings.utils import get_embeddings | |||
| from torch.nn import LayerNorm | |||
| import math | |||
| from .seq2seq_state import State, LSTMState, TransformerState | |||
| class Seq2SeqDecoder(nn.Module): | |||
| """ | |||
| Base class for sequence-to-sequence decoders. The forward function must be implemented; the remaining functions can be implemented as needed. Every Seq2SeqDecoder should have a corresponding State object | |||
| used to carry the encoder output the decoder needs and the history the decoder has to keep (for example the LSTM hidden state). | |||
| class Past: | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| super().__init__() | |||
| def forward(self, tokens, state, **kwargs): | |||
| """ | |||
| @abc.abstractmethod | |||
| def num_samples(self): | |||
| pass | |||
| :param torch.LongTensor tokens: bsz x max_len | |||
| :param State state: the state holds the encoder output and what has been decoded so far | |||
| :return: the return value may be a bsz x max_len x vocab_size Tensor, or a list whose first element must be the predicted distribution over words | |||
| """ | |||
| raise NotImplementedError | |||
| @abc.abstractmethod | |||
| def reorder_past(self, indices: torch.LongTensor): | |||
| def reorder_states(self, indices, states): | |||
| """ | |||
| Reorder the states in past according to the indices; changed in place. | |||
| Rearrange the states according to indices; this function is used when generating with beam search. | |||
| :param torch.LongTensor indices: | |||
| :param Past past: | |||
| :param State states: | |||
| :return: | |||
| """ | |||
| raise NotImplemented | |||
| assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" | |||
| states.reorder_state(indices) | |||
| def init_state(self, encoder_output, encoder_mask): | |||
| """ | |||
| Initialize a state object that records the encoder output and the part that has already been decoded. | |||
| :param Union[torch.Tensor, list, tuple] encoder_output: if not None, the inner elements must be torch.Tensor; by default the first dimension is the batch | |||
| dimension | |||
| :param Union[torch.Tensor, list, tuple] encoder_mask: if not None, the inner elements must be torch.Tensor; by default the first dimension is the batch | |||
| dimension | |||
| :param kwargs: | |||
| :return: State, a State object that records the encoder output | |||
| """ | |||
| state = State(encoder_output, encoder_mask) | |||
| return state | |||
| class TransformerPast(Past): | |||
| def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None, | |||
| num_decoder_layer: int = 6): | |||
| def decode(self, tokens, state): | |||
| """ | |||
| Continue the generation based on the contents of the states and of tokens. | |||
| :param encoder_outputs: (batch,src_seq_len,dim) | |||
| :param encoder_mask: (batch,src_seq_len) | |||
| :param encoder_key: list of (batch, src_seq_len, dim) | |||
| :param encoder_value: | |||
| :param decoder_prev_key: | |||
| :param decoder_prev_value: | |||
| :param torch.LongTensor tokens: bsz x max_len, the token output up to the previous time step. | |||
| :param State state: records the encoder output and the decoder's past state | |||
| :return: torch.FloatTensor: bsz x vocab_size, the distribution for the next time step | |||
| """ | |||
| outputs = self(state=state, tokens=tokens) | |||
| if isinstance(outputs, torch.Tensor): | |||
| return outputs[:, -1] | |||
| else: | |||
| raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") | |||
| class TiedEmbedding(nn.Module): | |||
| """ | |||
| Ties a weight to the original embedding weight | |||
| """ | |||
| def __init__(self, weight): | |||
| super().__init__() | |||
| self.encoder_outputs = encoder_outputs | |||
| self.encoder_mask = encoder_mask | |||
| self.encoder_key = [None] * num_decoder_layer | |||
| self.encoder_value = [None] * num_decoder_layer | |||
| self.decoder_prev_key = [None] * num_decoder_layer | |||
| self.decoder_prev_value = [None] * num_decoder_layer | |||
| def num_samples(self): | |||
| if self.encoder_outputs is not None: | |||
| return self.encoder_outputs.size(0) | |||
| return None | |||
| def _reorder_state(self, state, indices): | |||
| if type(state) == torch.Tensor: | |||
| state = state.index_select(index=indices, dim=0) | |||
| elif type(state) == list: | |||
| for i in range(len(state)): | |||
| assert state[i] is not None | |||
| state[i] = state[i].index_select(index=indices, dim=0) | |||
| self.weight = weight # vocab_size x embed_size | |||
| def forward(self, x): | |||
| """ | |||
| :param torch.FloatTensor x: bsz x * x embed_size | |||
| :return: torch.FloatTensor bsz x * x vocab_size | |||
| """ | |||
| return torch.matmul(x, self.weight.t()) | |||
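| As a sketch, the tied output layer is just a matmul against the shared embedding weight (the module path is assumed from the file being changed here): | |||
| import torch | |||
| from torch import nn | |||
| from fastNLP.modules.decoder.seq2seq_decoder import TiedEmbedding | |||
| embed = nn.Embedding(10, 4)  # vocab_size x embed_size | |||
| output_layer = TiedEmbedding(embed.weight)  # adds no new parameters | |||
| x = torch.rand(2, 3, 4)  # bsz x seq x embed_size | |||
| print(output_layer(x).shape)  # (2, 3, 10), i.e. bsz x seq x vocab_size | |||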
| def get_binded_decoder_output_embed(embed): | |||
| """ | |||
| Given an embedding, return the corresponding tied output embedding; the returned object is a TiedEmbedding | |||
| :param embed: | |||
| :return: | |||
| """ | |||
| if isinstance(embed, StaticEmbedding): | |||
| for idx, map2idx in enumerate(embed.words_to_words): | |||
| assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ | |||
| "include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ | |||
| "`lower=True` or `min_freq!=1`." | |||
| elif not isinstance(embed, nn.Embedding): | |||
| raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") | |||
| return TiedEmbedding(embed.weight) | |||
| class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||
| dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||
| """ | |||
| LSTM-based decoder | |||
| :param nn.Module,tuple embed: embedding of the decoder input. | |||
| :param int num_layers: number of LSTM layers | |||
| :param int hidden_size: hidden size; this value is also assumed to be the dimension of the encoder output | |||
| :param dropout: dropout rate | |||
| :param bool bind_decoder_input_output_embed: whether to tie the output layer to the input embedding (i.e. share the same weight); if embed is a StaticEmbedding, | |||
| its vocab must not contain no_create_entry tokens, and the StaticEmbedding must be initialized with lower=False and min_freq=1. | |||
| :param bool attention: whether to use attention | |||
| """ | |||
| super().__init__() | |||
| self.embed = get_embeddings(init_embed=embed) | |||
| self.embed_dim = embed.embedding_dim | |||
| if bind_decoder_input_output_embed: | |||
| self.output_layer = get_binded_decoder_output_embed(self.embed) | |||
| else: # no tying needed | |||
| self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
| self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
| self.hidden_size = hidden_size | |||
| self.num_layers = num_layers | |||
| self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, | |||
| batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) | |||
| self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None | |||
| self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||
| self.dropout_layer = nn.Dropout(dropout) | |||
| def forward(self, tokens, state, return_attention=False): | |||
| """ | |||
| :param torch.LongTensor tokens: batch x max_len | |||
| :param LSTMState state: State object holding the encoder output and the decode state | |||
| :param bool return_attention: whether to return the attention scores | |||
| :return: bsz x max_len x vocab_size; if return_attention=True, also returns bsz x max_len x encode_length | |||
| """ | |||
| src_output = state.encoder_output | |||
| encoder_mask = state.encoder_mask | |||
| assert tokens.size(1)>state.decode_length, "The state does not match the tokens." | |||
| tokens = tokens[:, state.decode_length:] | |||
| x = self.embed(tokens) | |||
| attn_weights = [] if self.attention_layer is not None else None # store the attention weights, batch,tgt_seq,src_seq | |||
| input_feed = state.input_feed | |||
| decoder_out = [] | |||
| cur_hidden = state.hidden | |||
| cur_cell = state.cell | |||
| # start decoding | |||
| for i in range(tokens.size(1)): | |||
| input = torch.cat( | |||
| (x[:, i:i + 1, :], | |||
| input_feed[:, None, :] | |||
| ), | |||
| dim=2 | |||
| ) # batch,1,2*dim | |||
| _, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell keep their original size | |||
| if self.attention_layer is not None: | |||
| input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) | |||
| attn_weights.append(attn_weight) | |||
| else: | |||
| input_feed = cur_hidden[-1] | |||
| state.input_feed = input_feed # batch, hidden | |||
| state.hidden = cur_hidden | |||
| state.cell = cur_cell | |||
| state.decode_length += 1 | |||
| decoder_out.append(input_feed) | |||
| decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden | |||
| decoder_out = self.dropout_layer(decoder_out) | |||
| if attn_weights is not None: | |||
| attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len | |||
| decoder_out = self.output_proj(decoder_out) | |||
| feats = self.output_layer(decoder_out) | |||
| if return_attention: | |||
| return feats, attn_weights | |||
| return feats | |||
| def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||
| """ | |||
| :param encoder_output: the input can take two forms: (1) a tuple (encoder_output, (hidden, cell)), where encoder_output: | |||
| bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell: bsz x hidden_size; usually the last layer's | |||
| hidden state and cell state of the LSTMEncoder are used to fill these two values | |||
| (2) only encoder_output: bsz x max_len x hidden_size, in which case hidden and cell are initialized with zeros | |||
| :param torch.ByteTensor encoder_mask: bsz x max_len, positions that are 0 are padding, indicating which source positions need not be attended | |||
| :return: | |||
| """ | |||
| if not isinstance(encoder_output, torch.Tensor): | |||
| encoder_output, (hidden, cell) = encoder_output | |||
| else: | |||
| raise ValueError('State does not support other format') | |||
| hidden = cell = None | |||
| assert encoder_output.ndim==3 | |||
| assert encoder_mask.size()==encoder_output.size()[:2] | |||
| assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ | |||
| "the hidden_size." | |||
| t = [hidden, cell] | |||
| for idx in range(2): | |||
| v = t[idx] | |||
| if v is None: | |||
| v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) | |||
| else: | |||
| assert v.dim()==2 | |||
| assert v.size(-1)==self.hidden_size | |||
| v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size | |||
| t[idx] = v | |||
| state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) | |||
| return state | |||
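| A rough end-to-end sketch of the LSTM decoder (all sizes hypothetical; the encoder output is faked with random tensors): | |||
| import torch | |||
| from torch import nn | |||
| from fastNLP.modules.decoder import LSTMSeq2SeqDecoder | |||
| decoder = LSTMSeq2SeqDecoder(nn.Embedding(100, 300), num_layers=2, hidden_size=300) | |||
| encoder_output = torch.rand(2, 7, 300)  # bsz x src_len x hidden_size | |||
| hidden = cell = torch.zeros(2, 300)  # last-layer hidden/cell of the encoder | |||
| encoder_mask = torch.ones(2, 7).bool() | |||
| state = decoder.init_state((encoder_output, (hidden, cell)), encoder_mask) | |||
| tokens = torch.randint(0, 100, (2, 4))  # bsz x tgt_len | |||
| feats = decoder(tokens, state)  # bsz x 4 x 100 | |||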
| def reorder_past(self, indices: torch.LongTensor): | |||
| self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices) | |||
| self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||
| self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||
| self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||
| self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||
| self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||
| return self | |||
| class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
| def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||
| """ | |||
| class Decoder(nn.Module): | |||
| def __init__(self): | |||
| :param int d_model: input and output dimension | |||
| :param int n_head: number of heads; d_model must be divisible by n_head | |||
| :param int dim_ff: | |||
| :param float dropout: | |||
| :param int layer_idx: index of the layer | |||
| """ | |||
| super().__init__() | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.layer_idx = layer_idx # record the layer index so that the matching entries in the state can be looked up | |||
| self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
| self.self_attn_layer_norm = nn.LayerNorm(d_model) | |||
| self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
| self.encoder_attn_layer_norm = nn.LayerNorm(d_model) | |||
| self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout), | |||
| nn.Linear(self.dim_ff, self.d_model), | |||
| nn.Dropout(dropout)) | |||
| self.final_layer_norm = nn.LayerNorm(self.d_model) | |||
| @abc.abstractmethod | |||
| def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||
| def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): | |||
| """ | |||
| Use this function when the model is decoding. It returns a batch_size x vocab_size result and the updated Past state. One special case has to be considered: the length of tokens may not be 1, i.e. the | |||
| beginning of the decoded sentence is given; in that case one has to check whether the decode state in Past has been computed correctly. | |||
| :return: tensor:batch_size x vocab_size, past: Past | |||
| :param x: (batch, seq_len, dim), input on the decoder side | |||
| :param encoder_output: (batch,src_seq_len,dim) | |||
| :param encoder_mask: batch,src_seq_len | |||
| :param self_attn_mask: seq_len x seq_len, lower-triangular mask matrix, only passed during training | |||
| :param TransformerState state: only passed during the inference stage | |||
| :return: | |||
| """ | |||
| raise NotImplemented | |||
| @abc.abstractmethod | |||
| def reorder_past(self, indices: torch.LongTensor, past: Past): | |||
| # self attention part | |||
| residual = x | |||
| x = self.self_attn_layer_norm(x) | |||
| x, _ = self.self_attn(query=x, | |||
| key=x, | |||
| value=x, | |||
| attn_mask=self_attn_mask, | |||
| state=state) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| x = residual + x | |||
| # encoder attention part | |||
| residual = x | |||
| x = self.encoder_attn_layer_norm(x) | |||
| x, attn_weight = self.encoder_attn(query=x, | |||
| key=encoder_output, | |||
| value=encoder_output, | |||
| key_mask=encoder_mask, | |||
| state=state) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| x = residual + x | |||
| # ffn | |||
| residual = x | |||
| x = self.final_layer_norm(x) | |||
| x = self.ffn(x) | |||
| x = residual + x | |||
| return x, attn_weight | |||
| class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||
| d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||
| bind_decoder_input_output_embed = True): | |||
| """ | |||
| Reorder the states in past according to the indices; changed in place. | |||
| :param torch.LongTensor indices: | |||
| :param Past past: | |||
| :return: | |||
| :param embed: embedding of the input tokens | |||
| :param nn.Module pos_embed: position embedding | |||
| :param int d_model: input and output size | |||
| :param int num_layers: number of layers | |||
| :param int n_head: number of heads | |||
| :param int dim_ff: intermediate size of the FFN | |||
| :param float dropout: dropout rate inside the self-attention and FFN | |||
| :param bool bind_decoder_input_output_embed: whether to tie the output layer to the input embedding (i.e. share the same weight); if embed is a StaticEmbedding, | |||
| its vocab must not contain no_create_entry tokens, and the StaticEmbedding must be initialized with lower=False and min_freq=1. | |||
| """ | |||
| super().__init__() | |||
| self.embed = get_embeddings(embed) | |||
| self.pos_embed = pos_embed | |||
| if bind_decoder_input_output_embed: | |||
| self.output_layer = get_binded_decoder_output_embed(self.embed) | |||
| else: # no tying needed | |||
| self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
| self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
| self.num_layers = num_layers | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
| self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) | |||
| for layer_idx in range(num_layers)]) | |||
| self.embed_scale = math.sqrt(d_model) | |||
| self.layer_norm = nn.LayerNorm(d_model) | |||
| self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||
| def forward(self, tokens, state, return_attention=False): | |||
| """ | |||
| raise NotImplemented | |||
| :param torch.LongTensor tokens: batch x tgt_len, the tokens to decode | |||
| :param TransformerState state: object recording the encoder output and the decode state; it can be obtained via init_state() | |||
| :param bool return_attention: whether to return the attention scores over the encoder results | |||
| :return: bsz x max_len x vocab_size; if return_attention=True, also returns bsz x max_len x encode_length | |||
| """ | |||
| encoder_output = state.encoder_output | |||
| encoder_mask = state.encoder_mask | |||
| assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens." | |||
| tokens = tokens[:, state.decode_length:] | |||
| device = tokens.device | |||
| x = self.embed_scale * self.embed(tokens) | |||
| if self.pos_embed is not None: | |||
| position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] | |||
| x += self.pos_embed(position) | |||
| x = self.input_fc(x) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| batch_size, max_tgt_len = tokens.size() | |||
| if max_tgt_len>1: | |||
| triangle_mask = self._get_triangle_mask(tokens) | |||
| else: | |||
| triangle_mask = None | |||
| for layer in self.layer_stacks: | |||
| x, attn_weight = layer(x=x, | |||
| encoder_output=encoder_output, | |||
| encoder_mask=encoder_mask, | |||
| self_attn_mask=triangle_mask, | |||
| state=state | |||
| ) | |||
| x = self.layer_norm(x) # batch, tgt_len, dim | |||
| x = self.output_fc(x) | |||
| feats = self.output_layer(x) | |||
| if return_attention: | |||
| return feats, attn_weight | |||
| return feats | |||
| def init_state(self, encoder_output, encoder_mask): | |||
| """ | |||
| Initialize a TransformerState to be used in forward | |||
| :param torch.FloatTensor encoder_output: bsz x max_len x d_model, output of the encoder | |||
| :param torch.ByteTensor encoder_mask: bsz x max_len, positions that are 1 must be attended. | |||
| :return: TransformerState | |||
| """ | |||
| if isinstance(encoder_output, torch.Tensor): | |||
| encoder_output = encoder_output | |||
| elif isinstance(encoder_output, (list, tuple)): | |||
| encoder_output = encoder_output[0] # in case it is the output of an LSTMEncoder | |||
| else: | |||
| raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") | |||
| state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) | |||
| return state | |||
| @staticmethod | |||
| def _get_triangle_mask(tokens): | |||
| tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) | |||
| return torch.tril(tensor).byte() | |||
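| For reference, the triangular self-attention mask produced above looks like this for a length-4 target (1 = may attend): | |||
| import torch | |||
| from fastNLP.modules.decoder import TransformerSeq2SeqDecoder | |||
| tokens = torch.zeros(1, 4, dtype=torch.long)  # bsz x tgt_len, values irrelevant | |||
| print(TransformerSeq2SeqDecoder._get_triangle_mask(tokens)) | |||
| # tensor([[1, 0, 0, 0], | |||
| #         [1, 1, 0, 0], | |||
| #         [1, 1, 1, 0], | |||
| #         [1, 1, 1, 1]], dtype=torch.uint8) | |||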
| @@ -0,0 +1,145 @@ | |||
| r""" | |||
| Every decoder has a corresponding State used to record the encoder output and the decoding history | |||
| """ | |||
| __all__ = [ | |||
| 'State', | |||
| "LSTMState", | |||
| "TransformerState" | |||
| ] | |||
| from typing import Union | |||
| import torch | |||
| class State: | |||
| def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||
| """ | |||
| Every decoder has a corresponding State object used to carry the encoder output and the decode state before the current time step. | |||
| :param Union[torch.Tensor, list, tuple] encoder_output: if not None, the inner elements must be torch.Tensor; by default the first dimension is the batch | |||
| dimension | |||
| :param Union[torch.Tensor, list, tuple] encoder_mask: if not None, the inner elements must be torch.Tensor; by default the first dimension is the batch | |||
| dimension | |||
| :param kwargs: | |||
| """ | |||
| self.encoder_output = encoder_output | |||
| self.encoder_mask = encoder_mask | |||
| self._decode_length = 0 | |||
| @property | |||
| def num_samples(self): | |||
| """ | |||
| How many samples' encoder states are contained in this State; mainly used to determine the batch size during generation. | |||
| :return: | |||
| """ | |||
| if self.encoder_output is not None: | |||
| return self.encoder_output.size(0) | |||
| else: | |||
| return None | |||
| @property | |||
| def decode_length(self): | |||
| """ | |||
| Which token decoding has currently reached; the decoder only starts decoding from the tokens after decode_length. A value of 0 means decoding has not started yet. | |||
| :return: | |||
| """ | |||
| return self._decode_length | |||
| @decode_length.setter | |||
| def decode_length(self, value): | |||
| self._decode_length = value | |||
| def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): | |||
| if isinstance(state, torch.Tensor): | |||
| state = state.index_select(index=indices, dim=dim) | |||
| elif isinstance(state, list): | |||
| for i in range(len(state)): | |||
| assert state[i] is not None | |||
| state[i] = self._reorder_state(state[i], indices, dim) | |||
| elif isinstance(state, tuple): | |||
| tmp_list = [] | |||
| for i in range(len(state)): | |||
| assert state[i] is not None | |||
| tmp_list.append(self._reorder_state(state[i], indices, dim)) | |||
| state = tuple(tmp_list) | |||
| else: | |||
| raise TypeError(f"Cannot reorder data of type:{type(state)}") | |||
| return state | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| if self.encoder_mask is not None: | |||
| self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||
| if self.encoder_output is not None: | |||
| self.encoder_output = self._reorder_state(self.encoder_output, indices) | |||
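| A sketch of how reorder_state behaves during beam search (indices pick which original sample each slot should now track): | |||
| import torch | |||
| from fastNLP.modules.decoder import State | |||
| encoder_output = torch.arange(3).float().view(3, 1, 1)  # three samples: 0., 1., 2. | |||
| state = State(encoder_output, torch.ones(3, 1).bool()) | |||
| state.reorder_state(torch.LongTensor([2, 2, 0])) | |||
| print(state.encoder_output.view(-1))  # tensor([2., 2., 0.]) | |||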
| class LSTMState(State): | |||
| def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||
| """ | |||
| State paired with the LSTMDecoder; stores the encoder output and some intermediate states of the LSTM decoding process | |||
| :param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size, output of the encoder | |||
| :param torch.BoolTensor encoder_mask: bsz x src_seq_len, positions that are 0 are padding | |||
| :param torch.FloatTensor hidden: num_layers x bsz x hidden_size, hidden state of the previous time step | |||
| :param torch.FloatTensor cell: num_layers x bsz x hidden_size, cell state of the previous time step | |||
| """ | |||
| super().__init__(encoder_output, encoder_mask) | |||
| self.hidden = hidden | |||
| self.cell = cell | |||
| self._input_feed = hidden[0] # defaults to the output of the previous time step | |||
| @property | |||
| def input_feed(self): | |||
| """ | |||
| At every time step the LSTMDecoder concatenates the previous token's embedding with input_feed as the input to the next step; when the LSTMDecoder does not use attention, | |||
| input_feed is the hidden state of the previous step, otherwise it is the output of the attention layer. | |||
| :return: torch.FloatTensor, bsz x hidden_size | |||
| """ | |||
| return self._input_feed | |||
| @input_feed.setter | |||
| def input_feed(self, value): | |||
| self._input_feed = value | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| super().reorder_state(indices) | |||
| self.hidden = self._reorder_state(self.hidden, indices, dim=1) | |||
| self.cell = self._reorder_state(self.cell, indices, dim=1) | |||
| if self.input_feed is not None: | |||
| self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) | |||
| class TransformerState(State): | |||
| def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||
| """ | |||
| State paired with TransformerSeq2SeqDecoder, | |||
| :param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, output of the encoder | |||
| :param torch.ByteTensor encoder_mask: bsz x encode_max_len, positions that are 1 must be attended | |||
| :param int num_decoder_layer: number of decoder layers | |||
| """ | |||
| super().__init__(encoder_output, encoder_mask) | |||
| self.encoder_key = [None] * num_decoder_layer # each element: bsz x encoder_max_len x key_dim | |||
| self.encoder_value = [None] * num_decoder_layer # each element: bsz x encoder_max_len x value_dim | |||
| self.decoder_prev_key = [None] * num_decoder_layer # each element: bsz x decode_length x key_dim | |||
| self.decoder_prev_value = [None] * num_decoder_layer # each element: bsz x decode_length x value_dim | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| super().reorder_state(indices) | |||
| self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||
| self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||
| self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||
| self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||
| @property | |||
| def decode_length(self): | |||
| if self.decoder_prev_key[0] is not None: | |||
| return self.decoder_prev_key[0].size(1) | |||
| return 0 | |||
| @@ -4,8 +4,6 @@ r""" | |||
| """ | |||
| __all__ = [ | |||
| # "BertModel", | |||
| "ConvolutionCharEncoder", | |||
| "LSTMCharEncoder", | |||
| @@ -35,10 +33,14 @@ __all__ = [ | |||
| "RobertaModel", | |||
| "GPT2Model" | |||
| "GPT2Model", | |||
| "LSTMSeq2SeqEncoder", | |||
| "TransformerSeq2SeqEncoder", | |||
| "Seq2SeqEncoder" | |||
| ] | |||
| from .attention import MultiHeadAttention, BiAttention, SelfAttention | |||
| from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention | |||
| from .bert import BertModel | |||
| from .roberta import RobertaModel | |||
| from .gpt2 import GPT2Model | |||
| @@ -49,3 +51,4 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo | |||
| from .star_transformer import StarTransformer | |||
| from .transformer import TransformerEncoder | |||
| from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
| from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder | |||
| @@ -10,6 +10,7 @@ __all__ = [ | |||
| import copy | |||
| import json | |||
| import math | |||
| import os | |||
| import torch | |||
| from torch import nn | |||
| @@ -20,7 +21,8 @@ from ...io.file_utils import _get_bert_dir | |||
| from ...core import logger | |||
| CONFIG_FILE = 'bert_config.json' | |||
| CONFIG_FILE = 'config.json' | |||
| WEIGHTS_NAME = 'pytorch_model.bin' | |||
| BERT_KEY_RENAME_MAP_1 = { | |||
| 'gamma': 'weight', | |||
| @@ -57,7 +59,8 @@ class BertConfig(object): | |||
| max_position_embeddings=512, | |||
| type_vocab_size=2, | |||
| initializer_range=0.02, | |||
| layer_norm_eps=1e-12): | |||
| layer_norm_eps=1e-12, | |||
| architectures='bert'): | |||
| r"""Constructs BertConfig. | |||
| Args: | |||
| @@ -101,6 +104,7 @@ class BertConfig(object): | |||
| self.type_vocab_size = type_vocab_size | |||
| self.initializer_range = initializer_range | |||
| self.layer_norm_eps = layer_norm_eps | |||
| self.architectures = architectures | |||
| else: | |||
| raise ValueError("First argument must be either a vocabulary size (int)" | |||
| "or the path to a pretrained model config file (str)") | |||
| @@ -134,9 +138,13 @@ class BertConfig(object): | |||
| def to_json_file(self, json_file_path): | |||
| r""" Save this instance to a json file.""" | |||
| if os.path.isdir(json_file_path): | |||
| json_file_path = os.path.join(json_file_path, CONFIG_FILE) | |||
| with open(json_file_path, "w", encoding='utf-8') as writer: | |||
| writer.write(self.to_json_string()) | |||
| def save_pretrained(self, save_directory): | |||
| self.to_json_file(save_directory) | |||
| def gelu(x): | |||
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
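| This is the exact erf-based GELU; as a quick sanity check it should match PyTorch's built-in (non-approximate) implementation: | |||
| import torch | |||
| import torch.nn.functional as F | |||
| x = torch.randn(8) | |||
| assert torch.allclose(gelu(x), F.gelu(x), atol=1e-6)  # gelu as defined above | |||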
| @@ -149,21 +157,6 @@ def swish(x): | |||
| ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
| # class BertLayerNorm(nn.Module): | |||
| # def __init__(self, hidden_size, eps=1e-12): | |||
| # r"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||
| # """ | |||
| # super(BertLayerNorm, self).__init__() | |||
| # self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
| # self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
| # self.variance_epsilon = eps | |||
| # | |||
| # def forward(self, x): | |||
| # u = x.mean(-1, keepdim=True) | |||
| # s = (x - u).pow(2).mean(-1, keepdim=True) | |||
| # x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||
| # return self.weight * x + self.bias | |||
| BertLayerNorm = torch.nn.LayerNorm | |||
| @@ -613,3 +606,24 @@ class BertModel(nn.Module): | |||
| logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") | |||
| return model | |||
| def save_pretrained(self, save_directory): | |||
| """ Save the model to a folder | |||
| """ | |||
| assert os.path.isdir( | |||
| save_directory | |||
| ), "Saving path should be a directory where the model and configuration can be saved" | |||
| # Only save the model itself if we are using distributed training | |||
| model_to_save = self.module if hasattr(self, "module") else self | |||
| # Attach architecture to the config | |||
| model_to_save.config.architectures = [model_to_save.__class__.__name__] | |||
| # Save configuration file | |||
| model_to_save.config.save_pretrained(save_directory) | |||
| # If we save using the predefined names, we can load using `from_pretrained` | |||
| output_model_file = os.path.join(save_directory, WEIGHTS_NAME) | |||
| torch.save(model_to_save.state_dict(), output_model_file) | |||
| logger.debug("Model weights saved in {}".format(output_model_file)) | |||
| @@ -15,9 +15,8 @@ import math | |||
| from torch.nn import CrossEntropyLoss | |||
| from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||
| from ..decoder.seq2seq_decoder import Decoder, Past | |||
| from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||
| from ..generator.seq2seq_generator import SequenceGenerator | |||
| from typing import Tuple | |||
| GELU_CONSTANT = math.sqrt(2 / math.pi) | |||
| @@ -732,7 +731,7 @@ class GPT2PreTrainedModel(nn.Module): | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_ids, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| results = generator.generate(input_ids, past=None) | |||
| results = generator.generate(tokens=input_ids, state=GPT2State()) | |||
| return results | |||
| @@ -788,21 +787,13 @@ class GPT2Model(GPT2PreTrainedModel): | |||
| for layer, heads in heads_to_prune.items(): | |||
| self.h[layer].attn.prune_heads(heads) | |||
| def forward( | |||
| self, | |||
| input_ids, | |||
| past=None, | |||
| attention_mask=None, | |||
| token_type_ids=None, | |||
| position_ids=None, | |||
| head_mask=None, | |||
| output_attentions=True | |||
| ): | |||
| def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, | |||
| head_mask=None, output_attentions=True): | |||
| """ | |||
| :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | |||
| :param GPT2Past past: the previous state | |||
| :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), the same size as the concatenation of input_ids and past. | |||
| :param GPT2State state: the previous state | |||
| :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), the same size as the concatenation of input_ids and state. | |||
| Positions that are 0 are padding. | |||
| :param torch.LongTensor token_type_ids: batch_size x max_len. | |||
| :param torch.LongTensor position_ids: positions corresponding to input_ids | |||
| @@ -818,11 +809,11 @@ class GPT2Model(GPT2PreTrainedModel): | |||
| if position_ids is not None: | |||
| position_ids = position_ids.view(-1, input_shape[-1]) | |||
| if past is None or len(past)==0: | |||
| if state is None or len(state)==0: | |||
| past_length = 0 | |||
| past = [None] * len(self.h) # len(self.h) is the number of layers | |||
| state = [None] * len(self.h) # len(self.h) is the number of layers | |||
| else: | |||
| past_length = past[0][0].size(-2) | |||
| past_length = state[0][0].size(-2) | |||
| if position_ids is None: # generate position ids if none were given | |||
| device = input_ids.device | |||
| position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) | |||
| @@ -880,7 +871,7 @@ class GPT2Model(GPT2PreTrainedModel): | |||
| presents = () | |||
| all_attentions = [] | |||
| all_hidden_states = () | |||
| for i, (block, layer_past) in enumerate(zip(self.h, past)): | |||
| for i, (block, layer_past) in enumerate(zip(self.h, state)): | |||
| all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) | |||
| outputs = block( | |||
| @@ -915,56 +906,63 @@ class GPT2Model(GPT2PreTrainedModel): | |||
| return outputs # last hidden state, (presents), (all hidden_states), (attentions) | |||
| class GPT2Past(Past): | |||
| class GPT2State(State): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.past = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] | |||
| super().__init__(None, None) | |||
| self.state = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] | |||
| @property | |||
| def num_samples(self): | |||
| if self.past is not None: | |||
| return self.past[0].size(1) | |||
| if self.state is not None: | |||
| return self.state[0].size(1) | |||
| return None | |||
| def reorder_past(self, indices): | |||
| for i in range(len(self.past)): | |||
| assert self.past[i] is not None | |||
| self.past[i] = self.past[i].index_select(index=indices, dim=1) | |||
| @property | |||
| def decode_length(self): | |||
| if self.state is None: | |||
| return 0 | |||
| return self.state[0].size(-2) | |||
| def reorder_state(self, indices): | |||
| if self.state: | |||
| for i in range(len(self.state)): | |||
| assert self.state[i] is not None | |||
| self.state[i] = self.state[i].index_select(index=indices, dim=1) | |||
| def __iter__(self): | |||
| for p in self.past: | |||
| for p in self.state: | |||
| yield p | |||
| def __getitem__(self, item): | |||
| assert isinstance(item, int) | |||
| return self.past[item] | |||
| return self.state[item] | |||
| def __len__(self): | |||
| if self.past is not None: | |||
| return len(self.past) | |||
| if self.state is not None: | |||
| return len(self.state) | |||
| return 0 | |||
| class _GPT2Decoder(Decoder): | |||
| class _GPT2Decoder(Seq2SeqDecoder): | |||
| """ | |||
| Wraps GPT2 so that it can be used in SequenceGenerator | |||
| """ | |||
| def __init__(self, gpt_model): | |||
| super().__init__() | |||
| self.gpt_model = gpt_model | |||
| def decode(self, tokens, past=None) -> Tuple[torch.Tensor, Past]: | |||
| if past is None: | |||
| past = GPT2Past() | |||
| lm_logits, presents, _ = self.gpt_model(input_ids=tokens, | |||
| past=past, | |||
| def decode(self, tokens, state=None) -> torch.Tensor: | |||
| if state is None: | |||
| state = GPT2State() | |||
| lm_logits, presents, _ = self.gpt_model(input_ids=tokens[:, state.decode_length:], | |||
| state=state, | |||
| attention_mask=None, | |||
| token_type_ids=None, | |||
| position_ids=None, | |||
| head_mask=None, | |||
| output_attentions=False) | |||
| past.past = list(presents) | |||
| return lm_logits[:, -1], past | |||
| def reorder_past(self, indices: torch.LongTensor, past: GPT2Past) -> GPT2Past: | |||
| past.reorder_past(indices) | |||
| return past | |||
| state.state = list(presents) | |||
| return lm_logits[:, -1] | |||
| class GPT2LMHeadModel(GPT2PreTrainedModel): | |||
| @@ -1008,21 +1006,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): | |||
| def get_input_embeddings(self): | |||
| return self.transformer.wte | |||
| def forward( | |||
| self, | |||
| input_ids, | |||
| past=None, | |||
| attention_mask=None, | |||
| token_type_ids=None, | |||
| position_ids=None, | |||
| head_mask=None, | |||
| labels=None, | |||
| output_attentions=False | |||
| ): | |||
| def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, | |||
| head_mask=None, labels=None, output_attentions=False): | |||
| """ | |||
| :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | |||
| :param tuple past: num_layers x 2 x batch_size x n_head x max_len' x head_dim. The presents from the previous time step can be passed in | |||
| :param tuple state: num_layers x 2 x batch_size x n_head x max_len' x head_dim. The presents from the previous time step can be passed in | |||
| :param torch.ByteTensor attention_mask: batch_size x max_len, the same size as input_ids. Positions that are 0 are padding. | |||
| :param torch.LongTensor token_type_ids: batch_size x max_len. | |||
| :param torch.LongTensor position_ids: positions corresponding to input_ids | |||
| @@ -1034,7 +1023,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): | |||
| """ | |||
| transformer_outputs = self.transformer( | |||
| input_ids, | |||
| past=past, | |||
| state=state, | |||
| attention_mask=attention_mask, | |||
| token_type_ids=token_type_ids, | |||
| position_ids=position_ids, | |||
| @@ -0,0 +1,189 @@ | |||
| import torch.nn as nn | |||
| import torch | |||
| from torch.nn import LayerNorm | |||
| import torch.nn.functional as F | |||
| from typing import Union, Tuple | |||
| from ...core.utils import seq_len_to_mask | |||
| import math | |||
| from ...modules.encoder.lstm import LSTM | |||
| from fastNLP.modules.attention import MultiHeadAttention | |||
| from ...embeddings import StaticEmbedding | |||
| from ...embeddings.utils import get_embeddings | |||
| class Seq2SeqEncoder(nn.Module): | |||
| """ | |||
| Base class for all sequence-to-sequence encoders. The forward function must be implemented | |||
| """ | |||
| def __init__(self): | |||
| super().__init__() | |||
| def forward(self, tokens, seq_len): | |||
| """ | |||
| :param torch.LongTensor tokens: bsz x max_len, input to the encoder | |||
| :param torch.LongTensor seq_len: bsz | |||
| :return: | |||
| """ | |||
| raise NotImplementedError | |||
| class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
| def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, | |||
| dropout: float = 0.1): | |||
| """ | |||
| Self-attention layer, | |||
| :param int d_model: dimension of the input and output | |||
| :param int n_head: number of heads; each head has dimension d_model/n_head | |||
| :param int dim_ff: dimension of the FFN | |||
| :param float dropout: dropout of the self-attention and FFN; 0 means no dropout | |||
| """ | |||
| super(TransformerSeq2SeqEncoderLayer, self).__init__() | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.self_attn = MultiHeadAttention(d_model, n_head, dropout) | |||
| self.attn_layer_norm = LayerNorm(d_model) | |||
| self.ffn_layer_norm = LayerNorm(d_model) | |||
| self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout), | |||
| nn.Linear(self.dim_ff, self.d_model), | |||
| nn.Dropout(dropout)) | |||
| def forward(self, x, mask): | |||
| """ | |||
| :param x: batch x src_seq x d_model | |||
| :param mask: batch x src_seq, positions that are 0 are padding | |||
| :return: | |||
| """ | |||
| # attention | |||
| residual = x | |||
| x = self.attn_layer_norm(x) | |||
| x, _ = self.self_attn(query=x, | |||
| key=x, | |||
| value=x, | |||
| key_mask=mask) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| x = residual + x | |||
| # ffn | |||
| residual = x | |||
| x = self.ffn_layer_norm(x) | |||
| x = self.ffn(x) | |||
| x = residual + x | |||
| return x | |||
| class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||
| num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||
| """ | |||
| Transformer-based encoder | |||
| :param embed: embedding of the encoder input tokens | |||
| :param nn.Module pos_embed: position embedding | |||
| :param int num_layers: number of encoder layers | |||
| :param int d_model: input and output dimension | |||
| :param int n_head: number of heads | |||
| :param int dim_ff: intermediate dimension of the FFN | |||
| :param float dropout: dropout of the attention and FFN | |||
| """ | |||
| super(TransformerSeq2SeqEncoder, self).__init__() | |||
| self.embed = get_embeddings(embed) | |||
| self.embed_scale = math.sqrt(d_model) | |||
| self.pos_embed = pos_embed | |||
| self.num_layers = num_layers | |||
| self.d_model = d_model | |||
| self.n_head = n_head | |||
| self.dim_ff = dim_ff | |||
| self.dropout = dropout | |||
| self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
| self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) | |||
| for _ in range(num_layers)]) | |||
| self.layer_norm = LayerNorm(d_model) | |||
| def forward(self, tokens, seq_len): | |||
| """ | |||
| :param tokens: batch x max_len | |||
| :param seq_len: [batch] | |||
| :return: bsz x max_len x d_model, bsz x max_len (positions that are 0 are padding) | |||
| """ | |||
| x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||
| batch_size, max_src_len, _ = x.size() | |||
| device = x.device | |||
| if self.pos_embed is not None: | |||
| position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) | |||
| x += self.pos_embed(position) | |||
| x = self.input_fc(x) | |||
| x = F.dropout(x, p=self.dropout, training=self.training) | |||
| encoder_mask = seq_len_to_mask(seq_len) | |||
| encoder_mask = encoder_mask.to(device) | |||
| for layer in self.layer_stacks: | |||
| x = layer(x, encoder_mask) | |||
| x = self.layer_norm(x) | |||
| return x, encoder_mask | |||
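| A minimal usage sketch (vocabulary size and shapes are made up; embed is given as a (num_embeddings, embedding_dim) tuple): | |||
| import torch | |||
| from fastNLP.modules.encoder import TransformerSeq2SeqEncoder | |||
| encoder = TransformerSeq2SeqEncoder(embed=(1000, 512), num_layers=2, d_model=512, n_head=8) | |||
| tokens = torch.randint(0, 1000, (2, 7))  # bsz x max_len | |||
| seq_len = torch.LongTensor([7, 5]) | |||
| encoder_output, encoder_mask = encoder(tokens, seq_len) | |||
| print(encoder_output.shape, encoder_mask.shape)  # (2, 7, 512) and (2, 7) | |||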
| class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||
| def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||
| hidden_size = 400, dropout = 0.3, bidirectional=True): | |||
| """ | |||
| LSTM-based encoder | |||
| :param embed: token embedding of the encoder | |||
| :param int num_layers: number of layers | |||
| :param int hidden_size: hidden and output size of the LSTM | |||
| :param float dropout: dropout between LSTM layers | |||
| :param bool bidirectional: whether to use a bidirectional LSTM | |||
| """ | |||
| super().__init__() | |||
| self.embed = get_embeddings(embed) | |||
| self.num_layers = num_layers | |||
| self.dropout = dropout | |||
| self.hidden_size = hidden_size | |||
| self.bidirectional = bidirectional | |||
| hidden_size = hidden_size//2 if bidirectional else hidden_size | |||
| self.lstm = LSTM(input_size=embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional, | |||
| batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||
| def forward(self, tokens, seq_len): | |||
| """ | |||
| :param torch.LongTensor tokens: bsz x max_len | |||
| :param torch.LongTensor seq_len: bsz | |||
| :return: (output, (hidden, cell)), encoder_mask | |||
| output: bsz x max_len x hidden_size, | |||
| hidden,cell: batch_size x hidden_size, hidden state or cell state of the last layer | |||
| encoder_mask: bsz x max_len, positions that are 0 are padding | |||
| """ | |||
| x = self.embed(tokens) | |||
| device = x.device | |||
| x, (final_hidden, final_cell) = self.lstm(x, seq_len) | |||
| encoder_mask = seq_len_to_mask(seq_len).to(device) | |||
| # x: batch,seq_len,dim; h/c: num_layers*2,batch,dim | |||
| if self.bidirectional: | |||
| final_hidden = self.concat_bidir(final_hidden) # concatenate the bidirectional hidden states, used as input for the upcoming decoder | |||
| final_cell = self.concat_bidir(final_cell) | |||
| return (x, (final_hidden[-1], final_cell[-1])), encoder_mask # split into two return values to match the forward of Seq2SeqBaseModel | |||
| def concat_bidir(self, input): | |||
| output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) | |||
| return output.reshape(self.num_layers, input.size(1), -1) | |||
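| concat_bidir merges the forward and backward directions of each layer; the same reshape in isolation (hypothetical 3 layers, hidden size 400, i.e. 200 per direction): | |||
| import torch | |||
| num_layers, batch, half = 3, 2, 200 | |||
| h = torch.rand(num_layers * 2, batch, half)  # (num_layers*2, batch, hidden//2) from the LSTM | |||
| merged = h.view(num_layers, 2, batch, -1).transpose(1, 2).reshape(num_layers, batch, -1) | |||
| print(merged.shape)  # (3, 2, 400) | |||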
| @@ -5,7 +5,7 @@ __all__ = [ | |||
| ] | |||
| from torch import nn | |||
| from .attention import MultiHeadAttention | |||
| from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
| class TransformerEncoder(nn.Module): | |||
| @@ -13,66 +13,30 @@ class TransformerEncoder(nn.Module): | |||
| The transformer's encoder module; it does not include the embedding layer | |||
| """ | |||
| def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||
| """ | |||
| class SubLayer(nn.Module): | |||
| def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | |||
| super(TransformerEncoder.SubLayer, self).__init__() | |||
| self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) | |||
| self.norm1 = nn.LayerNorm(model_size, eps=1e-6) | |||
| self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | |||
| nn.ReLU(), | |||
| nn.Dropout(dropout), | |||
| nn.Linear(inner_size, model_size)) | |||
| self.norm2 = nn.LayerNorm(model_size, eps=1e-6) | |||
| self.dropout = nn.Dropout(dropout) | |||
| def forward(self, input, seq_mask=None, atte_mask_out=None): | |||
| r""" | |||
| :param input: [batch, seq_len, model_size] | |||
| :param seq_mask: [batch, seq_len] | |||
| :return: [batch, seq_len, model_size] | |||
| """ | |||
| if seq_mask is None: # avoid errors in the multiplications below | |||
| seq_mask = 1 | |||
| input = self.norm1(input) | |||
| attention = self.atte(input, input, input, atte_mask_out) | |||
| input = input + self.dropout(attention) | |||
| attention *= seq_mask | |||
| input = self.norm2(input) | |||
| output = self.ffn(input) | |||
| input = input + self.dropout(output) | |||
| input *= seq_mask | |||
| return input | |||
| def __init__(self, num_layers, **kargs): | |||
| r""" | |||
| :param int num_layers: number of transformer layers | |||
| :param int model_size: size of the input dimension; also the size of the output dimension. | |||
| :param int inner_size: hidden size of the FFN layer | |||
| :param int key_size: dimension of each head. | |||
| :param int value_size: dimension of the value inside each head. | |||
| :param int num_head: number of heads. | |||
| :param float dropout: dropout probability. Default: 0.1 | |||
| :param int num_layers: number of Transformer layers | |||
| :param int d_model: size of the input and output | |||
| :param int n_head: number of heads | |||
| :param int dim_ff: intermediate hidden size of the FFN | |||
| :param float dropout: probability of dropping the attention and the intermediate FFN representations | |||
| """ | |||
| super(TransformerEncoder, self).__init__() | |||
| self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | |||
| self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) | |||
| self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||
| dropout = dropout) for _ in range(num_layers)]) | |||
| self.norm = nn.LayerNorm(d_model, eps=1e-6) | |||
| def forward(self, x, seq_mask=None): | |||
| r""" | |||
| :param x: [batch, seq_len, model_size] input sequence | |||
| :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated. | |||
| :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated. Positions that are 1 are attended | |||
| Default: ``None`` | |||
| :return: [batch, seq_len, model_size] output sequence | |||
| """ | |||
| output = x | |||
| if seq_mask is None: | |||
| atte_mask_out = None | |||
| else: | |||
| atte_mask_out = (seq_mask.eq(False))[:, None, :] | |||
| seq_mask = seq_mask[:, :, None] | |||
| seq_mask = x.new_ones(x.size(0), x.size(1)).bool() | |||
| for layer in self.layers: | |||
| output = layer(output, seq_mask, atte_mask_out) | |||
| output = layer(output, seq_mask) | |||
| return self.norm(output) | |||
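| The refactored TransformerEncoder now simply stacks TransformerSeq2SeqEncoderLayer; a usage sketch with made-up sizes (import path assumed from this file): | |||
| import torch | |||
| from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
| encoder = TransformerEncoder(num_layers=2, d_model=128, n_head=8, dim_ff=256, dropout=0.1) | |||
| x = torch.rand(2, 5, 128)  # batch x seq_len x d_model | |||
| seq_mask = torch.ones(2, 5).bool()  # 1 = positions to attend | |||
| out = encoder(x, seq_mask)  # batch x seq_len x d_model | |||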
| @@ -0,0 +1,9 @@ | |||
| r""" | |||
| """ | |||
| __all__ = [ | |||
| "SequenceGenerator" | |||
| ] | |||
| from .seq2seq_generator import SequenceGenerator | |||
| @@ -7,16 +7,35 @@ __all__ = [ | |||
| ] | |||
| import torch | |||
| from ..decoder.seq2seq_decoder import Decoder | |||
| from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||
| import torch.nn.functional as F | |||
| from fastNLP.core.utils import _get_model_device | |||
| from ...core.utils import _get_model_device | |||
| from functools import partial | |||
| class SequenceGenerator: | |||
| def __init__(self, decoder: Decoder, max_length=20, num_beams=1, | |||
| """ | |||
| Given a Seq2SeqDecoder, decode sentences from it | |||
| """ | |||
| def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1, | |||
| do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
| """ | |||
| :param Seq2SeqDecoder decoder: the Decoder object | |||
| :param int max_length: maximum length of the generated sentences | |||
| :param int num_beams: beam size for beam search | |||
| :param bool do_sample: whether to generate by sampling | |||
| :param float temperature: only meaningful when do_sample is True | |||
| :param int top_k: only sample from the top_k tokens | |||
| :param float top_p: only sample from tokens within cumulative probability top_p (nucleus sampling) | |||
| :param int,None bos_token_id: token id that starts a sentence | |||
| :param int,None eos_token_id: token id that ends a sentence | |||
| :param float repetition_penalty: how strongly repeated tokens are penalized | |||
| :param float length_penalty: length penalty; values below 1 encourage longer sentences, values above 1 encourage shorter ones | |||
| :param int pad_token_id: once a sentence has finished generating, the remaining positions are filled with pad_token_id | |||
| """ | |||
| if do_sample: | |||
| self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, | |||
| temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | |||
| @@ -40,19 +59,19 @@ class SequenceGenerator: | |||
| self.decoder = decoder | |||
| @torch.no_grad() | |||
| def generate(self, tokens=None, past=None): | |||
| def generate(self, state, tokens=None): | |||
| """ | |||
| :param torch.LongTensor tokens: batch_size x length, the starting tokens | |||
| :param past: | |||
| :return: | |||
| :param State state: the State built from the encoder results, used together with the Decoder | |||
| :param torch.LongTensor,None tokens: batch_size x length, the starting tokens | |||
| :return: bsz x max_length', the generated token sequences. If eos_token_id is not None, every sequence ends with eos_token_id | |||
| """ | |||
| # TODO check whether decoding still works directly when the length of tokens is not 1 | |||
| return self.generate_func(tokens=tokens, past=past) | |||
| return self.generate_func(tokens=tokens, state=state) | |||
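| A sketch of wiring everything together (assuming the Transformer seq2seq encoder/decoder from this PR; the bos/eos/pad ids are placeholders): | |||
| import torch | |||
| from fastNLP.modules.encoder import TransformerSeq2SeqEncoder | |||
| from fastNLP.modules.decoder import TransformerSeq2SeqDecoder | |||
| from fastNLP.modules.generator import SequenceGenerator | |||
| encoder = TransformerSeq2SeqEncoder(embed=(1000, 512), num_layers=2) | |||
| decoder = TransformerSeq2SeqDecoder(embed=(1000, 512), num_layers=2) | |||
| src_tokens, src_len = torch.randint(0, 1000, (2, 7)), torch.LongTensor([7, 5]) | |||
| encoder_output, encoder_mask = encoder(src_tokens, src_len) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder, max_length=20, num_beams=1, do_sample=False, | |||
| bos_token_id=1, eos_token_id=2, pad_token_id=0) | |||
| tokens = generator.generate(state)  # bsz x generated_len, padded with pad_token_id after eos | |||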
| @torch.no_grad() | |||
| def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
| def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, | |||
| bos_token_id=None, eos_token_id=None, pad_token_id=0, | |||
| repetition_penalty=1, length_penalty=1.0): | |||
| """ | |||
| @@ -60,23 +79,23 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
| :param Decoder decoder: the Decoder object | |||
| :param torch.LongTensor tokens: batch_size x len, input to decode from; if None, generation automatically starts from bos_token_id | |||
| :param Past past: should contain some outputs of the encoder. | |||
| :param State state: should contain some outputs of the encoder. | |||
| :param int max_length: maximum length of the generated sentences. | |||
| :param int num_beams: beam size used for decoding. | |||
| :param int bos_token_id: if tokens is None, decoding starts from bos_token_id. | |||
| :param int eos_token_id: the end token; if None, decoding always runs up to max_length. | |||
| :param int pad_token_id: | |||
| :param int pad_token_id: the pad token id | |||
| :param float repetition_penalty: how strongly repeated tokens are penalized. | |||
| :param float length_penalty: a length-based penalty applied to every token (except eos). | |||
| :return: | |||
| """ | |||
| if num_beams == 1: | |||
| token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1, | |||
| token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| else: | |||
| token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||
| token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, | |||
| temperature=1, top_k=50, top_p=1, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| @@ -86,7 +105,7 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
| @torch.no_grad() | |||
| def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||
| def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||
| top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | |||
| length_penalty=1.0): | |||
| """ | |||
| @@ -94,7 +113,7 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
| :param Decoder decoder: the Decoder object | |||
| :param torch.LongTensor tokens: batch_size x len, input to decode from; if None, generation automatically starts from bos_token_id | |||
| :param Past past: should contain some outputs of the encoder. | |||
| :param State state: should contain some outputs of the encoder. | |||
| :param int max_length: maximum length of the generated sentences. | |||
| :param int num_beams: beam size used for decoding. | |||
| :param float temperature: annealing temperature used when sampling | |||
| @@ -109,13 +128,13 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
| """ | |||
| # each position is generated by sampling during generation | |||
| if num_beams == 1: | |||
| token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature, | |||
| token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature, | |||
| top_k=top_k, top_p=top_p, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| pad_token_id=pad_token_id) | |||
| else: | |||
| token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||
| token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, | |||
| temperature=temperature, top_k=top_k, top_p=top_p, | |||
| bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
| repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
| @@ -123,40 +142,35 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
| return token_ids | |||
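For the do_sample=True path, the next-token scores are typically filtered with the standard top-k / nucleus (top-p) trick before multinomial sampling. The helper below is a generic sketch of that technique, not a verbatim copy of the implementation in this diff:

import torch

def top_k_top_p_filtering(logits, top_k=50, top_p=1.0, filter_value=-float('inf')):
    # logits: (batch_size, vocab_size) scores for the next token
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_best, filter_value)  # keep only the top_k tokens
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p              # tokens outside the nucleus
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False                  # always keep the single best token
        logits = logits.masked_fill(remove.scatter(-1, sorted_idx, remove), filter_value)
    return logits

# sampling then proceeds from the renormalized distribution, e.g.
# probs = torch.softmax(top_k_top_p_filtering(scores), dim=-1)
# next_tokens = torch.multinomial(probs, num_samples=1)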
| def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50, | |||
| def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50, | |||
| top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
| repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | |||
| device = _get_model_device(decoder) | |||
| if tokens is None: | |||
| if bos_token_id is None: | |||
| raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
| if past is None: | |||
| raise RuntimeError("You have to specify either `past` or `tokens`.") | |||
| batch_size = past.num_samples() | |||
| batch_size = state.num_samples | |||
| if batch_size is None: | |||
| raise RuntimeError("Cannot infer the number of samples from `past`.") | |||
| raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
| tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
| batch_size = tokens.size(0) | |||
| if past is not None: | |||
| assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||
| if state.num_samples: | |||
| assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
| if eos_token_id is None: | |||
| _eos_token_id = float('nan') | |||
| _eos_token_id = -1 | |||
| else: | |||
| _eos_token_id = eos_token_id | |||
| # for i in range(tokens.size(1)): | |||
| # scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||
| scores, past = decoder.decode(tokens, past) | |||
| token_ids = tokens.clone() | |||
| scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state | |||
| next_tokens = scores.argmax(dim=-1, keepdim=True) | |||
| token_ids = torch.cat([tokens, next_tokens], dim=1) | |||
| cur_len = token_ids.size(1) | |||
| dones = token_ids.new_zeros(batch_size).eq(1) | |||
| # tokens = tokens[:, -1:] | |||
| while cur_len < max_length: | |||
| # scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past | |||
| scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past | |||
| scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size | |||
| if repetition_penalty != 1.0: | |||
| token_scores = scores.gather(dim=1, index=token_ids) | |||
| @@ -204,7 +218,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
| return token_ids | |||
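The repetition penalty applied in the loop above rescales the scores of tokens that were already generated, with the sign of the score deciding the direction; a standalone sketch mirroring the gather/scatter pattern visible in these hunks:

import torch

def apply_repetition_penalty(scores, token_ids, penalty=1.2):
    # scores: (batch_size, vocab_size); token_ids: (batch_size, cur_len) tokens generated so far
    token_scores = scores.gather(dim=1, index=token_ids)
    lt_zero_mask = token_scores.lt(0).float()
    ge_zero_mask = lt_zero_mask.eq(0).float()
    # negative scores are multiplied by the penalty (pushed further down),
    # non-negative scores are divided by it (pulled towards zero)
    token_scores = lt_zero_mask * penalty * token_scores + ge_zero_mask * token_scores / penalty
    return scores.scatter(dim=1, index=token_ids, src=token_scores)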
| def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0, | |||
| def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0, | |||
| top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
| repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | |||
| # 进行beam search | |||
| @@ -212,21 +226,20 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| if tokens is None: | |||
| if bos_token_id is None: | |||
| raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
| if past is None: | |||
| raise RuntimeError("You have to specify either `past` or `tokens`.") | |||
| batch_size = past.num_samples() | |||
| batch_size = state.num_samples | |||
| if batch_size is None: | |||
| raise RuntimeError("Cannot infer the number of samples from `past`.") | |||
| raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
| tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
| batch_size = tokens.size(0) | |||
| if past is not None: | |||
| assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||
| # for i in range(tokens.size(1) - 1): # 如果输入的长度较长,先decode | |||
| # scores, past = decoder.decode_one(tokens[:, :i + 1], | |||
| # past) # (batch_size, vocab_size), Past | |||
| # scores, past = decoder.decode_one(tokens, past) # 这里要传入的是整个句子的长度 | |||
| scores, past = decoder.decode(tokens, past) # 这里要传入的是整个句子的长度 | |||
| if state.num_samples: | |||
| assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
| if eos_token_id is None: | |||
| _eos_token_id = -1 | |||
| else: | |||
| _eos_token_id = eos_token_id | |||
| scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度 | |||
| vocab_size = scores.size(1) | |||
| assert vocab_size >= num_beams, "num_beams should be no larger than the vocabulary size." | |||
| @@ -240,15 +253,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| # 得到(batch_size, num_beams), (batch_size, num_beams) | |||
| next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | |||
| # 根据index来做顺序的调转 | |||
| indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||
| indices = indices.repeat_interleave(num_beams) | |||
| decoder.reorder_past(indices, past) | |||
| state.reorder_state(indices) | |||
| tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||
| # 记录生成好的token (batch_size', cur_len) | |||
| token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | |||
| dones = [False] * batch_size | |||
| tokens = next_tokens.view(-1, 1) | |||
| beam_scores = next_scores.view(-1) # batch_size * num_beams | |||
| @@ -262,8 +275,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||
| while cur_len < max_length: | |||
| # scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past | |||
| scores, past = decoder.decode(tokens, past) | |||
| scores = decoder.decode(token_ids, state) | |||
| if repetition_penalty != 1.0: | |||
| token_scores = scores.gather(dim=1, index=token_ids) | |||
| lt_zero_mask = token_scores.lt(0).float() | |||
| @@ -307,7 +319,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | |||
| from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | |||
| not_eos_mask = next_tokens.ne(eos_token_id) # 为1的地方不是eos | |||
| not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos | |||
| keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 | |||
| keep_mask = not_eos_mask.__and__(keep_mask) # 为1的地方是需要进行下一步search的 | |||
| @@ -316,18 +328,18 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | |||
| beam_scores = _next_scores.view(-1) | |||
| # 更改past状态, 重组token_ids | |||
| # 更改state状态, 重组token_ids | |||
| reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 | |||
| decoder.reorder_past(reorder_inds, past) | |||
| state.reorder_state(reorder_inds) | |||
| flag = True | |||
| if cur_len + 1 == max_length: | |||
| if cur_len+1 == max_length: | |||
| eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | |||
| eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # 表示的是indice | |||
| eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 | |||
| else: | |||
| # 将每个batch中在num_beam内的序列添加到结束中, 为1的地方需要结束了 | |||
| effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams | |||
| effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams | |||
| if effective_eos_mask.sum().gt(0): | |||
| eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | |||
| # 是由于from_which_beam是 (batch_size, 2*num_beams)的,所以需要2*num_beams | |||
| @@ -335,16 +347,17 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos | |||
| else: | |||
| flag = False | |||
| # 重新组织token_ids的状态 | |||
| tokens = _next_tokens | |||
| token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||
| if flag: | |||
| for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | |||
| eos_beam_idx.tolist()): | |||
| if not dones[batch_idx]: | |||
| score = next_scores[batch_idx, beam_ind].item() | |||
| hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||
| # 重新组织token_ids的状态 | |||
| tokens = _next_tokens | |||
| token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||
| hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score) | |||
| for batch_idx in range(batch_size): | |||
| dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) | |||
| @@ -360,15 +373,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
| for i, hypotheses in enumerate(hypos): | |||
| best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | |||
| tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol | |||
| tgt_len[i] = len(best_hyp) # if eos_token_id is given, the last position is overwritten with <EOS> below | |||
| best.append(best_hyp) | |||
| # generate target batch | |||
| decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) | |||
| for i, hypo in enumerate(best): | |||
| decoded[i, :tgt_len[i] - 1] = hypo | |||
| decoded[i, :tgt_len[i]] = hypo | |||
| if eos_token_id is not None: | |||
| decoded[i, tgt_len[i] - 1] = eos_token_id | |||
| decoded[i, tgt_len[i] - 1] = _eos_token_id | |||
| return decoded | |||
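The reorder_inds passed to state.reorder_state / index_select combine a per-sample offset with the beam each surviving candidate came from, because the batch and beam dimensions are flattened into a single row dimension; a small worked example with hypothetical numbers:

import torch

batch_size, num_beams = 2, 3
# flattened rows are laid out as
# [b0-beam0, b0-beam1, b0-beam2, b1-beam0, b1-beam1, b1-beam2]
offsets = (torch.arange(batch_size) * num_beams).view(-1, 1)   # [[0], [3]]
from_which_beam = torch.LongTensor([[2, 0, 1],                 # sample 0 keeps beams 2, 0, 1
                                    [1, 1, 0]])                # sample 1 keeps beams 1, 1, 0
reorder_inds = (offsets + from_which_beam).view(-1)
print(reorder_inds.tolist())  # [2, 0, 1, 4, 4, 3] -> rows to select from the flattened tensors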
| @@ -384,6 +384,9 @@ class BertTokenizer(object): | |||
| index += 1 | |||
| return vocab_file | |||
| def save_pretrained(self, save_directory): | |||
| self.save_vocabulary(save_directory) | |||
| @classmethod | |||
| def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||
| r""" | |||
| @@ -377,6 +377,9 @@ class GPT2Tokenizer: | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | |||
| return text | |||
| def save_pretrained(self, save_directory): | |||
| return self.save_vocabulary(save_directory) | |||
| def save_vocabulary(self, save_directory): | |||
| """Save the tokenizer vocabulary and merge files to a directory.""" | |||
| if not os.path.isdir(save_directory): | |||
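Both tokenizers now expose save_pretrained, which simply forwards to save_vocabulary; a hedged round-trip sketch (the scratch directory name is arbitrary, the checkpoint path is the local test checkpoint used elsewhere in this diff, and BertTokenizer is the class patched above, whose import path is not shown in these hunks):

import os

save_dir = 'tokenizer_save_test'   # arbitrary scratch directory
os.makedirs(save_dir, exist_ok=True)

tokenizer = BertTokenizer.from_pretrained('test/data_for_tests/embedding/small_bert')
tokenizer.save_pretrained(save_dir)               # writes the vocabulary file(s) into save_dir
reloaded = BertTokenizer.from_pretrained(save_dir)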
| @@ -32,7 +32,6 @@ from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
| from tools.logger import * | |||
| from fastNLP.core.const import Const | |||
| from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
| from transformer.Layers import EncoderLayer | |||
| @@ -30,7 +30,7 @@ from .Encoder import Encoder | |||
| from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
| from fastNLP.core.const import Const | |||
| from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
| from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
| class TransformerModel(nn.Module): | |||
| def __init__(self, hps, vocab): | |||
| @@ -68,7 +68,8 @@ class TransformerModel(nn.Module): | |||
| get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||
| self.layer_stack = nn.ModuleList([ | |||
| TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) | |||
| TransformerSeq2SeqEncoderLayer(d_model = self.hidden_size, n_head = self.n_head, dim_ff = self.d_inner, | |||
| dropout = hps.atten_dropout_prob) | |||
| for _ in range(self.num_layers)]) | |||
| self.wh = nn.Linear(self.hidden_size, 2) | |||
| @@ -109,7 +110,7 @@ class TransformerModel(nn.Module): | |||
| for enc_layer in self.layer_stack: | |||
| # enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||
| # enc_slf_attn = [n_head * batch_size, N, N] | |||
| enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) | |||
| enc_input = enc_layer(enc_input, encoder_mask=self.slf_attn_mask) | |||
| enc_input_list += [enc_input] | |||
| self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | |||
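A standalone sketch of the replacement layer's call convention (the sizes are illustrative and d_model must be divisible by n_head):

import torch
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer

layer = TransformerSeq2SeqEncoderLayer(d_model=16, n_head=4, dim_ff=32, dropout=0.1)
x = torch.randn(2, 5, 16)                   # batch_size x seq_len x d_model
encoder_mask = torch.ones(2, 5).eq(1)       # True for positions holding real tokens
out = layer(x, encoder_mask=encoder_mask)   # output keeps the same shape as x
print(out.size())                           # torch.Size([2, 5, 16])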
| @@ -265,7 +265,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||
| label = Variable(label) | |||
| input_len = Variable(input_len, requires_grad=False) | |||
| model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||
| model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
| outputs = model_outputs["p_sent"] | |||
| prediction = model_outputs["prediction"] | |||
| @@ -264,7 +264,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||
| label = Variable(label) | |||
| input_len = Variable(input_len, requires_grad=False) | |||
| model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||
| model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
| outputs = model_outputs[Const.OUTPUTS] | |||
| prediction = model_outputs["prediction"] | |||
| @@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer | |||
| __author__ = "Yu-Hsiang Huang" | |||
| def get_non_pad_mask(seq): | |||
| assert seq.dim() == 2 | |||
| return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) | |||
| def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
| ''' Sinusoid position encoding table ''' | |||
| @@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
| return torch.FloatTensor(sinusoid_table) | |||
| def get_attn_key_pad_mask(seq_k, seq_q): | |||
| ''' For masking out the padding part of key sequence. ''' | |||
| @@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q): | |||
| return padding_mask | |||
| def get_subsequent_mask(seq): | |||
| ''' For masking out the subsequent info. ''' | |||
| @@ -51,6 +55,7 @@ def get_subsequent_mask(seq): | |||
| return subsequent_mask | |||
| class Encoder(nn.Module): | |||
| ''' A encoder model with self attention mechanism. ''' | |||
| @@ -98,6 +103,7 @@ class Encoder(nn.Module): | |||
| return enc_output, enc_slf_attn_list | |||
| return enc_output, | |||
| class Decoder(nn.Module): | |||
| ''' A decoder model with self attention mechanism. ''' | |||
| @@ -152,6 +158,7 @@ class Decoder(nn.Module): | |||
| return dec_output, dec_slf_attn_list, dec_enc_attn_list | |||
| return dec_output, | |||
| class Transformer(nn.Module): | |||
| ''' A sequence to sequence model with attention mechanism. ''' | |||
| @@ -181,8 +188,8 @@ class Transformer(nn.Module): | |||
| nn.init.xavier_normal_(self.tgt_word_prj.weight) | |||
| assert d_model == d_word_vec, \ | |||
| 'To facilitate the residual connections, \ | |||
| the dimensions of all module outputs shall be the same.' | |||
| 'To facilitate the residual connections, \ | |||
| the dimensions of all module outputs shall be the same.' | |||
| if tgt_emb_prj_weight_sharing: | |||
| # Share the weight matrix between target word embedding & the final logit dense layer | |||
| @@ -194,7 +201,7 @@ class Transformer(nn.Module): | |||
| if emb_src_tgt_weight_sharing: | |||
| # Share the weight matrix between source & target word embeddings | |||
| assert n_src_vocab == n_tgt_vocab, \ | |||
| "To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||
| "To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||
| self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight | |||
| def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): | |||
| @@ -1,7 +1,6 @@ | |||
| import fastNLP | |||
| import torch | |||
| import math | |||
| from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
| from fastNLP.modules.decoder.crf import ConditionalRandomField | |||
| from fastNLP import Const | |||
| import copy | |||
| @@ -181,7 +180,6 @@ def make_CWS( | |||
| freeze=True, | |||
| ): | |||
| c = copy.deepcopy | |||
| # encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout) | |||
| encoder = transformer.make_encoder( | |||
| N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff | |||
| ) | |||
| @@ -1,9 +1,8 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| from fastNLP.core.const import Const as C | |||
| from fastNLP.modules.encoder.lstm import LSTM | |||
| from fastNLP.embeddings.utils import get_embeddings | |||
| from fastNLP.modules.encoder.attention import SelfAttention | |||
| from fastNLP.modules.attention import SelfAttention | |||
| from fastNLP.modules.decoder.mlp import MLP | |||
| @@ -44,7 +44,7 @@ class WeightDrop(torch.nn.Module): | |||
| def forward(self, *args): | |||
| self._setweights() | |||
| return self.module.forward(*args) | |||
| return self.module.forward() | |||
| if __name__ == '__main__': | |||
| import torch | |||
| @@ -40,8 +40,7 @@ class TestBertEmbedding(unittest.TestCase): | |||
| result = embed(words) | |||
| self.assertEqual(result.size(), (1, 4, 16)) | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
| only_use_pretrain_bpe=True) | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
| embed.eval() | |||
| words = torch.LongTensor([[2, 3, 4, 0]]) | |||
| result = embed(words) | |||
| @@ -49,53 +48,30 @@ class TestBertEmbedding(unittest.TestCase): | |||
| # 自动截断而不报错 | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
| only_use_pretrain_bpe=True, auto_truncate=True) | |||
| auto_truncate=True) | |||
| words = torch.LongTensor([[2, 3, 4, 1]*10, | |||
| [2, 3]+[0]*38]) | |||
| result = embed(words) | |||
| self.assertEqual(result.size(), (2, 40, 16)) | |||
| def test_bert_embedding_2(self): | |||
| # 测试only_use_pretrain_vocab与truncate_embed是否正常工作 | |||
| with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f: | |||
| num_word = len(f.readlines()) | |||
| Embedding = BertEmbedding | |||
| vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split()) | |||
| embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
| only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1) | |||
| embed_bpe_vocab_size = len(vocab)-1 + 2 # 排除NotInBERT, 额外加##a, [CLS] | |||
| self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab)) | |||
| embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
| only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1) | |||
| embed_bpe_vocab_size = num_word # 排除NotInBERT | |||
| self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab)) | |||
| embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
| only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1) | |||
| embed_bpe_vocab_size = len(vocab)+2 # 新增##a, [CLS] | |||
| self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab)) | |||
| embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
| only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1) | |||
| embed_bpe_vocab_size = num_word+1 # 新增##a | |||
| self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab)) | |||
| # 测试各种情况下以下tensor的值是相等的 | |||
| embed1.eval() | |||
| embed2.eval() | |||
| embed3.eval() | |||
| embed4.eval() | |||
| tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]]) | |||
| t1 = embed1(tensor) | |||
| t2 = embed2(tensor) | |||
| t3 = embed3(tensor) | |||
| t4 = embed4(tensor) | |||
| self.assertEqual((t1-t2).sum(), 0) | |||
| self.assertEqual((t1-t3).sum(), 0) | |||
| self.assertEqual((t1-t4).sum(), 0) | |||
| def test_save_load(self): | |||
| bert_save_test = 'bert_save_test' | |||
| try: | |||
| os.makedirs(bert_save_test, exist_ok=True) | |||
| vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
| auto_truncate=True) | |||
| embed.save(bert_save_test) | |||
| load_embed = BertEmbedding.load(bert_save_test) | |||
| words = torch.randint(len(vocab), size=(2, 20)) | |||
| embed.eval(), load_embed.eval() | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| finally: | |||
| import shutil | |||
| shutil.rmtree(bert_save_test) | |||
| class TestBertWordPieceEncoder(unittest.TestCase): | |||
| @@ -120,11 +96,30 @@ class TestBertWordPieceEncoder(unittest.TestCase): | |||
| ds.set_input('words') | |||
| words = torch.LongTensor(ds['words'].get([0, 1])) | |||
| embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
| pool_method='first', include_cls_sep=True, pooled_cls=False) | |||
| pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) | |||
| embed.eval() | |||
| words_res = embed(words) | |||
| # 检查word piece什么的是正常work的 | |||
| self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0) | |||
| self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0) | |||
| self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) | |||
| self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) | |||
| def test_save_load(self): | |||
| bert_save_test = 'bert_save_test' | |||
| try: | |||
| os.makedirs(bert_save_test, exist_ok=True) | |||
| embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.0, | |||
| layers='-2') | |||
| ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||
| embed.index_datasets(ds, field_name='words') | |||
| self.assertTrue(ds.has_field('word_pieces')) | |||
| words = torch.LongTensor([[1, 2, 3, 4]]) | |||
| embed.save(bert_save_test) | |||
| load_embed = BertWordPieceEncoder.load(bert_save_test) | |||
| embed.eval(), load_embed.eval() | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| finally: | |||
| import shutil | |||
| shutil.rmtree(bert_save_test) | |||
| @@ -255,14 +255,17 @@ class TestGPT2WordPieceEncoder(unittest.TestCase): | |||
| result = embed(torch.LongTensor([[1, 2, 3, 4]])) | |||
| def test_generate(self): | |||
| weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||
| # weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||
| weight_path = 'en' | |||
| encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True) | |||
| # 测试一下各项东西是否正常work | |||
| print(encoder.generate_from_str('this', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||
| print(encoder.generate_from_str('This', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||
| repetition_penalty=1.0, length_penalty=1.0)) | |||
| print(encoder.generate_from_str('This day', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||
| repetition_penalty=1.0, length_penalty=1.0)) | |||
| print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||
| print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||
| repetition_penalty=1.0, length_penalty=1.0)) | |||
| print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||
| print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||
| repetition_penalty=2.0, length_penalty=1.5)) | |||
| @@ -47,7 +47,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase): | |||
| ds.set_input('words') | |||
| words = torch.LongTensor(ds['words'].get([0, 1])) | |||
| embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, | |||
| pool_method='first', include_cls_sep=True, pooled_cls=False) | |||
| pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) | |||
| embed.eval() | |||
| words_res = embed(words) | |||
| @@ -183,6 +183,24 @@ class TestRobertWordPieceEncoder(unittest.TestCase): | |||
| torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin') | |||
| print(model(torch.LongTensor([[0,1,2,3]]))) | |||
| def test_save_load(self): | |||
| bert_save_test = 'roberta_save_test' | |||
| try: | |||
| os.makedirs(bert_save_test, exist_ok=True) | |||
| embed = RobertaWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_roberta', word_dropout=0.0, | |||
| layers='-2') | |||
| ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||
| embed.index_datasets(ds, field_name='words') | |||
| self.assertTrue(ds.has_field('word_pieces')) | |||
| words = torch.LongTensor([[1, 2, 3, 4]]) | |||
| embed.save(bert_save_test) | |||
| load_embed = RobertaWordPieceEncoder.load(bert_save_test) | |||
| embed.eval(), load_embed.eval() | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| finally: | |||
| import shutil | |||
| shutil.rmtree(bert_save_test) | |||
| class TestRobertaEmbedding(unittest.TestCase): | |||
| def test_roberta_embedding_1(self): | |||
| @@ -250,3 +268,20 @@ class TestRobertaEmbedding(unittest.TestCase): | |||
| self.assertEqual((t1-t2).sum(), 0) | |||
| self.assertEqual((t1-t3).sum(), 0) | |||
| self.assertEqual((t1-t4).sum(), 0) | |||
| def test_save_load(self): | |||
| bert_save_test = 'roberta_save_test' | |||
| try: | |||
| os.makedirs(bert_save_test, exist_ok=True) | |||
| vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
| embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta', | |||
| word_dropout=0.1, | |||
| auto_truncate=True) | |||
| embed.save(bert_save_test) | |||
| load_embed = RobertaEmbedding.load(bert_save_test) | |||
| words = torch.randint(len(vocab), size=(2, 20)) | |||
| embed.eval(), load_embed.eval() | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| finally: | |||
| import shutil | |||
| shutil.rmtree(bert_save_test) | |||
| @@ -108,6 +108,56 @@ class TestLoad(unittest.TestCase): | |||
| for v1i, v2i in zip(v1, v2): | |||
| self.assertAlmostEqual(v1i, v2i, places=4) | |||
| def test_save_load_static_embed(self): | |||
| static_test_folder = 'static_save_test' | |||
| try: | |||
| # 测试包含no_create_entry | |||
| os.makedirs(static_test_folder, exist_ok=True) | |||
| vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||
| vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt') | |||
| embed.save(static_test_folder) | |||
| load_embed = StaticEmbedding.load(static_test_folder) | |||
| words = torch.randint(len(vocab), size=(2, 20)) | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| # 测试不包含no_create_entry | |||
| vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt') | |||
| embed.save(static_test_folder) | |||
| load_embed = StaticEmbedding.load(static_test_folder) | |||
| words = torch.randint(len(vocab), size=(2, 20)) | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| # 测试lower, min_freq | |||
| vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
| 'glove.6B.50d_test.txt', min_freq=2, lower=True) | |||
| embed.save(static_test_folder) | |||
| load_embed = StaticEmbedding.load(static_test_folder) | |||
| words = torch.randint(len(vocab), size=(2, 20)) | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| # 测试random的embedding | |||
| vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||
| vocab = vocab.add_word_lst(['b'], no_create_entry=True) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True, | |||
| normalize=True) | |||
| embed.weight.data += 0.2 # 使得它不是normalize | |||
| embed.save(static_test_folder) | |||
| load_embed = StaticEmbedding.load(static_test_folder) | |||
| words = torch.randint(len(vocab), size=(2, 20)) | |||
| self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
| finally: | |||
| if os.path.isdir(static_test_folder): | |||
| import shutil | |||
| shutil.rmtree(static_test_folder) | |||
| def read_static_embed(fp): | |||
| """ | |||
| @@ -30,11 +30,11 @@ class TestLoad(unittest.TestCase): | |||
| 'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False), | |||
| 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False), | |||
| 'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False), | |||
| 'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (7, 6, 6), False), | |||
| 'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False), | |||
| } | |||
| for k, v in data_set_dict.items(): | |||
| path, loader, data_set, warns = v | |||
| with self.subTest(loader=loader): | |||
| with self.subTest(path=path): | |||
| if warns: | |||
| with self.assertWarns(Warning): | |||
| data_bundle = loader().load(path) | |||
| @@ -45,5 +45,6 @@ class TestLoad(unittest.TestCase): | |||
| self.assertEqual(len(data_set), data_bundle.num_dataset) | |||
| for x, y in zip(data_set, data_bundle.iter_datasets()): | |||
| name, dataset = y | |||
| self.assertEqual(x, len(dataset)) | |||
| with self.subTest(split=name): | |||
| self.assertEqual(x, len(dataset)) | |||
| @@ -32,7 +32,7 @@ class TestMatchingLoad(unittest.TestCase): | |||
| 'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | |||
| 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | |||
| 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), | |||
| 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), | |||
| 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False), | |||
| } | |||
| for k, v in data_set_dict.items(): | |||
| path, loader, instance, warns = v | |||
| @@ -46,5 +46,6 @@ class TestMatchingLoad(unittest.TestCase): | |||
| self.assertEqual(len(instance), data_bundle.num_dataset) | |||
| for x, y in zip(instance, data_bundle.iter_datasets()): | |||
| name, dataset = y | |||
| self.assertEqual(x, len(dataset)) | |||
| with self.subTest(path=path, split=name): | |||
| self.assertEqual(x, len(dataset)) | |||
| @@ -70,7 +70,7 @@ class TestRunClassificationPipe(unittest.TestCase): | |||
| } | |||
| for k, v in data_set_dict.items(): | |||
| path, pipe, data_set, vocab, warns = v | |||
| with self.subTest(pipe=pipe): | |||
| with self.subTest(path=path): | |||
| if 'Chn' not in k: | |||
| if warns: | |||
| with self.assertWarns(Warning): | |||
| @@ -39,7 +39,7 @@ class TestRunMatchingPipe(unittest.TestCase): | |||
| 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | |||
| 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | |||
| 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), | |||
| 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), | |||
| 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False), | |||
| } | |||
| for k, v in data_set_dict.items(): | |||
| path, pipe1, pipe2, data_set, vocab, warns = v | |||
| @@ -58,7 +58,8 @@ class TestRunMatchingPipe(unittest.TestCase): | |||
| print(data_bundle2) | |||
| for x, y in zip(data_set, data_bundle1.iter_datasets()): | |||
| name, dataset = y | |||
| self.assertEqual(x, len(dataset)) | |||
| with self.subTest(path=path, split=name): | |||
| self.assertEqual(x, len(dataset)) | |||
| self.assertEqual(len(data_set), data_bundle2.num_dataset) | |||
| for x, y in zip(data_set, data_bundle2.iter_datasets()): | |||
| name, dataset = y | |||
| @@ -0,0 +1,76 @@ | |||
| import unittest | |||
| from fastNLP.models import SequenceGeneratorModel | |||
| from fastNLP.models import LSTMSeq2SeqModel, TransformerSeq2SeqModel | |||
| from fastNLP import Vocabulary, DataSet | |||
| import torch | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | |||
| from fastNLP import Callback | |||
| def prepare_env(): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
| src_words_idx = [[3, 1, 2], [1, 2]] | |||
| # tgt_words_idx = [[1, 2, 3, 4], [2, 3]] | |||
| src_seq_len = [3, 2] | |||
| # tgt_seq_len = [4, 2] | |||
| ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, | |||
| 'tgt_seq_len':src_seq_len}) | |||
| ds.set_input('src_tokens', 'tgt_tokens', 'src_seq_len') | |||
| ds.set_target('tgt_seq_len', 'tgt_tokens') | |||
| return embed, ds | |||
| class ExitCallback(Callback): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
| if eval_result['AccuracyMetric']['acc']==1: | |||
| raise KeyboardInterrupt() | |||
| class TestSeq2SeqGeneratorModel(unittest.TestCase): | |||
| def test_run(self): | |||
| # 检测是否能够使用SequenceGeneratorModel训练, 透传预测 | |||
| embed, ds = prepare_env() | |||
| model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, | |||
| dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| trainer = Trainer(ds, model1, optimizer=None, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'), | |||
| batch_size=32, sampler=None, drop_last=False, update_every=1, | |||
| num_workers=0, n_epochs=100, print_every=5, | |||
| dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), metric_key=None, | |||
| validate_every=-1, save_path=None, use_tqdm=False, device=None, | |||
| callbacks=ExitCallback(), check_code_level=0) | |||
| res = trainer.train() | |||
| self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1) | |||
| embed, ds = prepare_env() | |||
| model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| num_layers=1, hidden_size=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True, attention=True) | |||
| optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) | |||
| trainer = Trainer(ds, model2, optimizer=optimizer, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'), | |||
| batch_size=32, sampler=None, drop_last=False, update_every=1, | |||
| num_workers=0, n_epochs=200, print_every=1, | |||
| dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), | |||
| metric_key=None, | |||
| validate_every=-1, save_path=None, use_tqdm=False, device=None, | |||
| callbacks=ExitCallback(), check_code_level=0) | |||
| res = trainer.train() | |||
| self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1) | |||
| @@ -0,0 +1,114 @@ | |||
| import unittest | |||
| from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| import torch | |||
| from torch import optim | |||
| import torch.nn.functional as F | |||
| from fastNLP import seq_len_to_mask | |||
| def prepare_env(): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
| src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| tgt_seq_len = torch.LongTensor([4, 2]) | |||
| return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len | |||
| def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): | |||
| optimizer = optim.Adam(model.parameters(), lr=1e-2) | |||
| mask = seq_len_to_mask(tgt_seq_len).eq(0) | |||
| target = tgt_words_idx.masked_fill(mask, -100) | |||
| for i in range(100): | |||
| optimizer.zero_grad() | |||
| pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size | |||
| loss = F.cross_entropy(pred.transpose(1, 2), target) | |||
| loss.backward() | |||
| optimizer.step() | |||
| right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() | |||
| return right_count | |||
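The -100 fill value in train_model is F.cross_entropy's default ignore_index, so padded target positions contribute nothing to the loss; a tiny demonstration:

import torch
import torch.nn.functional as F

pred = torch.randn(2, 5, 3)               # batch_size x vocab_size x max_len
target = torch.LongTensor([[1, 2, -100],  # positions filled with -100 are padding
                           [0, -100, -100]])
loss = F.cross_entropy(pred, target)      # padded positions are ignored automatically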
| class TestTransformerSeq2SeqModel(unittest.TestCase): | |||
| def test_run(self): | |||
| # 测试能否跑通 | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| for pos_embed in ['learned', 'sin']: | |||
| with self.subTest(pos_embed=pos_embed): | |||
| model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
| self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||
| for bind_encoder_decoder_embed in [True, False]: | |||
| tgt_embed = None | |||
| for bind_decoder_input_output_embed in [True, False]: | |||
| if bind_encoder_decoder_embed == False: | |||
| tgt_embed = embed | |||
| with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed): | |||
| model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
| pos_embed='sin', max_position=20, num_layers=2, | |||
| d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
| output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
| self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||
| def test_train(self): | |||
| # 测试能否train到overfit | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
| self.assertEqual(right_count, tgt_words_idx.nelement()) | |||
| class TestLSTMSeq2SeqModel(unittest.TestCase): | |||
| def test_run(self): | |||
| # 测试能否跑通 | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| for bind_encoder_decoder_embed in [True, False]: | |||
| tgt_embed = None | |||
| for bind_decoder_input_output_embed in [True, False]: | |||
| if bind_encoder_decoder_embed == False: | |||
| tgt_embed = embed | |||
| with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed): | |||
| model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
| num_layers=2, hidden_size=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
| bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
| output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
| self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||
| def test_train(self): | |||
| embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
| model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
| num_layers=1, hidden_size=20, dropout=0.1, | |||
| bind_encoder_decoder_embed=True, | |||
| bind_decoder_input_output_embed=True) | |||
| right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
| self.assertEqual(right_count, tgt_words_idx.nelement()) | |||
| @@ -0,0 +1,50 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| from fastNLP.modules import TransformerSeq2SeqDecoder | |||
| from fastNLP.modules import LSTMSeq2SeqDecoder | |||
| from fastNLP import seq_len_to_mask | |||
| class TestTransformerSeq2SeqDecoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=10) | |||
| encoder_output = torch.randn(2, 3, 10) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_mask = seq_len_to_mask(src_seq_len) | |||
| for flag in [True, False]: | |||
| with self.subTest(bind_decoder_input_output_embed=flag): | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed = None, | |||
| d_model = 10, num_layers=2, n_head = 5, dim_ff = 20, dropout = 0.1, | |||
| bind_decoder_input_output_embed = True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) | |||
| self.assertEqual(output.size(), (2, 4, len(vocab))) | |||
| class TestLSTMDecoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) | |||
| encoder_output = torch.randn(2, 3, 10) | |||
| tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_mask = seq_len_to_mask(src_seq_len) | |||
| for flag in [True, False]: | |||
| for attention in [True, False]: | |||
| with self.subTest(bind_decoder_input_output_embed=flag, attention=attention): | |||
| decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers = 2, hidden_size = 10, | |||
| dropout = 0.3, bind_decoder_input_output_embed=flag, attention=attention) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| output = decoder(tgt_words_idx, state) | |||
| self.assertEqual(tuple(output.size()), (2, 4, len(vocab))) | |||
| @@ -0,0 +1,30 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| class TestTransformerSeq2SeqEncoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=5) | |||
| encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) | |||
| words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
| seq_len = torch.LongTensor([3]) | |||
| encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
| self.assertEqual(encoder_output.size(), (1, 3, 10)) | |||
| class TestBiLSTMEncoder(unittest.TestCase): | |||
| def test_case(self): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| embed = StaticEmbedding(vocab, embedding_dim=5) | |||
| encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) | |||
| words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
| seq_len = torch.LongTensor([3]) | |||
| encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
| self.assertEqual(encoder_mask.size(), (1, 3)) | |||
| @@ -0,0 +1 @@ | |||
| @@ -0,0 +1,110 @@ | |||
| import unittest | |||
| import torch | |||
| from fastNLP.modules.generator import SequenceGenerator | |||
| from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State | |||
| from fastNLP import Vocabulary | |||
| from fastNLP.embeddings import StaticEmbedding | |||
| from torch import nn | |||
| from fastNLP import seq_len_to_mask | |||
| def prepare_env(): | |||
| vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
| vocab.add_word_lst("Another test !".split()) | |||
| embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
| encoder_output = torch.randn(2, 3, 10) | |||
| src_seq_len = torch.LongTensor([3, 2]) | |||
| encoder_mask = seq_len_to_mask(src_seq_len) | |||
| return embed, encoder_output, encoder_mask | |||
| class TestSequenceGenerator(unittest.TestCase): | |||
| def test_run(self): | |||
| # 测试能否运行 (1) 初始化decoder,(2) decode一发 | |||
| embed, encoder_output, encoder_mask = prepare_env() | |||
| for do_sample in [True, False]: | |||
| for num_beams in [1, 3, 5]: | |||
| with self.subTest(do_sample=do_sample, num_beams=num_beams): | |||
| decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, | |||
| dropout=0.3, bind_decoder_input_output_embed=True, attention=True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
| generator.generate(state=state, tokens=None) | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
| d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, | |||
| bind_decoder_input_output_embed=True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
| repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
| generator.generate(state=state, tokens=None) | |||
| # 测试一下其它值 | |||
| decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
| d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, | |||
| dropout=0.1, | |||
| bind_decoder_input_output_embed=True) | |||
| state = decoder.init_state(encoder_output, encoder_mask) | |||
| generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
| do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, | |||
| eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) | |||
| generator.generate(state=state, tokens=None) | |||
| def test_greedy_decode(self): | |||
| # 测试能否正确的generate | |||
| class GreedyDummyDecoder(Seq2SeqDecoder): | |||
| def __init__(self, decoder_output): | |||
| super().__init__() | |||
| self.cur_length = 0 | |||
| self.decoder_output = decoder_output | |||
| def decode(self, tokens, state): | |||
| self.cur_length += 1 | |||
| scores = self.decoder_output[:, self.cur_length] | |||
| return scores | |||
| class DummyState(State): | |||
| def __init__(self, decoder): | |||
| super().__init__() | |||
| self.decoder = decoder | |||
| def reorder_state(self, indices: torch.LongTensor): | |||
| self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) | |||
| # greedy | |||
| for beam_search in [1, 3]: | |||
| decoder_output = torch.randn(2, 10, 5) | |||
| path = decoder_output.argmax(dim=-1) # 2 x 4 | |||
| decoder = GreedyDummyDecoder(decoder_output) | |||
| with self.subTest(beam_search=beam_search): | |||
| generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
| do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, | |||
| eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
| decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
| self.assertEqual(decode_path.eq(path).sum(), path.numel()) | |||
| # greedy check eos_token_id | |||
| for beam_search in [1, 3]: | |||
| decoder_output = torch.randn(2, 10, 5) | |||
| decoder_output[:, :7, 4].fill_(-100) | |||
| decoder_output[0, 7, 4] = 1000 # 在第8个结束 | |||
| decoder_output[1, 5, 4] = 1000 | |||
| path = decoder_output.argmax(dim=-1) # 2 x 4 | |||
| decoder = GreedyDummyDecoder(decoder_output) | |||
| with self.subTest(beam_search=beam_search): | |||
| generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
| do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||
| eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
| decode_path = generator.generate(DummyState(decoder), | |||
| tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
| self.assertEqual(decode_path.size(1), 8) # 长度为8 | |||
| self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8) | |||
| self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6) | |||