@@ -16,6 +16,7 @@ from ._logger import logger
from .dataset import DataSet
from .utils import Option
from .utils import _is_iterable
import io

class VocabularyOption(Option):
@@ -487,76 +488,99 @@ class Vocabulary(object):
    def save(self, filepath):
        r"""

        :param str filepath: path to save the Vocabulary to
        :param str,io.StringIO filepath: path, or writable stream, to save the Vocabulary to
        :return:
        """
        # removed by this PR:
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(f'max_size\t{self.max_size}\n')
            f.write(f'min_freq\t{self.min_freq}\n')
            f.write(f'unknown\t{self.unknown}\n')
            f.write(f'padding\t{self.padding}\n')
            f.write(f'rebuild\t{self.rebuild}\n')
            f.write('\n')
            # idx: -2 means build() has not run yet; -1 means the word was not indexed
            # no_create_entry: 1 means the word is a no_create_entry word; 0 otherwise
            # word \t count \t idx \t no_create_entry \n
            idx = -2
            for word, count in self.word_count.items():
                if self._word2idx is not None:
                    idx = self._word2idx.get(word, -1)
                is_no_create_entry = int(self._is_word_no_create_entry(word))
                f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n')
        # added by this PR:
        if isinstance(filepath, io.IOBase):
            assert filepath.writable()
            f = filepath
        elif isinstance(filepath, str):
            try:
                f = open(filepath, 'w', encoding='utf-8')
            except Exception as e:
                raise e
        else:
            raise TypeError("Illegal `filepath`.")
        f.write(f'max_size\t{self.max_size}\n')
        f.write(f'min_freq\t{self.min_freq}\n')
        f.write(f'unknown\t{self.unknown}\n')
        f.write(f'padding\t{self.padding}\n')
        f.write(f'rebuild\t{self.rebuild}\n')
        f.write('\n')
        # idx: -2 means build() has not run yet; -1 means the word was not indexed
        # no_create_entry: 1 means the word is a no_create_entry word; 0 otherwise
        # word \t count \t idx \t no_create_entry \n
        idx = -2
        for word, count in self.word_count.items():
            if self._word2idx is not None:
                idx = self._word2idx.get(word, -1)
            is_no_create_entry = int(self._is_word_no_create_entry(word))
            f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n')
        if isinstance(filepath, str):  # close the file only if we opened it ourselves
            f.close()
    @staticmethod
    def load(filepath):
        r"""

        :param str filepath: path to load the Vocabulary from
        :param str,io.StringIO filepath: path, or readable stream, to load the Vocabulary from
        :return: Vocabulary
        """
        # removed by this PR:
        with open(filepath, 'r', encoding='utf-8') as f:
            vocab = Vocabulary()
            for line in f:
                line = line.strip()
                if line:
                    name, value = line.split()
                    if name in ('max_size', 'min_freq'):
                        value = int(value) if value != 'None' else None
                        setattr(vocab, name, value)
                    elif name in ('unknown', 'padding'):
                        value = value if value != 'None' else None
                        setattr(vocab, name, value)
                    elif name == 'rebuild':
                        vocab.rebuild = True if value == 'True' else False
                else:
                    break
            word_counter = {}
            no_create_entry_counter = {}
            word2idx = {}
            for line in f:
                line = line.strip()
                if line:
                    parts = line.split('\t')
                    word, count, idx, no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3])
                    if idx >= 0:
                        word2idx[word] = idx
                    word_counter[word] = count
                    if no_create_entry:
                        no_create_entry_counter[word] = count
            word_counter = Counter(word_counter)
            no_create_entry_counter = Counter(no_create_entry_counter)
            if len(word2idx) > 0:
                if vocab.padding:
                    word2idx[vocab.padding] = 0
                if vocab.unknown:
                    word2idx[vocab.unknown] = 1 if vocab.padding else 0
                idx2word = {value: key for key, value in word2idx.items()}
            vocab.word_count = word_counter
            vocab._no_create_word = no_create_entry_counter
            if word2idx:
                vocab._word2idx = word2idx
                vocab._idx2word = idx2word
        # added by this PR:
        if isinstance(filepath, io.IOBase):
            assert filepath.readable()  # loading requires a readable stream
            f = filepath
        elif isinstance(filepath, str):
            try:
                f = open(filepath, 'r', encoding='utf-8')
            except Exception as e:
                raise e
        else:
            raise TypeError("Illegal `filepath`.")
        vocab = Vocabulary()
        for line in f:
            line = line.strip()
            if line:
                name, value = line.split()
                if name in ('max_size', 'min_freq'):
                    value = int(value) if value != 'None' else None
                    setattr(vocab, name, value)
                elif name in ('unknown', 'padding'):
                    value = value if value != 'None' else None
                    setattr(vocab, name, value)
                elif name == 'rebuild':
                    vocab.rebuild = True if value == 'True' else False
            else:
                break
        word_counter = {}
        no_create_entry_counter = {}
        word2idx = {}
        for line in f:
            line = line.strip()
            if line:
                parts = line.split('\t')
                word, count, idx, no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3])
                if idx >= 0:
                    word2idx[word] = idx
                word_counter[word] = count
                if no_create_entry:
                    no_create_entry_counter[word] = count
        word_counter = Counter(word_counter)
        no_create_entry_counter = Counter(no_create_entry_counter)
        if len(word2idx) > 0:
            if vocab.padding:
                word2idx[vocab.padding] = 0
            if vocab.unknown:
                word2idx[vocab.unknown] = 1 if vocab.padding else 0
            idx2word = {value: key for key, value in word2idx.items()}

        vocab.word_count = word_counter
        vocab._no_create_word = no_create_entry_counter
        if word2idx:
            vocab._word2idx = word2idx
            vocab._idx2word = idx2word
        if isinstance(filepath, str):  # close the file only if we opened it ourselves
            f.close()
        return vocab
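A quick round trip of the format above (a header of key\tvalue pairs, a blank line, then one word\tcount\tidx\tno_create_entry row per word) shows why the io.IOBase branch was added. This is an editorial sketch, not part of the PR; it assumes a small built Vocabulary:

import io
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(["apple", "banana", "apple"])
vocab.build_vocab()

# save/load now accept an in-memory stream as well as a file path
buf = io.StringIO()
vocab.save(buf)
buf.seek(0)  # rewind before reading back
restored = Vocabulary.load(buf)
assert restored.to_index("apple") == vocab.to_index("apple")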
@@ -22,8 +22,9 @@ __all__ = [
    "StackEmbedding",
    "LSTMCharEmbedding",
    "CNNCharEmbedding",

    "get_embeddings",
    "get_sinusoid_encoding_table"
]

from .embedding import Embedding, TokenEmbedding
@@ -34,7 +35,7 @@ from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder
from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
from .stack_embedding import StackEmbedding
from .utils import get_embeddings
from .utils import get_embeddings, get_sinusoid_encoding_table

import sys
from ..doc_utils import doc_process
@@ -8,11 +8,11 @@ __all__ = [
    "BertWordPieceEncoder"
]

import collections
import os
import warnings
from itertools import chain
from functools import partial
import json

import numpy as np
import torch
from torch import nn
@@ -24,6 +24,13 @@ from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR
from ..modules.encoder.bert import BertModel
from ..modules.tokenizer import BertTokenizer

# TODO needs rework so that the encoder can read the embedding weights directly

VOCAB_NAME = 'vocab.txt'
BERT_EMBED_HYPER = 'bert_hyper.json'
BERT_EMBED_FOLDER = 'bert'
BERT_ENCODER_HYPER = 'bert_hyper.json'
BERT_ENCODER_FOLDER = 'bert'


class BertEmbedding(ContextualEmbedding):
    r"""
@@ -82,10 +89,7 @@ class BertEmbedding(ContextualEmbedding):
            word pieces, setting the 512th word piece to [SEP]; the encoded result for anything beyond that length is
            simply zeroed out. Generally only tasks that classify from [CLS] alone should set auto_truncate to True.
        :param kwargs:
            bool only_use_pretrain_bpe: only use bpe that appears in the pretrained vocab, falling back to unk when a
                word cannot be tokenized. Recommended to set True if the embedding will not be updated.
            int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times
                are added to BERT's BPE vocab
            bool truncate_embed: whether to keep only the bpe actually used (reduces memory usage and speeds things up)
            int min_freq: words occurring fewer times than this are replaced by unk
        """
        super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -106,14 +110,11 @@ class BertEmbedding(ContextualEmbedding):
        if '[CLS]' in vocab:
            self._word_cls_index = vocab['[CLS]']

        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
        truncate_embed = kwargs.get('truncate_embed', True)
        min_freq = kwargs.get('min_freq', 2)
        self._min_freq = min_freq
        self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
                                    only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
                                    pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate)

        self.requires_grad = requires_grad
        self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -160,6 +161,57 @@ class BertEmbedding(ContextualEmbedding):
            words = words.masked_fill(mask, self._word_unk_index)
        return words

    def save(self, folder):
        """
        Save this embedding under the directory `folder`. Three entries are written: vocab.txt, bert_hyper.json and
        bert/, where bert/ contains config.json, pytorch_model.bin and vocab.txt (that subfolder can also be read
        directly by BertModel).

        :param str folder:
        :return:
        """
        os.makedirs(folder, exist_ok=True)

        self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME))

        hyper = {}
        hyper['min_freq'] = self._min_freq
        hyper['layers'] = ','.join(map(str, self.model.layers))
        hyper['pool_method'] = self.model.pool_method
        hyper['dropout'] = self.dropout_layer.p
        hyper['word_dropout'] = self.word_dropout
        hyper['include_cls_sep'] = self.model.include_cls_sep
        hyper['pooled_cls'] = self.model.pooled_cls
        hyper['auto_truncate'] = self.model.auto_truncate
        hyper['requires_grad'] = bool(self.requires_grad)

        with open(os.path.join(folder, BERT_EMBED_HYPER), 'w', encoding='utf-8') as f:
            json.dump(hyper, f, indent=2)

        os.makedirs(os.path.join(folder, BERT_EMBED_FOLDER), exist_ok=True)
        self.model.save(os.path.join(folder, BERT_EMBED_FOLDER))
        logger.debug(f"BERTEmbedding has been saved in {folder}")

    @classmethod
    def load(cls, folder):
        """
        Load a BertEmbedding from `folder`, which must contain vocab.txt, bert_hyper.json and bert/.

        :param str folder:
        :return:
        """
        for name in [VOCAB_NAME, BERT_EMBED_FOLDER, BERT_EMBED_HYPER]:
            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."

        vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME))

        with open(os.path.join(folder, BERT_EMBED_HYPER), 'r', encoding='utf-8') as f:
            hyper = json.load(f)

        model_dir_or_name = os.path.join(folder, BERT_EMBED_FOLDER)
        bert_embed = cls(vocab=vocab, model_dir_or_name=model_dir_or_name, **hyper)
        return bert_embed
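For orientation, a save/load round trip with the directory layout just described might look like this. Editorial sketch, not part of the PR; the model name, folder and vocabulary are placeholders:

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary().add_word_lst("this is an example .".split())
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')
embed.save('saved_bert_embed')  # writes vocab.txt, bert_hyper.json, bert/

restored = BertEmbedding.load('saved_bert_embed')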
class BertWordPieceEncoder(nn.Module):
    r"""
@@ -180,7 +232,7 @@ class BertWordPieceEncoder(nn.Module):
    """

    def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
                 word_dropout=0, dropout=0, requires_grad: bool = True):
                 word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs):
        r"""

        :param str model_dir_or_name: directory of the model, or the model's name. Defaults to ``en-base-uncased``
@@ -270,11 +322,53 @@ class BertWordPieceEncoder(nn.Module):
            words = words.masked_fill(mask, self._wordpiece_unk_index)
        return words

    def save(self, folder):
        """
        Creates two entries under `folder`: bert_hyper.json and bert/, where bert/ contains config.json,
        pytorch_model.bin and vocab.txt (that subfolder can also be read directly by BertModel).

        :param str folder:
        :return:
        """
        os.makedirs(folder, exist_ok=True)

        hyper = {}
        hyper['layers'] = ','.join(map(str, self.model.layers))
        hyper['dropout'] = self.dropout_layer.p
        hyper['word_dropout'] = self.word_dropout
        hyper['pooled_cls'] = self.model.pooled_cls
        hyper['requires_grad'] = bool(self.requires_grad)

        with open(os.path.join(folder, BERT_ENCODER_HYPER), 'w', encoding='utf-8') as f:
            json.dump(hyper, f, indent=2)

        os.makedirs(os.path.join(folder, BERT_ENCODER_FOLDER), exist_ok=True)
        self.model.save(os.path.join(folder, BERT_ENCODER_FOLDER))
        logger.debug(f"BertWordPieceEncoder has been saved in {folder}")

    @classmethod
    def load(cls, folder):
        """
        Load a BertWordPieceEncoder from `folder`, which must contain bert_hyper.json and bert/.

        :param folder:
        :return:
        """
        for name in [BERT_ENCODER_HYPER, BERT_ENCODER_FOLDER]:
            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."

        with open(os.path.join(folder, BERT_ENCODER_HYPER), 'r', encoding='utf-8') as f:
            hyper = json.load(f)

        model_dir_or_name = os.path.join(folder, BERT_ENCODER_FOLDER)
        bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper)
        return bert_encoder
class _BertWordModel(nn.Module):
    def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
                 only_use_pretrain_bpe=False, truncate_embed=True):
                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
        super().__init__()

        self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name)
@@ -303,73 +397,8 @@ class _BertWordModel(nn.Module):
        self.auto_truncate = auto_truncate

        # compute the word pieces for every word in vocab; [CLS] and [SEP] need extra care
        logger.info("Start to generate word pieces for word.")
        self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether the input needs token_ids
        # removed by this PR:
        # step 1: collect the needed word pieces, then create a new embed and word_piece_vocab and fill them in
        word_piece_dict = {'[CLS]': 1, '[SEP]': 1}  # word pieces in use, plus newly added ones
        new_add_to_bpe_vocab = 0
        unsegment_count = 0
        if '[sep]' in vocab:
            warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
        if "[CLS]" in vocab:
            warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the begin "
                          "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin"
                          " and end.")
        for word, index in vocab:
            if index == vocab.padding_idx:  # pad is a special symbol
                word = '[PAD]'
            elif index == vocab.unknown_idx:
                word = '[UNK]'
            _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()
            word_pieces = []
            for w in _words:
                word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w))
            if len(word_pieces) == 1:
                if not vocab._is_word_no_create_entry(word):  # the word comes from train but was not found
                    if index != vocab.unknown_idx and word_pieces[0] == '[UNK]':  # i.e. the word is not in the original vocab
                        if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
                                word) and not only_use_pretrain_bpe:  # only add words occurring at least this often
                            word_piece_dict[word] = 1  # add a new entry
                            new_add_to_bpe_vocab += 1
                        unsegment_count += 1
                        continue
            for word_piece in word_pieces:
                word_piece_dict[word_piece] = 1
        original_embed = self.encoder.embeddings.word_embeddings.weight.data
        # special tokens need special handling
        if not truncate_embed:  # if nothing is dropped, the existing vocab has to be merged in
            word_piece_dict.update(self.tokenzier.vocab)

        embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
        new_word_piece_vocab = collections.OrderedDict()

        for index, token in enumerate(['[PAD]', '[UNK]']):
            index = word_piece_dict.pop(token, None)
            if index is not None:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
                embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.vocab[token]]

        for token in word_piece_dict.keys():
            if token not in new_word_piece_vocab:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
            index = new_word_piece_vocab[token]
            if token in self.tokenzier.vocab:
                embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
            else:
                embed.weight.data[index] = original_embed[self.tokenzier.vocab['[UNK]']]
        self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
        self.encoder.embeddings.word_embeddings = embed
        self.encoder.config.vocab_size = len(new_word_piece_vocab)

        if unsegment_count > 0:
            if only_use_pretrain_bpe or new_add_to_bpe_vocab == 0:
                logger.info(f"{unsegment_count} words are unsegmented.")
            else:
                logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")

        word_to_wordpieces = []
        word_pieces_lengths = []
        for word, index in vocab:
@@ -377,6 +406,8 @@ class _BertWordModel(nn.Module):
                word = '[PAD]'
            elif index == vocab.unknown_idx:
                word = '[UNK]'
            elif vocab.word_count[word] < min_freq:
                word = '[UNK]'
            word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
            word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
            word_to_wordpieces.append(word_pieces)
@@ -504,6 +535,16 @@ class _BertWordModel(nn.Module):
        # 3. the final embedding result
        return outputs

    def save(self, folder):
        """
        Save pytorch_model.bin, config.json and vocab.txt under `folder`.

        :param str folder:
        :return:
        """
        self.tokenzier.save_pretrained(folder)
        self.encoder.save_pretrained(folder)


class _BertWordPieceModel(nn.Module):
    r"""
@@ -580,4 +621,14 @@ class _BertWordPieceModel(nn.Module):
            if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
                bert_output[:, 0] = pooled_cls
            outputs[l_index] = bert_output
        return outputs

    def save(self, folder):
        """
        Save pytorch_model.bin, config.json and vocab.txt under `folder`.

        :param folder:
        :return:
        """
        self.tokenzier.save_pretrained(folder)
        self.encoder.save_pretrained(folder)
@@ -78,7 +78,7 @@ class Embedding(nn.Module):
        if isinstance(self.embed, nn.Embedding):
            return self.embed.weight.size(0)
        else:
            return self.embed.num_embedding
            return self.embed.num_embeddings

    def __len__(self):
        return len(self.embed)
@@ -188,7 +188,7 @@ class TokenEmbedding(nn.Module):
        return self._embed_size

    @property
    def num_embedding(self) -> int:
    def num_embeddings(self) -> int:
        r"""
        This value may be larger than the actual size of the embedding matrix.

        :return:
@@ -205,7 +205,7 @@ class TokenEmbedding(nn.Module):
    @property
    def size(self):
        return torch.Size(self.num_embedding, self._embed_size)
        return torch.Size([self.num_embeddings, self._embed_size])

    @abstractmethod
    def forward(self, words):
@@ -10,8 +10,8 @@ __all__ = [

from functools import partial
import collections
import warnings
import os
import json
from itertools import chain

import numpy as np
@@ -24,6 +24,13 @@ from ..modules.encoder.roberta import RobertaModel
from ..modules.tokenizer import RobertaTokenizer

VOCAB_NAME = 'vocab.txt'
ROBERTA_EMBED_HYPER = 'roberta_hyper.json'
ROBERTA_ENCODER_HYPER = 'roberta_hyper.json'
ROBERTA_EMBED_FOLDER = 'roberta'
ROBERTA_ENCODER_FOLDER = 'roberta'
class RobertaEmbedding(ContextualEmbedding):
    r"""
    An Embedding that encodes words with RoBERTa. It is recommended to limit the input words to 430 or fewer rather
    than using the full 512 (this may vary with the pretrained model's parameters). This is because
@@ -71,10 +78,7 @@ class RobertaEmbedding(ContextualEmbedding):
            word pieces, setting the 512th word piece to </s>; the encoded result for anything beyond that length is
            simply zeroed out. Generally only tasks that classify from <s> alone should set auto_truncate to True.
        :param kwargs:
            bool only_use_pretrain_bpe: only use bpe that appears in the pretrained vocab, falling back to unk when a
                word cannot be tokenized. Recommended to set True if the embedding will not be updated.
            int min_freq: only effective when only_use_pretrain_bpe is False; words occurring at least this many times
                are added to the BPE vocab
            bool truncate_embed: whether to keep only the bpe actually used (reduces memory usage and speeds things up)
            int min_freq: words occurring fewer times than this are replaced by unk
        """
        super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)
@@ -89,14 +93,12 @@ class RobertaEmbedding(ContextualEmbedding):
        if '<s>' in vocab:
            self._word_cls_index = vocab['<s>']

        only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
        truncate_embed = kwargs.get('truncate_embed', True)
        min_freq = kwargs.get('min_freq', 2)
        self._min_freq = min_freq

        self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                       pool_method=pool_method, include_cls_sep=include_cls_sep,
                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq,
                                       only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed)
                                       pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq)

        self.requires_grad = requires_grad
        self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -142,11 +144,56 @@ class RobertaEmbedding(ContextualEmbedding):
            words = words.masked_fill(mask, self._word_unk_index)
        return words

    def save(self, folder):
        """
        Save this roberta embedding under `folder`; afterwards it contains three entries: vocab.txt,
        roberta_hyper.json and roberta/.

        :param str folder: target directory
        :return:
        """
        os.makedirs(folder, exist_ok=True)

        self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME))

        hyper = {}
        hyper['min_freq'] = self._min_freq
        hyper['layers'] = ','.join(map(str, self.model.layers))
        hyper['pool_method'] = self.model.pool_method
        hyper['dropout'] = self.dropout_layer.p
        hyper['word_dropout'] = self.word_dropout
        hyper['include_cls_sep'] = self.model.include_cls_sep
        hyper['pooled_cls'] = self.model.pooled_cls
        hyper['auto_truncate'] = self.model.auto_truncate
        hyper['requires_grad'] = bool(self.requires_grad)

        with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'w', encoding='utf-8') as f:
            json.dump(hyper, f, indent=2)

        os.makedirs(os.path.join(folder, ROBERTA_EMBED_FOLDER), exist_ok=True)
        self.model.save(os.path.join(folder, ROBERTA_EMBED_FOLDER))

    @classmethod
    def load(cls, folder):
        """
        Initialize a RobertaEmbedding from the data in `folder`.

        :param folder:
        :return:
        """
        for name in [VOCAB_NAME, ROBERTA_EMBED_HYPER, ROBERTA_EMBED_FOLDER]:
            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."

        vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME))
        with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'r', encoding='utf-8') as f:
            hyper = json.load(f)
        model_name_or_path = os.path.join(folder, ROBERTA_EMBED_FOLDER)

        roberta = cls(vocab=vocab, model_dir_or_name=model_name_or_path, **hyper)
        return roberta
class _RobertaWordModel(nn.Module):
    def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2,
                 only_use_pretrain_bpe=False, truncate_embed=True):
                 include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
        super().__init__()

        self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name)
@@ -177,72 +224,6 @@ class _RobertaWordModel(nn.Module):
        self.pooled_cls = pooled_cls
        self.auto_truncate = auto_truncate

        # removed by this PR:
        # compute the word pieces for every word in vocab; <s> and </s> need extra care
        logger.info("Start to generate word pieces for word.")
        # step 1: collect the needed word pieces, then create a new embed and word_piece_vocab and fill them in
        word_piece_dict = {'<s>': 1, '</s>': 1}  # word pieces in use, plus newly added ones
        found_count = 0
        new_add_to_bpe_vocab = 0
        unsegment_count = 0
        if "<s>" in vocab:
            warnings.warn("<s> detected in your vocabulary. RobertaEmbedding will add <s> and </s> to the begin "
                          "and end of the input automatically, make sure you don't add <s> and </s> at the begin"
                          " and end.")
        for word, index in vocab:
            if index == vocab.padding_idx:  # pad is a special symbol
                word = '<pad>'
            elif index == vocab.unknown_idx:
                word = '<unk>'
            # _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split()  # Chinese text is not handled here for now
            word_pieces = []
            # the case where this word is not at the start of a sentence
            word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True))
            if len(word_pieces) == 1:
                if not vocab._is_word_no_create_entry(word):  # the word comes from train but was not found
                    if index != vocab.unknown_idx and word_pieces[0] == '<unk>':  # i.e. the word is not in the original vocab
                        if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
                                word) and not only_use_pretrain_bpe:  # only add words occurring at least this often
                            word_piece_dict[word] = 1  # add a new entry
                            new_add_to_bpe_vocab += 1
                        unsegment_count += 1
                        continue
            found_count += 1
            for word_piece in word_pieces:
                word_piece_dict[word_piece] = 1
            # the case where this word is at the start of a sentence
        original_embed = self.encoder.embeddings.word_embeddings.weight.data
        # special tokens need special handling
        if not truncate_embed:  # if nothing is dropped, the existing vocab has to be merged in
            word_piece_dict.update(self.tokenzier.encoder)

        embed = nn.Embedding(len(word_piece_dict), original_embed.size(1))  # the new embed
        new_word_piece_vocab = collections.OrderedDict()

        for index, token in enumerate(['<s>', '<pad>', '</s>', '<unk>']):
            index = word_piece_dict.pop(token, None)
            if index is not None:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
                embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]]

        for token in word_piece_dict.keys():
            if token not in new_word_piece_vocab:
                new_word_piece_vocab[token] = len(new_word_piece_vocab)
            index = new_word_piece_vocab[token]
            if token in self.tokenzier.encoder:
                embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]]
            else:
                embed.weight.data[index] = original_embed[self.tokenzier.encoder['<unk>']]

        self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
        self.encoder.embeddings.word_embeddings = embed
        self.encoder.config.vocab_size = len(new_word_piece_vocab)

        if unsegment_count > 0:
            if only_use_pretrain_bpe or new_add_to_bpe_vocab == 0:
                logger.info(f"{unsegment_count} words are unsegmented.")
            else:
                logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.")

        word_to_wordpieces = []
        word_pieces_lengths = []
        for word, index in vocab:
@@ -250,6 +231,8 @@ class _RobertaWordModel(nn.Module):
                word = '<pad>'
            elif index == vocab.unknown_idx:
                word = '<unk>'
            elif vocab.word_count[word] < min_freq:
                word = '<unk>'
            word_pieces = self.tokenzier.tokenize(word)
            word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
            word_to_wordpieces.append(word_pieces)
@@ -368,6 +351,10 @@ class _RobertaWordModel(nn.Module):
        # 3. the final embedding result
        return outputs

    def save(self, folder):
        self.tokenzier.save_pretrained(folder)
        self.encoder.save_pretrained(folder)


class RobertaWordPieceEncoder(nn.Module):
    r"""
@@ -380,7 +367,7 @@ class RobertaWordPieceEncoder(nn.Module):
    """

    def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False,
                 word_dropout=0, dropout=0, requires_grad: bool = True):
                 word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs):
        r"""

        :param str model_dir_or_name: directory of the model, or the model's name. Defaults to ``en-base-uncased``
@@ -462,6 +449,36 @@ class RobertaWordPieceEncoder(nn.Module):
            words = words.masked_fill(mask, self._wordpiece_unk_index)
        return words

    def save(self, folder):
        os.makedirs(folder, exist_ok=True)

        hyper = {}
        hyper['layers'] = ','.join(map(str, self.model.layers))
        hyper['dropout'] = self.dropout_layer.p
        hyper['word_dropout'] = self.word_dropout
        hyper['pooled_cls'] = self.model.pooled_cls
        hyper['requires_grad'] = bool(self.requires_grad)

        with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'w', encoding='utf-8') as f:
            json.dump(hyper, f, indent=2)

        os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
        self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
        logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")

    @classmethod
    def load(cls, folder):
        for name in [ROBERTA_ENCODER_HYPER, ROBERTA_ENCODER_FOLDER]:
            assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}."

        with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'r', encoding='utf-8') as f:
            hyper = json.load(f)
        model_dir_or_name = os.path.join(folder, ROBERTA_ENCODER_FOLDER)

        encoder = cls(model_dir_or_name=model_dir_or_name, **hyper)
        return encoder


class _WordPieceRobertaModel(nn.Module):
    def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool = False):
@@ -535,4 +552,8 @@ class _WordPieceRobertaModel(nn.Module):
            if l in (len(roberta_output)-1, -1) and self.pooled_cls:
                roberta_output[:, 0] = pooled_cls
            outputs[l_index] = roberta_output
        return outputs

    def save(self, folder):
        self.tokenzier.save_pretrained(folder)
        self.encoder.save_pretrained(folder)
@@ -10,6 +10,8 @@ import os

import warnings
from collections import defaultdict
from copy import deepcopy
import json
from typing import Union

import numpy as np
import torch
@@ -19,7 +21,12 @@ from .embedding import TokenEmbedding
from ..core import logger
from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
from fastNLP.io.file_utils import _get_file_name_base_on_postfix
from ..io.file_utils import _get_file_name_base_on_postfix

VOCAB_FILENAME = 'vocab.txt'
STATIC_HYPER_FILENAME = 'static_hyper.json'
STATIC_EMBED_FILENAME = 'static.txt'


class StaticEmbedding(TokenEmbedding):
@@ -70,7 +77,7 @@ class StaticEmbedding(TokenEmbedding):
    """

    def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True,
    def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True,
                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
        r"""
@@ -95,8 +102,8 @@ class StaticEmbedding(TokenEmbedding):
        """
        super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
        if embedding_dim > 0:
            if model_dir_or_name is not None:
                warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
            if model_dir_or_name:
                logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with"
                            f" dimension {embedding_dim}. If you want to use pre-trained embedding, "
                            f"set `embedding_dim` to 0.")
            model_dir_or_name = None
@@ -116,7 +123,9 @@ class StaticEmbedding(TokenEmbedding):
            model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        kwargs['min_freq'] = min_freq
        kwargs['lower'] = lower
        # shrink the vocab according to min_freq
        truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
        if truncate_vocab:
@@ -143,7 +152,7 @@ class StaticEmbedding(TokenEmbedding):
            truncated_words_to_words = torch.arange(len(vocab)).long()
            for word, index in vocab:
                truncated_words_to_words[index] = truncated_vocab.to_index(word)
            logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
            logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.")
            vocab = truncated_vocab

        self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False)
@@ -198,6 +207,7 @@ class StaticEmbedding(TokenEmbedding):
                                      sparse=False, _weight=embedding)
        self._embed_size = self.embedding.weight.size(1)
        self.requires_grad = requires_grad
        self.kwargs = kwargs

    @property
    def weight(self):
words = self.embedding(words) | |||
words = self.dropout(words) | |||
return words | |||
def save(self, folder): | |||
""" | |||
将embedding存储到folder下,之后可以通过使用load方法读取 | |||
:param str folder: 会在该folder下生成三个文件, vocab.txt, static_embed_hyper.txt, static_embed_hyper.json. | |||
其中vocab.txt可以用Vocabulary通过load读取; embedding.txt按照word2vec的方式存储,以空格的方式隔开元素, | |||
第一行只有两个元素,剩下的行首先是word然后是各个维度的值; static_embed_hyper.json是StaticEmbedding的超参数 | |||
:return: | |||
""" | |||
os.makedirs(folder, exist_ok=True) | |||
vocab = self.get_word_vocab() | |||
vocab_fp = os.path.join(folder, VOCAB_FILENAME) | |||
vocab.save(vocab_fp) | |||
kwargs = self.kwargs.copy() | |||
kwargs['dropout'] = self.dropout_layer.p | |||
kwargs['word_dropout'] = self.word_dropout | |||
kwargs['requires_grad'] = self.requires_grad | |||
kwargs['only_norm_found_vector'] = False | |||
kwargs['only_use_pretrain_word'] = True | |||
with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f: | |||
json.dump(kwargs, f, indent=2) | |||
with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f: | |||
f.write('{}\n'.format(' '*30)) # 留白之后再来填写 | |||
word_count = 0 | |||
saved_word = {} | |||
valid_word_count = 0 | |||
for i in range(len(self.words_to_words)): | |||
word = vocab.to_word(i) | |||
if not vocab._is_word_no_create_entry(word): | |||
word_count += 1 | |||
if kwargs['lower']: | |||
word = word.lower() | |||
if word in saved_word: | |||
continue | |||
saved_word[word] = 1 | |||
vec_i = self.words_to_words[i] | |||
if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx: | |||
continue | |||
vec = self.embedding.weight.data[vec_i].tolist() | |||
vec_str = ' '.join(map(str, vec)) | |||
f.write(f'{word} {vec_str}\n') | |||
valid_word_count += 1 | |||
f.seek(0) | |||
f.write('{} {}'.format(valid_word_count, self.embedding_dim)) | |||
logger.debug(f"StaticEmbedding has been saved to {folder}.") | |||
@classmethod | |||
def load(cls, folder): | |||
""" | |||
:param str folder: 该folder下应该有以下三个文件vocab.txt, static_embed.txt, static_hyper.json | |||
:return: | |||
""" | |||
for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]: | |||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||
vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME)) | |||
with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f: | |||
hyper = json.load(f) | |||
logger.info(f"Load StaticEmbedding from {folder}.") | |||
embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper) | |||
return embed | |||
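The 30-space first line above is a placeholder: the word count is only known after the loop, so save() seeks back to offset 0 and overwrites it with "<count> <dim>", the word2vec-style header. A round-trip sketch (editorial, not part of the PR; uses a randomly initialized embedding and a throwaway folder name):

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst("a b c".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5)
embed.save('static_dir')  # writes vocab.txt, static.txt, static_hyper.json
restored = StaticEmbedding.load('static_dir')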
@@ -9,7 +9,8 @@ from torch import nn as nn
from ..core.vocabulary import Vocabulary

__all__ = [
    'get_embeddings',
    'get_sinusoid_encoding_table'
]
@@ -31,7 +32,7 @@ def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, inclu
    return char_vocab


def get_embeddings(init_embed):
def get_embeddings(init_embed, padding_idx=None):
    r"""
    Return an Embedding object built from init_embed. If the input is a tuple, a randomly initialized nn.Embedding is
    created; if it is a numpy.ndarray, the nn.Embedding is initialized from the ndarray's values; if it is a
    torch.Tensor, the nn.Embedding is initialized with that value; a fastNLP embedding is returned untouched.
@@ -40,11 +41,12 @@ def get_embeddings(init_embed):
    :param init_embed: may be a tuple (num_embeddings, embedding_dim), i.e. the embedding size and the dimension of
        each word; an nn.Embedding object, which is then used as the embedding directly; an np.ndarray, used to
        initialize the Embedding; or a torch.Tensor, likewise used as the initial Embedding value.
    :param padding_idx: only effective when a tuple is passed
    :return nn.Embedding: embeddings
    """
    if isinstance(init_embed, tuple):
        res = nn.Embedding(
            num_embeddings=init_embed[0], embedding_dim=init_embed[1])
            num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx)
        nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)),
                         b=np.sqrt(3 / res.weight.data.size(1)))
    elif isinstance(init_embed, nn.Module):
@@ -58,3 +60,32 @@ def get_embeddings(init_embed):
        raise TypeError(
            'invalid init_embed type: {}'.format((type(init_embed))))
    return res
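Minimal usage sketch of the new padding_idx pass-through (editorial; the sizes are arbitrary):

from fastNLP.embeddings.utils import get_embeddings

emb = get_embeddings((100, 16), padding_idx=0)  # random uniform init, index 0 marked as padding
print(emb.num_embeddings, emb.embedding_dim, emb.padding_idx)  # 100 16 0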
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    """
    Sinusoid position embedding: within each position's representation, the even dimensions (0, 2, 4, ...) use sin and
    the odd ones (1, 3, 5, ...) use cos.

    :param int n_position: total number of positions
    :param int d_hid: number of dimensions; must be even
    :param padding_idx:
    :return: torch.FloatTensor of shape n_position x d_hid
    """

    def cal_angle(position, hid_idx):
        return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)

    def get_posi_angle_vec(position):
        return [cal_angle(position, hid_j) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])

    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    if padding_idx is not None:
        # zero vector for padding dimension
        sinusoid_table[padding_idx] = 0.

    return torch.FloatTensor(sinusoid_table)
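In formula form this is PE(pos, 2i) = sin(pos / 10000^(2i/d_hid)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_hid)), the encoding from "Attention Is All You Need". A quick sanity check (editorial):

import torch
from fastNLP.embeddings.utils import get_sinusoid_encoding_table

table = get_sinusoid_encoding_table(n_position=10, d_hid=8, padding_idx=0)
assert table.shape == (10, 8)
assert torch.all(table[0] == 0)  # the padding row was zeroed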
@@ -9,18 +9,18 @@ fastNLP ships built-in models in the :mod:`~fastNLP.models` module, such as :class:`~fastNLP.models
"""
__all__ = [
    "CNNText",

    "SeqLabeling",
    "AdvSeqLabel",
    "BiLSTMCRF",

    "ESIM",

    "StarTransEnc",
    "STSeqLabel",
    "STNLICls",
    "STSeqCls",

    "BiaffineParser",
    "GraphParser",
@@ -28,7 +28,13 @@ __all__ = [
    "BertForSentenceMatching",
    "BertForMultipleChoice",
    "BertForTokenClassification",
    "BertForQuestionAnswering",

    "TransformerSeq2SeqModel",
    "LSTMSeq2SeqModel",
    "Seq2SeqModel",

    'SequenceGeneratorModel'
]

from .base_model import BaseModel
@@ -39,7 +45,9 @@ from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel
from .seq2seq_generator import SequenceGeneratorModel

import sys
from ..doc_utils import doc_process

doc_process(sys.modules[__name__])
@@ -39,7 +39,7 @@ from torch import nn

from .base_model import BaseModel
from ..core._logger import logger
from ..core.const import Const
from ..embeddings import BertEmbedding
from ..embeddings.bert_embedding import BertEmbedding


class BertForSequenceClassification(BaseModel):
@@ -314,13 +314,8 @@ class BiaffineParser(GraphParser):
                raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size))
            self.position_emb = nn.Embedding(num_embeddings=self.max_len,
                                             embedding_dim=rnn_out_size, )
            self.encoder = TransformerEncoder(num_layers=rnn_layers,
                                              model_size=rnn_out_size,
                                              inner_size=1024,
                                              key_size=d_k,
                                              value_size=d_v,
                                              num_head=n_head,
                                              dropout=dropout, )
            self.encoder = TransformerEncoder(num_layers=rnn_layers, d_model=rnn_out_size,
                                              n_head=n_head, dim_ff=1024, dropout=dropout)
        else:
            raise ValueError('unsupported encoder type: {}'.format(encoder))
@@ -0,0 +1,62 @@
import torch
from torch import nn

from .seq2seq_model import Seq2SeqModel
from ..modules.generator.seq2seq_generator import SequenceGenerator


class SequenceGeneratorModel(nn.Module):
    """
    Wraps a Seq2SeqModel so that it can be used for generation tasks.
    """

    def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, num_beams=1,
                 do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
                 repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
        """

        :param Seq2SeqModel seq2seq_model: the sequence-to-sequence model
        :param int,None bos_token_id: token id marking the start of a sentence
        :param int,None eos_token_id: token id marking the end of a sentence
        :param int max_length: maximum sentence length
        :param int num_beams: beam size for beam search
        :param bool do_sample: whether to generate by sampling
        :param float temperature: only meaningful when do_sample is True
        :param int top_k: sample only from the top_k tokens
        :param float top_p: sample only from the tokens within top_p probability mass (nucleus sampling)
        :param float repetition_penalty: how strongly repeated tokens are penalized
        :param float length_penalty: penalty on length; less than 1 favors long sentences, greater than 1 favors
            short ones
        :param int pad_token_id: once a sentence has finished generating, later positions are filled with pad_token_id
        """
        super().__init__()
        self.seq2seq_model = seq2seq_model
        self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, num_beams=num_beams,
                                           do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
                                           bos_token_id=bos_token_id,
                                           eos_token_id=eos_token_id,
                                           repetition_penalty=repetition_penalty, length_penalty=length_penalty,
                                           pad_token_id=pad_token_id)

    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        """
        Delegates straight to seq2seq_model's forward.

        :param torch.LongTensor src_tokens: bsz x max_len
        :param torch.LongTensor tgt_tokens: bsz x max_len'
        :param torch.LongTensor src_seq_len: bsz
        :param torch.LongTensor tgt_seq_len: bsz
        :return:
        """
        return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)

    def predict(self, src_tokens, src_seq_len=None):
        """
        Given the source content, output the generated content.

        :param torch.LongTensor src_tokens: bsz x max_len
        :param torch.LongTensor src_seq_len: bsz
        :return:
        """
        state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len)
        result = self.generator.generate(state)
        return {'pred': result}
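A usage sketch tying this wrapper to the model builders below (editorial; the vocabulary sizes and special-token ids are made up):

import torch
from fastNLP.models import LSTMSeq2SeqModel, SequenceGeneratorModel

model = LSTMSeq2SeqModel.build_model(src_embed=(1000, 64), tgt_embed=(1000, 64),
                                     num_layers=2, hidden_size=128)
gen_model = SequenceGeneratorModel(model, bos_token_id=1, eos_token_id=2, max_length=20,
                                   num_beams=4, do_sample=False)

src_tokens = torch.randint(3, 1000, (8, 15))          # bsz x max_len
src_seq_len = torch.full((8,), 15, dtype=torch.long)  # bsz
output = gen_model.predict(src_tokens, src_seq_len)   # {'pred': bsz x generated_len}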
@@ -0,0 +1,176 @@
r"""
Contains mainly the models that make up a Sequence-to-Sequence system
"""

import torch
from torch import nn

from ..embeddings import get_embeddings
from ..embeddings.utils import get_sinusoid_encoding_table
from ..modules.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder
from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder


class Seq2SeqModel(nn.Module):
    def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder):
        """
        A Seq2Seq model that can be trained in a Trainer. Normally a subclass only needs to implement the classmethod
        build_model.

        :param encoder: Encoder
        :param decoder: Decoder
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
        """

        :param torch.LongTensor src_tokens: source tokens
        :param torch.LongTensor tgt_tokens: target tokens
        :param torch.LongTensor src_seq_len: lengths of src
        :param torch.LongTensor tgt_seq_len: lengths of target; normally unused
        :return: {'pred': torch.Tensor}, where pred has shape bsz x max_len x vocab_size
        """
        state = self.prepare_state(src_tokens, src_seq_len)
        decoder_output = self.decoder(tgt_tokens, state)
        if isinstance(decoder_output, torch.Tensor):
            return {'pred': decoder_output}
        elif isinstance(decoder_output, (tuple, list)):
            return {'pred': decoder_output[0]}
        else:
            raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}")

    def prepare_state(self, src_tokens, src_seq_len=None):
        """
        Calls the encoder to obtain a state: the encoder's encoder_output and encoder_mask are passed straight into
        decoder.init_state to initialize a state object.

        :param src_tokens:
        :param src_seq_len:
        :return:
        """
        encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len)
        state = self.decoder.init_state(encoder_output, encoder_mask)
        return state

    @classmethod
    def build_model(cls, *args, **kwargs):
        """
        Subclasses must implement this method to construct the Seq2SeqModel.

        :return:
        """
        raise NotImplementedError
class TransformerSeq2SeqModel(Seq2SeqModel):
    """
    A model whose encoder is TransformerSeq2SeqEncoder and whose decoder is TransformerSeq2SeqDecoder; initialize it
    through the build_model method.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, src_embed, tgt_embed=None,
                    pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1,
                    bind_encoder_decoder_embed=False,
                    bind_decoder_input_output_embed=True):
        """
        Initialize a TransformerSeq2SeqModel.

        :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: embedding for the source side
        :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: embedding for the target side; do not pass this
            when bind_encoder_decoder_embed is True
        :param str pos_embed: either sin or learned
        :param int max_position: maximum supported length
        :param int num_layers: number of layers in the encoder and decoder
        :param int d_model: input/output size of the encoder and decoder
        :param int n_head: number of heads in the encoder and decoder
        :param int dim_ff: inner dimension of the FFN in the encoder and decoder
        :param float dropout: dropout applied in the attention and the FFN
        :param bool bind_encoder_decoder_embed: whether the encoder and decoder share the same embedding
        :param bool bind_decoder_input_output_embed: whether the decoder's output embedding shares weights with its
            input embedding
        :return: TransformerSeq2SeqModel
        """
        if bind_encoder_decoder_embed and tgt_embed is not None:
            raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.")

        src_embed = get_embeddings(src_embed)

        if bind_encoder_decoder_embed:
            tgt_embed = src_embed
        else:
            assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`"
            tgt_embed = get_embeddings(tgt_embed)

        if pos_embed == 'sin':
            encoder_pos_embed = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0),
                freeze=True)  # index 0 is reserved for padding here
            decoder_pos_embed = nn.Embedding.from_pretrained(
                get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0),
                freeze=True)  # index 0 is reserved for padding here
        elif pos_embed == 'learned':
            encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0)
            decoder_pos_embed = get_embeddings((max_position + 1, tgt_embed.embedding_dim), padding_idx=0)
        else:
            raise ValueError("pos_embed only supports sin or learned.")

        encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed,
                                            num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff,
                                            dropout=dropout)
        decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed,
                                            d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff,
                                            dropout=dropout,
                                            bind_decoder_input_output_embed=bind_decoder_input_output_embed)

        return cls(encoder, decoder)
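Building and running the transformer variant (editorial sketch; the sizes are arbitrary, and the embedding dimension is kept equal to d_model):

import torch
from fastNLP.models import TransformerSeq2SeqModel

model = TransformerSeq2SeqModel.build_model(src_embed=(1000, 512), tgt_embed=(1200, 512),
                                            pos_embed='sin', num_layers=2, d_model=512,
                                            n_head=8, dim_ff=1024)
src = torch.randint(1, 1000, (4, 10))
tgt = torch.randint(1, 1200, (4, 12))
out = model(src, tgt, src_seq_len=torch.full((4,), 10, dtype=torch.long))
print(out['pred'].shape)  # expected: torch.Size([4, 12, 1200])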
class LSTMSeq2SeqModel(Seq2SeqModel):
    """
    A model using LSTMSeq2SeqEncoder and LSTMSeq2SeqDecoder.
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, src_embed, tgt_embed=None,
                    num_layers=3, hidden_size=400, dropout=0.3, bidirectional=True,
                    attention=True, bind_encoder_decoder_embed=False,
                    bind_decoder_input_output_embed=True):
        """

        :param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: embedding for the source side
        :param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: embedding for the target side; do not pass this
            when bind_encoder_decoder_embed is True
        :param int num_layers: number of layers in the encoder and decoder
        :param int hidden_size: hidden size of the encoder and decoder
        :param float dropout: dropout between layers
        :param bool bidirectional: whether the encoder uses a bidirectional LSTM
        :param bool attention: whether the decoder attends over the encoder states at all time steps
        :param bool bind_encoder_decoder_embed: whether the encoder and decoder share the same embedding
        :param bool bind_decoder_input_output_embed: whether the decoder's output embedding shares weights with its
            input embedding
        :return: LSTMSeq2SeqModel
        """
        if bind_encoder_decoder_embed and tgt_embed is not None:
            raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.")

        src_embed = get_embeddings(src_embed)

        if bind_encoder_decoder_embed:
            tgt_embed = src_embed
        else:
            assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`"
            tgt_embed = get_embeddings(tgt_embed)

        encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers=num_layers,
                                     hidden_size=hidden_size, dropout=dropout, bidirectional=bidirectional)
        decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers=num_layers, hidden_size=hidden_size,
                                     dropout=dropout, bind_decoder_input_output_embed=bind_decoder_input_output_embed,
                                     attention=attention)
        return cls(encoder, decoder)
@@ -14,9 +14,9 @@ import torch.nn.functional as F

from .base_model import BaseModel
from ..core.const import Const as C
from ..core.utils import seq_len_to_mask
from ..embeddings import get_embeddings
from ..modules import ConditionalRandomField
from ..modules import LSTM
from ..embeddings.utils import get_embeddings
from ..modules.decoder import ConditionalRandomField
from ..modules.encoder import LSTM
from ..modules import decoder, encoder
from ..modules.decoder.crf import allowed_transitions
@@ -58,7 +58,21 @@ __all__ = [

    "RobertaModel",

    "GPT2Model",
    "GPT2Tokenizer",

    "TransformerSeq2SeqEncoder",
    "LSTMSeq2SeqEncoder",
    "Seq2SeqEncoder",

    "TransformerSeq2SeqDecoder",
    "LSTMSeq2SeqDecoder",
    "Seq2SeqDecoder",

    "TransformerState",
    "LSTMState",
    "State",

    "SequenceGenerator"
]

import sys
@@ -68,6 +82,7 @@ from . import encoder
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .generator import *
from .utils import summary
from ..doc_utils import doc_process
from .tokenizer import *
@@ -12,7 +12,8 @@ import torch
import torch.nn.functional as F
from torch import nn

from fastNLP.modules.utils import initial_parameter
from .utils import initial_parameter
from .decoder.seq2seq_state import TransformerState


class DotAttention(nn.Module):
@@ -45,64 +46,153 @@ class DotAttention(nn.Module):

# removed by this PR:
class MultiHeadAttention(nn.Module):
    r"""
    The MultiHeadAttention used inside the Transformer
    """

    def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
        r"""

        :param input_size: int, size of the input dimension; also the size of the output dimension.
        :param key_size: int, dimension of each head.
        :param value_size: int, dimension of value in each head.
        :param num_head: int, number of heads.
        :param dropout: float.
        """
        super(MultiHeadAttention, self).__init__()
        self.input_size = input_size
        self.key_size = key_size
        self.value_size = value_size
        self.num_head = num_head

        in_size = key_size * num_head
        self.q_in = nn.Linear(input_size, in_size)
        self.k_in = nn.Linear(input_size, in_size)
        self.v_in = nn.Linear(input_size, in_size)
        self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
        self.out = nn.Linear(value_size * num_head, input_size)

# added by this PR:
class MultiHeadAttention(nn.Module):
    """
    The multi-head attention described in "Attention is all you need"
    """

    def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dropout = dropout
        self.head_dim = d_model // n_head
        self.layer_idx = layer_idx
        assert d_model % n_head == 0, "d_model should be divisible by n_head"
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.reset_parameters()
    def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None):
        """

        :param query: batch x seq x dim
        :param key: batch x seq x dim
        :param value: batch x seq x dim
        :param key_mask: batch x seq, marks which keys should not be attended to; note that positions where the mask
            is 1 ARE attended to
        :param attn_mask: seq x seq, masks the attention map; mainly used for decoder self-attention during training,
            with ones in the lower triangle
        :param state: past information, used at inference time, e.g. the encoder output and the decoder's previous kv;
            this saves computation.
        :return:
        """
        assert key.size() == value.size()
        if state is not None:
            assert self.layer_idx is not None
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()

        q = self.q_proj(query)  # batch x seq x dim
        q *= self.scaling
        k = v = None
        prev_k = prev_v = None

        # fetch kv from the state
        if isinstance(state, TransformerState):  # i.e. we are at inference time
            if qkv_same:  # decoder self-attention
                prev_k = state.decoder_prev_key[self.layer_idx]
                prev_v = state.decoder_prev_value[self.layer_idx]
            else:  # decoder-encoder attention; simply load the stored keys
                k = state.encoder_key[self.layer_idx]
                v = state.encoder_value[self.layer_idx]

        if k is None:
            k = self.k_proj(key)
            v = self.v_proj(value)

        if prev_k is not None:
            k = torch.cat((prev_k, k), dim=1)
            v = torch.cat((prev_v, v), dim=1)

        # update the state
        if isinstance(state, TransformerState):
            if qkv_same:
                state.decoder_prev_key[self.layer_idx] = k
                state.decoder_prev_value[self.layer_idx] = v
            else:
                state.encoder_key[self.layer_idx] = k
                state.encoder_value[self.layer_idx] = v

        # compute the attention
        batch_size, q_len, d_model = query.size()
        k_len, v_len = k.size(1), v.size(1)
        q = q.reshape(batch_size, q_len, self.n_head, self.head_dim)
        k = k.reshape(batch_size, k_len, self.n_head, self.head_dim)
        v = v.reshape(batch_size, v_len, self.n_head, self.head_dim)

        attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k)  # bs,q_len,k_len,n_head
        if key_mask is not None:
            _key_mask = ~key_mask[:, None, :, None].bool()  # batch,1,k_len,1
            attn_weights = attn_weights.masked_fill(_key_mask, -float('inf'))

        if attn_mask is not None:
            _attn_mask = attn_mask[None, :, :, None].eq(0)  # 1,q_len,k_len,n_head
            attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf'))

        attn_weights = F.softmax(attn_weights, dim=2)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v)  # batch,q_len,n_head,head_dim
        output = output.reshape(batch_size, q_len, -1)
        output = self.out_proj(output)  # batch,q_len,dim

        return output, attn_weights

    # removed by this PR:
    def reset_parameters(self):
        sqrt = math.sqrt
        nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
        nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
        nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(1.0 / self.input_size))
        nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size))

    # added by this PR:
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)

    def set_layer_idx(self, layer_idx):
        self.layer_idx = layer_idx
    def forward(self, Q, K, V, atte_mask_out=None):
        r"""
        :param Q: [batch, seq_len_q, model_size]
        :param K: [batch, seq_len_k, model_size]
        :param V: [batch, seq_len_k, model_size]
        :param atte_mask_out: [batch, 1, seq_len]
        """
        batch, sq, _ = Q.size()
        sk = K.size(1)
        d_k, d_v, n_head = self.key_size, self.value_size, self.num_head
        # input linear
        q = self.q_in(Q).view(batch, sq, n_head, d_k).transpose(1, 2)
        k = self.k_in(K).view(batch, sk, n_head, d_k).transpose(1, 2)
        v = self.v_in(V).view(batch, sk, n_head, d_v).transpose(1, 2)
        if atte_mask_out is not None:
            atte_mask_out = atte_mask_out[:,None,:,:]  # [bsz,1,1,len]
        atte = self.attention(q, k, v, atte_mask_out).view(batch, n_head, sq, d_v)
        # concat all heads, do output linear
        atte = atte.transpose(1, 2).contiguous().view(batch, sq, -1)
        output = self.out(atte)
        return output
class AttentionLayer(nn.Module):
    def __init__(self, input_size, key_dim, value_dim, bias=False):
        """
        Attention usable in the decode step of an LSTM-to-LSTM sequence-to-sequence model: at each decode
        step it computes attention over the encoder outputs from the previous step's hidden state.
        :param int input_size: size of the input
        :param int key_dim: usually the output dimension of encoder_output
        :param int value_dim: size of the output, usually the size of the decoder hidden state
        :param bias:
        """
        super().__init__()
        self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
        self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)
def forward(self, input, encode_outputs, encode_mask): | |||
""" | |||
:param input: batch_size x input_size | |||
:param encode_outputs: batch_size x max_len x key_dim | |||
:param encode_mask: batch_size x max_len, 为0的地方为padding | |||
:return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized过的 | |||
""" | |||
# x: bsz x encode_hidden_size | |||
x = self.input_proj(input) | |||
# compute attention | |||
attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len | |||
# don't attend over padding | |||
if encode_mask is not None: | |||
attn_scores = attn_scores.float().masked_fill_( | |||
encode_mask.eq(0), | |||
float('-inf') | |||
).type_as(attn_scores) # FP16 support: cast to float and back | |||
        attn_scores = F.softmax(attn_scores, dim=-1)  # bsz x max_len
# sum weighted sources | |||
x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size | |||
x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) | |||
return x, attn_scores | |||
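# --- Editor's sketch (illustrative, not part of this diff): one decode step through
# --- AttentionLayer; all sizes below are made up.
import torch
layer = AttentionLayer(input_size=300, key_dim=300, value_dim=300)
hidden = torch.randn(4, 300)                       # previous decoder hidden state
encode_outputs = torch.randn(4, 10, 300)           # bsz x max_len x key_dim
encode_mask = torch.ones(4, 10, dtype=torch.long)  # no padding in this toy batch
out, scores = layer(hidden, encode_outputs, encode_mask)
assert out.shape == (4, 300) and scores.shape == (4, 10)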
def _masked_softmax(tensor, mask): |
@@ -6,10 +6,20 @@ __all__ = [ | |||
"MLP", | |||
"ConditionalRandomField", | |||
"viterbi_decode", | |||
"allowed_transitions" | |||
"allowed_transitions", | |||
"LSTMState", | |||
"TransformerState", | |||
"State", | |||
"TransformerSeq2SeqDecoder", | |||
"LSTMSeq2SeqDecoder", | |||
"Seq2SeqDecoder" | |||
] | |||
from .crf import ConditionalRandomField | |||
from .crf import allowed_transitions | |||
from .mlp import MLP | |||
from .utils import viterbi_decode | |||
from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder | |||
from .seq2seq_state import State, LSTMState, TransformerState |
@@ -1,109 +1,413 @@ | |||
# coding=utf-8 | |||
__all__ = [ | |||
"TransformerPast", | |||
"Past", | |||
"Decoder" | |||
] | |||
from typing import Union, Tuple | |||
import math | |||
import torch | |||
from torch import nn | |||
import abc | |||
import torch.nn.functional as F | |||
from ..attention import AttentionLayer, MultiHeadAttention | |||
from ...embeddings import StaticEmbedding | |||
import numpy as np
from ...embeddings.utils import get_embeddings
from torch.nn import LayerNorm
from .seq2seq_state import State, LSTMState, TransformerState
class Past:
    def __init__(self):
        pass
    @abc.abstractmethod
    def num_samples(self):
        pass
class Seq2SeqDecoder(nn.Module):
    """
    Base class for sequence-to-sequence decoders. forward() must be implemented; implement the remaining
    functions as needed. Every Seq2SeqDecoder should come with a matching State object that carries the
    encoder output the decoder needs and the history the decoder has to keep (e.g. the LSTM hidden state).
    """
    def __init__(self):
        super().__init__()
    def forward(self, tokens, state, **kwargs):
        """
        :param torch.LongTensor tokens: bsz x max_len
        :param State state: the state holds the encoder output and everything decoded before this point
        :return: either a bsz x max_len x vocab_size Tensor, or a list whose first element must be the predicted word distribution
        """
        raise NotImplementedError
    def reorder_states(self, indices, states):
        """
        Reorder the states according to indices; used during beam-search generation.
        :param torch.LongTensor indices:
        :param State states:
        :return:
        """
        assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}"
        states.reorder_state(indices)
def init_state(self, encoder_output, encoder_mask): | |||
""" | |||
初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 | |||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param kwargs: | |||
:return: State, 返回一个State对象,记录了encoder的输出 | |||
""" | |||
state = State(encoder_output, encoder_mask) | |||
return state | |||
    def decode(self, tokens, state):
        """
        Continue generating based on the contents of state and tokens.
        :param torch.LongTensor tokens: bsz x max_len, the tokens output up to the previous step.
        :param State state: records the encoder output and the decoder's past state
        :return: torch.FloatTensor: bsz x vocab_size, the distribution over the next token
        """
        outputs = self(state=state, tokens=tokens)
        if isinstance(outputs, torch.Tensor):
            return outputs[:, -1]
        else:
            raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.")
class TransformerPast(Past):
    def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None,
                 num_decoder_layer: int = 6):
        """
        :param encoder_outputs: (batch,src_seq_len,dim)
        :param encoder_mask: (batch,src_seq_len)
        :param encoder_key: list of (batch, src_seq_len, dim)
        :param encoder_value:
        :param decoder_prev_key:
        :param decoder_prev_value:
        """
        self.encoder_outputs = encoder_outputs
        self.encoder_mask = encoder_mask
        self.encoder_key = [None] * num_decoder_layer
        self.encoder_value = [None] * num_decoder_layer
        self.decoder_prev_key = [None] * num_decoder_layer
        self.decoder_prev_value = [None] * num_decoder_layer
    def num_samples(self):
        if self.encoder_outputs is not None:
            return self.encoder_outputs.size(0)
        return None
    def _reorder_state(self, state, indices):
        if type(state) == torch.Tensor:
            state = state.index_select(index=indices, dim=0)
        elif type(state) == list:
            for i in range(len(state)):
                assert state[i] is not None
                state[i] = state[i].index_select(index=indices, dim=0)
        return state
    def reorder_past(self, indices: torch.LongTensor):
        self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices)
        self.encoder_mask = self._reorder_state(self.encoder_mask, indices)
        self.encoder_key = self._reorder_state(self.encoder_key, indices)
        self.encoder_value = self._reorder_state(self.encoder_value, indices)
        self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices)
        self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices)
        return self
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
    @abc.abstractmethod
    def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]:
        """
        Used when the model decodes: returns a batch_size x vocab_size result and the updated Past state.
        One special case must be handled: the length of tokens may be larger than 1, i.e. the beginning of
        the target sentence is given; the implementation then has to check that the decode state in Past
        has been computed correctly.
        :return: tensor: batch_size x vocab_size, past: Past
        """
        raise NotImplementedError
    @abc.abstractmethod
    def reorder_past(self, indices: torch.LongTensor, past: Past):
        """
        Put the states inside past into the order given by indices; modifies in place.
        :param torch.LongTensor indices:
        :param Past past:
        :return:
        """
        raise NotImplementedError
class TiedEmbedding(nn.Module):
    """
    Ties an output projection to the weights of an existing embedding.
    """
    def __init__(self, weight):
        super().__init__()
        self.weight = weight  # vocab_size x embed_size
    def forward(self, x):
        """
        :param torch.FloatTensor x: bsz x * x embed_size
        :return: torch.FloatTensor bsz x * x vocab_size
        """
        return torch.matmul(x, self.weight.t())
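# --- Editor's sketch (illustrative, not part of this diff): tying simply projects onto
# --- the transposed embedding matrix; the sizes below are made up.
import torch
from torch import nn
emb = nn.Embedding(100, 16)
tied = TiedEmbedding(emb.weight)
x = torch.randn(2, 5, 16)
assert torch.allclose(tied(x), x @ emb.weight.t())  # bsz x seq x vocab_size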
def get_binded_decoder_output_embed(embed): | |||
""" | |||
给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding | |||
:param embed: | |||
:return: | |||
""" | |||
if isinstance(embed, StaticEmbedding): | |||
for idx, map2idx in enumerate(embed.words_to_words): | |||
assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ | |||
"include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ | |||
"`lower=True` or `min_freq!=1`." | |||
elif not isinstance(embed, nn.Embedding): | |||
raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") | |||
return TiedEmbedding(embed.weight) | |||
class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||
dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||
""" | |||
LSTM的Decoder | |||
:param nn.Module,tuple embed: decoder输入的embedding. | |||
:param int num_layers: 多少层LSTM | |||
:param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 | |||
:param dropout: Dropout的大小 | |||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||
:param bool attention: 是否使用attention | |||
""" | |||
super().__init__() | |||
self.embed = get_embeddings(init_embed=embed) | |||
        self.embed_dim = self.embed.embedding_dim  # use the constructed embedding, so a (num, dim) tuple also works
if bind_decoder_input_output_embed: | |||
self.output_layer = get_binded_decoder_output_embed(self.embed) | |||
        else:  # no binding needed
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, | |||
batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) | |||
self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None | |||
self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||
self.dropout_layer = nn.Dropout(dropout) | |||
def forward(self, tokens, state, return_attention=False): | |||
""" | |||
:param torch.LongTensor tokens: batch x max_len | |||
:param LSTMState state: 保存encoder输出和decode状态的State对象 | |||
:param bool return_attention: 是否返回attention的的score | |||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||
""" | |||
src_output = state.encoder_output | |||
encoder_mask = state.encoder_mask | |||
assert tokens.size(1)>state.decode_length, "The state does not match the tokens." | |||
tokens = tokens[:, state.decode_length:] | |||
x = self.embed(tokens) | |||
        attn_weights = [] if self.attention_layer is not None else None  # store attention weights, batch,tgt_seq,src_seq
input_feed = state.input_feed | |||
decoder_out = [] | |||
cur_hidden = state.hidden | |||
cur_cell = state.cell | |||
        # decode step by step
for i in range(tokens.size(1)): | |||
input = torch.cat( | |||
(x[:, i:i + 1, :], | |||
input_feed[:, None, :] | |||
), | |||
dim=2 | |||
) # batch,1,2*dim | |||
            _, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell))  # hidden/cell keep their original size
if self.attention_layer is not None: | |||
input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) | |||
attn_weights.append(attn_weight) | |||
else: | |||
input_feed = cur_hidden[-1] | |||
state.input_feed = input_feed # batch, hidden | |||
state.hidden = cur_hidden | |||
state.cell = cur_cell | |||
state.decode_length += 1 | |||
decoder_out.append(input_feed) | |||
decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden | |||
decoder_out = self.dropout_layer(decoder_out) | |||
if attn_weights is not None: | |||
attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len | |||
decoder_out = self.output_proj(decoder_out) | |||
feats = self.output_layer(decoder_out) | |||
if return_attention: | |||
return feats, attn_weights | |||
return feats | |||
def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||
""" | |||
:param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: | |||
bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 | |||
hidden state和cell state来赋值这两个值 | |||
(2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 | |||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend | |||
:return: | |||
""" | |||
if not isinstance(encoder_output, torch.Tensor): | |||
encoder_output, (hidden, cell) = encoder_output | |||
else: | |||
raise ValueError('State does not support other format') | |||
hidden = cell = None | |||
assert encoder_output.ndim==3 | |||
assert encoder_mask.size()==encoder_output.size()[:2] | |||
assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ | |||
"the hidden_size." | |||
t = [hidden, cell] | |||
for idx in range(2): | |||
v = t[idx] | |||
if v is None: | |||
v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) | |||
else: | |||
assert v.dim()==2 | |||
assert v.size(-1)==self.hidden_size | |||
v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size | |||
t[idx] = v | |||
state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) | |||
return state | |||
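# --- Editor's sketch (illustrative, not part of this diff): feeding an encoder result
# --- into LSTMSeq2SeqDecoder; all sizes are made up.
import torch
from torch import nn
decoder = LSTMSeq2SeqDecoder(embed=nn.Embedding(1000, 300), num_layers=2, hidden_size=300)
encoder_output = torch.randn(4, 10, 300)
hidden, cell = torch.randn(4, 300), torch.randn(4, 300)   # last-layer states of the encoder
encoder_mask = torch.ones(4, 10, dtype=torch.bool)
state = decoder.init_state((encoder_output, (hidden, cell)), encoder_mask)
tokens = torch.randint(0, 1000, (4, 7))
feats = decoder(tokens, state)
assert feats.shape == (4, 7, 1000)                         # bsz x max_len x vocab_size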
class TransformerSeq2SeqDecoderLayer(nn.Module): | |||
def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||
""" | |||
class Decoder(nn.Module): | |||
def __init__(self): | |||
:param int d_model: 输入、输出的维度 | |||
:param int n_head: 多少个head,需要能被d_model整除 | |||
:param int dim_ff: | |||
:param float dropout: | |||
:param int layer_idx: layer的编号 | |||
""" | |||
super().__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
        self.layer_idx = layer_idx  # record the layer index to make it easy to fetch this layer's state
self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
self.self_attn_layer_norm = nn.LayerNorm(d_model) | |||
self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||
self.encoder_attn_layer_norm = nn.LayerNorm(d_model) | |||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
nn.ReLU(), | |||
nn.Dropout(dropout), | |||
nn.Linear(self.dim_ff, self.d_model), | |||
nn.Dropout(dropout)) | |||
self.final_layer_norm = nn.LayerNorm(self.d_model) | |||
    def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None):
        """
        :param x: (batch, seq_len, dim), input to the decoder
        :param encoder_output: (batch,src_seq_len,dim)
        :param encoder_mask: batch,src_seq_len
        :param self_attn_mask: seq_len x seq_len, lower-triangular mask, only passed during training
        :param TransformerState state: only passed at inference time
        :return:
        """
# self attention part | |||
residual = x | |||
x = self.self_attn_layer_norm(x) | |||
x, _ = self.self_attn(query=x, | |||
key=x, | |||
value=x, | |||
attn_mask=self_attn_mask, | |||
state=state) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
x = residual + x | |||
# encoder attention part | |||
residual = x | |||
x = self.encoder_attn_layer_norm(x) | |||
x, attn_weight = self.encoder_attn(query=x, | |||
key=encoder_output, | |||
value=encoder_output, | |||
key_mask=encoder_mask, | |||
state=state) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
x = residual + x | |||
# ffn | |||
residual = x | |||
x = self.final_layer_norm(x) | |||
x = self.ffn(x) | |||
x = residual + x | |||
return x, attn_weight | |||
class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||
d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||
bind_decoder_input_output_embed = True): | |||
""" | |||
根据indices中的index,将past的中状态置为正确的顺序。inplace改变 | |||
:param torch.LongTensor indices: | |||
:param Past past: | |||
:return: | |||
:param embed: 输入token的embedding | |||
:param nn.Module pos_embed: 位置embedding | |||
:param int d_model: 输出、输出的大小 | |||
:param int num_layers: 多少层 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN 的中间大小 | |||
:param float dropout: Self-Attention和FFN中的dropout的大小 | |||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||
""" | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
self.pos_embed = pos_embed | |||
if bind_decoder_input_output_embed: | |||
self.output_layer = get_binded_decoder_output_embed(self.embed) | |||
        else:  # no binding needed
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||
self.num_layers = num_layers | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) | |||
for layer_idx in range(num_layers)]) | |||
self.embed_scale = math.sqrt(d_model) | |||
self.layer_norm = nn.LayerNorm(d_model) | |||
self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||
def forward(self, tokens, state, return_attention=False): | |||
""" | |||
raise NotImplemented | |||
:param torch.LongTensor tokens: batch x tgt_len,decode的词 | |||
:param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 | |||
:param bool return_attention: 是否返回对encoder结果的attention score | |||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||
""" | |||
encoder_output = state.encoder_output | |||
encoder_mask = state.encoder_mask | |||
assert state.decode_length<tokens.size(1), "The decoded tokens in State should be less than tokens." | |||
tokens = tokens[:, state.decode_length:] | |||
device = tokens.device | |||
x = self.embed_scale * self.embed(tokens) | |||
if self.pos_embed is not None: | |||
position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] | |||
x += self.pos_embed(position) | |||
x = self.input_fc(x) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
batch_size, max_tgt_len = tokens.size() | |||
if max_tgt_len>1: | |||
triangle_mask = self._get_triangle_mask(tokens) | |||
else: | |||
triangle_mask = None | |||
for layer in self.layer_stacks: | |||
x, attn_weight = layer(x=x, | |||
encoder_output=encoder_output, | |||
encoder_mask=encoder_mask, | |||
self_attn_mask=triangle_mask, | |||
state=state | |||
) | |||
x = self.layer_norm(x) # batch, tgt_len, dim | |||
x = self.output_fc(x) | |||
feats = self.output_layer(x) | |||
if return_attention: | |||
return feats, attn_weight | |||
return feats | |||
def init_state(self, encoder_output, encoder_mask): | |||
""" | |||
初始化一个TransformerState用于forward | |||
:param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 | |||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 | |||
:return: TransformerState | |||
""" | |||
if isinstance(encoder_output, torch.Tensor): | |||
encoder_output = encoder_output | |||
elif isinstance(encoder_output, (list, tuple)): | |||
encoder_output = encoder_output[0] # 防止是LSTMEncoder的输出结果 | |||
else: | |||
raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") | |||
state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) | |||
return state | |||
@staticmethod | |||
def _get_triangle_mask(tokens): | |||
tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) | |||
return torch.tril(tensor).byte() | |||
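# --- Editor's sketch (illustrative, not part of this diff): a smoke test of
# --- TransformerSeq2SeqDecoder; all sizes are made up.
import torch
from torch import nn
decoder = TransformerSeq2SeqDecoder(embed=nn.Embedding(1000, 512), d_model=512, num_layers=2)
encoder_output = torch.randn(4, 10, 512)
encoder_mask = torch.ones(4, 10, dtype=torch.uint8)
state = decoder.init_state(encoder_output, encoder_mask)
tokens = torch.randint(0, 1000, (4, 7))
feats = decoder(tokens, state)   # the 7x7 lower-triangular mask is built internally
assert feats.shape == (4, 7, 1000)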
@@ -0,0 +1,145 @@ | |||
r""" | |||
每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录 | |||
""" | |||
__all__ = [ | |||
'State', | |||
"LSTMState", | |||
"TransformerState" | |||
] | |||
from typing import Union | |||
import torch | |||
class State: | |||
def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||
""" | |||
每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 | |||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果部位None,内部元素需要torch.Tensor, 默认其中第一维是batch | |||
维度 | |||
:param kwargs: | |||
""" | |||
self.encoder_output = encoder_output | |||
self.encoder_mask = encoder_mask | |||
self._decode_length = 0 | |||
@property | |||
def num_samples(self): | |||
""" | |||
返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 | |||
:return: | |||
""" | |||
if self.encoder_output is not None: | |||
return self.encoder_output.size(0) | |||
else: | |||
return None | |||
@property | |||
def decode_length(self): | |||
""" | |||
当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 | |||
:return: | |||
""" | |||
return self._decode_length | |||
@decode_length.setter | |||
def decode_length(self, value): | |||
self._decode_length = value | |||
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): | |||
if isinstance(state, torch.Tensor): | |||
state = state.index_select(index=indices, dim=dim) | |||
elif isinstance(state, list): | |||
for i in range(len(state)): | |||
assert state[i] is not None | |||
state[i] = self._reorder_state(state[i], indices, dim) | |||
elif isinstance(state, tuple): | |||
tmp_list = [] | |||
for i in range(len(state)): | |||
assert state[i] is not None | |||
tmp_list.append(self._reorder_state(state[i], indices, dim)) | |||
state = tuple(tmp_list) | |||
else: | |||
raise TypeError(f"Cannot reorder data of type:{type(state)}") | |||
return state | |||
def reorder_state(self, indices: torch.LongTensor): | |||
if self.encoder_mask is not None: | |||
self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||
if self.encoder_output is not None: | |||
self.encoder_output = self._reorder_state(self.encoder_output, indices) | |||
class LSTMState(State): | |||
def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||
""" | |||
LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 | |||
:param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 | |||
:param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding | |||
:param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 | |||
:param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 | |||
""" | |||
super().__init__(encoder_output, encoder_mask) | |||
self.hidden = hidden | |||
self.cell = cell | |||
        self._input_feed = hidden[0]  # defaults to the output of the previous step
@property | |||
def input_feed(self): | |||
""" | |||
LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, | |||
input_feed即上个时刻的hidden state, 否则是attention layer的输出。 | |||
:return: torch.FloatTensor, bsz x hidden_size | |||
""" | |||
return self._input_feed | |||
@input_feed.setter | |||
def input_feed(self, value): | |||
self._input_feed = value | |||
def reorder_state(self, indices: torch.LongTensor): | |||
super().reorder_state(indices) | |||
self.hidden = self._reorder_state(self.hidden, indices, dim=1) | |||
self.cell = self._reorder_state(self.cell, indices, dim=1) | |||
if self.input_feed is not None: | |||
self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) | |||
class TransformerState(State): | |||
def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||
""" | |||
与TransformerSeq2SeqDecoder对应的State, | |||
:param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 | |||
:param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend | |||
:param int num_decoder_layer: decode有多少层 | |||
""" | |||
super().__init__(encoder_output, encoder_mask) | |||
        self.encoder_key = [None] * num_decoder_layer  # each element: bsz x encoder_max_len x key_dim
        self.encoder_value = [None] * num_decoder_layer  # each element: bsz x encoder_max_len x value_dim
        self.decoder_prev_key = [None] * num_decoder_layer  # each element: bsz x decode_length x key_dim
        self.decoder_prev_value = [None] * num_decoder_layer  # each element: bsz x decode_length x value_dim
def reorder_state(self, indices: torch.LongTensor): | |||
super().reorder_state(indices) | |||
self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||
self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||
@property | |||
def decode_length(self): | |||
if self.decoder_prev_key[0] is not None: | |||
return self.decoder_prev_key[0].size(1) | |||
return 0 | |||
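# --- Editor's sketch (illustrative, not part of this diff): how beam search uses
# --- reorder_state to duplicate each sample num_beams times (see the generator changes below).
import torch
state = State(encoder_output=torch.randn(2, 5, 8), encoder_mask=torch.ones(2, 5))
indices = torch.arange(2).repeat_interleave(3)   # -> [0, 0, 0, 1, 1, 1]
state.reorder_state(indices)
assert state.num_samples == 6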
@@ -4,8 +4,6 @@ r""" | |||
""" | |||
__all__ = [ | |||
# "BertModel", | |||
"ConvolutionCharEncoder", | |||
"LSTMCharEncoder", | |||
@@ -35,10 +33,14 @@ __all__ = [ | |||
"RobertaModel", | |||
"GPT2Model" | |||
"GPT2Model", | |||
"LSTMSeq2SeqEncoder", | |||
"TransformerSeq2SeqEncoder", | |||
"Seq2SeqEncoder" | |||
] | |||
from .attention import MultiHeadAttention, BiAttention, SelfAttention | |||
from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention | |||
from .bert import BertModel | |||
from .roberta import RobertaModel | |||
from .gpt2 import GPT2Model | |||
@@ -49,3 +51,4 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo | |||
from .star_transformer import StarTransformer | |||
from .transformer import TransformerEncoder | |||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | |||
from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder |
@@ -10,6 +10,7 @@ __all__ = [ | |||
import copy | |||
import json | |||
import math | |||
import os | |||
import torch | |||
from torch import nn | |||
@@ -20,7 +21,8 @@ from ...io.file_utils import _get_bert_dir | |||
from ...core import logger | |||
CONFIG_FILE = 'bert_config.json' | |||
CONFIG_FILE = 'config.json' | |||
WEIGHTS_NAME = 'pytorch_model.bin' | |||
BERT_KEY_RENAME_MAP_1 = { | |||
'gamma': 'weight', | |||
@@ -57,7 +59,8 @@ class BertConfig(object): | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02, | |||
layer_norm_eps=1e-12): | |||
layer_norm_eps=1e-12, | |||
architectures='bert'): | |||
r"""Constructs BertConfig. | |||
Args: | |||
@@ -101,6 +104,7 @@ class BertConfig(object): | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
self.layer_norm_eps = layer_norm_eps | |||
self.architectures = architectures | |||
else: | |||
raise ValueError("First argument must be either a vocabulary size (int)" | |||
"or the path to a pretrained model config file (str)") | |||
@@ -134,9 +138,13 @@ class BertConfig(object): | |||
def to_json_file(self, json_file_path): | |||
r""" Save this instance to a json file.""" | |||
if os.path.isdir(json_file_path): | |||
json_file_path = os.path.join(json_file_path, CONFIG_FILE) | |||
with open(json_file_path, "w", encoding='utf-8') as writer: | |||
writer.write(self.to_json_string()) | |||
def save_pretrained(self, save_directory): | |||
self.to_json_file(save_directory) | |||
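# --- Editor's sketch (illustrative, not part of this diff): when given a directory,
# --- to_json_file/save_pretrained write <dir>/config.json; the vocab size is made up.
import os
os.makedirs('./bert_out', exist_ok=True)
config = BertConfig(30522)            # first argument: vocabulary size
config.save_pretrained('./bert_out')  # writes ./bert_out/config.json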
def gelu(x): | |||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
@@ -149,21 +157,6 @@ def swish(x): | |||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
# class BertLayerNorm(nn.Module): | |||
# def __init__(self, hidden_size, eps=1e-12): | |||
# r"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||
# """ | |||
# super(BertLayerNorm, self).__init__() | |||
# self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
# self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
# self.variance_epsilon = eps | |||
# | |||
# def forward(self, x): | |||
# u = x.mean(-1, keepdim=True) | |||
# s = (x - u).pow(2).mean(-1, keepdim=True) | |||
# x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||
# return self.weight * x + self.bias | |||
BertLayerNorm = torch.nn.LayerNorm | |||
@@ -613,3 +606,24 @@ class BertModel(nn.Module): | |||
logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") | |||
return model | |||
def save_pretrained(self, save_directory): | |||
""" 保存模型到某个folder | |||
""" | |||
assert os.path.isdir( | |||
save_directory | |||
), "Saving path should be a directory where the model and configuration can be saved" | |||
# Only save the model itself if we are using distributed training | |||
model_to_save = self.module if hasattr(self, "module") else self | |||
# Attach architecture to the config | |||
model_to_save.config.architectures = [model_to_save.__class__.__name__] | |||
# Save configuration file | |||
model_to_save.config.save_pretrained(save_directory) | |||
# If we save using the predefined names, we can load using `from_pretrained` | |||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME) | |||
torch.save(model_to_save.state_dict(), output_model_file) | |||
logger.debug("Model weights saved in {}".format(output_model_file)) |
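# --- Editor's sketch (illustrative, not part of this diff): assuming `model` is a
# --- BertModel instance, saving uses the predefined file names (config.json,
# --- pytorch_model.bin) so the directory can later be reloaded.
import os
os.makedirs('./bert_out', exist_ok=True)
model.save_pretrained('./bert_out')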
@@ -15,9 +15,8 @@ import math | |||
from torch.nn import CrossEntropyLoss | |||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||
from ..decoder.seq2seq_decoder import Decoder, Past | |||
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||
from ..generator.seq2seq_generator import SequenceGenerator | |||
from typing import Tuple | |||
GELU_CONSTANT = math.sqrt(2 / math.pi) | |||
@@ -732,7 +731,7 @@ class GPT2PreTrainedModel(nn.Module): | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_ids, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
results = generator.generate(input_ids, past=None) | |||
results = generator.generate(tokens=input_ids, state=GPT2State()) | |||
return results | |||
@@ -788,21 +787,13 @@ class GPT2Model(GPT2PreTrainedModel): | |||
for layer, heads in heads_to_prune.items(): | |||
self.h[layer].attn.prune_heads(heads) | |||
def forward( | |||
self, | |||
input_ids, | |||
past=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
output_attentions=True | |||
): | |||
def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, | |||
head_mask=None, output_attentions=True): | |||
""" | |||
        :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1
        :param GPT2Past past: the previous state
        :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), the same size as the concatenation of input_ids and past.
        :param GPT2State state: the previous state
        :param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), the same size as the concatenation of input_ids and state.
            0 marks padding.
        :param torch.LongTensor token_type_ids: batch_size x max_len.
        :param torch.LongTensor position_ids: positions corresponding to input_ids
@@ -818,11 +809,11 @@ class GPT2Model(GPT2PreTrainedModel): | |||
if position_ids is not None: | |||
position_ids = position_ids.view(-1, input_shape[-1]) | |||
if past is None or len(past)==0: | |||
if state is None or len(state)==0: | |||
past_length = 0 | |||
            past = [None] * len(self.h)  # len(self.h) is the number of layers
            state = [None] * len(self.h)  # len(self.h) is the number of layers
else: | |||
past_length = past[0][0].size(-2) | |||
past_length = state[0][0].size(-2) | |||
        if position_ids is None:  # generate position ids if none are given
device = input_ids.device | |||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) | |||
@@ -880,7 +871,7 @@ class GPT2Model(GPT2PreTrainedModel): | |||
presents = () | |||
all_attentions = [] | |||
all_hidden_states = () | |||
for i, (block, layer_past) in enumerate(zip(self.h, past)): | |||
for i, (block, layer_past) in enumerate(zip(self.h, state)): | |||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) | |||
outputs = block( | |||
@@ -915,56 +906,63 @@ class GPT2Model(GPT2PreTrainedModel): | |||
return outputs # last hidden state, (presents), (all hidden_states), (attentions) | |||
class GPT2Past(Past): | |||
class GPT2State(State): | |||
def __init__(self): | |||
super().__init__() | |||
self.past = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] | |||
super().__init__(None, None) | |||
self.state = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] | |||
@property | |||
def num_samples(self): | |||
if self.past is not None: | |||
return self.past[0].size(1) | |||
if self.state is not None: | |||
return self.state[0].size(1) | |||
return None | |||
def reorder_past(self, indices): | |||
for i in range(len(self.past)): | |||
assert self.past[i] is not None | |||
self.past[i] = self.past[i].index_select(index=indices, dim=1) | |||
@property | |||
def decode_length(self): | |||
if self.state is None: | |||
return 0 | |||
return self.state[0].size(-2) | |||
def reorder_state(self, indices): | |||
if self.state: | |||
for i in range(len(self.state)): | |||
assert self.state[i] is not None | |||
self.state[i] = self.state[i].index_select(index=indices, dim=1) | |||
def __iter__(self): | |||
for p in self.past: | |||
for p in self.state: | |||
yield p | |||
def __getitem__(self, item): | |||
assert isinstance(item, int) | |||
return self.past[item] | |||
return self.state[item] | |||
def __len__(self): | |||
if self.past is not None: | |||
return len(self.past) | |||
if self.state is not None: | |||
return len(self.state) | |||
return 0 | |||
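# --- Editor's sketch (illustrative, not part of this diff): the cache layout GPT2State
# --- assumes, one tensor per layer of shape 2 x bsz x n_head x past_len x head_dim.
import torch
state = GPT2State()
state.state = [torch.randn(2, 4, 12, 7, 64) for _ in range(3)]
assert state.num_samples == 4 and state.decode_length == 7 and len(state) == 3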
class _GPT2Decoder(Decoder): | |||
class _GPT2Decoder(Seq2SeqDecoder): | |||
""" | |||
用于wrap GPT2是的可以在SequenceGenerator中使用 | |||
""" | |||
def __init__(self, gpt_model): | |||
super().__init__() | |||
self.gpt_model = gpt_model | |||
    def decode(self, tokens, past=None) -> Tuple[torch.Tensor, Past]:
        if past is None:
            past = GPT2Past()
        lm_logits, presents, _ = self.gpt_model(input_ids=tokens,
                                                past=past,
                                                attention_mask=None,
                                                token_type_ids=None,
                                                position_ids=None,
                                                head_mask=None,
                                                output_attentions=False)
        past.past = list(presents)
        return lm_logits[:, -1], past
    def reorder_past(self, indices: torch.LongTensor, past: GPT2Past) -> GPT2Past:
        past.reorder_past(indices)
        return past
    def decode(self, tokens, state=None) -> torch.Tensor:
        if state is None:
            state = GPT2State()
        lm_logits, presents, _ = self.gpt_model(input_ids=tokens[:, state.decode_length:],
                                                state=state,
                                                attention_mask=None,
                                                token_type_ids=None,
                                                position_ids=None,
                                                head_mask=None,
                                                output_attentions=False)
        state.state = list(presents)
        return lm_logits[:, -1]
class GPT2LMHeadModel(GPT2PreTrainedModel): | |||
@@ -1008,21 +1006,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): | |||
def get_input_embeddings(self): | |||
return self.transformer.wte | |||
def forward( | |||
self, | |||
input_ids, | |||
past=None, | |||
attention_mask=None, | |||
token_type_ids=None, | |||
position_ids=None, | |||
head_mask=None, | |||
labels=None, | |||
output_attentions=False | |||
): | |||
def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, | |||
head_mask=None, labels=None, output_attentions=False): | |||
""" | |||
:param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | |||
        :param tuple past: num_layers x 2 x batch_size x n_head x max_len' x head_dim. The presents from the previous step can be passed in
        :param tuple state: num_layers x 2 x batch_size x n_head x max_len' x head_dim. The presents from the previous step can be passed in
        :param torch.ByteTensor attention_mask: batch_size x max_len, the same size as input_ids. 0 marks padding.
        :param torch.LongTensor token_type_ids: batch_size x max_len.
        :param torch.LongTensor position_ids: positions corresponding to input_ids
@@ -1034,7 +1023,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): | |||
""" | |||
transformer_outputs = self.transformer( | |||
input_ids, | |||
past=past, | |||
state=state, | |||
attention_mask=attention_mask, | |||
token_type_ids=token_type_ids, | |||
position_ids=position_ids, | |||
@@ -0,0 +1,189 @@ | |||
import torch.nn as nn | |||
import torch | |||
from torch.nn import LayerNorm | |||
import torch.nn.functional as F | |||
from typing import Union, Tuple | |||
from ...core.utils import seq_len_to_mask | |||
import math | |||
from ...modules.encoder.lstm import LSTM | |||
from fastNLP.modules.attention import MultiHeadAttention | |||
from ...embeddings import StaticEmbedding | |||
from ...embeddings.utils import get_embeddings | |||
class Seq2SeqEncoder(nn.Module): | |||
""" | |||
所有Sequence2Sequence Encoder的基类。需要实现forward函数 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len, encoder的输入 | |||
:param torch.LongTensor seq_len: bsz | |||
:return: | |||
""" | |||
raise NotImplementedError | |||
class TransformerSeq2SeqEncoderLayer(nn.Module): | |||
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, | |||
dropout: float = 0.1): | |||
""" | |||
Self-Attention的Layer, | |||
:param int d_model: input和output的输出维度 | |||
:param int n_head: 多少个head,每个head的维度为d_model/n_head | |||
:param int dim_ff: FFN的维度大小 | |||
:param float dropout: Self-attention和FFN的dropout大小,0表示不drop | |||
""" | |||
super(TransformerSeq2SeqEncoderLayer, self).__init__() | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.self_attn = MultiHeadAttention(d_model, n_head, dropout) | |||
self.attn_layer_norm = LayerNorm(d_model) | |||
self.ffn_layer_norm = LayerNorm(d_model) | |||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||
nn.ReLU(), | |||
nn.Dropout(dropout), | |||
nn.Linear(self.dim_ff, self.d_model), | |||
nn.Dropout(dropout)) | |||
def forward(self, x, mask): | |||
""" | |||
:param x: batch x src_seq x d_model | |||
:param mask: batch x src_seq,为0的地方为padding | |||
:return: | |||
""" | |||
# attention | |||
residual = x | |||
x = self.attn_layer_norm(x) | |||
x, _ = self.self_attn(query=x, | |||
key=x, | |||
value=x, | |||
key_mask=mask) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
x = residual + x | |||
# ffn | |||
residual = x | |||
x = self.ffn_layer_norm(x) | |||
x = self.ffn(x) | |||
x = residual + x | |||
return x | |||
class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||
num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||
""" | |||
基于Transformer的Encoder | |||
:param embed: encoder输入token的embedding | |||
:param nn.Module pos_embed: position embedding | |||
:param int num_layers: 多少层的encoder | |||
:param int d_model: 输入输出的维度 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN中间的维度大小 | |||
:param float dropout: Attention和FFN的dropout大小 | |||
""" | |||
super(TransformerSeq2SeqEncoder, self).__init__() | |||
self.embed = get_embeddings(embed) | |||
self.embed_scale = math.sqrt(d_model) | |||
self.pos_embed = pos_embed | |||
self.num_layers = num_layers | |||
self.d_model = d_model | |||
self.n_head = n_head | |||
self.dim_ff = dim_ff | |||
self.dropout = dropout | |||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) | |||
for _ in range(num_layers)]) | |||
self.layer_norm = LayerNorm(d_model) | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param tokens: batch x max_len | |||
:param seq_len: [batch] | |||
:return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) | |||
""" | |||
x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||
batch_size, max_src_len, _ = x.size() | |||
device = x.device | |||
if self.pos_embed is not None: | |||
position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) | |||
x += self.pos_embed(position) | |||
x = self.input_fc(x) | |||
x = F.dropout(x, p=self.dropout, training=self.training) | |||
encoder_mask = seq_len_to_mask(seq_len) | |||
encoder_mask = encoder_mask.to(device) | |||
for layer in self.layer_stacks: | |||
x = layer(x, encoder_mask) | |||
x = self.layer_norm(x) | |||
return x, encoder_mask | |||
class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||
hidden_size = 400, dropout = 0.3, bidirectional=True): | |||
""" | |||
LSTM的Encoder | |||
:param embed: encoder的token embed | |||
:param int num_layers: 多少层 | |||
:param int hidden_size: LSTM隐藏层、输出的大小 | |||
:param float dropout: LSTM层之间的Dropout是多少 | |||
:param bool bidirectional: 是否使用双向 | |||
""" | |||
super().__init__() | |||
self.embed = get_embeddings(embed) | |||
self.num_layers = num_layers | |||
self.dropout = dropout | |||
self.hidden_size = hidden_size | |||
self.bidirectional = bidirectional | |||
hidden_size = hidden_size//2 if bidirectional else hidden_size | |||
        self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional,
batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||
def forward(self, tokens, seq_len): | |||
""" | |||
:param torch.LongTensor tokens: bsz x max_len | |||
:param torch.LongTensor seq_len: bsz | |||
:return: (output, (hidden, cell)), encoder_mask | |||
output: bsz x max_len x hidden_size, | |||
hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 | |||
encoder_mask: bsz x max_len, 为0的地方是padding | |||
""" | |||
x = self.embed(tokens) | |||
device = x.device | |||
x, (final_hidden, final_cell) = self.lstm(x, seq_len) | |||
encoder_mask = seq_len_to_mask(seq_len).to(device) | |||
# x: batch,seq_len,dim; h/c: num_layers*2,batch,dim | |||
if self.bidirectional: | |||
            final_hidden = self.concat_bidir(final_hidden)  # concatenate the bidirectional hidden states, to be used as the decoder's input
final_cell = self.concat_bidir(final_cell) | |||
        return (x, (final_hidden[-1], final_cell[-1])), encoder_mask  # split into two return values to match the forward of Seq2SeqBaseModel
def concat_bidir(self, input): | |||
output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) | |||
return output.reshape(self.num_layers, input.size(1), -1) |
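# --- Editor's sketch (illustrative, not part of this diff): a smoke test of
# --- LSTMSeq2SeqEncoder; all sizes are made up.
import torch
from torch import nn
encoder = LSTMSeq2SeqEncoder(embed=nn.Embedding(1000, 300), num_layers=2, hidden_size=400)
tokens = torch.randint(0, 1000, (4, 10))
seq_len = torch.tensor([10, 10, 10, 10])
(output, (hidden, cell)), encoder_mask = encoder(tokens, seq_len)
assert output.shape == (4, 10, 400) and hidden.shape == (4, 400)  # both directions concatenated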
@@ -5,7 +5,7 @@ __all__ = [ | |||
] | |||
from torch import nn | |||
from .attention import MultiHeadAttention | |||
from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
class TransformerEncoder(nn.Module): | |||
@@ -13,66 +13,30 @@ class TransformerEncoder(nn.Module): | |||
    The transformer encoder module, without the embedding layer.
""" | |||
def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||
""" | |||
class SubLayer(nn.Module): | |||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | |||
super(TransformerEncoder.SubLayer, self).__init__() | |||
self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) | |||
self.norm1 = nn.LayerNorm(model_size, eps=1e-6) | |||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | |||
nn.ReLU(), | |||
nn.Dropout(dropout), | |||
nn.Linear(inner_size, model_size)) | |||
self.norm2 = nn.LayerNorm(model_size, eps=1e-6) | |||
self.dropout = nn.Dropout(dropout) | |||
def forward(self, input, seq_mask=None, atte_mask_out=None): | |||
r""" | |||
:param input: [batch, seq_len, model_size] | |||
:param seq_mask: [batch, seq_len] | |||
:return: [batch, seq_len, model_size] | |||
""" | |||
            if seq_mask is None:  # avoid errors in the multiplications below
seq_mask = 1 | |||
input = self.norm1(input) | |||
attention = self.atte(input, input, input, atte_mask_out) | |||
input = input + self.dropout(attention) | |||
attention *= seq_mask | |||
input = self.norm2(input) | |||
output = self.ffn(input) | |||
input = input + self.dropout(output) | |||
input *= seq_mask | |||
return input | |||
def __init__(self, num_layers, **kargs): | |||
r""" | |||
:param int num_layers: transformer的层数 | |||
:param int model_size: 输入维度的大小。同时也是输出维度的大小。 | |||
:param int inner_size: FFN层的hidden大小 | |||
:param int key_size: 每个head的维度大小。 | |||
:param int value_size: 每个head中value的维度。 | |||
:param int num_head: head的数量。 | |||
:param float dropout: dropout概率. Default: 0.1 | |||
:param int num_layers: 多少层Transformer | |||
:param int d_model: input和output的大小 | |||
:param int n_head: 多少个head | |||
:param int dim_ff: FFN中间hidden大小 | |||
:param float dropout: 多大概率drop attention和ffn中间的表示 | |||
""" | |||
super(TransformerEncoder, self).__init__() | |||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | |||
self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) | |||
self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||
dropout = dropout) for _ in range(num_layers)]) | |||
self.norm = nn.LayerNorm(d_model, eps=1e-6) | |||
def forward(self, x, seq_mask=None): | |||
r""" | |||
        :param x: [batch, seq_len, model_size] input sequence
        :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated.
        :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated. Positions equal to 1 should be attended to
            Default: ``None``
        :return: [batch, seq_len, model_size] output sequence
""" | |||
output = x | |||
if seq_mask is None: | |||
atte_mask_out = None | |||
else: | |||
atte_mask_out = (seq_mask.eq(False))[:, None, :] | |||
seq_mask = seq_mask[:, :, None] | |||
seq_mask = x.new_ones(x.size(0), x.size(1)).bool() | |||
for layer in self.layers: | |||
output = layer(output, seq_mask, atte_mask_out) | |||
output = layer(output, seq_mask) | |||
return self.norm(output) |
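# --- Editor's sketch (illustrative, not part of this diff): the refactored
# --- TransformerEncoder now takes d_model/n_head/dim_ff directly; sizes are made up.
import torch
encoder = TransformerEncoder(num_layers=2, d_model=512, n_head=8, dim_ff=2048, dropout=0.1)
x = torch.randn(4, 10, 512)
seq_mask = torch.ones(4, 10).bool()
out = encoder(x, seq_mask)
assert out.shape == (4, 10, 512)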
@@ -0,0 +1,9 @@ | |||
r""" | |||
""" | |||
__all__ = [ | |||
"SequenceGenerator" | |||
] | |||
from .seq2seq_generator import SequenceGenerator |
@@ -7,16 +7,35 @@ __all__ = [ | |||
] | |||
import torch | |||
from ..decoder.seq2seq_decoder import Decoder | |||
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||
import torch.nn.functional as F | |||
from fastNLP.core.utils import _get_model_device | |||
from ...core.utils import _get_model_device | |||
from functools import partial | |||
class SequenceGenerator: | |||
def __init__(self, decoder: Decoder, max_length=20, num_beams=1, | |||
""" | |||
    Given a Seq2SeqDecoder, decode sentences.
""" | |||
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1, | |||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||
""" | |||
:param Seq2SeqDecoder decoder: Decoder对象 | |||
:param int max_length: 句子的最大长度 | |||
:param int num_beams: beam search的大小 | |||
:param bool do_sample: 是否通过采样的方式生成 | |||
:param float temperature: 只有在do_sample为True才有意义 | |||
:param int top_k: 只从top_k中采样 | |||
:param float top_p: 只从top_p的token中采样,nucles sample | |||
:param int,None bos_token_id: 句子开头的token id | |||
:param int,None eos_token_id: 句子结束的token id | |||
:param float repetition_penalty: 多大程度上惩罚重复的token | |||
:param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 | |||
:param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 | |||
""" | |||
if do_sample: | |||
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, | |||
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | |||
@@ -40,19 +59,19 @@ class SequenceGenerator: | |||
self.decoder = decoder | |||
@torch.no_grad() | |||
def generate(self, tokens=None, past=None): | |||
def generate(self, state, tokens=None): | |||
""" | |||
        :param torch.LongTensor tokens: batch_size x length, the starting tokens
        :param past:
        :return:
        :param State state: the State produced by the encoder, used together with the Decoder
        :param torch.LongTensor,None tokens: batch_size x length, the starting tokens
        :return: bsz x max_length', the generated token sequence; if eos_token_id is not None, every sequence ends with eos_token_id
""" | |||
        # TODO need to check: if the length of tokens is not 1, can decode still run directly?
return self.generate_func(tokens=tokens, past=past) | |||
return self.generate_func(tokens=tokens, state=state) | |||
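# --- Editor's sketch (illustrative, not part of this diff): wiring encoder, decoder and
# --- SequenceGenerator together; the vocab size and the bos/eos ids 1/2 are made up.
import torch
from torch import nn
embed = nn.Embedding(1000, 512)
encoder = TransformerSeq2SeqEncoder(embed=embed, d_model=512, num_layers=2)
decoder = TransformerSeq2SeqDecoder(embed=embed, d_model=512, num_layers=2)
src_tokens = torch.randint(0, 1000, (4, 10))
seq_len = torch.full((4,), 10, dtype=torch.long)
encoder_output, encoder_mask = encoder(src_tokens, seq_len)
state = decoder.init_state(encoder_output, encoder_mask)
generator = SequenceGenerator(decoder, max_length=20, num_beams=1, do_sample=False,
                              bos_token_id=1, eos_token_id=2)
token_ids = generator.generate(state)   # bsz x max_length'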
@torch.no_grad() | |||
def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, | |||
bos_token_id=None, eos_token_id=None, pad_token_id=0, | |||
repetition_penalty=1, length_penalty=1.0): | |||
""" | |||
@@ -60,23 +79,23 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
    :param Decoder decoder: the Decoder object
    :param torch.LongTensor tokens: batch_size x len, input for decoding; if None, generation starts from bos_token_id
    :param Past past: should contain some of the encoder outputs.
    :param State state: should contain some of the encoder outputs.
    :param int max_length: maximum length of the generated sentence.
    :param int num_beams: beam size used for decoding.
    :param int bos_token_id: if tokens is None, decoding starts from bos_token_id.
    :param int eos_token_id: the end token; if None, decoding always runs to max_length.
    :param int pad_token_id:
    :param int pad_token_id: the pad token id
    :param float repetition_penalty: how strongly repeated tokens are penalized.
    :param float length_penalty: a length penalty applied to each token (except eos).
:return: | |||
""" | |||
if num_beams == 1: | |||
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1, | |||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
else: | |||
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, | |||
temperature=1, top_k=50, top_p=1, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
@@ -86,7 +105,7 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
@torch.no_grad() | |||
def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||
def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||
top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | |||
length_penalty=1.0): | |||
""" | |||
@@ -94,7 +113,7 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
    :param Decoder decoder: the Decoder object
    :param torch.LongTensor tokens: batch_size x len, input for decoding; if None, generation starts from bos_token_id
    :param Past past: should contain some of the encoder outputs.
    :param State state: should contain some of the encoder outputs.
    :param int max_length: maximum length of the generated sentence.
    :param int num_beams: beam size used for decoding.
    :param float temperature: annealing temperature used when sampling
@@ -109,13 +128,13 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
""" | |||
    # each position is generated by sampling
if num_beams == 1: | |||
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature, | |||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature, | |||
top_k=top_k, top_p=top_p, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
pad_token_id=pad_token_id) | |||
else: | |||
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, | |||
temperature=temperature, top_k=top_k, top_p=top_p, | |||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | |||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||
@@ -123,40 +142,35 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||
return token_ids | |||
def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50, | |||
def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50, | |||
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | |||
device = _get_model_device(decoder) | |||
if tokens is None: | |||
if bos_token_id is None: | |||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
if past is None: | |||
raise RuntimeError("You have to specify either `past` or `tokens`.") | |||
batch_size = past.num_samples() | |||
batch_size = state.num_samples | |||
if batch_size is None: | |||
raise RuntimeError("Cannot infer the number of samples from `past`.") | |||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
batch_size = tokens.size(0) | |||
if past is not None: | |||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||
if state.num_samples: | |||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
if eos_token_id is None: | |||
_eos_token_id = float('nan') | |||
_eos_token_id = -1 | |||
else: | |||
_eos_token_id = eos_token_id | |||
    # for i in range(tokens.size(1)):
    #     scores, past = decoder.decode_one(tokens[:, :i + 1], past)  # batch_size x vocab_size, Past
    scores, past = decoder.decode(tokens, past)
    token_ids = tokens.clone()
    scores = decoder.decode(tokens=tokens, state=state)  # mainly to update the state
next_tokens = scores.argmax(dim=-1, keepdim=True) | |||
token_ids = torch.cat([tokens, next_tokens], dim=1) | |||
cur_len = token_ids.size(1) | |||
dones = token_ids.new_zeros(batch_size).eq(1) | |||
# tokens = tokens[:, -1:] | |||
while cur_len < max_length: | |||
# scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past | |||
scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past | |||
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size | |||
if repetition_penalty != 1.0: | |||
token_scores = scores.gather(dim=1, index=token_ids) | |||
@@ -204,7 +218,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||
return token_ids | |||
def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0, | |||
def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0, | |||
top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | |||
repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | |||
# 进行beam search | |||
@@ -212,21 +226,20 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
if tokens is None: | |||
if bos_token_id is None: | |||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | |||
if past is None: | |||
raise RuntimeError("You have to specify either `past` or `tokens`.") | |||
batch_size = past.num_samples() | |||
batch_size = state.num_samples | |||
if batch_size is None: | |||
raise RuntimeError("Cannot infer the number of samples from `past`.") | |||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | |||
batch_size = tokens.size(0) | |||
if past is not None: | |||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||
    # for i in range(tokens.size(1) - 1):  # if the input is long, decode it first
    #     scores, past = decoder.decode_one(tokens[:, :i + 1],
    #                                       past)  # (batch_size, vocab_size), Past
    # scores, past = decoder.decode_one(tokens, past)  # the full sentence length has to be passed in here
    scores, past = decoder.decode(tokens, past)  # the full sentence length has to be passed in here
if state.num_samples: | |||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||
if eos_token_id is None: | |||
_eos_token_id = -1 | |||
else: | |||
_eos_token_id = eos_token_id | |||
scores = decoder.decode(tokens=tokens, state=state) # the whole sentence must be passed in here | |||
vocab_size = scores.size(1) | |||
assert vocab_size >= num_beams, "num_beams should not exceed the vocabulary size." | |||
@@ -240,15 +253,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
# obtain (batch_size, num_beams), (batch_size, num_beams) | |||
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | |||
# reorder according to the indices | |||
indices = torch.arange(batch_size, dtype=torch.long).to(device) | |||
indices = indices.repeat_interleave(num_beams) | |||
decoder.reorder_past(indices, past) | |||
state.reorder_state(indices) | |||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | |||
# record the generated tokens (batch_size', cur_len) | |||
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | |||
dones = [False] * batch_size | |||
tokens = next_tokens.view(-1, 1) | |||
beam_scores = next_scores.view(-1) # batch_size * num_beams | |||
@@ -262,8 +275,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | |||
while cur_len < max_length: | |||
# scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past | |||
scores, past = decoder.decode(tokens, past) | |||
scores = decoder.decode(token_ids, state) | |||
if repetition_penalty != 1.0: | |||
token_scores = scores.gather(dim=1, index=token_ids) | |||
lt_zero_mask = token_scores.lt(0).float() | |||
@@ -307,7 +319,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | |||
from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | |||
not_eos_mask = next_tokens.ne(eos_token_id) # positions equal to 1 are not eos | |||
not_eos_mask = next_tokens.ne(_eos_token_id) # positions equal to 1 are not eos | |||
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # positions equal to 1 should be kept | |||
keep_mask = not_eos_mask.__and__(keep_mask) # positions equal to 1 go on to the next search step | |||
@@ -316,18 +328,18 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | |||
beam_scores = _next_scores.view(-1) | |||
# update the past state and reassemble token_ids | |||
# update the state and reassemble token_ids | |||
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten to one dimension | |||
decoder.reorder_past(reorder_inds, past) | |||
state.reorder_state(reorder_inds) | |||
flag = True | |||
if cur_len + 1 == max_length: | |||
if cur_len+1 == max_length: | |||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | |||
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # the beam indices | |||
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # which beam each one came from | |||
else: | |||
# add each batch's sequences within num_beams to the finished set; positions equal to 1 should stop | |||
effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams | |||
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams | |||
if effective_eos_mask.sum().gt(0): | |||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | |||
# since from_which_beam is (batch_size, 2*num_beams), 2*num_beams is needed | |||
@@ -335,16 +347,17 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # get the actual beam the eos came from | |||
else: | |||
flag = False | |||
# reorganize token_ids | |||
tokens = _next_tokens | |||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||
if flag: | |||
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | |||
eos_beam_idx.tolist()): | |||
if not dones[batch_idx]: | |||
score = next_scores[batch_idx, beam_ind].item() | |||
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||
# reorganize token_ids | |||
tokens = _next_tokens | |||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score) | |||
for batch_idx in range(batch_size): | |||
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) | |||
@@ -360,15 +373,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||
for i, hypotheses in enumerate(hypos): | |||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | |||
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol | |||
tgt_len[i] = len(best_hyp) # <EOS> is already included in best_hyp | |||
best.append(best_hyp) | |||
# generate target batch | |||
decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) | |||
for i, hypo in enumerate(best): | |||
decoded[i, :tgt_len[i] - 1] = hypo | |||
decoded[i, :tgt_len[i]] = hypo | |||
if eos_token_id is not None: | |||
decoded[i, tgt_len[i] - 1] = eos_token_id | |||
decoded[i, tgt_len[i] - 1] = _eos_token_id | |||
return decoded | |||
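The flattened-index bookkeeping in the loop above is easy to misread, so here is a small self-contained sketch of how reorder_inds maps (batch, beam) pairs onto rows of the flattened batch_size * num_beams state; the tensors are toy values, not taken from the library.

import torch

batch_size, num_beams = 2, 3
# beam b of batch i lives at flat row i * num_beams + b
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1)
from_which_beam = torch.tensor([[0, 2, 1],
                                [1, 0, 0]])  # parent beam chosen for each new beam
reorder_inds = (batch_inds_with_numbeams_interval + from_which_beam).view(-1)
print(reorder_inds.tolist())  # [0, 2, 1, 4, 3, 3]: the rows state.reorder_state copies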
@@ -384,6 +384,9 @@ class BertTokenizer(object): | |||
index += 1 | |||
return vocab_file | |||
def save_pretrained(self, save_directory): | |||
self.save_vocabulary(save_directory) | |||
@classmethod | |||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | |||
r""" | |||
@@ -377,6 +377,9 @@ class GPT2Tokenizer: | |||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | |||
return text | |||
def save_pretrained(self, save_directory): | |||
return self.save_vocabulary(save_directory) | |||
def save_vocabulary(self, save_directory): | |||
"""Save the tokenizer vocabulary and merge files to a directory.""" | |||
if not os.path.isdir(save_directory): | |||
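Both save_pretrained hooks added here simply delegate to save_vocabulary, so the intended round trip looks like the sketch below; 'saved_tokenizer' is a hypothetical directory.

tokenizer = BertTokenizer.from_pretrained('test/data_for_tests/embedding/small_bert')
tokenizer.save_pretrained('saved_tokenizer')  # delegates to save_vocabulary
reloaded = BertTokenizer.from_pretrained('saved_tokenizer')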
@@ -32,7 +32,6 @@ from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
from tools.logger import * | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
from transformer.Layers import EncoderLayer | |||
@@ -30,7 +30,7 @@ from .Encoder import Encoder | |||
from tools.PositionEmbedding import get_sinusoid_encoding_table | |||
from fastNLP.core.const import Const | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||
class TransformerModel(nn.Module): | |||
def __init__(self, hps, vocab): | |||
@@ -68,7 +68,8 @@ class TransformerModel(nn.Module): | |||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | |||
self.layer_stack = nn.ModuleList([ | |||
TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) | |||
TransformerSeq2SeqEncoderLayer(d_model = self.hidden_size, n_head = self.n_head, dim_ff = self.d_inner, | |||
dropout = hps.atten_dropout_prob) | |||
for _ in range(self.num_layers)]) | |||
self.wh = nn.Linear(self.hidden_size, 2) | |||
@@ -109,7 +110,7 @@ class TransformerModel(nn.Module): | |||
for enc_layer in self.layer_stack: | |||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | |||
# enc_slf_attn = [n_head * batch_size, N, N] | |||
enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) | |||
enc_input = enc_layer(enc_input, encoder_mask=self.slf_attn_mask) | |||
enc_input_list += [enc_input] | |||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | |||
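This hunk swaps the old TransformerEncoder.SubLayer for TransformerSeq2SeqEncoderLayer, whose forward takes the hidden states plus a single encoder_mask instead of seq_mask/atte_mask_out. A minimal sketch of the new call convention, with the mask semantics assumed from the call site above:

import torch
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer

layer = TransformerSeq2SeqEncoderLayer(d_model=16, n_head=4, dim_ff=32, dropout=0.1)
x = torch.randn(2, 5, 16)                  # batch_size x seq_len x d_model
mask = torch.ones(2, 5, dtype=torch.bool)  # mask over the source positions
out = layer(x, encoder_mask=mask)          # same shape as x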
@@ -265,7 +265,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||
label = Variable(label) | |||
input_len = Variable(input_len, requires_grad=False) | |||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
outputs = model_outputs["p_sent"] | |||
prediction = model_outputs["prediction"] | |||
@@ -264,7 +264,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||
label = Variable(label) | |||
input_len = Variable(input_len, requires_grad=False) | |||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||
outputs = model_outputs[Const.OUTPUTS] | |||
prediction = model_outputs["prediction"] | |||
@@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer | |||
__author__ = "Yu-Hsiang Huang" | |||
def get_non_pad_mask(seq): | |||
assert seq.dim() == 2 | |||
return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) | |||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
''' Sinusoid position encoding table ''' | |||
@@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
return torch.FloatTensor(sinusoid_table) | |||
def get_attn_key_pad_mask(seq_k, seq_q): | |||
''' For masking out the padding part of key sequence. ''' | |||
@@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q): | |||
return padding_mask | |||
def get_subsequent_mask(seq): | |||
''' For masking out the subsequent info. ''' | |||
@@ -51,6 +55,7 @@ def get_subsequent_mask(seq): | |||
return subsequent_mask | |||
class Encoder(nn.Module): | |||
''' An encoder model with self attention mechanism. ''' | |||
@@ -98,6 +103,7 @@ class Encoder(nn.Module): | |||
return enc_output, enc_slf_attn_list | |||
return enc_output, | |||
class Decoder(nn.Module): | |||
''' A decoder model with self attention mechanism. ''' | |||
@@ -152,6 +158,7 @@ class Decoder(nn.Module): | |||
return dec_output, dec_slf_attn_list, dec_enc_attn_list | |||
return dec_output, | |||
class Transformer(nn.Module): | |||
''' A sequence to sequence model with attention mechanism. ''' | |||
@@ -181,8 +188,8 @@ class Transformer(nn.Module): | |||
nn.init.xavier_normal_(self.tgt_word_prj.weight) | |||
assert d_model == d_word_vec, \ | |||
'To facilitate the residual connections, \ | |||
the dimensions of all module outputs shall be the same.' | |||
'To facilitate the residual connections, \ | |||
the dimensions of all module outputs shall be the same.' | |||
if tgt_emb_prj_weight_sharing: | |||
# Share the weight matrix between target word embedding & the final logit dense layer | |||
@@ -194,7 +201,7 @@ class Transformer(nn.Module): | |||
if emb_src_tgt_weight_sharing: | |||
# Share the weight matrix between source & target word embeddings | |||
assert n_src_vocab == n_tgt_vocab, \ | |||
"To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||
"To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||
self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight | |||
def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): | |||
@@ -1,7 +1,6 @@ | |||
import fastNLP | |||
import torch | |||
import math | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||
from fastNLP import Const | |||
import copy | |||
@@ -181,7 +180,6 @@ def make_CWS( | |||
freeze=True, | |||
): | |||
c = copy.deepcopy | |||
# encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout) | |||
encoder = transformer.make_encoder( | |||
N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff | |||
) | |||
@@ -1,9 +1,8 @@ | |||
import torch | |||
import torch.nn as nn | |||
from fastNLP.core.const import Const as C | |||
from fastNLP.modules.encoder.lstm import LSTM | |||
from fastNLP.embeddings.utils import get_embeddings | |||
from fastNLP.modules.encoder.attention import SelfAttention | |||
from fastNLP.modules.attention import SelfAttention | |||
from fastNLP.modules.decoder.mlp import MLP | |||
@@ -44,7 +44,7 @@ class WeightDrop(torch.nn.Module): | |||
def forward(self, *args): | |||
self._setweights() | |||
return self.module.forward(*args) | |||
return self.module.forward() | |||
if __name__ == '__main__': | |||
import torch | |||
@@ -40,8 +40,7 @@ class TestBertEmbedding(unittest.TestCase): | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
only_use_pretrain_bpe=True) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
embed.eval() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
@@ -49,53 +48,30 @@ class TestBertEmbedding(unittest.TestCase): | |||
# truncate automatically instead of raising an error | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
only_use_pretrain_bpe=True, auto_truncate=True) | |||
auto_truncate=True) | |||
words = torch.LongTensor([[2, 3, 4, 1]*10, | |||
[2, 3]+[0]*38]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (2, 40, 16)) | |||
def test_bert_embedding_2(self): | |||
# test whether only_use_pretrain_vocab and truncate_embed work correctly | |||
with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f: | |||
num_word = len(f.readlines()) | |||
Embedding = BertEmbedding | |||
vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split()) | |||
embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1) | |||
embed_bpe_vocab_size = len(vocab)-1 + 2 # exclude NotInBERT; add ##a and [CLS] | |||
self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab)) | |||
embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1) | |||
embed_bpe_vocab_size = num_word # excludes NotInBERT | |||
self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab)) | |||
embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1) | |||
embed_bpe_vocab_size = len(vocab)+2 # adds ##a and [CLS] | |||
self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab)) | |||
embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1) | |||
embed_bpe_vocab_size = num_word+1 # adds ##a | |||
self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab)) | |||
# test that the following tensors are equal in every case | |||
embed1.eval() | |||
embed2.eval() | |||
embed3.eval() | |||
embed4.eval() | |||
tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]]) | |||
t1 = embed1(tensor) | |||
t2 = embed2(tensor) | |||
t3 = embed3(tensor) | |||
t4 = embed4(tensor) | |||
self.assertEqual((t1-t2).sum(), 0) | |||
self.assertEqual((t1-t3).sum(), 0) | |||
self.assertEqual((t1-t4).sum(), 0) | |||
def test_save_load(self): | |||
bert_save_test = 'bert_save_test' | |||
try: | |||
os.makedirs(bert_save_test, exist_ok=True) | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
auto_truncate=True) | |||
embed.save(bert_save_test) | |||
load_embed = BertEmbedding.load(bert_save_test) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
embed.eval(), load_embed.eval() | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
finally: | |||
import shutil | |||
shutil.rmtree(bert_save_test) | |||
class TestBertWordPieceEncoder(unittest.TestCase): | |||
@@ -120,11 +96,30 @@ class TestBertWordPieceEncoder(unittest.TestCase): | |||
ds.set_input('words') | |||
words = torch.LongTensor(ds['words'].get([0, 1])) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||
pool_method='first', include_cls_sep=True, pooled_cls=False) | |||
pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) | |||
embed.eval() | |||
words_res = embed(words) | |||
# check that the word pieces work correctly | |||
self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0) | |||
self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0) | |||
self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) | |||
def test_save_load(self): | |||
bert_save_test = 'bert_save_test' | |||
try: | |||
os.makedirs(bert_save_test, exist_ok=True) | |||
embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.0, | |||
layers='-2') | |||
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||
embed.index_datasets(ds, field_name='words') | |||
self.assertTrue(ds.has_field('word_pieces')) | |||
words = torch.LongTensor([[1, 2, 3, 4]]) | |||
embed.save(bert_save_test) | |||
load_embed = BertWordPieceEncoder.load(bert_save_test) | |||
embed.eval(), load_embed.eval() | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
finally: | |||
import shutil | |||
shutil.rmtree(bert_save_test) | |||
@@ -255,14 +255,17 @@ class TestGPT2WordPieceEncoder(unittest.TestCase): | |||
result = embed(torch.LongTensor([[1, 2, 3, 4]])) | |||
def test_generate(self): | |||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||
# weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||
weight_path = 'en' | |||
encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True) | |||
# test that everything works correctly | |||
print(encoder.generate_from_str('this', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||
print(encoder.generate_from_str('This', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||
repetition_penalty=1.0, length_penalty=1.0)) | |||
print(encoder.generate_from_str('This day', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||
repetition_penalty=1.0, length_penalty=1.0)) | |||
print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||
print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||
repetition_penalty=1.0, length_penalty=1.0)) | |||
print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||
print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||
repetition_penalty=2.0, length_penalty=1.5)) |
@@ -47,7 +47,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase): | |||
ds.set_input('words') | |||
words = torch.LongTensor(ds['words'].get([0, 1])) | |||
embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, | |||
pool_method='first', include_cls_sep=True, pooled_cls=False) | |||
pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) | |||
embed.eval() | |||
words_res = embed(words) | |||
@@ -183,6 +183,24 @@ class TestRobertWordPieceEncoder(unittest.TestCase): | |||
torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin') | |||
print(model(torch.LongTensor([[0,1,2,3]]))) | |||
def test_save_load(self): | |||
bert_save_test = 'roberta_save_test' | |||
try: | |||
os.makedirs(bert_save_test, exist_ok=True) | |||
embed = RobertaWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_roberta', word_dropout=0.0, | |||
layers='-2') | |||
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||
embed.index_datasets(ds, field_name='words') | |||
self.assertTrue(ds.has_field('word_pieces')) | |||
words = torch.LongTensor([[1, 2, 3, 4]]) | |||
embed.save(bert_save_test) | |||
load_embed = RobertaWordPieceEncoder.load(bert_save_test) | |||
embed.eval(), load_embed.eval() | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
finally: | |||
import shutil | |||
shutil.rmtree(bert_save_test) | |||
class TestRobertaEmbedding(unittest.TestCase): | |||
def test_roberta_embedding_1(self): | |||
@@ -250,3 +268,20 @@ class TestRobertaEmbedding(unittest.TestCase): | |||
self.assertEqual((t1-t2).sum(), 0) | |||
self.assertEqual((t1-t3).sum(), 0) | |||
self.assertEqual((t1-t4).sum(), 0) | |||
def test_save_load(self): | |||
bert_save_test = 'roberta_save_test' | |||
try: | |||
os.makedirs(bert_save_test, exist_ok=True) | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta', | |||
word_dropout=0.1, | |||
auto_truncate=True) | |||
embed.save(bert_save_test) | |||
load_embed = RobertaEmbedding.load(bert_save_test) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
embed.eval(), load_embed.eval() | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
finally: | |||
import shutil | |||
shutil.rmtree(bert_save_test) |
@@ -108,6 +108,56 @@ class TestLoad(unittest.TestCase): | |||
for v1i, v2i in zip(v1, v2): | |||
self.assertAlmostEqual(v1i, v2i, places=4) | |||
def test_save_load_static_embed(self): | |||
static_test_folder = 'static_save_test' | |||
try: | |||
# test with no_create_entry words | |||
os.makedirs(static_test_folder, exist_ok=True) | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||
vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
# test without no_create_entry words | |||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt') | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
# test lower and min_freq | |||
vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||
'glove.6B.50d_test.txt', min_freq=2, lower=True) | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
# test randomly initialized embeddings | |||
vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||
vocab = vocab.add_word_lst(['b'], no_create_entry=True) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True, | |||
normalize=True) | |||
embed.weight.data += 0.2 # so that it is no longer normalized | |||
embed.save(static_test_folder) | |||
load_embed = StaticEmbedding.load(static_test_folder) | |||
words = torch.randint(len(vocab), size=(2, 20)) | |||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||
finally: | |||
if os.path.isdir(static_test_folder): | |||
import shutil | |||
shutil.rmtree(static_test_folder) | |||
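Outside the test harness, the save/load round trip these cases exercise reduces to the sketch below; 'static_save_dir' is a hypothetical directory.

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst(['The', 'a'])
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4)
embed.save('static_save_dir')
loaded = StaticEmbedding.load('static_save_dir')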
def read_static_embed(fp): | |||
""" | |||
@@ -30,11 +30,11 @@ class TestLoad(unittest.TestCase): | |||
'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False), | |||
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False), | |||
'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False), | |||
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (7, 6, 6), False), | |||
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
path, loader, data_set, warns = v | |||
with self.subTest(loader=loader): | |||
with self.subTest(path=path): | |||
if warns: | |||
with self.assertWarns(Warning): | |||
data_bundle = loader().load(path) | |||
@@ -45,5 +45,6 @@ class TestLoad(unittest.TestCase): | |||
self.assertEqual(len(data_set), data_bundle.num_dataset) | |||
for x, y in zip(data_set, data_bundle.iter_datasets()): | |||
name, dataset = y | |||
self.assertEqual(x, len(dataset)) | |||
with self.subTest(split=name): | |||
self.assertEqual(x, len(dataset)) | |||
@@ -32,7 +32,7 @@ class TestMatchingLoad(unittest.TestCase): | |||
'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | |||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | |||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), | |||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), | |||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
path, loader, instance, warns = v | |||
@@ -46,5 +46,6 @@ class TestMatchingLoad(unittest.TestCase): | |||
self.assertEqual(len(instance), data_bundle.num_dataset) | |||
for x, y in zip(instance, data_bundle.iter_datasets()): | |||
name, dataset = y | |||
self.assertEqual(x, len(dataset)) | |||
with self.subTest(path=path, split=name): | |||
self.assertEqual(x, len(dataset)) | |||
@@ -70,7 +70,7 @@ class TestRunClassificationPipe(unittest.TestCase): | |||
} | |||
for k, v in data_set_dict.items(): | |||
path, pipe, data_set, vocab, warns = v | |||
with self.subTest(pipe=pipe): | |||
with self.subTest(path=path): | |||
if 'Chn' not in k: | |||
if warns: | |||
with self.assertWarns(Warning): | |||
@@ -39,7 +39,7 @@ class TestRunMatchingPipe(unittest.TestCase): | |||
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | |||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | |||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), | |||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), | |||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False), | |||
} | |||
for k, v in data_set_dict.items(): | |||
path, pipe1, pipe2, data_set, vocab, warns = v | |||
@@ -58,7 +58,8 @@ class TestRunMatchingPipe(unittest.TestCase): | |||
print(data_bundle2) | |||
for x, y in zip(data_set, data_bundle1.iter_datasets()): | |||
name, dataset = y | |||
self.assertEqual(x, len(dataset)) | |||
with self.subTest(path=path, split=name): | |||
self.assertEqual(x, len(dataset)) | |||
self.assertEqual(len(data_set), data_bundle2.num_dataset) | |||
for x, y in zip(data_set, data_bundle2.iter_datasets()): | |||
name, dataset = y | |||
@@ -0,0 +1,76 @@ | |||
import unittest | |||
from fastNLP.models import SequenceGeneratorModel | |||
from fastNLP.models import LSTMSeq2SeqModel, TransformerSeq2SeqModel | |||
from fastNLP import Vocabulary, DataSet | |||
import torch | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | |||
from fastNLP import Callback | |||
def prepare_env(): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
src_words_idx = [[3, 1, 2], [1, 2]] | |||
# tgt_words_idx = [[1, 2, 3, 4], [2, 3]] | |||
src_seq_len = [3, 2] | |||
# tgt_seq_len = [4, 2] | |||
ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, | |||
'tgt_seq_len':src_seq_len}) | |||
ds.set_input('src_tokens', 'tgt_tokens', 'src_seq_len') | |||
ds.set_target('tgt_seq_len', 'tgt_tokens') | |||
return embed, ds | |||
class ExitCallback(Callback): | |||
def __init__(self): | |||
super().__init__() | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
if eval_result['AccuracyMetric']['acc']==1: | |||
raise KeyboardInterrupt() | |||
class TestSeq2SeqGeneratorModel(unittest.TestCase): | |||
def test_run(self): | |||
# check that SequenceGeneratorModel can be trained, with predictions passed through | |||
embed, ds = prepare_env() | |||
model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, | |||
dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
trainer = Trainer(ds, model1, optimizer=None, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'), | |||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||
num_workers=0, n_epochs=100, print_every=5, | |||
dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), metric_key=None, | |||
validate_every=-1, save_path=None, use_tqdm=False, device=None, | |||
callbacks=ExitCallback(), check_code_level=0) | |||
res = trainer.train() | |||
self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1) | |||
embed, ds = prepare_env() | |||
model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
num_layers=1, hidden_size=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True, attention=True) | |||
optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) | |||
trainer = Trainer(ds, model2, optimizer=optimizer, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'), | |||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||
num_workers=0, n_epochs=200, print_every=1, | |||
dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), | |||
metric_key=None, | |||
validate_every=-1, save_path=None, use_tqdm=False, device=None, | |||
callbacks=ExitCallback(), check_code_level=0) | |||
res = trainer.train() | |||
self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1) | |||
@@ -0,0 +1,114 @@ | |||
import unittest | |||
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
import torch | |||
from torch import optim | |||
import torch.nn.functional as F | |||
from fastNLP import seq_len_to_mask | |||
def prepare_env(): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
tgt_seq_len = torch.LongTensor([4, 2]) | |||
return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len | |||
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): | |||
optimizer = optim.Adam(model.parameters(), lr=1e-2) | |||
mask = seq_len_to_mask(tgt_seq_len).eq(0) | |||
target = tgt_words_idx.masked_fill(mask, -100) | |||
for i in range(100): | |||
optimizer.zero_grad() | |||
pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size | |||
loss = F.cross_entropy(pred.transpose(1, 2), target) | |||
loss.backward() | |||
optimizer.step() | |||
right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() | |||
return right_count | |||
class TestTransformerSeq2SeqModel(unittest.TestCase): | |||
def test_run(self): | |||
# test that the model runs | |||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
for pos_embed in ['learned', 'sin']: | |||
with self.subTest(pos_embed=pos_embed): | |||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||
for bind_encoder_decoder_embed in [True, False]: | |||
tgt_embed = None | |||
for bind_decoder_input_output_embed in [True, False]: | |||
if bind_encoder_decoder_embed == False: | |||
tgt_embed = embed | |||
with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed): | |||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
pos_embed='sin', max_position=20, num_layers=2, | |||
d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||
def test_train(self): | |||
# test that the model can be trained to overfit | |||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
self.assertEqual(right_count, tgt_words_idx.nelement()) | |||
class TestLSTMSeq2SeqModel(unittest.TestCase): | |||
def test_run(self): | |||
# test that the model runs | |||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
for bind_encoder_decoder_embed in [True, False]: | |||
tgt_embed = None | |||
for bind_decoder_input_output_embed in [True, False]: | |||
if bind_encoder_decoder_embed == False: | |||
tgt_embed = embed | |||
with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed): | |||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||
num_layers=2, hidden_size=20, dropout=0.1, | |||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||
self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||
def test_train(self): | |||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||
num_layers=1, hidden_size=20, dropout=0.1, | |||
bind_encoder_decoder_embed=True, | |||
bind_decoder_input_output_embed=True) | |||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||
self.assertEqual(right_count, tgt_words_idx.nelement()) | |||
@@ -0,0 +1,50 @@ | |||
import unittest | |||
import torch | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.modules import TransformerSeq2SeqDecoder | |||
from fastNLP.modules import LSTMSeq2SeqDecoder | |||
from fastNLP import seq_len_to_mask | |||
class TestTransformerSeq2SeqDecoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=10) | |||
encoder_output = torch.randn(2, 3, 10) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_mask = seq_len_to_mask(src_seq_len) | |||
for flag in [True, False]: | |||
with self.subTest(bind_decoder_input_output_embed=flag): | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed = None, | |||
d_model = 10, num_layers=2, n_head = 5, dim_ff = 20, dropout = 0.1, | |||
bind_decoder_input_output_embed = True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) | |||
self.assertEqual(output.size(), (2, 4, len(vocab))) | |||
class TestLSTMDecoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) | |||
encoder_output = torch.randn(2, 3, 10) | |||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_mask = seq_len_to_mask(src_seq_len) | |||
for flag in [True, False]: | |||
for attention in [True, False]: | |||
with self.subTest(bind_decoder_input_output_embed=flag, attention=attention): | |||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers = 2, hidden_size = 10, | |||
dropout = 0.3, bind_decoder_input_output_embed=flag, attention=attention) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
output = decoder(tgt_words_idx, state) | |||
self.assertEqual(tuple(output.size()), (2, 4, len(vocab))) |
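The same decoders also drive step-wise generation through init_state plus repeated decode calls; a short sketch reusing decoder, encoder_output and encoder_mask as constructed in the test above (this assumes decode returns the scores for the last position, as the generator code relies on):

state = decoder.init_state(encoder_output, encoder_mask)
tokens = torch.LongTensor([[1], [1]])  # one bos token per batch element
for _ in range(3):
    scores = decoder.decode(tokens=tokens, state=state)  # batch_size x vocab_size
    next_token = scores.argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)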
@@ -0,0 +1,30 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
class TestTransformerSeq2SeqEncoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||
encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) | |||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
seq_len = torch.LongTensor([3]) | |||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
self.assertEqual(encoder_output.size(), (1, 3, 10)) | |||
class TestBiLSTMEncoder(unittest.TestCase): | |||
def test_case(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||
encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) | |||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||
seq_len = torch.LongTensor([3]) | |||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||
self.assertEqual(encoder_mask.size(), (1, 3)) |
@@ -0,0 +1 @@ | |||
@@ -0,0 +1,110 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.modules.generator import SequenceGenerator | |||
from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State | |||
from fastNLP import Vocabulary | |||
from fastNLP.embeddings import StaticEmbedding | |||
from torch import nn | |||
from fastNLP import seq_len_to_mask | |||
def prepare_env(): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
vocab.add_word_lst("Another test !".split()) | |||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||
encoder_output = torch.randn(2, 3, 10) | |||
src_seq_len = torch.LongTensor([3, 2]) | |||
encoder_mask = seq_len_to_mask(src_seq_len) | |||
return embed, encoder_output, encoder_mask | |||
class TestSequenceGenerator(unittest.TestCase): | |||
def test_run(self): | |||
# test that it runs: (1) initialize the decoder, (2) decode once | |||
embed, encoder_output, encoder_mask = prepare_env() | |||
for do_sample in [True, False]: | |||
for num_beams in [1, 3, 5]: | |||
with self.subTest(do_sample=do_sample, num_beams=num_beams): | |||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, | |||
dropout=0.3, bind_decoder_input_output_embed=True, attention=True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, | |||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
generator.generate(state=state, tokens=None) | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, | |||
bind_decoder_input_output_embed=True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||
generator.generate(state=state, tokens=None) | |||
# test some other parameter values | |||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, | |||
dropout=0.1, | |||
bind_decoder_input_output_embed=True) | |||
state = decoder.init_state(encoder_output, encoder_mask) | |||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||
do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, | |||
eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) | |||
generator.generate(state=state, tokens=None) | |||
def test_greedy_decode(self): | |||
# test that generation is correct | |||
class GreedyDummyDecoder(Seq2SeqDecoder): | |||
def __init__(self, decoder_output): | |||
super().__init__() | |||
self.cur_length = 0 | |||
self.decoder_output = decoder_output | |||
def decode(self, tokens, state): | |||
self.cur_length += 1 | |||
scores = self.decoder_output[:, self.cur_length] | |||
return scores | |||
class DummyState(State): | |||
def __init__(self, decoder): | |||
super().__init__() | |||
self.decoder = decoder | |||
def reorder_state(self, indices: torch.LongTensor): | |||
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) | |||
# greedy | |||
for beam_search in [1, 3]: | |||
decoder_output = torch.randn(2, 10, 5) | |||
path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
decoder = GreedyDummyDecoder(decoder_output) | |||
with self.subTest(beam_search=beam_search): | |||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, | |||
eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
self.assertEqual(decode_path.eq(path).sum(), path.numel()) | |||
# greedy check eos_token_id | |||
for beam_search in [1, 3]: | |||
decoder_output = torch.randn(2, 10, 5) | |||
decoder_output[:, :7, 4].fill_(-100) | |||
decoder_output[0, 7, 4] = 1000 # ends at the 8th step | |||
decoder_output[1, 5, 4] = 1000 | |||
path = decoder_output.argmax(dim=-1) # 2 x 10 | |||
decoder = GreedyDummyDecoder(decoder_output) | |||
with self.subTest(beam_search=beam_search): | |||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||
do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||
decode_path = generator.generate(DummyState(decoder), | |||
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||
self.assertEqual(decode_path.size(1), 8) # length is 8 | |||
self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8) | |||
self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6) |