@@ -16,6 +16,7 @@ from ._logger import logger | |||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .utils import Option | from .utils import Option | ||||
from .utils import _is_iterable | from .utils import _is_iterable | ||||
import io | |||||
class VocabularyOption(Option): | class VocabularyOption(Option): | ||||
@@ -487,76 +488,99 @@ class Vocabulary(object): | |||||
def save(self, filepath): | def save(self, filepath): | ||||
r""" | r""" | ||||
:param str filepath: Vocabulary的储存路径 | |||||
:param str,io.IOBase filepath: path or writable text stream to save the Vocabulary to | |||||
:return: | :return: | ||||
""" | """ | ||||
with open(filepath, 'w', encoding='utf-8') as f: | |||||
f.write(f'max_size\t{self.max_size}\n') | |||||
f.write(f'min_freq\t{self.min_freq}\n') | |||||
f.write(f'unknown\t{self.unknown}\n') | |||||
f.write(f'padding\t{self.padding}\n') | |||||
f.write(f'rebuild\t{self.rebuild}\n') | |||||
f.write('\n') | |||||
# idx: 如果idx为-2, 说明还没有进行build; 如果idx为-1,说明该词未编入 | |||||
# no_create_entry: 如果为1,说明该词是no_create_entry; 0 otherwise | |||||
# word \t count \t idx \t no_create_entry \n | |||||
idx = -2 | |||||
for word, count in self.word_count.items(): | |||||
if self._word2idx is not None: | |||||
idx = self._word2idx.get(word, -1) | |||||
is_no_create_entry = int(self._is_word_no_create_entry(word)) | |||||
f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') | |||||
if isinstance(filepath, io.IOBase): | |||||
assert filepath.writable() | |||||
f = filepath | |||||
elif isinstance(filepath, str): | |||||
try: | |||||
f = open(filepath, 'w', encoding='utf-8') | |||||
except Exception as e: | |||||
raise e | |||||
else: | |||||
raise TypeError("Illegal `filepath`.") | |||||
f.write(f'max_size\t{self.max_size}\n') | |||||
f.write(f'min_freq\t{self.min_freq}\n') | |||||
f.write(f'unknown\t{self.unknown}\n') | |||||
f.write(f'padding\t{self.padding}\n') | |||||
f.write(f'rebuild\t{self.rebuild}\n') | |||||
f.write('\n') | |||||
# idx: -2 means the vocabulary has not been built yet; -1 means the word has not been assigned an index | |||||
# no_create_entry: 1 if the word is a no_create_entry word; 0 otherwise | |||||
# word \t count \t idx \t no_create_entry \n | |||||
idx = -2 | |||||
for word, count in self.word_count.items(): | |||||
if self._word2idx is not None: | |||||
idx = self._word2idx.get(word, -1) | |||||
is_no_create_entry = int(self._is_word_no_create_entry(word)) | |||||
f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') | |||||
if isinstance(filepath, str):  # close only if we opened the file from a path ourselves | |||||
f.close() | |||||
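For illustration, a file written by save for a tiny, already-built vocabulary with the default padding and unknown tokens would look roughly like the following (counts and indices are made up; fields are tab-separated):

max_size	None
min_freq	None
unknown	<unk>
padding	<pad>
rebuild	False

apple	2	2	0
banana	1	3	0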
@staticmethod | @staticmethod | ||||
def load(filepath): | def load(filepath): | ||||
r""" | r""" | ||||
:param str filepath: Vocabulary的读取路径 | |||||
:param str,io.IOBase filepath: path or readable text stream to load the Vocabulary from | |||||
:return: Vocabulary | :return: Vocabulary | ||||
""" | """ | ||||
with open(filepath, 'r', encoding='utf-8') as f: | |||||
vocab = Vocabulary() | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
name, value = line.split() | |||||
if name in ('max_size', 'min_freq'): | |||||
value = int(value) if value!='None' else None | |||||
setattr(vocab, name, value) | |||||
elif name in ('unknown', 'padding'): | |||||
value = value if value!='None' else None | |||||
setattr(vocab, name, value) | |||||
elif name == 'rebuild': | |||||
vocab.rebuild = True if value=='True' else False | |||||
else: | |||||
break | |||||
word_counter = {} | |||||
no_create_entry_counter = {} | |||||
word2idx = {} | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) | |||||
if idx >= 0: | |||||
word2idx[word] = idx | |||||
word_counter[word] = count | |||||
if no_create_entry: | |||||
no_create_entry_counter[word] = count | |||||
word_counter = Counter(word_counter) | |||||
no_create_entry_counter = Counter(no_create_entry_counter) | |||||
if len(word2idx)>0: | |||||
if vocab.padding: | |||||
word2idx[vocab.padding] = 0 | |||||
if vocab.unknown: | |||||
word2idx[vocab.unknown] = 1 if vocab.padding else 0 | |||||
idx2word = {value:key for key,value in word2idx.items()} | |||||
vocab.word_count = word_counter | |||||
vocab._no_create_word = no_create_entry_counter | |||||
if word2idx: | |||||
vocab._word2idx = word2idx | |||||
vocab._idx2word = idx2word | |||||
if isinstance(filepath, io.IOBase): | |||||
assert filepath.readable() | |||||
f = filepath | |||||
elif isinstance(filepath, str): | |||||
try: | |||||
f = open(filepath, 'r', encoding='utf-8') | |||||
except Exception as e: | |||||
raise e | |||||
else: | |||||
raise TypeError("Illegal `filepath`.") | |||||
vocab = Vocabulary() | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
name, value = line.split() | |||||
if name in ('max_size', 'min_freq'): | |||||
value = int(value) if value!='None' else None | |||||
setattr(vocab, name, value) | |||||
elif name in ('unknown', 'padding'): | |||||
value = value if value!='None' else None | |||||
setattr(vocab, name, value) | |||||
elif name == 'rebuild': | |||||
vocab.rebuild = True if value=='True' else False | |||||
else: | |||||
break | |||||
word_counter = {} | |||||
no_create_entry_counter = {} | |||||
word2idx = {} | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) | |||||
if idx >= 0: | |||||
word2idx[word] = idx | |||||
word_counter[word] = count | |||||
if no_create_entry: | |||||
no_create_entry_counter[word] = count | |||||
word_counter = Counter(word_counter) | |||||
no_create_entry_counter = Counter(no_create_entry_counter) | |||||
if len(word2idx)>0: | |||||
if vocab.padding: | |||||
word2idx[vocab.padding] = 0 | |||||
if vocab.unknown: | |||||
word2idx[vocab.unknown] = 1 if vocab.padding else 0 | |||||
idx2word = {value:key for key,value in word2idx.items()} | |||||
vocab.word_count = word_counter | |||||
vocab._no_create_word = no_create_entry_counter | |||||
if word2idx: | |||||
vocab._word2idx = word2idx | |||||
vocab._idx2word = idx2word | |||||
if isinstance(filepath, str):  # close only if we opened the file from a path ourselves | |||||
f.close() | |||||
return vocab | return vocab |
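A minimal usage sketch of the round trip these two methods provide, exercising the in-memory stream support added here; it assumes only the public Vocabulary API shown above:

import io
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.add_word_lst(["apple", "banana", "apple"])
vocab.build_vocab()

buffer = io.StringIO()            # any text stream works, not only a file path
vocab.save(buffer)
buffer.seek(0)
restored = Vocabulary.load(buffer)
assert restored.to_index("apple") == vocab.to_index("apple")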
@@ -22,8 +22,9 @@ __all__ = [ | |||||
"StackEmbedding", | "StackEmbedding", | ||||
"LSTMCharEmbedding", | "LSTMCharEmbedding", | ||||
"CNNCharEmbedding", | "CNNCharEmbedding", | ||||
"get_embeddings", | |||||
"get_embeddings", | |||||
"get_sinusoid_encoding_table" | |||||
] | ] | ||||
from .embedding import Embedding, TokenEmbedding | from .embedding import Embedding, TokenEmbedding | ||||
@@ -34,7 +35,7 @@ from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder | |||||
from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding | from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding | ||||
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | ||||
from .stack_embedding import StackEmbedding | from .stack_embedding import StackEmbedding | ||||
from .utils import get_embeddings | |||||
from .utils import get_embeddings, get_sinusoid_encoding_table | |||||
import sys | import sys | ||||
from ..doc_utils import doc_process | from ..doc_utils import doc_process |
@@ -8,11 +8,11 @@ __all__ = [ | |||||
"BertWordPieceEncoder" | "BertWordPieceEncoder" | ||||
] | ] | ||||
import collections | |||||
import os | |||||
import warnings | import warnings | ||||
from itertools import chain | from itertools import chain | ||||
from functools import partial | from functools import partial | ||||
import json | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
@@ -24,6 +24,13 @@ from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR | |||||
from ..modules.encoder.bert import BertModel | from ..modules.encoder.bert import BertModel | ||||
from ..modules.tokenizer import BertTokenizer | from ..modules.tokenizer import BertTokenizer | ||||
# TODO 需要重新修改,使得encoder可以直接读取embedding的权重 | |||||
VOCAB_NAME = 'vocab.txt' | |||||
BERT_EMBED_HYPER = 'bert_hyper.json' | |||||
BERT_EMBED_FOLDER = 'bert' | |||||
BERT_ENCODER_HYPER = 'bert_hyper.json' | |||||
BERT_ENCODER_FOLDER = 'bert' | |||||
class BertEmbedding(ContextualEmbedding): | class BertEmbedding(ContextualEmbedding): | ||||
r""" | r""" | ||||
@@ -82,10 +89,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | ||||
来进行分类的任务将auto_truncate置为True。 | 来进行分类的任务将auto_truncate置为True。 | ||||
:param kwargs: | :param kwargs: | ||||
bool only_use_pretrain_bpe: 仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新 | |||||
建议设置为True。 | |||||
int min_freq: 仅在only_use_pretrain_bpe为False有效,大于等于该次数的词会被新加入BERT的BPE词表中 | |||||
bool truncate_embed: 是否仅保留用到的bpe(这样会减内存占用和加快速度) | |||||
int min_freq: 小于该次数的词会被unk代替 | |||||
""" | """ | ||||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
@@ -106,14 +110,11 @@ class BertEmbedding(ContextualEmbedding): | |||||
if '[CLS]' in vocab: | if '[CLS]' in vocab: | ||||
self._word_cls_index = vocab['[CLS]'] | |||||
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||||
truncate_embed = kwargs.get('truncate_embed', True) | |||||
min_freq = kwargs.get('min_freq', 2) | min_freq = kwargs.get('min_freq', 2) | ||||
self._min_freq = min_freq | |||||
self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | ||||
pool_method=pool_method, include_cls_sep=include_cls_sep, | pool_method=pool_method, include_cls_sep=include_cls_sep, | ||||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, | |||||
only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) | |||||
pooled_cls=pooled_cls, min_freq=min_freq, auto_truncate=auto_truncate) | |||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | ||||
@@ -160,6 +161,57 @@ class BertEmbedding(ContextualEmbedding): | |||||
words = words.masked_fill(mask, self._word_unk_index) | words = words.masked_fill(mask, self._word_unk_index) | ||||
return words | return words | ||||
def save(self, folder): | |||||
""" | |||||
Save this embedding under `folder`. Three items are written: vocab.txt, bert_hyper.json and a bert/ sub-folder, | |||||
where bert/ contains config.json, pytorch_model.bin and vocab.txt (that sub-folder can also be read directly as a BERT model). | |||||
:param str folder: | |||||
:return: | |||||
""" | |||||
os.makedirs(folder, exist_ok=True) | |||||
self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME)) | |||||
hyper = {} | |||||
hyper['min_freq'] = self._min_freq | |||||
hyper['layers'] = ','.join(map(str, self.model.layers)) | |||||
hyper['pool_method'] = self.model.pool_method | |||||
hyper['dropout'] = self.dropout_layer.p | |||||
hyper['word_dropout'] = self.word_dropout | |||||
hyper['include_cls_sep'] = self.model.include_cls_sep | |||||
hyper['pooled_cls'] = self.model.pooled_cls | |||||
hyper['auto_truncate'] = self.model.auto_truncate | |||||
hyper['requires_grad'] = bool(self.requires_grad) | |||||
with open(os.path.join(folder, BERT_EMBED_HYPER), 'w', encoding='utf-8') as f: | |||||
json.dump(hyper, f, indent=2) | |||||
os.makedirs(os.path.join(folder, BERT_EMBED_FOLDER), exist_ok=True) | |||||
self.model.save(os.path.join(folder, BERT_EMBED_FOLDER)) | |||||
logger.debug(f"BERTEmbedding has been saved in {folder}") | |||||
@classmethod | |||||
def load(cls, folder): | |||||
""" | |||||
Given a `folder`, it must contain vocab.txt, bert_hyper.json and a bert/ sub-folder (as written by `save`). | |||||
:param str folder: | |||||
:return: | |||||
""" | |||||
for name in [VOCAB_NAME, BERT_EMBED_FOLDER, BERT_EMBED_HYPER]: | |||||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||||
vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME)) | |||||
with open(os.path.join(folder, BERT_EMBED_HYPER), 'r', encoding='utf-8') as f: | |||||
hyper = json.load(f) | |||||
model_dir_or_name = os.path.join(os.path.join(folder, BERT_EMBED_FOLDER)) | |||||
bert_embed = cls(vocab=vocab, model_dir_or_name=model_dir_or_name, **hyper) | |||||
return bert_embed | |||||
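A hedged usage sketch of the new save/load pair; the model name 'en-base-uncased', the folder name and the toy vocabulary are placeholders:

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is a test sentence".split())
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False)
embed.save('bert_embed_ckpt')                   # writes vocab.txt, bert_hyper.json and bert/
embed2 = BertEmbedding.load('bert_embed_ckpt')  # rebuilt from the saved hyper-parameters and weights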
class BertWordPieceEncoder(nn.Module): | class BertWordPieceEncoder(nn.Module): | ||||
r""" | r""" | ||||
@@ -180,7 +232,7 @@ class BertWordPieceEncoder(nn.Module): | |||||
""" | """ | ||||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | ||||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||||
word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs): | |||||
r""" | r""" | ||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | ||||
@@ -270,11 +322,53 @@ class BertWordPieceEncoder(nn.Module): | |||||
words = words.masked_fill(mask, self._wordpiece_unk_index) | words = words.masked_fill(mask, self._wordpiece_unk_index) | ||||
return words | return words | ||||
def save(self, folder): | |||||
""" | |||||
Creates bert_hyper.json and a bert/ sub-folder under `folder`; bert/ contains config.json, pytorch_model.bin | |||||
and vocab.txt (that sub-folder can also be read directly as a BERT model). | |||||
:param str folder: | |||||
:return: | |||||
""" | |||||
os.makedirs(folder, exist_ok=True) | |||||
hyper = {} | |||||
hyper['layers'] = ','.join(map(str, self.model.layers)) | |||||
hyper['dropout'] = self.dropout_layer.p | |||||
hyper['word_dropout'] = self.word_dropout | |||||
hyper['pooled_cls'] = self.model.pooled_cls | |||||
hyper['requires_grad'] = bool(self.requires_grad) | |||||
with open(os.path.join(folder, BERT_ENCODER_HYPER), 'w', encoding='utf-8') as f: | |||||
json.dump(hyper, f, indent=2) | |||||
os.makedirs(os.path.join(folder, BERT_ENCODER_FOLDER), exist_ok=True) | |||||
self.model.save(os.path.join(folder, BERT_ENCODER_FOLDER)) | |||||
logger.debug(f"BertWordPieceEncoder has been saved in {folder}") | |||||
@classmethod | |||||
def load(cls, folder): | |||||
""" | |||||
Given a `folder` that contains bert_hyper.json and a bert/ sub-folder (as written by `save`), rebuild a BertWordPieceEncoder. | |||||
:param folder: | |||||
:return: | |||||
""" | |||||
for name in [BERT_ENCODER_HYPER, BERT_ENCODER_FOLDER]: | |||||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||||
with open(os.path.join(folder, BERT_ENCODER_HYPER), 'r', encoding='utf-8') as f: | |||||
hyper = json.load(f) | |||||
model_dir_or_name = os.path.join(os.path.join(folder, BERT_ENCODER_FOLDER)) | |||||
bert_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper) | |||||
return bert_encoder | |||||
class _BertWordModel(nn.Module): | class _BertWordModel(nn.Module): | ||||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | ||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||||
only_use_pretrain_bpe=False, truncate_embed=True): | |||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||||
super().__init__() | super().__init__() | ||||
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | ||||
@@ -303,73 +397,8 @@ class _BertWordModel(nn.Module): | |||||
self.auto_truncate = auto_truncate | self.auto_truncate = auto_truncate | ||||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | ||||
logger.info("Start to generate word pieces for word.") | |||||
self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids | self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids | ||||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||||
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 | |||||
new_add_to_bpe_vocab = 0 | |||||
unsegment_count = 0 | |||||
if '[sep]' in vocab: | |||||
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") | |||||
if "[CLS]" in vocab: | |||||
warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the begin " | |||||
"and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin" | |||||
" and end.") | |||||
for word, index in vocab: | |||||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||||
word = '[PAD]' | |||||
elif index == vocab.unknown_idx: | |||||
word = '[UNK]' | |||||
_words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() | |||||
word_pieces = [] | |||||
for w in _words: | |||||
word_pieces.extend(self.tokenzier.wordpiece_tokenizer.tokenize(w)) | |||||
if len(word_pieces) == 1: | |||||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||||
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面 | |||||
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||||
word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 | |||||
word_piece_dict[word] = 1 # 新增一个值 | |||||
new_add_to_bpe_vocab += 1 | |||||
unsegment_count += 1 | |||||
continue | |||||
for word_piece in word_pieces: | |||||
word_piece_dict[word_piece] = 1 | |||||
original_embed = self.encoder.embeddings.word_embeddings.weight.data | |||||
# 特殊词汇要特殊处理 | |||||
if not truncate_embed:# 如果不删除的话需要将已有的加上 | |||||
word_piece_dict.update(self.tokenzier.vocab) | |||||
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed | |||||
new_word_piece_vocab = collections.OrderedDict() | |||||
for index, token in enumerate(['[PAD]', '[UNK]']): | |||||
index = word_piece_dict.pop(token, None) | |||||
if index is not None: | |||||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||||
embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.vocab[token]] | |||||
for token in word_piece_dict.keys(): | |||||
if token not in new_word_piece_vocab: | |||||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||||
index = new_word_piece_vocab[token] | |||||
if token in self.tokenzier.vocab: | |||||
embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]] | |||||
else: | |||||
embed.weight.data[index] = original_embed[self.tokenzier.vocab['[UNK]']] | |||||
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||||
self.encoder.embeddings.word_embeddings = embed | |||||
self.encoder.config.vocab_size = len(new_word_piece_vocab) | |||||
if unsegment_count>0: | |||||
if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: | |||||
logger.info(f"{unsegment_count} words are unsegmented.") | |||||
else: | |||||
logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") | |||||
word_to_wordpieces = [] | word_to_wordpieces = [] | ||||
word_pieces_lengths = [] | word_pieces_lengths = [] | ||||
for word, index in vocab: | for word, index in vocab: | ||||
@@ -377,6 +406,8 @@ class _BertWordModel(nn.Module): | |||||
word = '[PAD]' | word = '[PAD]' | ||||
elif index == vocab.unknown_idx: | elif index == vocab.unknown_idx: | ||||
word = '[UNK]' | word = '[UNK]' | ||||
elif vocab.word_count[word]<min_freq: | |||||
word = '[UNK]' | |||||
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | ||||
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | ||||
word_to_wordpieces.append(word_pieces) | word_to_wordpieces.append(word_pieces) | ||||
@@ -504,6 +535,16 @@ class _BertWordModel(nn.Module): | |||||
# 3. 最终的embedding结果 | # 3. 最终的embedding结果 | ||||
return outputs | return outputs | ||||
def save(self, folder): | |||||
""" | |||||
给定一个folder保存pytorch_model.bin, config.json, vocab.txt | |||||
:param str folder: | |||||
:return: | |||||
""" | |||||
self.tokenzier.save_pretrained(folder) | |||||
self.encoder.save_pretrained(folder) | |||||
class _BertWordPieceModel(nn.Module): | class _BertWordPieceModel(nn.Module): | ||||
r""" | r""" | ||||
@@ -580,4 +621,14 @@ class _BertWordPieceModel(nn.Module): | |||||
if l in (len(bert_outputs)-1, -1) and self.pooled_cls: | if l in (len(bert_outputs)-1, -1) and self.pooled_cls: | ||||
bert_output[:, 0] = pooled_cls | bert_output[:, 0] = pooled_cls | ||||
outputs[l_index] = bert_output | outputs[l_index] = bert_output | ||||
return outputs | |||||
return outputs | |||||
def save(self, folder): | |||||
""" | |||||
给定一个folder保存pytorch_model.bin, config.json, vocab.txt | |||||
:param folder: | |||||
:return: | |||||
""" | |||||
self.tokenzier.save_pretrained(folder) | |||||
self.encoder.save_pretrained(folder) |
@@ -78,7 +78,7 @@ class Embedding(nn.Module): | |||||
if isinstance(self.embed, nn.Embedding): | if isinstance(self.embed, nn.Embedding): | ||||
return self.embed.weight.size(0) | return self.embed.weight.size(0) | ||||
else: | else: | ||||
return self.embed.num_embedding | |||||
return self.embed.num_embeddings | |||||
def __len__(self): | def __len__(self): | ||||
return len(self.embed) | return len(self.embed) | ||||
@@ -188,7 +188,7 @@ class TokenEmbedding(nn.Module): | |||||
return self._embed_size | return self._embed_size | ||||
@property | @property | ||||
def num_embedding(self) -> int: | |||||
def num_embeddings(self) -> int: | |||||
r""" | r""" | ||||
这个值可能会大于实际的embedding矩阵的大小。 | 这个值可能会大于实际的embedding矩阵的大小。 | ||||
:return: | :return: | ||||
@@ -205,7 +205,7 @@ class TokenEmbedding(nn.Module): | |||||
@property | @property | ||||
def size(self): | def size(self): | ||||
return torch.Size(self.num_embedding, self._embed_size) | |||||
return torch.Size([self.num_embeddings, self._embed_size]) | |||||
@abstractmethod | @abstractmethod | ||||
def forward(self, words): | def forward(self, words): | ||||
@@ -10,8 +10,8 @@ __all__ = [ | |||||
from functools import partial | from functools import partial | ||||
import collections | |||||
import warnings | |||||
import os | |||||
import json | |||||
from itertools import chain | from itertools import chain | ||||
import numpy as np | import numpy as np | ||||
@@ -24,6 +24,13 @@ from ..modules.encoder.roberta import RobertaModel | |||||
from ..modules.tokenizer import RobertaTokenizer | from ..modules.tokenizer import RobertaTokenizer | ||||
VOCAB_NAME = 'vocab.txt' | |||||
ROBERTA_EMBED_HYPER = 'roberta_hyper.json' | |||||
ROBERTA_ENCODER_HYPER = 'roberta_hyper.json' | |||||
ROBERTA_EMBED_FOLDER = 'roberta' | |||||
ROBERTA_ENCODER_FOLDER = 'roberta' | |||||
class RobertaEmbedding(ContextualEmbedding): | class RobertaEmbedding(ContextualEmbedding): | ||||
r""" | r""" | ||||
使用RoBERTa对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | 使用RoBERTa对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | ||||
@@ -71,10 +78,7 @@ class RobertaEmbedding(ContextualEmbedding): | |||||
word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> | word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> | ||||
来进行分类的任务将auto_truncate置为True。 | 来进行分类的任务将auto_truncate置为True。 | ||||
:param kwargs: | :param kwargs: | ||||
bool only_use_pretrain_bpe: 仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新 | |||||
建议设置为True。 | |||||
int min_freq: 仅在only_use_pretrain_bpe为False有效,大于等于该次数的词会被新加入BERT的BPE词表中 | |||||
bool truncate_embed: 是否仅保留用到的bpe(这样会减内存占用和加快速度) | |||||
int min_freq: 小于该次数的词会被unk代替 | |||||
""" | """ | ||||
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
@@ -89,14 +93,12 @@ class RobertaEmbedding(ContextualEmbedding): | |||||
if '<s>' in vocab: | if '<s>' in vocab: | ||||
self._word_cls_index = vocab['<s>'] | self._word_cls_index = vocab['<s>'] | ||||
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||||
truncate_embed = kwargs.get('truncate_embed', True) | |||||
min_freq = kwargs.get('min_freq', 2) | min_freq = kwargs.get('min_freq', 2) | ||||
self._min_freq = min_freq | |||||
self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | ||||
pool_method=pool_method, include_cls_sep=include_cls_sep, | pool_method=pool_method, include_cls_sep=include_cls_sep, | ||||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq, | |||||
only_use_pretrain_bpe=only_use_pretrain_bpe, truncate_embed=truncate_embed) | |||||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=min_freq) | |||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | ||||
@@ -142,11 +144,56 @@ class RobertaEmbedding(ContextualEmbedding): | |||||
words = words.masked_fill(mask, self._word_unk_index) | words = words.masked_fill(mask, self._word_unk_index) | ||||
return words | return words | ||||
def save(self, folder): | |||||
""" | |||||
Save this RoBERTa embedding under `folder`; afterwards it contains vocab.txt, roberta_hyper.json and a roberta/ sub-folder. | |||||
:param str folder: target folder | |||||
:return: | |||||
""" | |||||
os.makedirs(folder, exist_ok=True) | |||||
self.get_word_vocab().save(os.path.join(folder, VOCAB_NAME)) | |||||
hyper = {} | |||||
hyper['min_freq'] = self._min_freq | |||||
hyper['layers'] = ','.join(map(str, self.model.layers)) | |||||
hyper['pool_method'] = self.model.pool_method | |||||
hyper['dropout'] = self.dropout_layer.p | |||||
hyper['word_dropout'] = self.word_dropout | |||||
hyper['include_cls_sep'] = self.model.include_cls_sep | |||||
hyper['pooled_cls'] = self.model.pooled_cls | |||||
hyper['auto_truncate'] = self.model.auto_truncate | |||||
hyper['requires_grad'] = bool(self.requires_grad) | |||||
with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'w', encoding='utf-8') as f: | |||||
json.dump(hyper, f, indent=2) | |||||
os.makedirs(os.path.join(folder, ROBERTA_EMBED_FOLDER), exist_ok=True) | |||||
self.model.save(os.path.join(folder, ROBERTA_EMBED_FOLDER)) | |||||
@classmethod | |||||
def load(cls, folder): | |||||
""" | |||||
Initialize a RobertaEmbedding from the data saved in `folder`. | |||||
:param folder: | |||||
:return: | |||||
""" | |||||
for name in [VOCAB_NAME, ROBERTA_EMBED_HYPER, ROBERTA_EMBED_FOLDER]: | |||||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||||
vocab = Vocabulary.load(os.path.join(folder, VOCAB_NAME)) | |||||
with open(os.path.join(folder, ROBERTA_EMBED_HYPER), 'r', encoding='utf-8') as f: | |||||
hyper = json.load(f) | |||||
model_name_or_path = os.path.join(folder, ROBERTA_EMBED_FOLDER) | |||||
roberta = cls(vocab=vocab, model_dir_or_name=model_name_or_path, **hyper) | |||||
return roberta | |||||
class _RobertaWordModel(nn.Module): | class _RobertaWordModel(nn.Module): | ||||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | ||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||||
only_use_pretrain_bpe=False, truncate_embed=True): | |||||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||||
super().__init__() | super().__init__() | ||||
self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) | self.tokenzier = RobertaTokenizer.from_pretrained(model_dir_or_name) | ||||
@@ -177,72 +224,6 @@ class _RobertaWordModel(nn.Module): | |||||
self.pooled_cls = pooled_cls | self.pooled_cls = pooled_cls | ||||
self.auto_truncate = auto_truncate | self.auto_truncate = auto_truncate | ||||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑<s>和</s> | |||||
logger.info("Start to generate word pieces for word.") | |||||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||||
word_piece_dict = {'<s>': 1, '</s>': 1} # 用到的word_piece以及新增的 | |||||
found_count = 0 | |||||
new_add_to_bpe_vocab = 0 | |||||
unsegment_count = 0 | |||||
if "<s>" in vocab: | |||||
warnings.warn("<s> detected in your vocabulary. RobertaEmbedding will add <s> and </s> to the begin " | |||||
"and end of the input automatically, make sure you don't add <s> and </s> at the begin" | |||||
" and end.") | |||||
for word, index in vocab: | |||||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||||
word = '<pad>' | |||||
elif index == vocab.unknown_idx: | |||||
word = '<unk>' | |||||
# _words = self.tokenzier.basic_tokenizer._tokenize_chinese_chars(word).split() # 这里暂时不考虑中文内容 | |||||
word_pieces = [] | |||||
# 如果这个word不是在句子开头 | |||||
word_pieces.extend(self.tokenzier.tokenize(word, add_prefix_space=True)) | |||||
if len(word_pieces) == 1: | |||||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||||
if index != vocab.unknown_idx and word_pieces[0] == '<unk>': # 说明这个词不在原始的word里面 | |||||
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||||
word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 | |||||
word_piece_dict[word] = 1 # 新增一个值 | |||||
new_add_to_bpe_vocab += 1 | |||||
unsegment_count += 1 | |||||
continue | |||||
found_count += 1 | |||||
for word_piece in word_pieces: | |||||
word_piece_dict[word_piece] = 1 | |||||
# 如果这个word是在句子开头 | |||||
original_embed = self.encoder.embeddings.word_embeddings.weight.data | |||||
# 特殊词汇要特殊处理 | |||||
if not truncate_embed: # 如果不删除的话需要将已有的加上 | |||||
word_piece_dict.update(self.tokenzier.encoder) | |||||
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed | |||||
new_word_piece_vocab = collections.OrderedDict() | |||||
for index, token in enumerate(['<s>', '<pad>', '</s>', '<unk>']): | |||||
index = word_piece_dict.pop(token, None) | |||||
if index is not None: | |||||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||||
embed.weight.data[new_word_piece_vocab[token]] = original_embed[self.tokenzier.encoder[token]] | |||||
for token in word_piece_dict.keys(): | |||||
if token not in new_word_piece_vocab: | |||||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||||
index = new_word_piece_vocab[token] | |||||
if token in self.tokenzier.encoder: | |||||
embed.weight.data[index] = original_embed[self.tokenzier.encoder[token]] | |||||
else: | |||||
embed.weight.data[index] = original_embed[self.tokenzier.encoder['<unk>']] | |||||
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||||
self.encoder.embeddings.word_embeddings = embed | |||||
self.encoder.config.vocab_size = len(new_word_piece_vocab) | |||||
if unsegment_count>0: | |||||
if only_use_pretrain_bpe or new_add_to_bpe_vocab==0: | |||||
logger.info(f"{unsegment_count} words are unsegmented.") | |||||
else: | |||||
logger.info(f"{unsegment_count} words are unsegmented. Among them, {new_add_to_bpe_vocab} added to the BPE vocab.") | |||||
word_to_wordpieces = [] | word_to_wordpieces = [] | ||||
word_pieces_lengths = [] | word_pieces_lengths = [] | ||||
for word, index in vocab: | for word, index in vocab: | ||||
@@ -250,6 +231,8 @@ class _RobertaWordModel(nn.Module): | |||||
word = '<pad>' | word = '<pad>' | ||||
elif index == vocab.unknown_idx: | elif index == vocab.unknown_idx: | ||||
word = '<unk>' | word = '<unk>' | ||||
elif vocab.word_count[word]<min_freq: | |||||
word = '<unk>' | |||||
word_pieces = self.tokenzier.tokenize(word) | word_pieces = self.tokenzier.tokenize(word) | ||||
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | ||||
word_to_wordpieces.append(word_pieces) | word_to_wordpieces.append(word_pieces) | ||||
@@ -368,6 +351,10 @@ class _RobertaWordModel(nn.Module): | |||||
# 3. 最终的embedding结果 | # 3. 最终的embedding结果 | ||||
return outputs | return outputs | ||||
def save(self, folder): | |||||
self.tokenzier.save_pretrained(folder) | |||||
self.encoder.save_pretrained(folder) | |||||
class RobertaWordPieceEncoder(nn.Module): | class RobertaWordPieceEncoder(nn.Module): | ||||
r""" | r""" | ||||
@@ -380,7 +367,7 @@ class RobertaWordPieceEncoder(nn.Module): | |||||
""" | """ | ||||
def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False, | def __init__(self, model_dir_or_name: str = 'en', layers: str = '-1', pooled_cls: bool = False, | ||||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||||
word_dropout=0, dropout=0, requires_grad: bool = True, **kwargs): | |||||
r""" | r""" | ||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | ||||
@@ -462,6 +449,36 @@ class RobertaWordPieceEncoder(nn.Module): | |||||
words = words.masked_fill(mask, self._wordpiece_unk_index) | words = words.masked_fill(mask, self._wordpiece_unk_index) | ||||
return words | return words | ||||
def save(self, folder): | |||||
os.makedirs(folder, exist_ok=True) | |||||
hyper = {} | |||||
hyper['layers'] = ','.join(map(str, self.model.layers)) | |||||
hyper['dropout'] = self.dropout_layer.p | |||||
hyper['word_dropout'] = self.word_dropout | |||||
hyper['pooled_cls'] = self.model.pooled_cls | |||||
hyper['requires_grad'] = bool(self.requires_grad) | |||||
with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'w', encoding='utf-8') as f: | |||||
json.dump(hyper, f, indent=2) | |||||
os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True) | |||||
self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER)) | |||||
logger.debug(f"BertWordPieceEncoder has been saved in {folder}") | |||||
@classmethod | |||||
def load(cls, folder): | |||||
for name in [ROBERTA_ENCODER_HYPER, ROBERTA_ENCODER_FOLDER]: | |||||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||||
with open(os.path.join(folder, ROBERTA_ENCODER_HYPER), 'r', encoding='utf-8') as f: | |||||
hyper = json.load(f) | |||||
model_dir_or_name = os.path.join(os.path.join(folder, ROBERTA_ENCODER_FOLDER)) | |||||
roberta_encoder = cls(model_dir_or_name=model_dir_or_name, **hyper) | |||||
return roberta_encoder | |||||
class _WordPieceRobertaModel(nn.Module): | class _WordPieceRobertaModel(nn.Module): | ||||
def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False): | def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False): | ||||
@@ -535,4 +552,8 @@ class _WordPieceRobertaModel(nn.Module): | |||||
if l in (len(roberta_output)-1, -1) and self.pooled_cls: | if l in (len(roberta_output)-1, -1) and self.pooled_cls: | ||||
roberta_output[:, 0] = pooled_cls | roberta_output[:, 0] = pooled_cls | ||||
outputs[l_index] = roberta_output | outputs[l_index] = roberta_output | ||||
return outputs | |||||
return outputs | |||||
def save(self, folder): | |||||
self.tokenzier.save_pretrained(folder) | |||||
self.encoder.save_pretrained(folder) |
@@ -10,6 +10,8 @@ import os | |||||
import warnings | import warnings | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from copy import deepcopy | from copy import deepcopy | ||||
import json | |||||
from typing import Union | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
@@ -19,7 +21,12 @@ from .embedding import TokenEmbedding | |||||
from ..core import logger | from ..core import logger | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | ||||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | |||||
from ..io.file_utils import _get_file_name_base_on_postfix | |||||
VOCAB_FILENAME = 'vocab.txt' | |||||
STATIC_HYPER_FILENAME = 'static_hyper.json' | |||||
STATIC_EMBED_FILENAME = 'static.txt' | |||||
class StaticEmbedding(TokenEmbedding): | class StaticEmbedding(TokenEmbedding): | ||||
@@ -70,7 +77,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
""" | """ | ||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | |||||
def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en', embedding_dim=-1, requires_grad: bool = True, | |||||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | ||||
r""" | r""" | ||||
@@ -95,8 +102,8 @@ class StaticEmbedding(TokenEmbedding): | |||||
""" | """ | ||||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
if embedding_dim > 0: | if embedding_dim > 0: | ||||
if model_dir_or_name is not None: | |||||
warnings.warn(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" | |||||
if model_dir_or_name: | |||||
logger.info(f"StaticEmbedding will ignore `model_dir_or_name`, and randomly initialize embedding with" | |||||
f" dimension {embedding_dim}. If you want to use pre-trained embedding, " | f" dimension {embedding_dim}. If you want to use pre-trained embedding, " | ||||
f"set `embedding_dim` to 0.") | f"set `embedding_dim` to 0.") | ||||
model_dir_or_name = None | model_dir_or_name = None | ||||
@@ -116,7 +123,9 @@ class StaticEmbedding(TokenEmbedding): | |||||
model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | ||||
else: | else: | ||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
kwargs['min_freq'] = min_freq | |||||
kwargs['lower'] = lower | |||||
# 根据min_freq缩小vocab | # 根据min_freq缩小vocab | ||||
truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq) | truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq) | ||||
if truncate_vocab: | if truncate_vocab: | ||||
@@ -143,7 +152,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
truncated_words_to_words = torch.arange(len(vocab)).long() | truncated_words_to_words = torch.arange(len(vocab)).long() | ||||
for word, index in vocab: | for word, index in vocab: | ||||
truncated_words_to_words[index] = truncated_vocab.to_index(word) | truncated_words_to_words[index] = truncated_vocab.to_index(word) | ||||
logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||||
logger.info(f"{len(vocab) - len(truncated_vocab)} words have frequency less than {min_freq}.") | |||||
vocab = truncated_vocab | vocab = truncated_vocab | ||||
self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) | self.only_use_pretrain_word = kwargs.get('only_use_pretrain_word', False) | ||||
@@ -198,6 +207,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
sparse=False, _weight=embedding) | sparse=False, _weight=embedding) | ||||
self._embed_size = self.embedding.weight.size(1) | self._embed_size = self.embedding.weight.size(1) | ||||
self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
self.kwargs = kwargs | |||||
@property | @property | ||||
def weight(self): | def weight(self): | ||||
@@ -321,3 +331,71 @@ class StaticEmbedding(TokenEmbedding): | |||||
words = self.embedding(words) | words = self.embedding(words) | ||||
words = self.dropout(words) | words = self.dropout(words) | ||||
return words | return words | ||||
def save(self, folder): | |||||
""" | |||||
Save the embedding under `folder`; it can later be restored with `load`. | |||||
:param str folder: three files are written under this folder: vocab.txt, static.txt and static_hyper.json. | |||||
vocab.txt can be read back with Vocabulary.load; static.txt follows the word2vec text format with space-separated values, | |||||
where the first line holds only the word count and the dimension and every later line is a word followed by its vector; | |||||
static_hyper.json stores the StaticEmbedding hyper-parameters. | |||||
:return: | |||||
""" | |||||
os.makedirs(folder, exist_ok=True) | |||||
vocab = self.get_word_vocab() | |||||
vocab_fp = os.path.join(folder, VOCAB_FILENAME) | |||||
vocab.save(vocab_fp) | |||||
kwargs = self.kwargs.copy() | |||||
kwargs['dropout'] = self.dropout_layer.p | |||||
kwargs['word_dropout'] = self.word_dropout | |||||
kwargs['requires_grad'] = self.requires_grad | |||||
kwargs['only_norm_found_vector'] = False | |||||
kwargs['only_use_pretrain_word'] = True | |||||
with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'w', encoding='utf-8') as f: | |||||
json.dump(kwargs, f, indent=2) | |||||
with open(os.path.join(folder, STATIC_EMBED_FILENAME), 'w', encoding='utf-8') as f: | |||||
f.write('{}\n'.format(' '*30)) # 留白之后再来填写 | |||||
word_count = 0 | |||||
saved_word = {} | |||||
valid_word_count = 0 | |||||
for i in range(len(self.words_to_words)): | |||||
word = vocab.to_word(i) | |||||
if not vocab._is_word_no_create_entry(word): | |||||
word_count += 1 | |||||
if kwargs['lower']: | |||||
word = word.lower() | |||||
if word in saved_word: | |||||
continue | |||||
saved_word[word] = 1 | |||||
vec_i = self.words_to_words[i] | |||||
if vec_i==vocab.unknown_idx and i!=vocab.unknown_idx: | |||||
continue | |||||
vec = self.embedding.weight.data[vec_i].tolist() | |||||
vec_str = ' '.join(map(str, vec)) | |||||
f.write(f'{word} {vec_str}\n') | |||||
valid_word_count += 1 | |||||
f.seek(0) | |||||
f.write('{} {}'.format(valid_word_count, self.embedding_dim)) | |||||
logger.debug(f"StaticEmbedding has been saved to {folder}.") | |||||
@classmethod | |||||
def load(cls, folder): | |||||
""" | |||||
:param str folder: the folder must contain vocab.txt, static.txt and static_hyper.json (as written by `save`) | |||||
:return: | |||||
""" | |||||
for name in [VOCAB_FILENAME, STATIC_EMBED_FILENAME, STATIC_HYPER_FILENAME]: | |||||
assert os.path.exists(os.path.join(folder, name)), f"{name} not found in {folder}." | |||||
vocab = Vocabulary.load(os.path.join(folder, VOCAB_FILENAME)) | |||||
with open(os.path.join(folder, STATIC_HYPER_FILENAME), 'r', encoding='utf-8') as f: | |||||
hyper = json.load(f) | |||||
logger.info(f"Load StaticEmbedding from {folder}.") | |||||
embed = cls(vocab=vocab, model_dir_or_name=os.path.join(folder, STATIC_EMBED_FILENAME), **hyper) | |||||
return embed | |||||
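A small sketch of the StaticEmbedding round trip; passing model_dir_or_name=None with embedding_dim=50 gives a randomly initialized embedding so no pre-trained file is needed, and the folder name is arbitrary:

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("the quick brown fox jumps".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)
embed.save('static_ckpt')                     # writes vocab.txt, static.txt and static_hyper.json
embed2 = StaticEmbedding.load('static_ckpt')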
@@ -9,7 +9,8 @@ from torch import nn as nn | |||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
__all__ = [ | __all__ = [ | ||||
'get_embeddings' | |||||
'get_embeddings', | |||||
'get_sinusoid_encoding_table' | |||||
] | ] | ||||
@@ -31,7 +32,7 @@ def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, inclu | |||||
return char_vocab | return char_vocab | ||||
def get_embeddings(init_embed): | |||||
def get_embeddings(init_embed, padding_idx=None): | |||||
r""" | r""" | ||||
根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | 根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray | ||||
的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | 的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 | ||||
@@ -40,11 +41,12 @@ def get_embeddings(init_embed): | |||||
:param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 | :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 | ||||
nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; | nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; | ||||
传入torch.Tensor, 将使用传入的值作为Embedding初始化。 | 传入torch.Tensor, 将使用传入的值作为Embedding初始化。 | ||||
:param padding_idx: 当传入tuple时,padding_idx有效 | |||||
:return nn.Embedding: embeddings | :return nn.Embedding: embeddings | ||||
""" | """ | ||||
if isinstance(init_embed, tuple): | if isinstance(init_embed, tuple): | ||||
res = nn.Embedding( | res = nn.Embedding( | ||||
num_embeddings=init_embed[0], embedding_dim=init_embed[1]) | |||||
num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) | |||||
nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), | nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), | ||||
b=np.sqrt(3 / res.weight.data.size(1))) | b=np.sqrt(3 / res.weight.data.size(1))) | ||||
elif isinstance(init_embed, nn.Module): | elif isinstance(init_embed, nn.Module): | ||||
@@ -58,3 +60,32 @@ def get_embeddings(init_embed): | |||||
raise TypeError( | raise TypeError( | ||||
'invalid init_embed type: {}'.format((type(init_embed)))) | 'invalid init_embed type: {}'.format((type(init_embed)))) | ||||
return res | return res | ||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||||
""" | |||||
sinusoid的embedding,其中position的表示中,偶数维(0,2,4,...)是sin, 奇数(1,3,5...)是cos | |||||
:param int n_position: 一共多少个position | |||||
:param int d_hid: 多少维度,需要为偶数 | |||||
:param padding_idx: | |||||
:return: torch.FloatTensor, shape为n_position x d_hid | |||||
""" | |||||
def cal_angle(position, hid_idx): | |||||
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) | |||||
def get_posi_angle_vec(position): | |||||
return [cal_angle(position, hid_j) for hid_j in range(d_hid)] | |||||
sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) | |||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i | |||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 | |||||
if padding_idx is not None: | |||||
# zero vector for the padding position | |||||
sinusoid_table[padding_idx] = 0. | |||||
return torch.FloatTensor(sinusoid_table) | |||||
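A short sketch of the two helpers; the sizes are arbitrary, and the frozen-table pattern mirrors how TransformerSeq2SeqModel uses it later in this change:

from torch import nn
from fastNLP.embeddings.utils import get_embeddings, get_sinusoid_encoding_table

embed = get_embeddings((100, 32), padding_idx=0)   # randomly initialized nn.Embedding(100, 32), row 0 is padding
table = get_sinusoid_encoding_table(n_position=512, d_hid=32, padding_idx=0)
assert table.shape == (512, 32)                    # n_position x d_hid
pos_embed = nn.Embedding.from_pretrained(table, freeze=True)  # non-trainable positional embedding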
@@ -9,18 +9,18 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"CNNText", | "CNNText", | ||||
"SeqLabeling", | "SeqLabeling", | ||||
"AdvSeqLabel", | "AdvSeqLabel", | ||||
"BiLSTMCRF", | "BiLSTMCRF", | ||||
"ESIM", | "ESIM", | ||||
"StarTransEnc", | "StarTransEnc", | ||||
"STSeqLabel", | "STSeqLabel", | ||||
"STNLICls", | "STNLICls", | ||||
"STSeqCls", | "STSeqCls", | ||||
"BiaffineParser", | "BiaffineParser", | ||||
"GraphParser", | "GraphParser", | ||||
@@ -28,7 +28,13 @@ __all__ = [ | |||||
"BertForSentenceMatching", | "BertForSentenceMatching", | ||||
"BertForMultipleChoice", | "BertForMultipleChoice", | ||||
"BertForTokenClassification", | "BertForTokenClassification", | ||||
"BertForQuestionAnswering" | |||||
"BertForQuestionAnswering", | |||||
"TransformerSeq2SeqModel", | |||||
"LSTMSeq2SeqModel", | |||||
"Seq2SeqModel", | |||||
'SequenceGeneratorModel' | |||||
] | ] | ||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
@@ -39,7 +45,9 @@ from .cnn_text_classification import CNNText | |||||
from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF | ||||
from .snli import ESIM | from .snli import ESIM | ||||
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel | ||||
from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel | |||||
from .seq2seq_generator import SequenceGeneratorModel | |||||
import sys | import sys | ||||
from ..doc_utils import doc_process | from ..doc_utils import doc_process | ||||
doc_process(sys.modules[__name__]) | |||||
doc_process(sys.modules[__name__]) |
@@ -39,7 +39,7 @@ from torch import nn | |||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..core._logger import logger | from ..core._logger import logger | ||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..embeddings import BertEmbedding | |||||
from ..embeddings.bert_embedding import BertEmbedding | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
@@ -314,13 +314,8 @@ class BiaffineParser(GraphParser): | |||||
raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) | raise ValueError('unsupported rnn_out_size: {} for transformer'.format(rnn_out_size)) | ||||
self.position_emb = nn.Embedding(num_embeddings=self.max_len, | self.position_emb = nn.Embedding(num_embeddings=self.max_len, | ||||
embedding_dim=rnn_out_size, ) | embedding_dim=rnn_out_size, ) | ||||
self.encoder = TransformerEncoder(num_layers=rnn_layers, | |||||
model_size=rnn_out_size, | |||||
inner_size=1024, | |||||
key_size=d_k, | |||||
value_size=d_v, | |||||
num_head=n_head, | |||||
dropout=dropout, ) | |||||
self.encoder = TransformerEncoder( num_layers=rnn_layers, d_model=rnn_out_size, | |||||
n_head=n_head, dim_ff=1024, dropout=dropout) | |||||
else: | else: | ||||
raise ValueError('unsupported encoder type: {}'.format(encoder)) | raise ValueError('unsupported encoder type: {}'.format(encoder)) | ||||
@@ -0,0 +1,62 @@ | |||||
import torch | |||||
from torch import nn | |||||
from .seq2seq_model import Seq2SeqModel | |||||
from ..modules.generator.seq2seq_generator import SequenceGenerator | |||||
class SequenceGeneratorModel(nn.Module): | |||||
""" | |||||
Wraps a Seq2SeqModel so that it can be used for sequence generation. | |||||
""" | |||||
def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, num_beams=1, | |||||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | |||||
""" | |||||
:param Seq2SeqModel seq2seq_model: the sequence-to-sequence model to wrap | |||||
:param int,None bos_token_id: token id that marks the beginning of a sentence | |||||
:param int,None eos_token_id: token id that marks the end of a sentence | |||||
:param int max_length: maximum length of a generated sentence | |||||
:param int num_beams: beam size for beam search | |||||
:param bool do_sample: whether to generate by sampling | |||||
:param float temperature: only meaningful when do_sample is True | |||||
:param int top_k: sample only from the top_k tokens | |||||
:param float top_p: sample only from tokens within cumulative probability top_p (nucleus sampling) | |||||
:param float repetition_penalty: how strongly repeated tokens are penalized | |||||
:param float length_penalty: penalty on length; values below 1 favor longer sentences, values above 1 favor shorter ones | |||||
:param int pad_token_id: once a sentence has finished, later positions are filled with pad_token_id | |||||
""" | |||||
super().__init__() | |||||
self.seq2seq_model = seq2seq_model | |||||
self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, | |||||
bos_token_id=bos_token_id, | |||||
eos_token_id=eos_token_id, | |||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | |||||
pad_token_id=pad_token_id) | |||||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||||
""" | |||||
Pass-through call to seq2seq_model's forward. | |||||
:param torch.LongTensor src_tokens: bsz x max_len | |||||
:param torch.LongTensor tgt_tokens: bsz x max_len' | |||||
:param torch.LongTensor src_seq_len: bsz | |||||
:param torch.LongTensor tgt_seq_len: bsz | |||||
:return: | |||||
""" | |||||
return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) | |||||
def predict(self, src_tokens, src_seq_len=None): | |||||
""" | |||||
Given the source tokens, generate and return the output tokens. | |||||
:param torch.LongTensor src_tokens: bsz x max_len | |||||
:param torch.LongTensor src_seq_len: bsz | |||||
:return: | |||||
""" | |||||
state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) | |||||
result = self.generator.generate(state) | |||||
return {'pred': result} |
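A hedged sketch of how this wrapper is intended to be used; the underlying model comes from LSTMSeq2SeqModel.build_model (added in the next file of this change), and all sizes and token ids are placeholders:

import torch
from fastNLP.models import LSTMSeq2SeqModel, SequenceGeneratorModel

seq2seq = LSTMSeq2SeqModel.build_model(src_embed=(1000, 64), tgt_embed=(1000, 64),
                                       num_layers=2, hidden_size=64, bidirectional=True)
model = SequenceGeneratorModel(seq2seq, bos_token_id=1, eos_token_id=2,
                               max_length=20, num_beams=4, do_sample=False)
src_tokens = torch.randint(3, 1000, (8, 15))           # bsz x max_len
src_seq_len = torch.full((8,), 15, dtype=torch.long)   # bsz
pred = model.predict(src_tokens, src_seq_len)['pred']  # bsz x generated_len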
@@ -0,0 +1,176 @@ | |||||
r""" | |||||
主要包含组成Sequence-to-Sequence的model | |||||
""" | |||||
import torch | |||||
from torch import nn | |||||
from ..embeddings import get_embeddings | |||||
from ..embeddings.utils import get_sinusoid_encoding_table | |||||
from ..modules.decoder.seq2seq_decoder import Seq2SeqDecoder, TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder | |||||
from ..modules.encoder.seq2seq_encoder import Seq2SeqEncoder, TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||||
class Seq2SeqModel(nn.Module): | |||||
def __init__(self, encoder: Seq2SeqEncoder, decoder: Seq2SeqDecoder): | |||||
""" | |||||
可以用于在Trainer中训练的Seq2Seq模型。正常情况下,继承了该函数之后,只需要实现classmethod build_model即可。 | |||||
:param encoder: Encoder | |||||
:param decoder: Decoder | |||||
""" | |||||
super().__init__() | |||||
self.encoder = encoder | |||||
self.decoder = decoder | |||||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): | |||||
""" | |||||
:param torch.LongTensor src_tokens: source的token | |||||
:param torch.LongTensor tgt_tokens: target的token | |||||
:param torch.LongTensor src_seq_len: src的长度 | |||||
:param torch.LongTensor tgt_seq_len: target的长度,默认用不上 | |||||
:return: {'pred': torch.Tensor}, 其中pred的shape为bsz x max_len x vocab_size | |||||
""" | |||||
state = self.prepare_state(src_tokens, src_seq_len) | |||||
decoder_output = self.decoder(tgt_tokens, state) | |||||
if isinstance(decoder_output, torch.Tensor): | |||||
return {'pred': decoder_output} | |||||
elif isinstance(decoder_output, (tuple, list)): | |||||
return {'pred': decoder_output[0]} | |||||
else: | |||||
raise TypeError(f"Unsupported return type from Decoder:{type(self.decoder)}") | |||||
def prepare_state(self, src_tokens, src_seq_len=None): | |||||
""" | |||||
调用encoder获取state,会把encoder的encoder_output, encoder_mask直接传入到decoder.init_state中初始化一个state | |||||
:param src_tokens: | |||||
:param src_seq_len: | |||||
:return: | |||||
""" | |||||
encoder_output, encoder_mask = self.encoder(src_tokens, src_seq_len) | |||||
state = self.decoder.init_state(encoder_output, encoder_mask) | |||||
return state | |||||
@classmethod | |||||
def build_model(cls, *args, **kwargs): | |||||
""" | |||||
需要实现本方法来进行Seq2SeqModel的初始化 | |||||
:return: | |||||
""" | |||||
raise NotImplementedError() | |||||
class TransformerSeq2SeqModel(Seq2SeqModel): | |||||
""" | |||||
Encoder为TransformerSeq2SeqEncoder, decoder为TransformerSeq2SeqDecoder,通过build_model方法初始化 | |||||
""" | |||||
def __init__(self, encoder, decoder): | |||||
super().__init__(encoder, decoder) | |||||
@classmethod | |||||
def build_model(cls, src_embed, tgt_embed=None, | |||||
pos_embed='sin', max_position=1024, num_layers=6, d_model=512, n_head=8, dim_ff=2048, dropout=0.1, | |||||
bind_encoder_decoder_embed=False, | |||||
bind_decoder_input_output_embed=True): | |||||
""" | |||||
初始化一个TransformerSeq2SeqModel | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 | |||||
True,则不要输入该值 | |||||
:param str pos_embed: 支持sin, learned两种 | |||||
:param int max_position: 最大支持长度 | |||||
:param int num_layers: encoder和decoder的层数 | |||||
:param int d_model: encoder和decoder输入输出的大小 | |||||
:param int n_head: encoder和decoder的head的数量 | |||||
:param int dim_ff: encoder和decoder中FFN中间映射的维度 | |||||
:param float dropout: Attention和FFN dropout的大小 | |||||
:param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding | |||||
:param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 | |||||
:return: TransformerSeq2SeqModel | |||||
""" | |||||
if bind_encoder_decoder_embed and tgt_embed is not None: | |||||
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||||
src_embed = get_embeddings(src_embed) | |||||
if bind_encoder_decoder_embed: | |||||
tgt_embed = src_embed | |||||
else: | |||||
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||||
tgt_embed = get_embeddings(tgt_embed) | |||||
if pos_embed == 'sin': | |||||
encoder_pos_embed = nn.Embedding.from_pretrained( | |||||
get_sinusoid_encoding_table(max_position + 1, src_embed.embedding_dim, padding_idx=0), | |||||
freeze=True) # 这里规定0是padding | |||||
decoder_pos_embed = nn.Embedding.from_pretrained(
get_sinusoid_encoding_table(max_position + 1, tgt_embed.embedding_dim, padding_idx=0), | |||||
freeze=True) # 这里规定0是padding | |||||
elif pos_embed == 'learned': | |||||
encoder_pos_embed = get_embeddings((max_position + 1, src_embed.embedding_dim), padding_idx=0) | |||||
decoder_pos_embed = get_embeddings((max_position + 1, tgt_embed.embedding_dim), padding_idx=0)
else: | |||||
raise ValueError("pos_embed only supports sin or learned.") | |||||
encoder = TransformerSeq2SeqEncoder(embed=src_embed, pos_embed=encoder_pos_embed, | |||||
num_layers=num_layers, d_model=d_model, n_head=n_head, dim_ff=dim_ff, | |||||
dropout=dropout) | |||||
decoder = TransformerSeq2SeqDecoder(embed=tgt_embed, pos_embed=decoder_pos_embed,
d_model=d_model, num_layers=num_layers, n_head=n_head, dim_ff=dim_ff, | |||||
dropout=dropout, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||||
return cls(encoder, decoder) | |||||
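Below is a minimal usage sketch of build_model as just defined. The import path, vocabulary sizes and sequence lengths are assumptions made for illustration; the point is that forward returns a dict whose 'pred' entry has shape bsz x max_len x vocab_size, matching the Seq2SeqModel.forward contract.

# Hedged sketch: build a small TransformerSeq2SeqModel from (vocab_size, embed_dim) tuples
# and run one training-style forward pass. Import path and sizes are assumed.
import torch
from fastNLP.models import TransformerSeq2SeqModel  # import path assumed

model = TransformerSeq2SeqModel.build_model(
    src_embed=(1000, 512), tgt_embed=(800, 512),     # (vocab_size, embed_dim) tuples
    pos_embed='sin', max_position=1024,
    num_layers=2, d_model=512, n_head=8, dim_ff=1024, dropout=0.1,
    bind_encoder_decoder_embed=False,
    bind_decoder_input_output_embed=True)

src_tokens = torch.randint(1, 1000, (4, 20))          # bsz x src_len
tgt_tokens = torch.randint(1, 800, (4, 15))           # bsz x tgt_len
src_seq_len = torch.full((4,), 20, dtype=torch.long)

out = model(src_tokens, tgt_tokens, src_seq_len)
print(out['pred'].shape)                              # torch.Size([4, 15, 800])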
class LSTMSeq2SeqModel(Seq2SeqModel): | |||||
""" | |||||
使用LSTMSeq2SeqEncoder和LSTMSeq2SeqDecoder的model | |||||
""" | |||||
def __init__(self, encoder, decoder): | |||||
super().__init__(encoder, decoder) | |||||
@classmethod | |||||
def build_model(cls, src_embed, tgt_embed=None, | |||||
num_layers = 3, hidden_size = 400, dropout = 0.3, bidirectional=True, | |||||
attention=True, bind_encoder_decoder_embed=False, | |||||
bind_decoder_input_output_embed=True): | |||||
""" | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] src_embed: source的embedding | |||||
:param nn.Module, StaticEmbedding, Tuple[int, int] tgt_embed: target的embedding,如果bind_encoder_decoder_embed为 | |||||
True,则不要输入该值 | |||||
:param int num_layers: Encoder和Decoder的层数 | |||||
:param int hidden_size: encoder和decoder的隐藏层大小 | |||||
:param float dropout: 每层之间的Dropout的大小 | |||||
:param bool bidirectional: encoder是否使用双向LSTM | |||||
:param bool attention: decoder是否使用attention attend encoder在所有时刻的状态 | |||||
:param bool bind_encoder_decoder_embed: 是否对encoder和decoder使用相同的embedding | |||||
:param bool bind_decoder_input_output_embed: decoder的输出embedding是否与其输入embedding是一样的权重 | |||||
:return: LSTMSeq2SeqModel | |||||
""" | |||||
if bind_encoder_decoder_embed and tgt_embed is not None: | |||||
raise RuntimeError("If you set `bind_encoder_decoder_embed=True`, please do not provide `tgt_embed`.") | |||||
src_embed = get_embeddings(src_embed) | |||||
if bind_encoder_decoder_embed: | |||||
tgt_embed = src_embed | |||||
else: | |||||
assert tgt_embed is not None, "You need to pass `tgt_embed` when `bind_encoder_decoder_embed=False`" | |||||
tgt_embed = get_embeddings(tgt_embed) | |||||
encoder = LSTMSeq2SeqEncoder(embed=src_embed, num_layers = num_layers, | |||||
hidden_size = hidden_size, dropout = dropout, bidirectional=bidirectional) | |||||
decoder = LSTMSeq2SeqDecoder(embed=tgt_embed, num_layers = num_layers, hidden_size = hidden_size, | |||||
dropout = dropout, bind_decoder_input_output_embed = bind_decoder_input_output_embed, | |||||
attention=attention) | |||||
return cls(encoder, decoder) |
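A hedged sketch of step-wise generation with the LSTM variant: prepare_state wraps the encoder output into an LSTMState, and decoder.decode returns the score distribution for the next token only. Vocabulary sizes, the <bos> id and the import path are assumptions.

# Hedged sketch: one greedy decode step with LSTMSeq2SeqModel.
import torch
from fastNLP.models import LSTMSeq2SeqModel  # import path assumed

model = LSTMSeq2SeqModel.build_model(src_embed=(1000, 300), tgt_embed=(800, 300),
                                     num_layers=2, hidden_size=400, dropout=0.3,
                                     bidirectional=True, attention=True)
src_tokens = torch.randint(1, 1000, (4, 20))
src_seq_len = torch.full((4,), 20, dtype=torch.long)

state = model.prepare_state(src_tokens, src_seq_len)   # LSTMState holding the encoder output
bos = torch.full((4, 1), 1, dtype=torch.long)           # assume token id 1 is <bos>
scores = model.decoder.decode(bos, state)               # bsz x vocab_size, next-token scores
next_tokens = scores.argmax(dim=-1, keepdim=True)       # greedy step; state.decode_length is now 1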
@@ -14,9 +14,9 @@ import torch.nn.functional as F | |||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..core.const import Const as C | from ..core.const import Const as C | ||||
from ..core.utils import seq_len_to_mask | from ..core.utils import seq_len_to_mask | ||||
from ..embeddings import get_embeddings | |||||
from ..modules import ConditionalRandomField | |||||
from ..modules import LSTM | |||||
from ..embeddings.utils import get_embeddings | |||||
from ..modules.decoder import ConditionalRandomField | |||||
from ..modules.encoder import LSTM | |||||
from ..modules import decoder, encoder | from ..modules import decoder, encoder | ||||
from ..modules.decoder.crf import allowed_transitions | from ..modules.decoder.crf import allowed_transitions | ||||
@@ -58,7 +58,21 @@ __all__ = [ | |||||
"RobertaModel", | "RobertaModel", | ||||
"GPT2Model", | "GPT2Model", | ||||
"GPT2Tokenizer" | |||||
"GPT2Tokenizer", | |||||
"TransformerSeq2SeqEncoder", | |||||
"LSTMSeq2SeqEncoder", | |||||
"Seq2SeqEncoder", | |||||
"TransformerSeq2SeqDecoder", | |||||
"LSTMSeq2SeqDecoder", | |||||
"Seq2SeqDecoder", | |||||
"TransformerState", | |||||
"LSTMState", | |||||
"State", | |||||
"SequenceGenerator" | |||||
] | ] | ||||
import sys | import sys | ||||
@@ -68,6 +82,7 @@ from . import encoder | |||||
from .decoder import * | from .decoder import * | ||||
from .dropout import TimestepDropout | from .dropout import TimestepDropout | ||||
from .encoder import * | from .encoder import * | ||||
from .generator import * | |||||
from .utils import summary | from .utils import summary | ||||
from ..doc_utils import doc_process | from ..doc_utils import doc_process | ||||
from .tokenizer import * | from .tokenizer import * | ||||
@@ -12,7 +12,8 @@ import torch | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | from torch import nn | ||||
from fastNLP.modules.utils import initial_parameter | |||||
from .utils import initial_parameter | |||||
from .decoder.seq2seq_state import TransformerState | |||||
class DotAttention(nn.Module): | class DotAttention(nn.Module): | ||||
@@ -45,64 +46,153 @@ class DotAttention(nn.Module): | |||||
class MultiHeadAttention(nn.Module): | class MultiHeadAttention(nn.Module): | ||||
r""" | |||||
Transformer当中的MultiHeadAttention | |||||
""" | """ | ||||
Attention is all you need中提到的多头注意力 | |||||
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): | |||||
r""" | |||||
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
:param key_size: int, 每个head的维度大小。 | |||||
:param value_size: int,每个head中value的维度。 | |||||
:param num_head: int,head的数量。 | |||||
:param dropout: float。 | |||||
""" | |||||
""" | |||||
def __init__(self, d_model: int = 512, n_head: int = 8, dropout: float = 0.0, layer_idx: int = None): | |||||
super(MultiHeadAttention, self).__init__() | super(MultiHeadAttention, self).__init__() | ||||
self.input_size = input_size | |||||
self.key_size = key_size | |||||
self.value_size = value_size | |||||
self.num_head = num_head | |||||
in_size = key_size * num_head | |||||
self.q_in = nn.Linear(input_size, in_size) | |||||
self.k_in = nn.Linear(input_size, in_size) | |||||
self.v_in = nn.Linear(input_size, in_size) | |||||
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout) | |||||
self.out = nn.Linear(value_size * num_head, input_size) | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dropout = dropout | |||||
self.head_dim = d_model // n_head | |||||
self.layer_idx = layer_idx | |||||
assert d_model % n_head == 0, "d_model should be divisible by n_head" | |||||
self.scaling = self.head_dim ** -0.5 | |||||
self.q_proj = nn.Linear(d_model, d_model) | |||||
self.k_proj = nn.Linear(d_model, d_model) | |||||
self.v_proj = nn.Linear(d_model, d_model) | |||||
self.out_proj = nn.Linear(d_model, d_model) | |||||
self.reset_parameters() | self.reset_parameters() | ||||
def forward(self, query, key, value, key_mask=None, attn_mask=None, state=None): | |||||
""" | |||||
:param query: batch x seq x dim | |||||
:param key: batch x seq x dim | |||||
:param value: batch x seq x dim | |||||
:param key_mask: batch x seq 用于指示哪些key不要attend到;注意到mask为1的地方是要attend到的 | |||||
:param attn_mask: seq x seq, 用于mask掉attention map。 主要是用在训练时decoder端的self attention,下三角为1 | |||||
:param state: 过去的信息,在inference的时候会用到,比如encoder output、decoder的prev kv。这样可以减少计算。 | |||||
:return: | |||||
""" | |||||
assert key.size() == value.size() | |||||
if state is not None: | |||||
assert self.layer_idx is not None | |||||
qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() | |||||
q = self.q_proj(query) # batch x seq x dim | |||||
q *= self.scaling | |||||
k = v = None | |||||
prev_k = prev_v = None | |||||
# 从state中取kv | |||||
if isinstance(state, TransformerState): # 说明此时在inference阶段 | |||||
if qkv_same: # 此时在decoder self attention | |||||
prev_k = state.decoder_prev_key[self.layer_idx] | |||||
prev_v = state.decoder_prev_value[self.layer_idx] | |||||
else: # 此时在decoder-encoder attention,直接将保存下来的key装载起来即可 | |||||
k = state.encoder_key[self.layer_idx] | |||||
v = state.encoder_value[self.layer_idx] | |||||
if k is None: | |||||
k = self.k_proj(key) | |||||
v = self.v_proj(value) | |||||
if prev_k is not None: | |||||
k = torch.cat((prev_k, k), dim=1) | |||||
v = torch.cat((prev_v, v), dim=1) | |||||
# 更新state | |||||
if isinstance(state, TransformerState): | |||||
if qkv_same: | |||||
state.decoder_prev_key[self.layer_idx] = k | |||||
state.decoder_prev_value[self.layer_idx] = v | |||||
else: | |||||
state.encoder_key[self.layer_idx] = k | |||||
state.encoder_value[self.layer_idx] = v | |||||
# 开始计算attention | |||||
batch_size, q_len, d_model = query.size() | |||||
k_len, v_len = k.size(1), v.size(1) | |||||
q = q.reshape(batch_size, q_len, self.n_head, self.head_dim) | |||||
k = k.reshape(batch_size, k_len, self.n_head, self.head_dim) | |||||
v = v.reshape(batch_size, v_len, self.n_head, self.head_dim) | |||||
attn_weights = torch.einsum('bqnh,bknh->bqkn', q, k) # bs,q_len,k_len,n_head | |||||
if key_mask is not None: | |||||
_key_mask = ~key_mask[:, None, :, None].bool() # batch,1,k_len,1 | |||||
attn_weights = attn_weights.masked_fill(_key_mask, -float('inf')) | |||||
if attn_mask is not None: | |||||
_attn_mask = attn_mask[None, :, :, None].eq(0) # 1,q_len,k_len,n_head | |||||
attn_weights = attn_weights.masked_fill(_attn_mask, -float('inf')) | |||||
attn_weights = F.softmax(attn_weights, dim=2) | |||||
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) | |||||
output = torch.einsum('bqkn,bknh->bqnh', attn_weights, v) # batch,q_len,n_head,head_dim | |||||
output = output.reshape(batch_size, q_len, -1) | |||||
output = self.out_proj(output) # batch,q_len,dim | |||||
return output, attn_weights | |||||
def reset_parameters(self): | def reset_parameters(self): | ||||
sqrt = math.sqrt | |||||
nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||||
nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||||
nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||||
nn.init.normal_(self.out.weight, mean=0, std=sqrt(1.0 / self.input_size)) | |||||
nn.init.xavier_uniform_(self.q_proj.weight) | |||||
nn.init.xavier_uniform_(self.k_proj.weight) | |||||
nn.init.xavier_uniform_(self.v_proj.weight) | |||||
nn.init.xavier_uniform_(self.out_proj.weight) | |||||
def set_layer_idx(self, layer_idx): | |||||
self.layer_idx = layer_idx | |||||
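A small, hedged shape check for the rewritten MultiHeadAttention, first as ordinary self-attention and then with a TransformerState to exercise the key/value caching used during incremental decoding; all sizes are invented.

# Hedged sketch: MultiHeadAttention with and without a TransformerState.
# With a state and layer_idx set, keys/values of previous decode steps are cached and reused.
import torch
from fastNLP.modules.attention import MultiHeadAttention            # path as used in this diff
from fastNLP.modules.decoder.seq2seq_state import TransformerState  # path as used in this diff

attn = MultiHeadAttention(d_model=512, n_head=8, dropout=0.1, layer_idx=0)
x = torch.rand(2, 5, 512)                       # bsz x seq_len x d_model
out, weights = attn(x, x, x)                    # plain self-attention, no mask, no state
print(out.shape, weights.shape)                 # [2, 5, 512], [2, 5, 5, 8]

# incremental decoding: feed one new position, previous keys/values live in the state
encoder_out, encoder_mask = torch.rand(2, 7, 512), torch.ones(2, 7).bool()
state = TransformerState(encoder_out, encoder_mask, num_decoder_layer=1)
step = torch.rand(2, 1, 512)
out, _ = attn(step, step, step, state=state)    # fills state.decoder_prev_key/value[0]
print(state.decoder_prev_key[0].shape)          # [2, 1, 512] after the first step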
def forward(self, Q, K, V, atte_mask_out=None): | |||||
r""" | |||||
:param Q: [batch, seq_len_q, model_size] | |||||
:param K: [batch, seq_len_k, model_size] | |||||
:param V: [batch, seq_len_k, model_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
class AttentionLayer(nn.Module): | |||||
def __init__(self, input_size, key_dim, value_dim, bias=False):
""" | """ | ||||
batch, sq, _ = Q.size() | |||||
sk = K.size(1) | |||||
d_k, d_v, n_head = self.key_size, self.value_size, self.num_head | |||||
# input linear | |||||
q = self.q_in(Q).view(batch, sq, n_head, d_k).transpose(1, 2) | |||||
k = self.k_in(K).view(batch, sk, n_head, d_k).transpose(1, 2) | |||||
v = self.v_in(V).view(batch, sk, n_head, d_v).transpose(1, 2) | |||||
if atte_mask_out is not None: | |||||
atte_mask_out = atte_mask_out[:,None,:,:] # [bsz,1,1,len] | |||||
atte = self.attention(q, k, v, atte_mask_out).view(batch, n_head, sq, d_v) | |||||
# concat all heads, do output linear | |||||
atte = atte.transpose(1, 2).contiguous().view(batch, sq, -1) | |||||
output = self.out(atte) | |||||
return output | |||||
可用于LSTM2LSTM的序列到序列模型的decode过程中,该attention是在decode过程中根据上一个step的hidden计算对encoder结果的attention | |||||
:param int input_size: 输入的大小 | |||||
:param int key_dim: 一般就是encoder_output输出的维度 | |||||
:param int value_dim: 输出的大小维度, 一般就是decoder hidden的大小 | |||||
:param bias: | |||||
""" | |||||
super().__init__() | |||||
self.input_proj = nn.Linear(input_size, key_dim, bias=bias)
self.output_proj = nn.Linear(input_size + key_dim, value_dim, bias=bias)
def forward(self, input, encode_outputs, encode_mask): | |||||
""" | |||||
:param input: batch_size x input_size | |||||
:param encode_outputs: batch_size x max_len x key_dim | |||||
:param encode_mask: batch_size x max_len, 为0的地方为padding | |||||
:return: hidden: batch_size x value_dim, scores: batch_size x max_len, normalized过的 | |||||
""" | |||||
# x: bsz x encode_hidden_size | |||||
x = self.input_proj(input) | |||||
# compute attention | |||||
attn_scores = torch.matmul(encode_outputs, x.unsqueeze(-1)).squeeze(-1) # b x max_len | |||||
# don't attend over padding | |||||
if encode_mask is not None: | |||||
attn_scores = attn_scores.float().masked_fill_( | |||||
encode_mask.eq(0), | |||||
float('-inf') | |||||
).type_as(attn_scores) # FP16 support: cast to float and back | |||||
attn_scores = F.softmax(attn_scores, dim=-1) # srclen x bsz | |||||
# sum weighted sources | |||||
x = torch.matmul(attn_scores.unsqueeze(1), encode_outputs).squeeze(1) # b x encode_hidden_size | |||||
x = torch.tanh(self.output_proj(torch.cat((x, input), dim=1))) | |||||
return x, attn_scores | |||||
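A hedged shape sketch for AttentionLayer: a single decoder hidden state attends over all encoder outputs and returns a context vector plus normalized scores. Sizes and the import path are assumptions.

# Hedged sketch: one decoder hidden state attends over the encoder outputs.
import torch
from fastNLP.modules.attention import AttentionLayer  # path as used in this diff

layer = AttentionLayer(input_size=400, key_dim=400, value_dim=400)
hidden = torch.rand(2, 400)                 # bsz x input_size, e.g. the last decoder hidden state
encode_outputs = torch.rand(2, 7, 400)      # bsz x max_len x key_dim
encode_mask = torch.ones(2, 7)              # 1 = real token, 0 = padding
context, scores = layer(hidden, encode_outputs, encode_mask)
print(context.shape, scores.shape)          # [2, 400], [2, 7]; scores sum to 1 over dim -1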
def _masked_softmax(tensor, mask): | def _masked_softmax(tensor, mask): |
@@ -6,10 +6,20 @@ __all__ = [ | |||||
"MLP", | "MLP", | ||||
"ConditionalRandomField", | "ConditionalRandomField", | ||||
"viterbi_decode", | "viterbi_decode", | ||||
"allowed_transitions" | |||||
"allowed_transitions", | |||||
"LSTMState", | |||||
"TransformerState", | |||||
"State", | |||||
"TransformerSeq2SeqDecoder", | |||||
"LSTMSeq2SeqDecoder", | |||||
"Seq2SeqDecoder" | |||||
] | ] | ||||
from .crf import ConditionalRandomField | from .crf import ConditionalRandomField | ||||
from .crf import allowed_transitions | from .crf import allowed_transitions | ||||
from .mlp import MLP | from .mlp import MLP | ||||
from .utils import viterbi_decode | from .utils import viterbi_decode | ||||
from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder | |||||
from .seq2seq_state import State, LSTMState, TransformerState |
@@ -1,109 +1,413 @@ | |||||
# coding=utf-8 | |||||
__all__ = [ | |||||
"TransformerPast", | |||||
"Past", | |||||
"Decoder" | |||||
] | |||||
from typing import Union, Tuple | |||||
import math | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
import abc | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from ..attention import AttentionLayer, MultiHeadAttention | |||||
from ...embeddings import StaticEmbedding | from ...embeddings import StaticEmbedding | ||||
import numpy as np | |||||
from typing import Union, Tuple | |||||
from ...embeddings.utils import get_embeddings | from ...embeddings.utils import get_embeddings | ||||
from torch.nn import LayerNorm | |||||
import math | |||||
from .seq2seq_state import State, LSTMState, TransformerState | |||||
class Seq2SeqDecoder(nn.Module): | |||||
""" | |||||
Sequence-to-Sequence Decoder的基类。一定需要实现forward函数,剩下的函数根据需要实现。每个Seq2SeqDecoder都应该有相应的State对象 | |||||
用来承载该Decoder所需要的Encoder输出、Decoder需要记录的历史信息(例如LSTM的hidden信息)。 | |||||
class Past: | |||||
""" | |||||
def __init__(self): | def __init__(self): | ||||
pass | |||||
super().__init__() | |||||
def forward(self, tokens, state, **kwargs): | |||||
""" | |||||
@abc.abstractmethod | |||||
def num_samples(self): | |||||
pass | |||||
:param torch.LongTensor tokens: bsz x max_len | |||||
:param State state: state包含了encoder的输出以及decode之前的内容 | |||||
:return: 返回值可以为bsz x max_len x vocab_size的Tensor,也可以是一个list,但是第一个元素必须是词的预测分布 | |||||
""" | |||||
raise NotImplementedError
@abc.abstractmethod | |||||
def reorder_past(self, indices: torch.LongTensor): | |||||
def reorder_states(self, indices, states): | |||||
""" | """ | ||||
根据indices中的index,将past的中状态置为正确的顺序。inplace改变 | |||||
根据indices重新排列states中的状态,在beam search进行生成时,会用到该函数。 | |||||
:param torch.LongTensor indices: | :param torch.LongTensor indices: | ||||
:param Past past: | |||||
:param State states: | |||||
:return: | :return: | ||||
""" | """ | ||||
raise NotImplemented | |||||
assert isinstance(states, State), f"`states` should be of type State instead of {type(states)}" | |||||
states.reorder_state(indices) | |||||
def init_state(self, encoder_output, encoder_mask): | |||||
""" | |||||
初始化一个state对象,用来记录了encoder的输出以及decode已经完成的部分。 | |||||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||||
维度 | |||||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度 | |||||
:param kwargs: | |||||
:return: State, 返回一个State对象,记录了encoder的输出 | |||||
""" | |||||
state = State(encoder_output, encoder_mask) | |||||
return state | |||||
class TransformerPast(Past): | |||||
def __init__(self, encoder_outputs: torch.Tensor = None, encoder_mask: torch.Tensor = None, | |||||
num_decoder_layer: int = 6): | |||||
def decode(self, tokens, state): | |||||
""" | """ | ||||
根据states中的内容,以及tokens中的内容进行之后的生成。 | |||||
:param encoder_outputs: (batch,src_seq_len,dim) | |||||
:param encoder_mask: (batch,src_seq_len) | |||||
:param encoder_key: list of (batch, src_seq_len, dim) | |||||
:param encoder_value: | |||||
:param decoder_prev_key: | |||||
:param decoder_prev_value: | |||||
:param torch.LongTensor tokens: bsz x max_len, 上一个时刻的token输出。 | |||||
:param State state: 记录了encoder输出与decoder过去状态 | |||||
:return: torch.FloatTensor: bsz x vocab_size, 输出的是下一个时刻的分布 | |||||
""" | """ | ||||
outputs = self(state=state, tokens=tokens) | |||||
if isinstance(outputs, torch.Tensor): | |||||
return outputs[:, -1] | |||||
else: | |||||
raise RuntimeError("Unrecognized output from the `forward()` function. Please override the `decode()` function.") | |||||
class TiedEmbedding(nn.Module): | |||||
""" | |||||
用于将weight和原始weight绑定 | |||||
""" | |||||
def __init__(self, weight): | |||||
super().__init__() | super().__init__() | ||||
self.encoder_outputs = encoder_outputs | |||||
self.encoder_mask = encoder_mask | |||||
self.encoder_key = [None] * num_decoder_layer | |||||
self.encoder_value = [None] * num_decoder_layer | |||||
self.decoder_prev_key = [None] * num_decoder_layer | |||||
self.decoder_prev_value = [None] * num_decoder_layer | |||||
def num_samples(self): | |||||
if self.encoder_outputs is not None: | |||||
return self.encoder_outputs.size(0) | |||||
return None | |||||
def _reorder_state(self, state, indices): | |||||
if type(state) == torch.Tensor: | |||||
state = state.index_select(index=indices, dim=0) | |||||
elif type(state) == list: | |||||
for i in range(len(state)): | |||||
assert state[i] is not None | |||||
state[i] = state[i].index_select(index=indices, dim=0) | |||||
self.weight = weight # vocab_size x embed_size | |||||
def forward(self, x): | |||||
""" | |||||
:param torch.FloatTensor x: bsz x * x embed_size | |||||
:return: torch.FloatTensor bsz x * x vocab_size | |||||
""" | |||||
return torch.matmul(x, self.weight.t()) | |||||
def get_binded_decoder_output_embed(embed): | |||||
""" | |||||
给定一个embedding,输出对应的绑定的embedding,输出对象为TiedEmbedding | |||||
:param embed: | |||||
:return: | |||||
""" | |||||
if isinstance(embed, StaticEmbedding): | |||||
for idx, map2idx in enumerate(embed.words_to_words): | |||||
assert idx == map2idx, "Invalid StaticEmbedding for Decoder, please check:(1) whether the vocabulary " \ | |||||
"include `no_create_entry=True` word; (2) StaticEmbedding should not initialize with " \ | |||||
"`lower=True` or `min_freq!=1`." | |||||
elif not isinstance(embed, nn.Embedding): | |||||
raise TypeError("Only nn.Embedding or StaticEmbedding is allowed for binding.") | |||||
return TiedEmbedding(embed.weight) | |||||
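TiedEmbedding simply multiplies its input by the transposed embedding matrix, so the decoder's output projection shares parameters with the input embedding. A minimal sketch (sizes arbitrary, import path assumed):

# Hedged sketch of weight tying: logits are x @ W^T, reusing the input embedding weight,
# so no extra vocab_size x embed_size output matrix is learned.
import torch
from torch import nn
from fastNLP.modules.decoder.seq2seq_decoder import TiedEmbedding  # path assumed

vocab_size, embed_size = 10, 4
embed = nn.Embedding(vocab_size, embed_size)
output_layer = TiedEmbedding(embed.weight)       # what get_binded_decoder_output_embed(embed) returns

x = torch.rand(2, 3, embed_size)                 # bsz x len x embed_size
logits = output_layer(x)                         # bsz x len x vocab_size
assert logits.shape == (2, 3, vocab_size)
assert output_layer.weight is embed.weight       # the weight is shared, not copied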
class LSTMSeq2SeqDecoder(Seq2SeqDecoder): | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, hidden_size = 300, | |||||
dropout = 0.3, bind_decoder_input_output_embed = True, attention=True): | |||||
""" | |||||
LSTM的Decoder | |||||
:param nn.Module,tuple embed: decoder输入的embedding. | |||||
:param int num_layers: 多少层LSTM | |||||
:param int hidden_size: 隐藏层大小, 该值也被认为是encoder的输出维度大小 | |||||
:param dropout: Dropout的大小 | |||||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||||
:param bool attention: 是否使用attention | |||||
""" | |||||
super().__init__() | |||||
self.embed = get_embeddings(init_embed=embed) | |||||
self.embed_dim = self.embed.embedding_dim
if bind_decoder_input_output_embed: | |||||
self.output_layer = get_binded_decoder_output_embed(self.embed) | |||||
else: # 不需要bind | |||||
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||||
self.hidden_size = hidden_size | |||||
self.num_layers = num_layers | |||||
self.lstm = nn.LSTM(input_size=self.embed_dim + hidden_size, hidden_size=hidden_size, num_layers=num_layers, | |||||
batch_first=True, bidirectional=False, dropout=dropout if num_layers>1 else 0) | |||||
self.attention_layer = AttentionLayer(hidden_size, hidden_size, hidden_size) if attention else None | |||||
self.output_proj = nn.Linear(hidden_size, self.embed_dim) | |||||
self.dropout_layer = nn.Dropout(dropout) | |||||
def forward(self, tokens, state, return_attention=False): | |||||
""" | |||||
:param torch.LongTensor tokens: batch x max_len | |||||
:param LSTMState state: 保存encoder输出和decode状态的State对象 | |||||
:param bool return_attention: 是否返回attention的的score | |||||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||||
""" | |||||
src_output = state.encoder_output | |||||
encoder_mask = state.encoder_mask | |||||
assert tokens.size(1)>state.decode_length, "The state does not match the tokens." | |||||
tokens = tokens[:, state.decode_length:] | |||||
x = self.embed(tokens) | |||||
attn_weights = [] if self.attention_layer is not None else None # 保存attention weight, batch,tgt_seq,src_seq | |||||
input_feed = state.input_feed | |||||
decoder_out = [] | |||||
cur_hidden = state.hidden | |||||
cur_cell = state.cell | |||||
# 开始计算 | |||||
for i in range(tokens.size(1)): | |||||
input = torch.cat( | |||||
(x[:, i:i + 1, :], | |||||
input_feed[:, None, :] | |||||
), | |||||
dim=2 | |||||
) # batch,1,2*dim | |||||
_, (cur_hidden, cur_cell) = self.lstm(input, hx=(cur_hidden, cur_cell)) # hidden/cell保持原来的size | |||||
if self.attention_layer is not None: | |||||
input_feed, attn_weight = self.attention_layer(cur_hidden[-1], src_output, encoder_mask) | |||||
attn_weights.append(attn_weight) | |||||
else: | |||||
input_feed = cur_hidden[-1] | |||||
state.input_feed = input_feed # batch, hidden | |||||
state.hidden = cur_hidden | |||||
state.cell = cur_cell | |||||
state.decode_length += 1 | |||||
decoder_out.append(input_feed) | |||||
decoder_out = torch.stack(decoder_out, dim=1) # batch,seq_len,hidden | |||||
decoder_out = self.dropout_layer(decoder_out) | |||||
if attn_weights is not None: | |||||
attn_weights = torch.cat(attn_weights, dim=1) # batch, tgt_len, src_len | |||||
decoder_out = self.output_proj(decoder_out) | |||||
feats = self.output_layer(decoder_out) | |||||
if return_attention: | |||||
return feats, attn_weights | |||||
return feats | |||||
def init_state(self, encoder_output, encoder_mask) -> LSTMState: | |||||
""" | |||||
:param encoder_output: 输入可以有两种情况(1) 输入为一个tuple,包含三个内容(encoder_output, (hidden, cell)),其中encoder_output: | |||||
bsz x max_len x hidden_size, hidden: bsz x hidden_size, cell:bsz x hidden_size,一般使用LSTMEncoder的最后一层的 | |||||
hidden state和cell state来赋值这两个值 | |||||
(2) 只有encoder_output: bsz x max_len x hidden_size, 这种情况下hidden和cell使用0初始化 | |||||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为0的位置是padding, 用来指示source中哪些不需要attend | |||||
:return: | |||||
""" | |||||
if not isinstance(encoder_output, torch.Tensor): | |||||
encoder_output, (hidden, cell) = encoder_output | |||||
else: | else: | ||||
raise ValueError('State does not support other format') | |||||
hidden = cell = None | |||||
assert encoder_output.ndim==3 | |||||
assert encoder_mask.size()==encoder_output.size()[:2] | |||||
assert encoder_output.size(-1)==self.hidden_size, "The dimension of encoder outputs should be the same with " \ | |||||
"the hidden_size." | |||||
t = [hidden, cell] | |||||
for idx in range(2): | |||||
v = t[idx] | |||||
if v is None: | |||||
v = encoder_output.new_zeros(self.num_layers, encoder_output.size(0), self.hidden_size) | |||||
else: | |||||
assert v.dim()==2 | |||||
assert v.size(-1)==self.hidden_size | |||||
v = v[None].repeat(self.num_layers, 1, 1) # num_layers x bsz x hidden_size | |||||
t[idx] = v | |||||
state = LSTMState(encoder_output, encoder_mask, t[0], t[1]) | |||||
return state | return state | ||||
def reorder_past(self, indices: torch.LongTensor): | |||||
self.encoder_outputs = self._reorder_state(self.encoder_outputs, indices) | |||||
self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||||
self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||||
self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||||
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||||
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||||
return self | |||||
class TransformerSeq2SeqDecoderLayer(nn.Module): | |||||
def __init__(self, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1, layer_idx = None): | |||||
""" | |||||
class Decoder(nn.Module): | |||||
def __init__(self): | |||||
:param int d_model: 输入、输出的维度 | |||||
:param int n_head: 多少个head,需要能被d_model整除 | |||||
:param int dim_ff: | |||||
:param float dropout: | |||||
:param int layer_idx: layer的编号 | |||||
""" | |||||
super().__init__() | super().__init__() | ||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.layer_idx = layer_idx # 记录layer的层索引,以方便获取state的信息 | |||||
self.self_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||||
self.self_attn_layer_norm = nn.LayerNorm(d_model) | |||||
self.encoder_attn = MultiHeadAttention(d_model, n_head, dropout, layer_idx) | |||||
self.encoder_attn_layer_norm = nn.LayerNorm(d_model) | |||||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||||
nn.ReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(self.dim_ff, self.d_model), | |||||
nn.Dropout(dropout)) | |||||
self.final_layer_norm = nn.LayerNorm(self.d_model) | |||||
@abc.abstractmethod | |||||
def decode(self, *args, **kwargs) -> Tuple[torch.Tensor, Past]: | |||||
def forward(self, x, encoder_output, encoder_mask=None, self_attn_mask=None, state=None): | |||||
""" | """ | ||||
当模型进行解码时,使用这个函数。返回一个batch_size x vocab_size的结果与更新的Past状态。需要考虑一种特殊情况,即tokens长度不是1,即给定了 | |||||
解码句子开头的情况,这种情况需要查看Past中是否正确计算了decode的状态。 | |||||
:return: tensor:batch_size x vocab_size, past: Past | |||||
:param x: (batch, seq_len, dim), decoder端的输入 | |||||
:param encoder_output: (batch,src_seq_len,dim) | |||||
:param encoder_mask: batch,src_seq_len | |||||
:param self_attn_mask: seq_len, seq_len,下三角的mask矩阵,只在训练时传入 | |||||
:param TransformerState state: 只在inference阶段传入 | |||||
:return: | |||||
""" | """ | ||||
raise NotImplemented | |||||
@abc.abstractmethod | |||||
def reorder_past(self, indices: torch.LongTensor, past: Past): | |||||
# self attention part | |||||
residual = x | |||||
x = self.self_attn_layer_norm(x) | |||||
x, _ = self.self_attn(query=x, | |||||
key=x, | |||||
value=x, | |||||
attn_mask=self_attn_mask, | |||||
state=state) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
x = residual + x | |||||
# encoder attention part | |||||
residual = x | |||||
x = self.encoder_attn_layer_norm(x) | |||||
x, attn_weight = self.encoder_attn(query=x, | |||||
key=encoder_output, | |||||
value=encoder_output, | |||||
key_mask=encoder_mask, | |||||
state=state) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
x = residual + x | |||||
# ffn | |||||
residual = x | |||||
x = self.final_layer_norm(x) | |||||
x = self.ffn(x) | |||||
x = residual + x | |||||
return x, attn_weight | |||||
class TransformerSeq2SeqDecoder(Seq2SeqDecoder): | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed: nn.Module = None, | |||||
d_model = 512, num_layers=6, n_head = 8, dim_ff = 2048, dropout = 0.1, | |||||
bind_decoder_input_output_embed = True): | |||||
""" | """ | ||||
根据indices中的index,将past的中状态置为正确的顺序。inplace改变 | |||||
:param torch.LongTensor indices: | |||||
:param Past past: | |||||
:return: | |||||
:param embed: 输入token的embedding | |||||
:param nn.Module pos_embed: 位置embedding | |||||
:param int d_model: 输出、输出的大小 | |||||
:param int num_layers: 多少层 | |||||
:param int n_head: 多少个head | |||||
:param int dim_ff: FFN 的中间大小 | |||||
:param float dropout: Self-Attention和FFN中的dropout的大小 | |||||
:param bool bind_decoder_input_output_embed: 是否将输出层和输入层的词向量绑定在一起(即为同一个),若embed为StaticEmbedding, | |||||
则StaticEmbedding的vocab不能包含no_create_entry的token,同时StaticEmbedding初始化时lower为False, min_freq=1. | |||||
""" | |||||
super().__init__() | |||||
self.embed = get_embeddings(embed) | |||||
self.pos_embed = pos_embed | |||||
if bind_decoder_input_output_embed: | |||||
self.output_layer = get_binded_decoder_output_embed(self.embed) | |||||
else: # 不需要bind | |||||
self.output_embed = get_embeddings((self.embed.num_embeddings, self.embed.embedding_dim)) | |||||
self.output_layer = TiedEmbedding(self.output_embed.weight) | |||||
self.num_layers = num_layers | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqDecoderLayer(d_model, n_head, dim_ff, dropout, layer_idx) | |||||
for layer_idx in range(num_layers)]) | |||||
self.embed_scale = math.sqrt(d_model) | |||||
self.layer_norm = nn.LayerNorm(d_model) | |||||
self.output_fc = nn.Linear(self.d_model, self.embed.embedding_dim) | |||||
def forward(self, tokens, state, return_attention=False): | |||||
""" | """ | ||||
raise NotImplemented | |||||
:param torch.LongTensor tokens: batch x tgt_len,decode的词 | |||||
:param TransformerState state: 用于记录encoder的输出以及decode状态的对象,可以通过init_state()获取 | |||||
:param bool return_attention: 是否返回对encoder结果的attention score | |||||
:return: bsz x max_len x vocab_size; 如果return_attention=True, 还会返回bsz x max_len x encode_length | |||||
""" | |||||
encoder_output = state.encoder_output | |||||
encoder_mask = state.encoder_mask | |||||
assert state.decode_length<tokens.size(1), "The number of decoded tokens in State should be less than the number of input tokens."
tokens = tokens[:, state.decode_length:] | |||||
device = tokens.device | |||||
x = self.embed_scale * self.embed(tokens) | |||||
if self.pos_embed is not None: | |||||
position = torch.arange(state.decode_length, state.decode_length+tokens.size(1)).long().to(device)[None] | |||||
x += self.pos_embed(position) | |||||
x = self.input_fc(x) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
batch_size, max_tgt_len = tokens.size() | |||||
if max_tgt_len>1: | |||||
triangle_mask = self._get_triangle_mask(tokens) | |||||
else: | |||||
triangle_mask = None | |||||
for layer in self.layer_stacks: | |||||
x, attn_weight = layer(x=x, | |||||
encoder_output=encoder_output, | |||||
encoder_mask=encoder_mask, | |||||
self_attn_mask=triangle_mask, | |||||
state=state | |||||
) | |||||
x = self.layer_norm(x) # batch, tgt_len, dim | |||||
x = self.output_fc(x) | |||||
feats = self.output_layer(x) | |||||
if return_attention: | |||||
return feats, attn_weight | |||||
return feats | |||||
def init_state(self, encoder_output, encoder_mask): | |||||
""" | |||||
初始化一个TransformerState用于forward | |||||
:param torch.FloatTensor encoder_output: bsz x max_len x d_model, encoder的输出 | |||||
:param torch.ByteTensor encoder_mask: bsz x max_len, 为1的位置需要attend。 | |||||
:return: TransformerState | |||||
""" | |||||
if isinstance(encoder_output, torch.Tensor): | |||||
encoder_output = encoder_output | |||||
elif isinstance(encoder_output, (list, tuple)): | |||||
encoder_output = encoder_output[0] # 防止是LSTMEncoder的输出结果 | |||||
else: | |||||
raise TypeError("Unsupported `encoder_output` for TransformerSeq2SeqDecoder") | |||||
state = TransformerState(encoder_output, encoder_mask, num_decoder_layer=self.num_layers) | |||||
return state | |||||
@staticmethod | |||||
def _get_triangle_mask(tokens): | |||||
tensor = tokens.new_ones(tokens.size(1), tokens.size(1)) | |||||
return torch.tril(tensor).byte() | |||||
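A hedged sketch of greedy, step-wise decoding with TransformerSeq2SeqDecoder: the TransformerState caches per-layer keys and values, so each decode call only processes the newly generated position. Vocabulary size, dimensions and the <bos> id are assumptions.

# Hedged sketch: incremental decoding driven by decode() and a TransformerState.
import torch
from fastNLP.modules.decoder import TransformerSeq2SeqDecoder  # exported in decoder/__init__ above

decoder = TransformerSeq2SeqDecoder(embed=(800, 512), d_model=512, num_layers=2,
                                    n_head=8, dim_ff=1024, dropout=0.1)
encoder_output = torch.rand(4, 20, 512)               # bsz x src_len x d_model
encoder_mask = torch.ones(4, 20).bool()               # 1 = attend
state = decoder.init_state(encoder_output, encoder_mask)

tokens = torch.full((4, 1), 1, dtype=torch.long)      # assume id 1 is <bos>
for _ in range(5):
    scores = decoder.decode(tokens, state)            # bsz x vocab_size for the newest position
    next_token = scores.argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=1)   # decode() only reads tokens[:, decode_length:]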
@@ -0,0 +1,145 @@ | |||||
r""" | |||||
每个Decoder都有对应的State用来记录encoder的输出以及Decode的历史记录 | |||||
""" | |||||
__all__ = [ | |||||
'State', | |||||
"LSTMState", | |||||
"TransformerState" | |||||
] | |||||
from typing import Union | |||||
import torch | |||||
class State: | |||||
def __init__(self, encoder_output=None, encoder_mask=None, **kwargs): | |||||
""" | |||||
每个Decoder都有对应的State对象用来承载encoder的输出以及当前时刻之前的decode状态。 | |||||
:param Union[torch.Tensor, list, tuple] encoder_output: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch | |||||
维度 | |||||
:param Union[torch.Tensor, list, tuple] encoder_mask: 如果不为None,内部元素需要为torch.Tensor, 默认其中第一维是batch
维度 | |||||
:param kwargs: | |||||
""" | |||||
self.encoder_output = encoder_output | |||||
self.encoder_mask = encoder_mask | |||||
self._decode_length = 0 | |||||
@property | |||||
def num_samples(self): | |||||
""" | |||||
返回的State中包含的是多少个sample的encoder状态,主要用于Generate的时候确定batch的大小。 | |||||
:return: | |||||
""" | |||||
if self.encoder_output is not None: | |||||
return self.encoder_output.size(0) | |||||
else: | |||||
return None | |||||
@property | |||||
def decode_length(self): | |||||
""" | |||||
当前Decode到哪个token了,decoder只会从decode_length之后的token开始decode, 为0说明还没开始decode。 | |||||
:return: | |||||
""" | |||||
return self._decode_length | |||||
@decode_length.setter | |||||
def decode_length(self, value): | |||||
self._decode_length = value | |||||
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0): | |||||
if isinstance(state, torch.Tensor): | |||||
state = state.index_select(index=indices, dim=dim) | |||||
elif isinstance(state, list): | |||||
for i in range(len(state)): | |||||
assert state[i] is not None | |||||
state[i] = self._reorder_state(state[i], indices, dim) | |||||
elif isinstance(state, tuple): | |||||
tmp_list = [] | |||||
for i in range(len(state)): | |||||
assert state[i] is not None | |||||
tmp_list.append(self._reorder_state(state[i], indices, dim)) | |||||
state = tuple(tmp_list) | |||||
else: | |||||
raise TypeError(f"Cannot reorder data of type:{type(state)}") | |||||
return state | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
if self.encoder_mask is not None: | |||||
self.encoder_mask = self._reorder_state(self.encoder_mask, indices) | |||||
if self.encoder_output is not None: | |||||
self.encoder_output = self._reorder_state(self.encoder_output, indices) | |||||
class LSTMState(State): | |||||
def __init__(self, encoder_output, encoder_mask, hidden, cell): | |||||
""" | |||||
LSTMDecoder对应的State,保存encoder的输出以及LSTM解码过程中的一些中间状态 | |||||
:param torch.FloatTensor encoder_output: bsz x src_seq_len x encode_output_size,encoder的输出 | |||||
:param torch.BoolTensor encoder_mask: bsz x src_seq_len, 为0的地方是padding | |||||
:param torch.FloatTensor hidden: num_layers x bsz x hidden_size, 上个时刻的hidden状态 | |||||
:param torch.FloatTensor cell: num_layers x bsz x hidden_size, 上个时刻的cell状态 | |||||
""" | |||||
super().__init__(encoder_output, encoder_mask) | |||||
self.hidden = hidden | |||||
self.cell = cell | |||||
self._input_feed = hidden[0] # 默认是上一个时刻的输出 | |||||
@property | |||||
def input_feed(self): | |||||
""" | |||||
LSTMDecoder中每个时刻的输入会把上个token的embedding和input_feed拼接起来输入到下个时刻,在LSTMDecoder不使用attention时, | |||||
input_feed即上个时刻的hidden state, 否则是attention layer的输出。 | |||||
:return: torch.FloatTensor, bsz x hidden_size | |||||
""" | |||||
return self._input_feed | |||||
@input_feed.setter | |||||
def input_feed(self, value): | |||||
self._input_feed = value | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
super().reorder_state(indices) | |||||
self.hidden = self._reorder_state(self.hidden, indices, dim=1) | |||||
self.cell = self._reorder_state(self.cell, indices, dim=1) | |||||
if self.input_feed is not None: | |||||
self.input_feed = self._reorder_state(self.input_feed, indices, dim=0) | |||||
class TransformerState(State): | |||||
def __init__(self, encoder_output, encoder_mask, num_decoder_layer): | |||||
""" | |||||
与TransformerSeq2SeqDecoder对应的State, | |||||
:param torch.FloatTensor encoder_output: bsz x encode_max_len x encoder_output_size, encoder的输出 | |||||
:param torch.ByteTensor encoder_mask: bsz x encode_max_len 为1的地方需要attend | |||||
:param int num_decoder_layer: decode有多少层 | |||||
""" | |||||
super().__init__(encoder_output, encoder_mask) | |||||
self.encoder_key = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x key_dim | |||||
self.encoder_value = [None] * num_decoder_layer # 每一个元素 bsz x encoder_max_len x value_dim | |||||
self.decoder_prev_key = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x key_dim | |||||
self.decoder_prev_value = [None] * num_decoder_layer # 每一个元素 bsz x decode_length x key_dim | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
super().reorder_state(indices) | |||||
self.encoder_key = self._reorder_state(self.encoder_key, indices) | |||||
self.encoder_value = self._reorder_state(self.encoder_value, indices) | |||||
self.decoder_prev_key = self._reorder_state(self.decoder_prev_key, indices) | |||||
self.decoder_prev_value = self._reorder_state(self.decoder_prev_value, indices) | |||||
@property | |||||
def decode_length(self): | |||||
if self.decoder_prev_key[0] is not None: | |||||
return self.decoder_prev_key[0].size(1) | |||||
return 0 | |||||
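A tiny, hedged illustration of reorder_state, which beam search uses after re-ranking hypotheses: every cached tensor is index_select-ed along its batch (beam) dimension so that it stays aligned with its hypothesis.

# Hedged sketch: reordering a State after beams are re-ranked.
import torch
from fastNLP.modules.decoder import State  # re-exported in decoder/__init__ above

encoder_output = torch.arange(6.).view(3, 1, 2)   # 3 "beams", src_len 1, dim 2
encoder_mask = torch.ones(3, 1).bool()
state = State(encoder_output, encoder_mask)

indices = torch.LongTensor([2, 0, 0])              # beam 2 survives, beam 0 is duplicated
state.reorder_state(indices)
print(state.encoder_output[:, 0, 0])               # tensor([4., 0., 0.])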
@@ -4,8 +4,6 @@ r""" | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
# "BertModel", | |||||
"ConvolutionCharEncoder", | "ConvolutionCharEncoder", | ||||
"LSTMCharEncoder", | "LSTMCharEncoder", | ||||
@@ -35,10 +33,14 @@ __all__ = [ | |||||
"RobertaModel", | "RobertaModel", | ||||
"GPT2Model" | |||||
"GPT2Model", | |||||
"LSTMSeq2SeqEncoder", | |||||
"TransformerSeq2SeqEncoder", | |||||
"Seq2SeqEncoder" | |||||
] | ] | ||||
from .attention import MultiHeadAttention, BiAttention, SelfAttention | |||||
from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention | |||||
from .bert import BertModel | from .bert import BertModel | ||||
from .roberta import RobertaModel | from .roberta import RobertaModel | ||||
from .gpt2 import GPT2Model | from .gpt2 import GPT2Model | ||||
@@ -49,3 +51,4 @@ from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPoo | |||||
from .star_transformer import StarTransformer | from .star_transformer import StarTransformer | ||||
from .transformer import TransformerEncoder | from .transformer import TransformerEncoder | ||||
from .variational_rnn import VarRNN, VarLSTM, VarGRU | from .variational_rnn import VarRNN, VarLSTM, VarGRU | ||||
from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder |
@@ -10,6 +10,7 @@ __all__ = [ | |||||
import copy | import copy | ||||
import json | import json | ||||
import math | import math | ||||
import os | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
@@ -20,7 +21,8 @@ from ...io.file_utils import _get_bert_dir | |||||
from ...core import logger | from ...core import logger | ||||
CONFIG_FILE = 'bert_config.json' | |||||
CONFIG_FILE = 'config.json' | |||||
WEIGHTS_NAME = 'pytorch_model.bin' | |||||
BERT_KEY_RENAME_MAP_1 = { | BERT_KEY_RENAME_MAP_1 = { | ||||
'gamma': 'weight', | 'gamma': 'weight', | ||||
@@ -57,7 +59,8 @@ class BertConfig(object): | |||||
max_position_embeddings=512, | max_position_embeddings=512, | ||||
type_vocab_size=2, | type_vocab_size=2, | ||||
initializer_range=0.02, | initializer_range=0.02, | ||||
layer_norm_eps=1e-12): | |||||
layer_norm_eps=1e-12, | |||||
architectures='bert'): | |||||
r"""Constructs BertConfig. | r"""Constructs BertConfig. | ||||
Args: | Args: | ||||
@@ -101,6 +104,7 @@ class BertConfig(object): | |||||
self.type_vocab_size = type_vocab_size | self.type_vocab_size = type_vocab_size | ||||
self.initializer_range = initializer_range | self.initializer_range = initializer_range | ||||
self.layer_norm_eps = layer_norm_eps | self.layer_norm_eps = layer_norm_eps | ||||
self.architectures = architectures | |||||
else: | else: | ||||
raise ValueError("First argument must be either a vocabulary size (int)" | raise ValueError("First argument must be either a vocabulary size (int)" | ||||
"or the path to a pretrained model config file (str)") | "or the path to a pretrained model config file (str)") | ||||
@@ -134,9 +138,13 @@ class BertConfig(object): | |||||
def to_json_file(self, json_file_path): | def to_json_file(self, json_file_path): | ||||
r""" Save this instance to a json file.""" | r""" Save this instance to a json file.""" | ||||
if os.path.isdir(json_file_path): | |||||
json_file_path = os.path.join(json_file_path, CONFIG_FILE) | |||||
with open(json_file_path, "w", encoding='utf-8') as writer: | with open(json_file_path, "w", encoding='utf-8') as writer: | ||||
writer.write(self.to_json_string()) | writer.write(self.to_json_string()) | ||||
def save_pretrained(self, save_directory): | |||||
self.to_json_file(save_directory) | |||||
def gelu(x): | def gelu(x): | ||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | ||||
@@ -149,21 +157,6 @@ def swish(x): | |||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | ||||
# class BertLayerNorm(nn.Module): | |||||
# def __init__(self, hidden_size, eps=1e-12): | |||||
# r"""Construct a layernorm module in the TF style (epsilon inside the square root). | |||||
# """ | |||||
# super(BertLayerNorm, self).__init__() | |||||
# self.weight = nn.Parameter(torch.ones(hidden_size)) | |||||
# self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||||
# self.variance_epsilon = eps | |||||
# | |||||
# def forward(self, x): | |||||
# u = x.mean(-1, keepdim=True) | |||||
# s = (x - u).pow(2).mean(-1, keepdim=True) | |||||
# x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||||
# return self.weight * x + self.bias | |||||
BertLayerNorm = torch.nn.LayerNorm | BertLayerNorm = torch.nn.LayerNorm | ||||
@@ -613,3 +606,24 @@ class BertModel(nn.Module): | |||||
logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") | logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") | ||||
return model | return model | ||||
def save_pretrained(self, save_directory): | |||||
""" 保存模型到某个folder | |||||
""" | |||||
assert os.path.isdir( | |||||
save_directory | |||||
), "Saving path should be a directory where the model and configuration can be saved" | |||||
# If we are using distributed training, save only the wrapped model itself, not the DataParallel/DDP wrapper
model_to_save = self.module if hasattr(self, "module") else self | |||||
# Attach architecture to the config | |||||
model_to_save.config.architectures = [model_to_save.__class__.__name__] | |||||
# Save configuration file | |||||
model_to_save.config.save_pretrained(save_directory) | |||||
# If we save using the predefined names, we can load using `from_pretrained` | |||||
output_model_file = os.path.join(save_directory, WEIGHTS_NAME) | |||||
torch.save(model_to_save.state_dict(), output_model_file) | |||||
logger.debug("Model weights saved in {}".format(output_model_file)) |
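A hedged round-trip sketch for the new save_pretrained: it writes config.json and pytorch_model.bin into an existing directory, which from_pretrained should then be able to load again. The BertConfig keyword arguments and the assumption that from_pretrained accepts a local directory are not confirmed by this diff.

# Hedged sketch: save a tiny BertModel and reload it from the same directory.
import os
from fastNLP.modules.encoder.bert import BertConfig, BertModel  # path assumed

config = BertConfig(30522, hidden_size=64, num_hidden_layers=2,   # keyword names assumed
                    num_attention_heads=2, intermediate_size=128)
model = BertModel(config)

save_dir = './tiny_bert'
os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir)          # -> ./tiny_bert/config.json, ./tiny_bert/pytorch_model.bin
reloaded = BertModel.from_pretrained(save_dir)  # assuming a local directory is accepted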
@@ -15,9 +15,8 @@ import math | |||||
from torch.nn import CrossEntropyLoss | from torch.nn import CrossEntropyLoss | ||||
from fastNLP.io.file_utils import _get_file_name_base_on_postfix | from fastNLP.io.file_utils import _get_file_name_base_on_postfix | ||||
from ..decoder.seq2seq_decoder import Decoder, Past | |||||
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||||
from ..generator.seq2seq_generator import SequenceGenerator | from ..generator.seq2seq_generator import SequenceGenerator | ||||
from typing import Tuple | |||||
GELU_CONSTANT = math.sqrt(2 / math.pi) | GELU_CONSTANT = math.sqrt(2 / math.pi) | ||||
@@ -732,7 +731,7 @@ class GPT2PreTrainedModel(nn.Module): | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_ids, | bos_token_id=bos_token_id, eos_token_id=eos_token_ids, | ||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | repetition_penalty=repetition_penalty, length_penalty=length_penalty, | ||||
pad_token_id=pad_token_id) | pad_token_id=pad_token_id) | ||||
results = generator.generate(input_ids, past=None) | |||||
results = generator.generate(tokens=input_ids, state=GPT2State()) | |||||
return results | return results | ||||
@@ -788,21 +787,13 @@ class GPT2Model(GPT2PreTrainedModel): | |||||
for layer, heads in heads_to_prune.items(): | for layer, heads in heads_to_prune.items(): | ||||
self.h[layer].attn.prune_heads(heads) | self.h[layer].attn.prune_heads(heads) | ||||
def forward( | |||||
self, | |||||
input_ids, | |||||
past=None, | |||||
attention_mask=None, | |||||
token_type_ids=None, | |||||
position_ids=None, | |||||
head_mask=None, | |||||
output_attentions=True | |||||
): | |||||
def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, | |||||
head_mask=None, output_attentions=True): | |||||
""" | """ | ||||
:param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | ||||
:param GPT2Past past: 之前的状态 | |||||
:param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), 与input_ids与past的concat一样大。 | |||||
:param GPT2State state: 之前的状态 | |||||
:param torch.ByteTensor attention_mask: batch_size x (pre_len+past_len), 与input_ids与state的concat一样大。 | |||||
为0的地方为padding。 | 为0的地方为padding。 | ||||
:param torch.LongTensor token_type_ids: batch_size x max_len。 | :param torch.LongTensor token_type_ids: batch_size x max_len。 | ||||
:param torch.LongTensor position_ids: 与input_ids对应的位置 | :param torch.LongTensor position_ids: 与input_ids对应的位置 | ||||
@@ -818,11 +809,11 @@ class GPT2Model(GPT2PreTrainedModel): | |||||
if position_ids is not None: | if position_ids is not None: | ||||
position_ids = position_ids.view(-1, input_shape[-1]) | position_ids = position_ids.view(-1, input_shape[-1]) | ||||
if past is None or len(past)==0: | |||||
if state is None or len(state)==0: | |||||
past_length = 0 | past_length = 0 | ||||
past = [None] * len(self.h) # len(self.h) 是layer的层数 | |||||
state = [None] * len(self.h) # len(self.h) 是layer的层数 | |||||
else: | else: | ||||
past_length = past[0][0].size(-2) | |||||
past_length = state[0][0].size(-2) | |||||
if position_ids is None: # 如果没有position id则生成 | if position_ids is None: # 如果没有position id则生成 | ||||
device = input_ids.device | device = input_ids.device | ||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) | ||||
@@ -880,7 +871,7 @@ class GPT2Model(GPT2PreTrainedModel): | |||||
presents = () | presents = () | ||||
all_attentions = [] | all_attentions = [] | ||||
all_hidden_states = () | all_hidden_states = () | ||||
for i, (block, layer_past) in enumerate(zip(self.h, past)): | |||||
for i, (block, layer_past) in enumerate(zip(self.h, state)): | |||||
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) | all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),) | ||||
outputs = block( | outputs = block( | ||||
@@ -915,56 +906,63 @@ class GPT2Model(GPT2PreTrainedModel): | |||||
return outputs # last hidden state, (presents), (all hidden_states), (attentions) | return outputs # last hidden state, (presents), (all hidden_states), (attentions) | ||||
class GPT2Past(Past): | |||||
class GPT2State(State): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | |||||
self.past = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] | |||||
super().__init__(None, None) | |||||
self.state = None # tuple [n_layer, 2 x batch_size x n_head x past_len x head_dim] | |||||
@property | |||||
def num_samples(self): | def num_samples(self): | ||||
if self.past is not None: | |||||
return self.past[0].size(1) | |||||
if self.state is not None: | |||||
return self.state[0].size(1) | |||||
return None | return None | ||||
def reorder_past(self, indices): | |||||
for i in range(len(self.past)): | |||||
assert self.past[i] is not None | |||||
self.past[i] = self.past[i].index_select(index=indices, dim=1) | |||||
@property | |||||
def decode_length(self): | |||||
if self.state is None: | |||||
return 0 | |||||
return self.state[0].size(-2) | |||||
def reorder_state(self, indices): | |||||
if self.state: | |||||
for i in range(len(self.state)): | |||||
assert self.state[i] is not None | |||||
self.state[i] = self.state[i].index_select(index=indices, dim=1) | |||||
def __iter__(self): | def __iter__(self): | ||||
for p in self.past: | |||||
for p in self.state: | |||||
yield p | yield p | ||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
assert isinstance(item, int) | assert isinstance(item, int) | ||||
return self.past[item] | |||||
return self.state[item] | |||||
def __len__(self): | def __len__(self): | ||||
if self.past is not None: | |||||
return len(self.past) | |||||
if self.state is not None: | |||||
return len(self.state) | |||||
return 0 | return 0 | ||||
class _GPT2Decoder(Decoder): | |||||
class _GPT2Decoder(Seq2SeqDecoder): | |||||
""" | |||||
用于wrap GPT2,使得其可以在SequenceGenerator中使用
""" | |||||
def __init__(self, gpt_model): | def __init__(self, gpt_model): | ||||
super().__init__() | super().__init__() | ||||
self.gpt_model = gpt_model | self.gpt_model = gpt_model | ||||
def decode(self, tokens, past=None) -> Tuple[torch.Tensor, Past]: | |||||
if past is None: | |||||
past = GPT2Past() | |||||
lm_logits, presents, _ = self.gpt_model(input_ids=tokens, | |||||
past=past, | |||||
def decode(self, tokens, state=None) -> torch.Tensor: | |||||
if state is None: | |||||
state = GPT2State() | |||||
lm_logits, presents, _ = self.gpt_model(input_ids=tokens[:, state.decode_length:], | |||||
state=state, | |||||
attention_mask=None, | attention_mask=None, | ||||
token_type_ids=None, | token_type_ids=None, | ||||
position_ids=None, | position_ids=None, | ||||
head_mask=None, | head_mask=None, | ||||
output_attentions=False) | output_attentions=False) | ||||
past.past = list(presents) | |||||
return lm_logits[:, -1], past | |||||
def reorder_past(self, indices: torch.LongTensor, past: GPT2Past) -> GPT2Past: | |||||
past.reorder_past(indices) | |||||
return past | |||||
state.state = list(presents) | |||||
return lm_logits[:, -1] | |||||
class GPT2LMHeadModel(GPT2PreTrainedModel): | class GPT2LMHeadModel(GPT2PreTrainedModel): | ||||
@@ -1008,21 +1006,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): | |||||
def get_input_embeddings(self): | def get_input_embeddings(self): | ||||
return self.transformer.wte | return self.transformer.wte | ||||
def forward( | |||||
self, | |||||
input_ids, | |||||
past=None, | |||||
attention_mask=None, | |||||
token_type_ids=None, | |||||
position_ids=None, | |||||
head_mask=None, | |||||
labels=None, | |||||
output_attentions=False | |||||
): | |||||
def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None, | |||||
head_mask=None, labels=None, output_attentions=False): | |||||
""" | """ | ||||
:param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | :param torch.LongTensor input_ids: batch_size x max_len or batch_size x beam_size x 1 | ||||
:param tuple past: num_layers x 2 x batch_size x n_head x max_len' x head_dim. 可以将前一个时刻的presents作为输入 | |||||
:param tuple state: num_layers x 2 x batch_size x n_head x max_len' x head_dim. 可以将前一个时刻的presents作为输入 | |||||
:param torch.ByteTensor attention_mask: batch_size x max_len, 与input_ids一样大。为0的地方为padding。 | :param torch.ByteTensor attention_mask: batch_size x max_len, 与input_ids一样大。为0的地方为padding。 | ||||
:param torch.LongTensor token_type_ids: batch_size x max_len。 | :param torch.LongTensor token_type_ids: batch_size x max_len。 | ||||
:param torch.LongTensor position_ids: 与input_ids对应的位置 | :param torch.LongTensor position_ids: 与input_ids对应的位置 | ||||
@@ -1034,7 +1023,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): | |||||
""" | """ | ||||
transformer_outputs = self.transformer( | transformer_outputs = self.transformer( | ||||
input_ids, | input_ids, | ||||
past=past, | |||||
state=state, | |||||
attention_mask=attention_mask, | attention_mask=attention_mask, | ||||
token_type_ids=token_type_ids, | token_type_ids=token_type_ids, | ||||
position_ids=position_ids, | position_ids=position_ids, | ||||
@@ -0,0 +1,189 @@ | |||||
import torch.nn as nn | |||||
import torch | |||||
from torch.nn import LayerNorm | |||||
import torch.nn.functional as F | |||||
from typing import Union, Tuple | |||||
from ...core.utils import seq_len_to_mask | |||||
import math | |||||
from ...modules.encoder.lstm import LSTM | |||||
from fastNLP.modules.attention import MultiHeadAttention | |||||
from ...embeddings import StaticEmbedding | |||||
from ...embeddings.utils import get_embeddings | |||||
class Seq2SeqEncoder(nn.Module): | |||||
""" | |||||
所有Sequence2Sequence Encoder的基类。需要实现forward函数 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, tokens, seq_len): | |||||
""" | |||||
:param torch.LongTensor tokens: bsz x max_len, encoder的输入 | |||||
:param torch.LongTensor seq_len: bsz | |||||
:return: | |||||
""" | |||||
raise NotImplementedError | |||||
class TransformerSeq2SeqEncoderLayer(nn.Module): | |||||
def __init__(self, d_model: int = 512, n_head: int = 8, dim_ff: int = 2048, | |||||
dropout: float = 0.1): | |||||
""" | |||||
Self-Attention的Layer, | |||||
:param int d_model: input和output的输出维度 | |||||
:param int n_head: 多少个head,每个head的维度为d_model/n_head | |||||
:param int dim_ff: FFN的维度大小 | |||||
:param float dropout: Self-attention和FFN的dropout大小,0表示不drop | |||||
""" | |||||
super(TransformerSeq2SeqEncoderLayer, self).__init__() | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.self_attn = MultiHeadAttention(d_model, n_head, dropout) | |||||
self.attn_layer_norm = LayerNorm(d_model) | |||||
self.ffn_layer_norm = LayerNorm(d_model) | |||||
self.ffn = nn.Sequential(nn.Linear(self.d_model, self.dim_ff), | |||||
nn.ReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(self.dim_ff, self.d_model), | |||||
nn.Dropout(dropout)) | |||||
def forward(self, x, mask): | |||||
""" | |||||
:param x: batch x src_seq x d_model | |||||
:param mask: batch x src_seq,为0的地方为padding | |||||
:return: | |||||
""" | |||||
# attention | |||||
residual = x | |||||
x = self.attn_layer_norm(x) | |||||
x, _ = self.self_attn(query=x, | |||||
key=x, | |||||
value=x, | |||||
key_mask=mask) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
x = residual + x | |||||
# ffn | |||||
residual = x | |||||
x = self.ffn_layer_norm(x) | |||||
x = self.ffn(x) | |||||
x = residual + x | |||||
return x | |||||
class TransformerSeq2SeqEncoder(Seq2SeqEncoder): | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], pos_embed = None, | |||||
num_layers = 6, d_model = 512, n_head = 8, dim_ff = 2048, dropout = 0.1): | |||||
""" | |||||
基于Transformer的Encoder | |||||
:param embed: encoder输入token的embedding | |||||
:param nn.Module pos_embed: position embedding | |||||
:param int num_layers: 多少层的encoder | |||||
:param int d_model: 输入输出的维度 | |||||
:param int n_head: 多少个head | |||||
:param int dim_ff: FFN中间的维度大小 | |||||
:param float dropout: Attention和FFN的dropout大小 | |||||
""" | |||||
super(TransformerSeq2SeqEncoder, self).__init__() | |||||
self.embed = get_embeddings(embed) | |||||
self.embed_scale = math.sqrt(d_model) | |||||
self.pos_embed = pos_embed | |||||
self.num_layers = num_layers | |||||
self.d_model = d_model | |||||
self.n_head = n_head | |||||
self.dim_ff = dim_ff | |||||
self.dropout = dropout | |||||
self.input_fc = nn.Linear(self.embed.embedding_dim, d_model) | |||||
self.layer_stacks = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model, n_head, dim_ff, dropout) | |||||
for _ in range(num_layers)]) | |||||
self.layer_norm = LayerNorm(d_model) | |||||
def forward(self, tokens, seq_len): | |||||
""" | |||||
:param tokens: batch x max_len | |||||
:param seq_len: [batch] | |||||
:return: bsz x max_len x d_model, bsz x max_len(为0的地方为padding) | |||||
""" | |||||
x = self.embed(tokens) * self.embed_scale # batch, seq, dim | |||||
batch_size, max_src_len, _ = x.size() | |||||
device = x.device | |||||
if self.pos_embed is not None: | |||||
position = torch.arange(1, max_src_len + 1).unsqueeze(0).long().to(device) | |||||
x += self.pos_embed(position) | |||||
x = self.input_fc(x) | |||||
x = F.dropout(x, p=self.dropout, training=self.training) | |||||
encoder_mask = seq_len_to_mask(seq_len) | |||||
encoder_mask = encoder_mask.to(device) | |||||
for layer in self.layer_stacks: | |||||
x = layer(x, encoder_mask) | |||||
x = self.layer_norm(x) | |||||
return x, encoder_mask | |||||
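Usage sketch with illustrative numbers; passing a (vocab_size, embed_dim) tuple as embed relies on get_embeddings building a plain nn.Embedding, as the type hint suggests.

import torch
encoder = TransformerSeq2SeqEncoder(embed=(1000, 512), num_layers=2,
                                    d_model=512, n_head=8, dim_ff=2048, dropout=0.1)
tokens = torch.randint(0, 1000, (2, 7))         # bsz x max_len
seq_len = torch.LongTensor([7, 5])
x, encoder_mask = encoder(tokens, seq_len)      # (2, 7, 512), (2, 7)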
class LSTMSeq2SeqEncoder(Seq2SeqEncoder): | |||||
def __init__(self, embed: Union[nn.Module, StaticEmbedding, Tuple[int, int]], num_layers = 3, | |||||
hidden_size = 400, dropout = 0.3, bidirectional=True): | |||||
""" | |||||
LSTM的Encoder | |||||
:param embed: encoder的token embed | |||||
:param int num_layers: 多少层 | |||||
:param int hidden_size: LSTM隐藏层、输出的大小 | |||||
:param float dropout: LSTM层之间的Dropout是多少 | |||||
:param bool bidirectional: 是否使用双向 | |||||
""" | |||||
super().__init__() | |||||
self.embed = get_embeddings(embed) | |||||
self.num_layers = num_layers | |||||
self.dropout = dropout | |||||
self.hidden_size = hidden_size | |||||
self.bidirectional = bidirectional | |||||
hidden_size = hidden_size//2 if bidirectional else hidden_size | |||||
self.lstm = LSTM(input_size=self.embed.embedding_dim, hidden_size=hidden_size, bidirectional=bidirectional, |||||
batch_first=True, dropout=dropout if num_layers>1 else 0, num_layers=num_layers) | |||||
def forward(self, tokens, seq_len): | |||||
""" | |||||
:param torch.LongTensor tokens: bsz x max_len | |||||
:param torch.LongTensor seq_len: bsz | |||||
:return: (output, (hidden, cell)), encoder_mask | |||||
output: bsz x max_len x hidden_size, | |||||
hidden,cell: batch_size x hidden_size, 最后一层的隐藏状态或cell状态 | |||||
encoder_mask: bsz x max_len, 为0的地方是padding | |||||
""" | |||||
x = self.embed(tokens) | |||||
device = x.device | |||||
x, (final_hidden, final_cell) = self.lstm(x, seq_len) | |||||
encoder_mask = seq_len_to_mask(seq_len).to(device) | |||||
# x: batch,seq_len,dim; h/c: num_layers*2,batch,dim | |||||
if self.bidirectional: | |||||
final_hidden = self.concat_bidir(final_hidden) # 将双向的hidden state拼接起来,用于接下来的decoder的input | |||||
final_cell = self.concat_bidir(final_cell) | |||||
return (x, (final_hidden[-1], final_cell[-1])), encoder_mask # 为了配合Seq2SeqBaseModel的forward,这边需要分为两个return | |||||
def concat_bidir(self, input): | |||||
output = input.view(self.num_layers, 2, input.size(1), -1).transpose(1, 2) | |||||
return output.reshape(self.num_layers, input.size(1), -1) |
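Usage sketch with made-up sizes; a plain nn.Embedding is passed directly here, and with bidirectional=True the two directions are concatenated so the output width equals hidden_size.

import torch
import torch.nn as nn
encoder = LSTMSeq2SeqEncoder(embed=nn.Embedding(1000, 300), num_layers=2,
                             hidden_size=400, dropout=0.3, bidirectional=True)
tokens = torch.randint(0, 1000, (2, 7))
seq_len = torch.LongTensor([7, 4])
(output, (hidden, cell)), encoder_mask = encoder(tokens, seq_len)
# output: (2, 7, 400); hidden/cell: (2, 400); encoder_mask: (2, 7), 0 where padding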
@@ -5,7 +5,7 @@ __all__ = [ | |||||
] | ] | ||||
from torch import nn | from torch import nn | ||||
from .attention import MultiHeadAttention | |||||
from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
@@ -13,66 +13,30 @@ class TransformerEncoder(nn.Module): | |||||
transformer的encoder模块,不包含embedding层 | transformer的encoder模块,不包含embedding层 | ||||
""" | """ | ||||
def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): | |||||
""" | |||||
class SubLayer(nn.Module): | |||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | |||||
super(TransformerEncoder.SubLayer, self).__init__() | |||||
self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) | |||||
self.norm1 = nn.LayerNorm(model_size, eps=1e-6) | |||||
self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), | |||||
nn.ReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(inner_size, model_size)) | |||||
self.norm2 = nn.LayerNorm(model_size, eps=1e-6) | |||||
self.dropout = nn.Dropout(dropout) | |||||
def forward(self, input, seq_mask=None, atte_mask_out=None): | |||||
r""" | |||||
:param input: [batch, seq_len, model_size] | |||||
:param seq_mask: [batch, seq_len] | |||||
:return: [batch, seq_len, model_size] | |||||
""" | |||||
if seq_mask is None: # 防止后续乘法时出错 | |||||
seq_mask = 1 | |||||
input = self.norm1(input) | |||||
attention = self.atte(input, input, input, atte_mask_out) | |||||
input = input + self.dropout(attention) | |||||
attention *= seq_mask | |||||
input = self.norm2(input) | |||||
output = self.ffn(input) | |||||
input = input + self.dropout(output) | |||||
input *= seq_mask | |||||
return input | |||||
def __init__(self, num_layers, **kargs): | |||||
r""" | |||||
:param int num_layers: transformer的层数 | |||||
:param int model_size: 输入维度的大小。同时也是输出维度的大小。 | |||||
:param int inner_size: FFN层的hidden大小 | |||||
:param int key_size: 每个head的维度大小。 | |||||
:param int value_size: 每个head中value的维度。 | |||||
:param int num_head: head的数量。 | |||||
:param float dropout: dropout概率. Default: 0.1 | |||||
:param int num_layers: 多少层Transformer | |||||
:param int d_model: input和output的大小 | |||||
:param int n_head: 多少个head | |||||
:param int dim_ff: FFN中间hidden大小 | |||||
:param float dropout: 多大概率drop attention和ffn中间的表示 | |||||
""" | """ | ||||
super(TransformerEncoder, self).__init__() | super(TransformerEncoder, self).__init__() | ||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | |||||
self.norm = nn.LayerNorm(kargs['model_size'], eps=1e-6) | |||||
self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, | |||||
dropout = dropout) for _ in range(num_layers)]) | |||||
self.norm = nn.LayerNorm(d_model, eps=1e-6) | |||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
r""" | r""" | ||||
:param x: [batch, seq_len, model_size] 输入序列 | :param x: [batch, seq_len, model_size] 输入序列 | ||||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. | |||||
:param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 为1的地方需要attend | |||||
Default: ``None`` | Default: ``None`` | ||||
:return: [batch, seq_len, model_size] 输出序列 | :return: [batch, seq_len, model_size] 输出序列 | ||||
""" | """ | ||||
output = x | output = x | ||||
if seq_mask is None: | if seq_mask is None: | ||||
atte_mask_out = None | |||||
else: | |||||
atte_mask_out = (seq_mask.eq(False))[:, None, :] | |||||
seq_mask = seq_mask[:, :, None] | |||||
seq_mask = x.new_ones(x.size(0), x.size(1)).bool() | |||||
for layer in self.layers: | for layer in self.layers: | ||||
output = layer(output, seq_mask, atte_mask_out) | |||||
output = layer(output, seq_mask) | |||||
return self.norm(output) | return self.norm(output) |
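A quick sketch of the slimmed-down TransformerEncoder interface (shapes invented); seq_mask now follows the same 1-for-real-token convention as the seq2seq layers it reuses.

import torch
encoder = TransformerEncoder(num_layers=2, d_model=256, n_head=4, dim_ff=1024, dropout=0.1)
x = torch.randn(2, 7, 256)
seq_mask = torch.ones(2, 7, dtype=torch.bool)
seq_mask[0, 6:] = 0                    # last position of sample 0 is padding
out = encoder(x, seq_mask)             # (2, 7, 256)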
@@ -0,0 +1,9 @@ | |||||
r""" | |||||
""" | |||||
__all__ = [ | |||||
"SequenceGenerator" | |||||
] | |||||
from .seq2seq_generator import SequenceGenerator |
@@ -7,16 +7,35 @@ __all__ = [ | |||||
] | ] | ||||
import torch | import torch | ||||
from ..decoder.seq2seq_decoder import Decoder | |||||
from ..decoder.seq2seq_decoder import Seq2SeqDecoder, State | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from fastNLP.core.utils import _get_model_device | |||||
from ...core.utils import _get_model_device | |||||
from functools import partial | from functools import partial | ||||
class SequenceGenerator: | class SequenceGenerator: | ||||
def __init__(self, decoder: Decoder, max_length=20, num_beams=1, | |||||
""" | |||||
给定一个Seq2SeqDecoder,decode出句子 | |||||
""" | |||||
def __init__(self, decoder: Seq2SeqDecoder, max_length=20, num_beams=1, | |||||
do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | do_sample=True, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, | ||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | repetition_penalty=1, length_penalty=1.0, pad_token_id=0): | ||||
""" | |||||
:param Seq2SeqDecoder decoder: Decoder对象 | |||||
:param int max_length: 句子的最大长度 | |||||
:param int num_beams: beam search的大小 | |||||
:param bool do_sample: 是否通过采样的方式生成 | |||||
:param float temperature: 只有在do_sample为True才有意义 | |||||
:param int top_k: 只从top_k中采样 | |||||
:param float top_p: 只从top_p的token中采样,即nucleus sampling |||||
:param int,None bos_token_id: 句子开头的token id | |||||
:param int,None eos_token_id: 句子结束的token id | |||||
:param float repetition_penalty: 多大程度上惩罚重复的token | |||||
:param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短句 |||||
:param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 | |||||
""" | |||||
if do_sample: | if do_sample: | ||||
self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, | self.generate_func = partial(sample_generate, decoder=decoder, max_length=max_length, num_beams=num_beams, | ||||
temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | temperature=temperature, top_k=top_k, top_p=top_p, bos_token_id=bos_token_id, | ||||
@@ -40,19 +59,19 @@ class SequenceGenerator: | |||||
self.decoder = decoder | self.decoder = decoder | ||||
@torch.no_grad() | @torch.no_grad() | ||||
def generate(self, tokens=None, past=None): | |||||
def generate(self, state, tokens=None): | |||||
""" | """ | ||||
:param torch.LongTensor tokens: batch_size x length, 开始的token | |||||
:param past: | |||||
:return: | |||||
:param State state: encoder结果的State, 是与Decoder配套是用的 | |||||
:param torch.LongTensor,None tokens: batch_size x length, 开始的token | |||||
:return: bsz x max_length' 生成的token序列。如果eos_token_id不为None, 每个sequence的结尾一定是eos_token_id | |||||
""" | """ | ||||
# TODO 需要查看如果tokens长度不是1,decode的时候是否还能够直接decode? | |||||
return self.generate_func(tokens=tokens, past=past) | |||||
return self.generate_func(tokens=tokens, state=state) | |||||
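A hedged end-to-end sketch of the new interface, in which generate takes the encoder-derived State instead of a past object. The encoder, decoder, and the init_state helper are assumptions used only for illustration and are not defined in this hunk.

# Illustrative only: `encoder`, `decoder` and `decoder.init_state` are assumed to exist
# (e.g. a TransformerSeq2Seq encoder/decoder pair); only the SequenceGenerator calls
# below follow this diff.
encoder_outputs, encoder_mask = encoder(src_tokens, src_seq_len)
state = decoder.init_state(encoder_outputs, encoder_mask)            # builds a State (assumed helper)
generator = SequenceGenerator(decoder, max_length=20, num_beams=4, do_sample=False,
                              bos_token_id=1, eos_token_id=2, pad_token_id=0)
generated = generator.generate(state=state)                          # bsz x max_length'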
@torch.no_grad() | @torch.no_grad() | ||||
def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
def greedy_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, | |||||
bos_token_id=None, eos_token_id=None, pad_token_id=0, | bos_token_id=None, eos_token_id=None, pad_token_id=0, | ||||
repetition_penalty=1, length_penalty=1.0): | repetition_penalty=1, length_penalty=1.0): | ||||
""" | """ | ||||
@@ -60,23 +79,23 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
:param Decoder decoder: Decoder对象 | :param Decoder decoder: Decoder对象 | ||||
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | ||||
:param Past past: 应该包好encoder的一些输出。 | |||||
:param State state: 应该包含encoder的一些输出。 | |||||
:param int max_length: 生成句子的最大长度。 | :param int max_length: 生成句子的最大长度。 | ||||
:param int num_beams: 使用多大的beam进行解码。 | :param int num_beams: 使用多大的beam进行解码。 | ||||
:param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 | :param int bos_token_id: 如果tokens传入为None,则使用bos_token_id开始往后解码。 | ||||
:param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 | :param int eos_token_id: 结束的token,如果为None,则一定会解码到max_length这么长。 | ||||
:param int pad_token_id: | |||||
:param int pad_token_id: pad的token id | |||||
:param float repetition_penalty: 对重复出现的token多大的惩罚。 | :param float repetition_penalty: 对重复出现的token多大的惩罚。 | ||||
:param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 | :param float length_penalty: 对每个token(除了eos)按照长度进行一定的惩罚。 | ||||
:return: | :return: | ||||
""" | """ | ||||
if num_beams == 1: | if num_beams == 1: | ||||
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=1, top_k=50, top_p=1, | |||||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=1, top_k=50, top_p=1, | |||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | ||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | repetition_penalty=repetition_penalty, length_penalty=length_penalty, | ||||
pad_token_id=pad_token_id) | pad_token_id=pad_token_id) | ||||
else: | else: | ||||
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, | |||||
temperature=1, top_k=50, top_p=1, | temperature=1, top_k=50, top_p=1, | ||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False, | ||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | repetition_penalty=repetition_penalty, length_penalty=length_penalty, | ||||
@@ -86,7 +105,7 @@ def greedy_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
@torch.no_grad() | @torch.no_grad() | ||||
def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||||
def sample_generate(decoder, tokens=None, state=None, max_length=20, num_beams=1, temperature=1.0, top_k=50, | |||||
top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | top_p=1.0, bos_token_id=None, eos_token_id=None, pad_token_id=0, repetition_penalty=1.0, | ||||
length_penalty=1.0): | length_penalty=1.0): | ||||
""" | """ | ||||
@@ -94,7 +113,7 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
:param Decoder decoder: Decoder对象 | :param Decoder decoder: Decoder对象 | ||||
:param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | :param torch.LongTensor tokens: batch_size x len, decode的输入值,如果为None,则自动从bos_token_id开始生成 | ||||
:param Past past: 应该包好encoder的一些输出。 | |||||
:param State state: 应该包含encoder的一些输出。 | |||||
:param int max_length: 生成句子的最大长度。 | :param int max_length: 生成句子的最大长度。 | ||||
:param int num_beam: 使用多大的beam进行解码。 | :param int num_beam: 使用多大的beam进行解码。 | ||||
:param float temperature: 采样时的退火大小 | :param float temperature: 采样时的退火大小 | ||||
@@ -109,13 +128,13 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
""" | """ | ||||
# 每个位置在生成的时候会sample生成 | # 每个位置在生成的时候会sample生成 | ||||
if num_beams == 1: | if num_beams == 1: | ||||
token_ids = _no_beam_search_generate(decoder, tokens, past, max_length, temperature=temperature, | |||||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, temperature=temperature, | |||||
top_k=top_k, top_p=top_p, | top_k=top_k, top_p=top_p, | ||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | ||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | repetition_penalty=repetition_penalty, length_penalty=length_penalty, | ||||
pad_token_id=pad_token_id) | pad_token_id=pad_token_id) | ||||
else: | else: | ||||
token_ids = _beam_search_generate(decoder, tokens, past, max_length, num_beams=num_beams, | |||||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, num_beams=num_beams, | |||||
temperature=temperature, top_k=top_k, top_p=top_p, | temperature=temperature, top_k=top_k, top_p=top_p, | ||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=True, | ||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty, | repetition_penalty=repetition_penalty, length_penalty=length_penalty, | ||||
@@ -123,40 +142,35 @@ def sample_generate(decoder, tokens=None, past=None, max_length=20, num_beams=1, | |||||
return token_ids | return token_ids | ||||
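For reference, a generic sketch of how top-k / top-p (nucleus) filtering of next-token scores is commonly implemented; this mirrors the usual HuggingFace-style helper and is not necessarily the exact code used further down in this file.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('inf')):
    # logits: (batch, vocab_size) scores for the next token
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_best, filter_value)   # keep only the top-k scores
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()   # shift right so the first token crossing
        remove[..., 0] = False                       # the threshold is still kept
        indices_to_remove = remove.scatter(1, sorted_indices, remove)
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits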
def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, temperature=1.0, top_k=50, | |||||
def _no_beam_search_generate(decoder: Seq2SeqDecoder, state, tokens=None, max_length=20, temperature=1.0, top_k=50, | |||||
top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | ||||
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0): | ||||
device = _get_model_device(decoder) | device = _get_model_device(decoder) | ||||
if tokens is None: | if tokens is None: | ||||
if bos_token_id is None: | if bos_token_id is None: | ||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | ||||
if past is None: | |||||
raise RuntimeError("You have to specify either `past` or `tokens`.") | |||||
batch_size = past.num_samples() | |||||
batch_size = state.num_samples | |||||
if batch_size is None: | if batch_size is None: | ||||
raise RuntimeError("Cannot infer the number of samples from `past`.") | |||||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | ||||
batch_size = tokens.size(0) | batch_size = tokens.size(0) | ||||
if past is not None: | |||||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||||
if state.num_samples: | |||||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||||
if eos_token_id is None: | if eos_token_id is None: | ||||
_eos_token_id = float('nan') | |||||
_eos_token_id = -1 | |||||
else: | else: | ||||
_eos_token_id = eos_token_id | _eos_token_id = eos_token_id | ||||
# for i in range(tokens.size(1)): | |||||
# scores, past = decoder.decode_one(tokens[:, :i + 1], past) # batch_size x vocab_size, Past | |||||
scores, past = decoder.decode(tokens, past) | |||||
token_ids = tokens.clone() | |||||
scores = decoder.decode(tokens=tokens, state=state) # 主要是为了update state | |||||
next_tokens = scores.argmax(dim=-1, keepdim=True) | |||||
token_ids = torch.cat([tokens, next_tokens], dim=1) | |||||
cur_len = token_ids.size(1) | cur_len = token_ids.size(1) | ||||
dones = token_ids.new_zeros(batch_size).eq(1) | dones = token_ids.new_zeros(batch_size).eq(1) | ||||
# tokens = tokens[:, -1:] | # tokens = tokens[:, -1:] | ||||
while cur_len < max_length: | while cur_len < max_length: | ||||
# scores, past = decoder.decode_one(tokens, past) # batch_size x vocab_size, Past | |||||
scores, past = decoder.decode(tokens, past) # batch_size x vocab_size, Past | |||||
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size | |||||
if repetition_penalty != 1.0: | if repetition_penalty != 1.0: | ||||
token_scores = scores.gather(dim=1, index=token_ids) | token_scores = scores.gather(dim=1, index=token_ids) | ||||
@@ -204,7 +218,7 @@ def _no_beam_search_generate(decoder: Decoder, tokens=None, past=None, max_lengt | |||||
return token_ids | return token_ids | ||||
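The repetition-penalty fragments visible above (the gather of already-generated token scores and the lt_zero_mask) follow the CTRL-style formulation; below is a hedged reconstruction of the idea, not guaranteed to match the in-place code here line for line.

def apply_repetition_penalty(scores, token_ids, penalty):
    """Scale down scores of tokens that were already generated (CTRL-style)."""
    if penalty == 1.0:
        return scores
    token_scores = scores.gather(dim=1, index=token_ids)   # bsz x cur_len
    lt_zero_mask = token_scores.lt(0).float()
    ge_zero_mask = lt_zero_mask.eq(0).float()
    # negative scores are multiplied by the penalty, positive scores divided by it
    token_scores = lt_zero_mask * penalty * token_scores + ge_zero_mask / penalty * token_scores
    return scores.scatter(dim=1, index=token_ids, src=token_scores)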
def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=20, num_beams=4, temperature=1.0, | |||||
def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_length=20, num_beams=4, temperature=1.0, | |||||
top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | top_k=50, top_p=1.0, bos_token_id=None, eos_token_id=None, do_sample=True, | ||||
repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | repetition_penalty=1.0, length_penalty=None, pad_token_id=0) -> torch.LongTensor: | ||||
# 进行beam search | # 进行beam search | ||||
@@ -212,21 +226,20 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
if tokens is None: | if tokens is None: | ||||
if bos_token_id is None: | if bos_token_id is None: | ||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.") | ||||
if past is None: | |||||
raise RuntimeError("You have to specify either `past` or `tokens`.") | |||||
batch_size = past.num_samples() | |||||
batch_size = state.num_samples | |||||
if batch_size is None: | if batch_size is None: | ||||
raise RuntimeError("Cannot infer the number of samples from `past`.") | |||||
raise RuntimeError("Cannot infer the number of samples from `state`.") | |||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device) | ||||
batch_size = tokens.size(0) | batch_size = tokens.size(0) | ||||
if past is not None: | |||||
assert past.num_samples() == batch_size, "The number of samples in `tokens` and `past` should match." | |||||
# for i in range(tokens.size(1) - 1): # 如果输入的长度较长,先decode | |||||
# scores, past = decoder.decode_one(tokens[:, :i + 1], | |||||
# past) # (batch_size, vocab_size), Past | |||||
# scores, past = decoder.decode_one(tokens, past) # 这里要传入的是整个句子的长度 | |||||
scores, past = decoder.decode(tokens, past) # 这里要传入的是整个句子的长度 | |||||
if state.num_samples: | |||||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match." | |||||
if eos_token_id is None: | |||||
_eos_token_id = -1 | |||||
else: | |||||
_eos_token_id = eos_token_id | |||||
scores = decoder.decode(tokens=tokens, state=state) # 这里要传入的是整个句子的长度 | |||||
vocab_size = scores.size(1) | vocab_size = scores.size(1) | ||||
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." | assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size." | ||||
@@ -240,15 +253,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
# 得到(batch_size, num_beams), (batch_size, num_beams) | # 得到(batch_size, num_beams), (batch_size, num_beams) | ||||
next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) | ||||
# 根据index来做顺序的调转 | |||||
indices = torch.arange(batch_size, dtype=torch.long).to(device) | indices = torch.arange(batch_size, dtype=torch.long).to(device) | ||||
indices = indices.repeat_interleave(num_beams) | indices = indices.repeat_interleave(num_beams) | ||||
decoder.reorder_past(indices, past) | |||||
state.reorder_state(indices) | |||||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length | ||||
# 记录生成好的token (batch_size', cur_len) | # 记录生成好的token (batch_size', cur_len) | ||||
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1) | ||||
dones = [False] * batch_size | dones = [False] * batch_size | ||||
tokens = next_tokens.view(-1, 1) | |||||
beam_scores = next_scores.view(-1) # batch_size * num_beams | beam_scores = next_scores.view(-1) # batch_size * num_beams | ||||
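A small worked example (values invented) of the flattened indexing that batch_inds_with_numbeams_interval and reorder_inds perform in the hunk below:

import torch
batch_size, num_beams = 2, 3
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1)
# tensor([[0], [3]]): row offset of each sample's first beam in the flattened
# (batch_size * num_beams, ...) tensors
from_which_beam = torch.tensor([[1, 0, 2], [2, 1, 0]])       # hypothetical beam choices
reorder_inds = (batch_inds_with_numbeams_interval + from_which_beam).view(-1)
# tensor([1, 0, 2, 5, 4, 3]): absolute rows to keep when reordering state and token_ids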
@@ -262,8 +275,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) | ||||
while cur_len < max_length: | while cur_len < max_length: | ||||
# scores, past = decoder.decode_one(tokens, past) # batch_size * num_beams x vocab_size, Past | |||||
scores, past = decoder.decode(tokens, past) | |||||
scores = decoder.decode(token_ids, state) | |||||
if repetition_penalty != 1.0: | if repetition_penalty != 1.0: | ||||
token_scores = scores.gather(dim=1, index=token_ids) | token_scores = scores.gather(dim=1, index=token_ids) | ||||
lt_zero_mask = token_scores.lt(0).float() | lt_zero_mask = token_scores.lt(0).float() | ||||
@@ -307,7 +319,7 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | next_tokens = next_tokens.gather(dim=1, index=sorted_inds) | ||||
from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | from_which_beam = from_which_beam.gather(dim=1, index=sorted_inds) | ||||
not_eos_mask = next_tokens.ne(eos_token_id) # 为1的地方不是eos | |||||
not_eos_mask = next_tokens.ne(_eos_token_id) # 为1的地方不是eos | |||||
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 | keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) # 为1的地方需要保留 | ||||
keep_mask = not_eos_mask.__and__(keep_mask) # 为1的地方是需要进行下一步search的 | keep_mask = not_eos_mask.__and__(keep_mask) # 为1的地方是需要进行下一步search的 | ||||
@@ -316,18 +328,18 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams) | ||||
beam_scores = _next_scores.view(-1) | beam_scores = _next_scores.view(-1) | ||||
# 更改past状态, 重组token_ids | |||||
# 更改state状态, 重组token_ids | |||||
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 | reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten成一维 | ||||
decoder.reorder_past(reorder_inds, past) | |||||
state.reorder_state(reorder_inds) | |||||
flag = True | flag = True | ||||
if cur_len + 1 == max_length: | |||||
if cur_len+1 == max_length: | |||||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0) | ||||
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # 表示的是indice | eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size) # 表示的是indice | ||||
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 | eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1) # 表示的是从哪个beam获取得到的 | ||||
else: | else: | ||||
# 将每个batch中在num_beam内的序列添加到结束中, 为1的地方需要结束了 | # 将每个batch中在num_beam内的序列添加到结束中, 为1的地方需要结束了 | ||||
effective_eos_mask = next_tokens[:, :num_beams].eq(eos_token_id) # batch_size x num_beams | |||||
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams | |||||
if effective_eos_mask.sum().gt(0): | if effective_eos_mask.sum().gt(0): | ||||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True) | ||||
# 是由于from_which_beam是 (batch_size, 2*num_beams)的,所以需要2*num_beams | # 是由于from_which_beam是 (batch_size, 2*num_beams)的,所以需要2*num_beams | ||||
@@ -335,16 +347,17 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos | eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx] # 获取真实的从哪个beam获取的eos | ||||
else: | else: | ||||
flag = False | flag = False | ||||
# 重新组织token_ids的状态 | |||||
tokens = _next_tokens | |||||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||||
if flag: | if flag: | ||||
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(), | ||||
eos_beam_idx.tolist()): | eos_beam_idx.tolist()): | ||||
if not dones[batch_idx]: | if not dones[batch_idx]: | ||||
score = next_scores[batch_idx, beam_ind].item() | score = next_scores[batch_idx, beam_ind].item() | ||||
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score) | |||||
# 重新组织token_ids的状态 | |||||
tokens = _next_tokens | |||||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), tokens], dim=-1) | |||||
hypos[batch_idx].add(token_ids[batch_idx * num_beams + beam_idx, :cur_len+1].clone(), score) | |||||
for batch_idx in range(batch_size): | for batch_idx in range(batch_size): | ||||
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) | dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) | ||||
@@ -360,15 +373,15 @@ def _beam_search_generate(decoder: Decoder, tokens=None, past=None, max_length=2 | |||||
for i, hypotheses in enumerate(hypos): | for i, hypotheses in enumerate(hypos): | ||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1] | ||||
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol | |||||
tgt_len[i] = len(best_hyp) # best_hyp already contains the <EOS> symbol |||||
best.append(best_hyp) | best.append(best_hyp) | ||||
# generate target batch | # generate target batch | ||||
decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) | decoded = token_ids.new(batch_size, tgt_len.max().item()).fill_(pad_token_id) | ||||
for i, hypo in enumerate(best): | for i, hypo in enumerate(best): | ||||
decoded[i, :tgt_len[i] - 1] = hypo | |||||
decoded[i, :tgt_len[i]] = hypo | |||||
if eos_token_id is not None: | if eos_token_id is not None: | ||||
decoded[i, tgt_len[i] - 1] = eos_token_id | |||||
decoded[i, tgt_len[i] - 1] = _eos_token_id | |||||
return decoded | return decoded | ||||
@@ -384,6 +384,9 @@ class BertTokenizer(object): | |||||
index += 1 | index += 1 | ||||
return vocab_file | return vocab_file | ||||
def save_pretrained(self, save_directory): | |||||
self.save_vocabulary(save_directory) | |||||
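Hedged usage of the new alias (paths taken from the test fixtures elsewhere in this diff, target directory hypothetical); it simply forwards to save_vocabulary, so it writes the vocab file into the given directory.

tokenizer = BertTokenizer.from_pretrained('test/data_for_tests/embedding/small_bert')
tokenizer.save_pretrained('bert_tokenizer_dump')   # same effect as save_vocabulary('bert_tokenizer_dump')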
@classmethod | @classmethod | ||||
def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): | ||||
r""" | r""" | ||||
@@ -377,6 +377,9 @@ class GPT2Tokenizer: | |||||
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) | ||||
return text | return text | ||||
def save_pretrained(self, save_directory): | |||||
return self.save_vocabulary(save_directory) | |||||
def save_vocabulary(self, save_directory): | def save_vocabulary(self, save_directory): | ||||
"""Save the tokenizer vocabulary and merge files to a directory.""" | """Save the tokenizer vocabulary and merge files to a directory.""" | ||||
if not os.path.isdir(save_directory): | if not os.path.isdir(save_directory): | ||||
@@ -32,7 +32,6 @@ from tools.PositionEmbedding import get_sinusoid_encoding_table | |||||
from tools.logger import * | from tools.logger import * | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from transformer.Layers import EncoderLayer | from transformer.Layers import EncoderLayer | ||||
@@ -30,7 +30,7 @@ from .Encoder import Encoder | |||||
from tools.PositionEmbedding import get_sinusoid_encoding_table | from tools.PositionEmbedding import get_sinusoid_encoding_table | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoderLayer | |||||
class TransformerModel(nn.Module): | class TransformerModel(nn.Module): | ||||
def __init__(self, hps, vocab): | def __init__(self, hps, vocab): | ||||
@@ -68,7 +68,8 @@ class TransformerModel(nn.Module): | |||||
get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | get_sinusoid_encoding_table(hps.doc_max_timesteps + 1, self.hidden_size, padding_idx=0), freeze=True) | ||||
self.layer_stack = nn.ModuleList([ | self.layer_stack = nn.ModuleList([ | ||||
TransformerEncoder.SubLayer(model_size=self.hidden_size, inner_size=self.d_inner, key_size=self.d_k, value_size=self.d_v,num_head=self.n_head, dropout=hps.atten_dropout_prob) | |||||
TransformerSeq2SeqEncoderLayer(d_model = self.hidden_size, n_head = self.n_head, dim_ff = self.d_inner, | |||||
dropout = hps.atten_dropout_prob) | |||||
for _ in range(self.num_layers)]) | for _ in range(self.num_layers)]) | ||||
self.wh = nn.Linear(self.hidden_size, 2) | self.wh = nn.Linear(self.hidden_size, 2) | ||||
@@ -109,7 +110,7 @@ class TransformerModel(nn.Module): | |||||
for enc_layer in self.layer_stack: | for enc_layer in self.layer_stack: | ||||
# enc_output = [batch_size, N, hidden_size = n_head * d_v] | # enc_output = [batch_size, N, hidden_size = n_head * d_v] | ||||
# enc_slf_attn = [n_head * batch_size, N, N] | # enc_slf_attn = [n_head * batch_size, N, N] | ||||
enc_input = enc_layer(enc_input, seq_mask=self.non_pad_mask, atte_mask_out=self.slf_attn_mask) | |||||
enc_input = enc_layer(enc_input, encoder_mask=self.slf_attn_mask) | |||||
enc_input_list += [enc_input] | enc_input_list += [enc_input] | ||||
self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | self.dec_output_state = torch.cat(enc_input_list[-4:]) # [4, batch_size, N, hidden_state] | ||||
@@ -265,7 +265,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||||
label = Variable(label) | label = Variable(label) | ||||
input_len = Variable(input_len, requires_grad=False) | input_len = Variable(input_len, requires_grad=False) | ||||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
outputs = model_outputs["p_sent"] | outputs = model_outputs["p_sent"] | ||||
prediction = model_outputs["prediction"] | prediction = model_outputs["prediction"] | ||||
@@ -264,7 +264,7 @@ def run_eval(model, loader, hps, best_loss, best_F, non_descent_cnt): | |||||
label = Variable(label) | label = Variable(label) | ||||
input_len = Variable(input_len, requires_grad=False) | input_len = Variable(input_len, requires_grad=False) | ||||
model_outputs = model.forward(input,input_len) # [batch, N, 2] | |||||
model_outputs = model.forward(input, input_len) # [batch, N, 2] | |||||
outputs = model_outputs[Const.OUTPUTS] | outputs = model_outputs[Const.OUTPUTS] | ||||
prediction = model_outputs["prediction"] | prediction = model_outputs["prediction"] | ||||
@@ -7,10 +7,12 @@ from transformer.Layers import EncoderLayer, DecoderLayer | |||||
__author__ = "Yu-Hsiang Huang" | __author__ = "Yu-Hsiang Huang" | ||||
def get_non_pad_mask(seq): | def get_non_pad_mask(seq): | ||||
assert seq.dim() == 2 | assert seq.dim() == 2 | ||||
return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) | ||||
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | ||||
''' Sinusoid position encoding table ''' | ''' Sinusoid position encoding table ''' | ||||
@@ -31,6 +33,7 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||||
return torch.FloatTensor(sinusoid_table) | return torch.FloatTensor(sinusoid_table) | ||||
def get_attn_key_pad_mask(seq_k, seq_q): | def get_attn_key_pad_mask(seq_k, seq_q): | ||||
''' For masking out the padding part of key sequence. ''' | ''' For masking out the padding part of key sequence. ''' | ||||
@@ -41,6 +44,7 @@ def get_attn_key_pad_mask(seq_k, seq_q): | |||||
return padding_mask | return padding_mask | ||||
def get_subsequent_mask(seq): | def get_subsequent_mask(seq): | ||||
''' For masking out the subsequent info. ''' | ''' For masking out the subsequent info. ''' | ||||
@@ -51,6 +55,7 @@ def get_subsequent_mask(seq): | |||||
return subsequent_mask | return subsequent_mask | ||||
class Encoder(nn.Module): | class Encoder(nn.Module): | ||||
''' A encoder model with self attention mechanism. ''' | ''' A encoder model with self attention mechanism. ''' | ||||
@@ -98,6 +103,7 @@ class Encoder(nn.Module): | |||||
return enc_output, enc_slf_attn_list | return enc_output, enc_slf_attn_list | ||||
return enc_output, | return enc_output, | ||||
class Decoder(nn.Module): | class Decoder(nn.Module): | ||||
''' A decoder model with self attention mechanism. ''' | ''' A decoder model with self attention mechanism. ''' | ||||
@@ -152,6 +158,7 @@ class Decoder(nn.Module): | |||||
return dec_output, dec_slf_attn_list, dec_enc_attn_list | return dec_output, dec_slf_attn_list, dec_enc_attn_list | ||||
return dec_output, | return dec_output, | ||||
class Transformer(nn.Module): | class Transformer(nn.Module): | ||||
''' A sequence to sequence model with attention mechanism. ''' | ''' A sequence to sequence model with attention mechanism. ''' | ||||
@@ -181,8 +188,8 @@ class Transformer(nn.Module): | |||||
nn.init.xavier_normal_(self.tgt_word_prj.weight) | nn.init.xavier_normal_(self.tgt_word_prj.weight) | ||||
assert d_model == d_word_vec, \ | assert d_model == d_word_vec, \ | ||||
'To facilitate the residual connections, \ | |||||
the dimensions of all module outputs shall be the same.' | |||||
'To facilitate the residual connections, \ | |||||
the dimensions of all module outputs shall be the same.' | |||||
if tgt_emb_prj_weight_sharing: | if tgt_emb_prj_weight_sharing: | ||||
# Share the weight matrix between target word embedding & the final logit dense layer | # Share the weight matrix between target word embedding & the final logit dense layer | ||||
@@ -194,7 +201,7 @@ class Transformer(nn.Module): | |||||
if emb_src_tgt_weight_sharing: | if emb_src_tgt_weight_sharing: | ||||
# Share the weight matrix between source & target word embeddings | # Share the weight matrix between source & target word embeddings | ||||
assert n_src_vocab == n_tgt_vocab, \ | assert n_src_vocab == n_tgt_vocab, \ | ||||
"To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||||
"To share word embedding table, the vocabulary size of src/tgt shall be the same." | |||||
self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight | self.encoder.src_word_emb.weight = self.decoder.tgt_word_emb.weight | ||||
def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): | def forward(self, src_seq, src_pos, tgt_seq, tgt_pos): | ||||
@@ -1,7 +1,6 @@ | |||||
import fastNLP | import fastNLP | ||||
import torch | import torch | ||||
import math | import math | ||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField | from fastNLP.modules.decoder.crf import ConditionalRandomField | ||||
from fastNLP import Const | from fastNLP import Const | ||||
import copy | import copy | ||||
@@ -181,7 +180,6 @@ def make_CWS( | |||||
freeze=True, | freeze=True, | ||||
): | ): | ||||
c = copy.deepcopy | c = copy.deepcopy | ||||
# encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout) | |||||
encoder = transformer.make_encoder( | encoder = transformer.make_encoder( | ||||
N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff | N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff | ||||
) | ) | ||||
@@ -1,9 +1,8 @@ | |||||
import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
from fastNLP.core.const import Const as C | from fastNLP.core.const import Const as C | ||||
from fastNLP.modules.encoder.lstm import LSTM | from fastNLP.modules.encoder.lstm import LSTM | ||||
from fastNLP.embeddings.utils import get_embeddings | from fastNLP.embeddings.utils import get_embeddings | ||||
from fastNLP.modules.encoder.attention import SelfAttention | |||||
from fastNLP.modules.attention import SelfAttention | |||||
from fastNLP.modules.decoder.mlp import MLP | from fastNLP.modules.decoder.mlp import MLP | ||||
@@ -44,7 +44,7 @@ class WeightDrop(torch.nn.Module): | |||||
def forward(self, *args): | def forward(self, *args): | ||||
self._setweights() | self._setweights() | ||||
return self.module.forward(*args) | |||||
return self.module.forward() | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
import torch | import torch | ||||
@@ -40,8 +40,7 @@ class TestBertEmbedding(unittest.TestCase): | |||||
result = embed(words) | result = embed(words) | ||||
self.assertEqual(result.size(), (1, 4, 16)) | self.assertEqual(result.size(), (1, 4, 16)) | ||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||||
only_use_pretrain_bpe=True) | |||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||||
embed.eval() | embed.eval() | ||||
words = torch.LongTensor([[2, 3, 4, 0]]) | words = torch.LongTensor([[2, 3, 4, 0]]) | ||||
result = embed(words) | result = embed(words) | ||||
@@ -49,53 +48,30 @@ class TestBertEmbedding(unittest.TestCase): | |||||
# 自动截断而不报错 | # 自动截断而不报错 | ||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | ||||
only_use_pretrain_bpe=True, auto_truncate=True) | |||||
auto_truncate=True) | |||||
words = torch.LongTensor([[2, 3, 4, 1]*10, | words = torch.LongTensor([[2, 3, 4, 1]*10, | ||||
[2, 3]+[0]*38]) | [2, 3]+[0]*38]) | ||||
result = embed(words) | result = embed(words) | ||||
self.assertEqual(result.size(), (2, 40, 16)) | self.assertEqual(result.size(), (2, 40, 16)) | ||||
def test_bert_embedding_2(self): | |||||
# 测试only_use_pretrain_vocab与truncate_embed是否正常工作 | |||||
with open('test/data_for_tests/embedding/small_bert/vocab.txt', 'r', encoding='utf-8') as f: | |||||
num_word = len(f.readlines()) | |||||
Embedding = BertEmbedding | |||||
vocab = Vocabulary().add_word_lst("this is a texta and [SEP] NotInBERT".split()) | |||||
embed1 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||||
only_use_pretrain_bpe=True, truncate_embed=True, min_freq=1) | |||||
embed_bpe_vocab_size = len(vocab)-1 + 2 # 排除NotInBERT, 额外加##a, [CLS] | |||||
self.assertEqual(embed_bpe_vocab_size, len(embed1.model.tokenzier.vocab)) | |||||
embed2 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||||
only_use_pretrain_bpe=True, truncate_embed=False, min_freq=1) | |||||
embed_bpe_vocab_size = num_word # 排除NotInBERT | |||||
self.assertEqual(embed_bpe_vocab_size, len(embed2.model.tokenzier.vocab)) | |||||
embed3 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||||
only_use_pretrain_bpe=False, truncate_embed=True, min_freq=1) | |||||
embed_bpe_vocab_size = len(vocab)+2 # 新增##a, [CLS] | |||||
self.assertEqual(embed_bpe_vocab_size, len(embed3.model.tokenzier.vocab)) | |||||
embed4 = Embedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | |||||
only_use_pretrain_bpe=False, truncate_embed=False, min_freq=1) | |||||
embed_bpe_vocab_size = num_word+1 # 新增##a | |||||
self.assertEqual(embed_bpe_vocab_size, len(embed4.model.tokenzier.vocab)) | |||||
# 测试各种情况下以下tensor的值是相等的 | |||||
embed1.eval() | |||||
embed2.eval() | |||||
embed3.eval() | |||||
embed4.eval() | |||||
tensor = torch.LongTensor([[vocab.to_index(w) for w in 'this is a texta and'.split()]]) | |||||
t1 = embed1(tensor) | |||||
t2 = embed2(tensor) | |||||
t3 = embed3(tensor) | |||||
t4 = embed4(tensor) | |||||
self.assertEqual((t1-t2).sum(), 0) | |||||
self.assertEqual((t1-t3).sum(), 0) | |||||
self.assertEqual((t1-t4).sum(), 0) | |||||
def test_save_load(self): | |||||
bert_save_test = 'bert_save_test' | |||||
try: | |||||
os.makedirs(bert_save_test, exist_ok=True) | |||||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||||
auto_truncate=True) | |||||
embed.save(bert_save_test) | |||||
load_embed = BertEmbedding.load(bert_save_test) | |||||
words = torch.randint(len(vocab), size=(2, 20)) | |||||
embed.eval(), load_embed.eval() | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
finally: | |||||
import shutil | |||||
shutil.rmtree(bert_save_test) | |||||
class TestBertWordPieceEncoder(unittest.TestCase): | class TestBertWordPieceEncoder(unittest.TestCase): | ||||
@@ -120,11 +96,30 @@ class TestBertWordPieceEncoder(unittest.TestCase): | |||||
ds.set_input('words') | ds.set_input('words') | ||||
words = torch.LongTensor(ds['words'].get([0, 1])) | words = torch.LongTensor(ds['words'].get([0, 1])) | ||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', | ||||
pool_method='first', include_cls_sep=True, pooled_cls=False) | |||||
pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) | |||||
embed.eval() | embed.eval() | ||||
words_res = embed(words) | words_res = embed(words) | ||||
# 检查word piece什么的是正常work的 | # 检查word piece什么的是正常work的 | ||||
self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0) | self.assertEqual((word_pieces_res[0, :5]-words_res[0, :5]).sum(), 0) | ||||
self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0) | self.assertEqual((word_pieces_res[0, 6:]-words_res[0, 5:]).sum(), 0) | ||||
self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) | |||||
self.assertEqual((word_pieces_res[1, :3]-words_res[1, :3]).sum(), 0) | |||||
def test_save_load(self): | |||||
bert_save_test = 'bert_save_test' | |||||
try: | |||||
os.makedirs(bert_save_test, exist_ok=True) | |||||
embed = BertWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.0, | |||||
layers='-2') | |||||
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||||
embed.index_datasets(ds, field_name='words') | |||||
self.assertTrue(ds.has_field('word_pieces')) | |||||
words = torch.LongTensor([[1, 2, 3, 4]]) | |||||
embed.save(bert_save_test) | |||||
load_embed = BertWordPieceEncoder.load(bert_save_test) | |||||
embed.eval(), load_embed.eval() | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
finally: | |||||
import shutil | |||||
shutil.rmtree(bert_save_test) | |||||
@@ -255,14 +255,17 @@ class TestGPT2WordPieceEncoder(unittest.TestCase): | |||||
result = embed(torch.LongTensor([[1, 2, 3, 4]])) | result = embed(torch.LongTensor([[1, 2, 3, 4]])) | ||||
def test_generate(self): | def test_generate(self): | ||||
weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
# weight_path = 'test/data_for_tests/embedding/small_gpt2' | |||||
weight_path = 'en' | |||||
encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True) | encoder = GPT2WordPieceEncoder(model_dir_or_name=weight_path, language_model=True) | ||||
# 测试一下各项东西是否正常work | # 测试一下各项东西是否正常work | ||||
print(encoder.generate_from_str('this', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||||
print(encoder.generate_from_str('This', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||||
repetition_penalty=1.0, length_penalty=1.0)) | |||||
print(encoder.generate_from_str('This day', max_len=20, do_sample=False, num_beams=1, temperature=1, top_k=50, top_p=1.0, | |||||
repetition_penalty=1.0, length_penalty=1.0)) | repetition_penalty=1.0, length_penalty=1.0)) | ||||
print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||||
print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=1, top_k=50, top_p=1.0, | |||||
repetition_penalty=1.0, length_penalty=1.0)) | repetition_penalty=1.0, length_penalty=1.0)) | ||||
print(encoder.generate_from_str('this', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||||
print(encoder.generate_from_str('This', max_len=20, do_sample=True, num_beams=3, temperature=2, top_k=20, top_p=2.0, | |||||
repetition_penalty=2.0, length_penalty=1.5)) | repetition_penalty=2.0, length_penalty=1.5)) |
@@ -47,7 +47,7 @@ class TestRobertWordPieceEncoder(unittest.TestCase): | |||||
ds.set_input('words') | ds.set_input('words') | ||||
words = torch.LongTensor(ds['words'].get([0, 1])) | words = torch.LongTensor(ds['words'].get([0, 1])) | ||||
embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, | embed = RobertaEmbedding(vocab, model_dir_or_name=weight_path, | ||||
pool_method='first', include_cls_sep=True, pooled_cls=False) | |||||
pool_method='first', include_cls_sep=True, pooled_cls=False, min_freq=1) | |||||
embed.eval() | embed.eval() | ||||
words_res = embed(words) | words_res = embed(words) | ||||
@@ -183,6 +183,24 @@ class TestRobertWordPieceEncoder(unittest.TestCase): | |||||
torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin') | torch.save(model.state_dict(), 'test/data_for_tests/embedding/small_roberta/small_pytorch_model.bin') | ||||
print(model(torch.LongTensor([[0,1,2,3]]))) | print(model(torch.LongTensor([[0,1,2,3]]))) | ||||
def test_save_load(self): | |||||
bert_save_test = 'roberta_save_test' | |||||
try: | |||||
os.makedirs(bert_save_test, exist_ok=True) | |||||
embed = RobertaWordPieceEncoder(model_dir_or_name='test/data_for_tests/embedding/small_roberta', word_dropout=0.0, | |||||
layers='-2') | |||||
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||||
embed.index_datasets(ds, field_name='words') | |||||
self.assertTrue(ds.has_field('word_pieces')) | |||||
words = torch.LongTensor([[1, 2, 3, 4]]) | |||||
embed.save(bert_save_test) | |||||
load_embed = RobertaWordPieceEncoder.load(bert_save_test) | |||||
embed.eval(), load_embed.eval() | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
finally: | |||||
import shutil | |||||
shutil.rmtree(bert_save_test) | |||||
class TestRobertaEmbedding(unittest.TestCase): | class TestRobertaEmbedding(unittest.TestCase): | ||||
def test_roberta_embedding_1(self): | def test_roberta_embedding_1(self): | ||||
@@ -250,3 +268,20 @@ class TestRobertaEmbedding(unittest.TestCase): | |||||
self.assertEqual((t1-t2).sum(), 0) | self.assertEqual((t1-t2).sum(), 0) | ||||
self.assertEqual((t1-t3).sum(), 0) | self.assertEqual((t1-t3).sum(), 0) | ||||
self.assertEqual((t1-t4).sum(), 0) | self.assertEqual((t1-t4).sum(), 0) | ||||
def test_save_load(self): | |||||
bert_save_test = 'roberta_save_test' | |||||
try: | |||||
os.makedirs(bert_save_test, exist_ok=True) | |||||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||||
embed = RobertaEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_roberta', | |||||
word_dropout=0.1, | |||||
auto_truncate=True) | |||||
embed.save(bert_save_test) | |||||
load_embed = RobertaEmbedding.load(bert_save_test) | |||||
words = torch.randint(len(vocab), size=(2, 20)) | |||||
embed.eval(), load_embed.eval() | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
finally: | |||||
import shutil | |||||
shutil.rmtree(bert_save_test) |
@@ -108,6 +108,56 @@ class TestLoad(unittest.TestCase): | |||||
for v1i, v2i in zip(v1, v2): | for v1i, v2i in zip(v1, v2): | ||||
self.assertAlmostEqual(v1i, v2i, places=4) | self.assertAlmostEqual(v1i, v2i, places=4) | ||||
def test_save_load_static_embed(self): | |||||
static_test_folder = 'static_save_test' | |||||
try: | |||||
# 测试包含no_create_entry | |||||
os.makedirs(static_test_folder, exist_ok=True) | |||||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||||
vocab.add_word_lst(['notinfile2', 'notinfile2'], no_create_entry=True) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||||
'glove.6B.50d_test.txt') | |||||
embed.save(static_test_folder) | |||||
load_embed = StaticEmbedding.load(static_test_folder) | |||||
words = torch.randint(len(vocab), size=(2, 20)) | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
# 测试不包含no_create_entry | |||||
vocab = Vocabulary().add_word_lst(['The', 'a', 'notinfile1', 'A']) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||||
'glove.6B.50d_test.txt') | |||||
embed.save(static_test_folder) | |||||
load_embed = StaticEmbedding.load(static_test_folder) | |||||
words = torch.randint(len(vocab), size=(2, 20)) | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
# 测试lower, min_freq | |||||
vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_static_embedding/' | |||||
'glove.6B.50d_test.txt', min_freq=2, lower=True) | |||||
embed.save(static_test_folder) | |||||
load_embed = StaticEmbedding.load(static_test_folder) | |||||
words = torch.randint(len(vocab), size=(2, 20)) | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
# 测试random的embedding | |||||
vocab = Vocabulary().add_word_lst(['The', 'the', 'the', 'A', 'a', 'B']) | |||||
vocab = vocab.add_word_lst(['b'], no_create_entry=True) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=4, min_freq=2, lower=True, | |||||
normalize=True) | |||||
embed.weight.data += 0.2 # 使得它不是normalize | |||||
embed.save(static_test_folder) | |||||
load_embed = StaticEmbedding.load(static_test_folder) | |||||
words = torch.randint(len(vocab), size=(2, 20)) | |||||
self.assertEqual((embed(words) - load_embed(words)).sum(), 0) | |||||
finally: | |||||
if os.path.isdir(static_test_folder): | |||||
import shutil | |||||
shutil.rmtree(static_test_folder) | |||||
def read_static_embed(fp): | def read_static_embed(fp): | ||||
""" | """ | ||||
@@ -30,11 +30,11 @@ class TestLoad(unittest.TestCase): | |||||
'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False), | 'imdb': ('test/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False), | ||||
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False), | 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False), | ||||
'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False), | 'THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False), | ||||
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (7, 6, 6), False), | |||||
'WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 6, 7), False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, loader, data_set, warns = v | path, loader, data_set, warns = v | ||||
with self.subTest(loader=loader): | |||||
with self.subTest(path=path): | |||||
if warns: | if warns: | ||||
with self.assertWarns(Warning): | with self.assertWarns(Warning): | ||||
data_bundle = loader().load(path) | data_bundle = loader().load(path) | ||||
@@ -45,5 +45,6 @@ class TestLoad(unittest.TestCase): | |||||
self.assertEqual(len(data_set), data_bundle.num_dataset) | self.assertEqual(len(data_set), data_bundle.num_dataset) | ||||
for x, y in zip(data_set, data_bundle.iter_datasets()): | for x, y in zip(data_set, data_bundle.iter_datasets()): | ||||
name, dataset = y | name, dataset = y | ||||
self.assertEqual(x, len(dataset)) | |||||
with self.subTest(split=name): | |||||
self.assertEqual(x, len(dataset)) | |||||
@@ -32,7 +32,7 @@ class TestMatchingLoad(unittest.TestCase): | |||||
'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | 'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | ||||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | ||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), | 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), | ||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (6, 6, 5), False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, loader, instance, warns = v | path, loader, instance, warns = v | ||||
@@ -46,5 +46,6 @@ class TestMatchingLoad(unittest.TestCase): | |||||
self.assertEqual(len(instance), data_bundle.num_dataset) | self.assertEqual(len(instance), data_bundle.num_dataset) | ||||
for x, y in zip(instance, data_bundle.iter_datasets()): | for x, y in zip(instance, data_bundle.iter_datasets()): | ||||
name, dataset = y | name, dataset = y | ||||
self.assertEqual(x, len(dataset)) | |||||
with self.subTest(path=path, split=name): | |||||
self.assertEqual(x, len(dataset)) | |||||
@@ -70,7 +70,7 @@ class TestRunClassificationPipe(unittest.TestCase): | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, pipe, data_set, vocab, warns = v | path, pipe, data_set, vocab, warns = v | ||||
with self.subTest(pipe=pipe): | |||||
with self.subTest(path=path): | |||||
if 'Chn' not in k: | if 'Chn' not in k: | ||||
if warns: | if warns: | ||||
with self.assertWarns(Warning): | with self.assertWarns(Warning): | ||||
@@ -39,7 +39,7 @@ class TestRunMatchingPipe(unittest.TestCase): | |||||
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | ||||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | ||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), | 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), | ||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 6, 5), (36, 2), False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, pipe1, pipe2, data_set, vocab, warns = v | path, pipe1, pipe2, data_set, vocab, warns = v | ||||
@@ -58,7 +58,8 @@ class TestRunMatchingPipe(unittest.TestCase): | |||||
print(data_bundle2) | print(data_bundle2) | ||||
for x, y in zip(data_set, data_bundle1.iter_datasets()): | for x, y in zip(data_set, data_bundle1.iter_datasets()): | ||||
name, dataset = y | name, dataset = y | ||||
self.assertEqual(x, len(dataset)) | |||||
with self.subTest(path=path, split=name): | |||||
self.assertEqual(x, len(dataset)) | |||||
self.assertEqual(len(data_set), data_bundle2.num_dataset) | self.assertEqual(len(data_set), data_bundle2.num_dataset) | ||||
for x, y in zip(data_set, data_bundle2.iter_datasets()): | for x, y in zip(data_set, data_bundle2.iter_datasets()): | ||||
name, dataset = y | name, dataset = y | ||||
@@ -0,0 +1,76 @@ | |||||
import unittest | |||||
from fastNLP.models import SequenceGeneratorModel | |||||
from fastNLP.models import LSTMSeq2SeqModel, TransformerSeq2SeqModel | |||||
from fastNLP import Vocabulary, DataSet | |||||
import torch | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | |||||
from fastNLP import Callback | |||||
def prepare_env(): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||||
src_words_idx = [[3, 1, 2], [1, 2]] | |||||
# tgt_words_idx = [[1, 2, 3, 4], [2, 3]] | |||||
src_seq_len = [3, 2] | |||||
# tgt_seq_len = [4, 2] | |||||
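# the target is simply the source sequence, so the model only has to learn to copy its input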
ds = DataSet({'src_tokens': src_words_idx, 'src_seq_len': src_seq_len, 'tgt_tokens': src_words_idx, | |||||
'tgt_seq_len':src_seq_len}) | |||||
ds.set_input('src_tokens', 'tgt_tokens', 'src_seq_len') | |||||
ds.set_target('tgt_seq_len', 'tgt_tokens') | |||||
return embed, ds | |||||
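# callback that stops training early (by raising KeyboardInterrupt) once dev accuracy reaches 1.0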
class ExitCallback(Callback): | |||||
def __init__(self): | |||||
super().__init__() | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
if eval_result['AccuracyMetric']['acc'] == 1:
raise KeyboardInterrupt() | |||||
class TestSeq2SeqGeneratorModel(unittest.TestCase): | |||||
def test_run(self): | |||||
# check that SequenceGeneratorModel can be used for training; prediction is passed straight through
embed, ds = prepare_env() | |||||
model1 = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, | |||||
dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
trainer = Trainer(ds, model1, optimizer=None, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'), | |||||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||||
num_workers=0, n_epochs=100, print_every=5, | |||||
dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), metric_key=None, | |||||
validate_every=-1, save_path=None, use_tqdm=False, device=None, | |||||
callbacks=ExitCallback(), check_code_level=0) | |||||
res = trainer.train() | |||||
self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1) | |||||
embed, ds = prepare_env() | |||||
model2 = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
num_layers=1, hidden_size=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True, attention=True) | |||||
optimizer = torch.optim.Adam(model2.parameters(), lr=0.01) | |||||
trainer = Trainer(ds, model2, optimizer=optimizer, loss=CrossEntropyLoss(target='tgt_tokens', seq_len='tgt_seq_len'), | |||||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||||
num_workers=0, n_epochs=200, print_every=1, | |||||
dev_data=ds, metrics=AccuracyMetric(target='tgt_tokens', seq_len='tgt_seq_len'), | |||||
metric_key=None, | |||||
validate_every=-1, save_path=None, use_tqdm=False, device=None, | |||||
callbacks=ExitCallback(), check_code_level=0) | |||||
res = trainer.train() | |||||
self.assertEqual(res['best_eval']['AccuracyMetric']['acc'], 1) | |||||
@@ -0,0 +1,114 @@ | |||||
import unittest | |||||
from fastNLP.models.seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
import torch | |||||
from torch import optim | |||||
import torch.nn.functional as F | |||||
from fastNLP import seq_len_to_mask | |||||
def prepare_env(): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||||
src_words_idx = torch.LongTensor([[3, 1, 2], [1, 2, 0]]) | |||||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
tgt_seq_len = torch.LongTensor([4, 2]) | |||||
return embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len | |||||
def train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len): | |||||
optimizer = optim.Adam(model.parameters(), lr=1e-2) | |||||
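# mask out padded target positions; -100 is the default ignore_index of F.cross_entropy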
mask = seq_len_to_mask(tgt_seq_len).eq(0) | |||||
target = tgt_words_idx.masked_fill(mask, -100) | |||||
for i in range(100): | |||||
optimizer.zero_grad() | |||||
pred = model(src_words_idx, tgt_words_idx, src_seq_len)['pred'] # bsz x max_len x vocab_size | |||||
loss = F.cross_entropy(pred.transpose(1, 2), target) | |||||
loss.backward() | |||||
optimizer.step() | |||||
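# padded positions are forced to count as correct, so right_count can be compared with tgt_words_idx.nelement()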
right_count = pred.argmax(dim=-1).eq(target).masked_fill(mask, 1).sum() | |||||
return right_count | |||||
class TestTransformerSeq2SeqModel(unittest.TestCase): | |||||
def test_run(self): | |||||
# test that the forward pass runs
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
for pos_embed in ['learned', 'sin']: | |||||
with self.subTest(pos_embed=pos_embed): | |||||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
pos_embed=pos_embed, max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||||
self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||||
for bind_encoder_decoder_embed in [True, False]: | |||||
tgt_embed = None | |||||
for bind_decoder_input_output_embed in [True, False]: | |||||
if not bind_encoder_decoder_embed:
tgt_embed = embed | |||||
with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed): | |||||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||||
pos_embed='sin', max_position=20, num_layers=2, | |||||
d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||||
self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||||
def test_train(self): | |||||
# test that the model can be trained to overfit the data
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
model = TransformerSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
pos_embed='sin', max_position=20, num_layers=2, d_model=30, n_head=6, dim_ff=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||||
self.assertEqual(right_count, tgt_words_idx.nelement()) | |||||
class TestLSTMSeq2SeqModel(unittest.TestCase): | |||||
def test_run(self): | |||||
# test that the forward pass runs
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
for bind_encoder_decoder_embed in [True, False]: | |||||
tgt_embed = None | |||||
for bind_decoder_input_output_embed in [True, False]: | |||||
if not bind_encoder_decoder_embed:
tgt_embed = embed | |||||
with self.subTest(bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed): | |||||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=tgt_embed, | |||||
num_layers=2, hidden_size=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=bind_encoder_decoder_embed, | |||||
bind_decoder_input_output_embed=bind_decoder_input_output_embed) | |||||
output = model(src_words_idx, tgt_words_idx, src_seq_len) | |||||
self.assertEqual(output['pred'].size(), (2, 4, len(embed))) | |||||
def test_train(self): | |||||
embed, src_words_idx, tgt_words_idx, src_seq_len, tgt_seq_len = prepare_env() | |||||
model = LSTMSeq2SeqModel.build_model(src_embed=embed, tgt_embed=None, | |||||
num_layers=1, hidden_size=20, dropout=0.1, | |||||
bind_encoder_decoder_embed=True, | |||||
bind_decoder_input_output_embed=True) | |||||
right_count = train_model(model, src_words_idx, tgt_words_idx, tgt_seq_len, src_seq_len) | |||||
self.assertEqual(right_count, tgt_words_idx.nelement()) | |||||
@@ -0,0 +1,50 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from fastNLP.modules import TransformerSeq2SeqDecoder | |||||
from fastNLP.modules import LSTMSeq2SeqDecoder | |||||
from fastNLP import seq_len_to_mask | |||||
class TestTransformerSeq2SeqDecoder(unittest.TestCase): | |||||
def test_case(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, embedding_dim=10) | |||||
encoder_output = torch.randn(2, 3, 10) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
encoder_mask = seq_len_to_mask(src_seq_len) | |||||
for flag in [True, False]: | |||||
with self.subTest(bind_decoder_input_output_embed=flag): | |||||
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=None,
d_model=10, num_layers=2, n_head=5, dim_ff=20, dropout=0.1,
bind_decoder_input_output_embed=flag)
state = decoder.init_state(encoder_output, encoder_mask) | |||||
output = decoder(tokens=torch.randint(0, len(vocab), size=(2, 4)), state=state) | |||||
self.assertEqual(output.size(), (2, 4, len(vocab))) | |||||
class TestLSTMDecoder(unittest.TestCase): | |||||
def test_case(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10) | |||||
encoder_output = torch.randn(2, 3, 10) | |||||
tgt_words_idx = torch.LongTensor([[1, 2, 3, 4], [2, 3, 0, 0]]) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
encoder_mask = seq_len_to_mask(src_seq_len) | |||||
for flag in [True, False]: | |||||
for attention in [True, False]: | |||||
with self.subTest(bind_decoder_input_output_embed=flag, attention=attention): | |||||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=2, hidden_size=10,
dropout=0.3, bind_decoder_input_output_embed=flag, attention=attention)
state = decoder.init_state(encoder_output, encoder_mask) | |||||
output = decoder(tgt_words_idx, state) | |||||
self.assertEqual(tuple(output.size()), (2, 4, len(vocab))) |
@@ -0,0 +1,30 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.modules.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
class TestTransformerSeq2SeqEncoder(unittest.TestCase): | |||||
def test_case(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||||
encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2) | |||||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||||
seq_len = torch.LongTensor([3]) | |||||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||||
self.assertEqual(encoder_output.size(), (1, 3, 10)) | |||||
class TestBiLSTMEncoder(unittest.TestCase): | |||||
def test_case(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
embed = StaticEmbedding(vocab, embedding_dim=5) | |||||
encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1) | |||||
words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0) | |||||
seq_len = torch.LongTensor([3]) | |||||
encoder_output, encoder_mask = encoder(words_idx, seq_len) | |||||
self.assertEqual(encoder_mask.size(), (1, 3)) |
@@ -0,0 +1 @@ | |||||
@@ -0,0 +1,110 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.modules.generator import SequenceGenerator | |||||
from fastNLP.modules import TransformerSeq2SeqDecoder, LSTMSeq2SeqDecoder, Seq2SeqDecoder, State | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from torch import nn | |||||
from fastNLP import seq_len_to_mask | |||||
def prepare_env(): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
vocab.add_word_lst("Another test !".split()) | |||||
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5) | |||||
encoder_output = torch.randn(2, 3, 10) | |||||
src_seq_len = torch.LongTensor([3, 2]) | |||||
encoder_mask = seq_len_to_mask(src_seq_len) | |||||
return embed, encoder_output, encoder_mask | |||||
class TestSequenceGenerator(unittest.TestCase): | |||||
def test_run(self): | |||||
# test that it runs: (1) initialize the decoder, (2) run one round of decoding
embed, encoder_output, encoder_mask = prepare_env() | |||||
for do_sample in [True, False]: | |||||
for num_beams in [1, 3, 5]: | |||||
with self.subTest(do_sample=do_sample, num_beams=num_beams): | |||||
decoder = LSTMSeq2SeqDecoder(embed=embed, num_layers=1, hidden_size=10, | |||||
dropout=0.3, bind_decoder_input_output_embed=True, attention=True) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=20, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||||
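# tokens=None: the generator starts decoding from bos_token_id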
generator.generate(state=state, tokens=None) | |||||
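# also exercise the Transformer decoder, here with a learned positional embedding (nn.Embedding over 10 positions)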
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, dropout=0.1, | |||||
bind_decoder_input_output_embed=True) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=1.0, top_k=50, top_p=1.0, bos_token_id=1, eos_token_id=None, | |||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0) | |||||
generator.generate(state=state, tokens=None) | |||||
# also test with other parameter values
decoder = TransformerSeq2SeqDecoder(embed=embed, pos_embed=nn.Embedding(10, embed.embedding_dim), | |||||
d_model=encoder_output.size(-1), num_layers=2, n_head=2, dim_ff=10, | |||||
dropout=0.1, | |||||
bind_decoder_input_output_embed=True) | |||||
state = decoder.init_state(encoder_output, encoder_mask) | |||||
generator = SequenceGenerator(decoder=decoder, max_length=5, num_beams=num_beams, | |||||
do_sample=do_sample, temperature=0.9, top_k=50, top_p=0.5, bos_token_id=1, | |||||
eos_token_id=3, repetition_penalty=2, length_penalty=1.5, pad_token_id=0) | |||||
generator.generate(state=state, tokens=None) | |||||
def test_greedy_decode(self): | |||||
# test that generation produces the correct sequence
class GreedyDummyDecoder(Seq2SeqDecoder): | |||||
def __init__(self, decoder_output): | |||||
super().__init__() | |||||
self.cur_length = 0 | |||||
self.decoder_output = decoder_output | |||||
def decode(self, tokens, state): | |||||
self.cur_length += 1 | |||||
scores = self.decoder_output[:, self.cur_length] | |||||
return scores | |||||
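# minimal State whose only job is to reorder the canned decoder outputs when beam search reorders hypotheses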
class DummyState(State): | |||||
def __init__(self, decoder): | |||||
super().__init__() | |||||
self.decoder = decoder | |||||
def reorder_state(self, indices: torch.LongTensor): | |||||
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0) | |||||
# greedy | |||||
for beam_search in [1, 3]: | |||||
decoder_output = torch.randn(2, 10, 5) | |||||
path = decoder_output.argmax(dim=-1)  # 2 x 10
decoder = GreedyDummyDecoder(decoder_output) | |||||
with self.subTest(beam_search=beam_search): | |||||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||||
do_sample=False, temperature=1, top_k=50, top_p=1, bos_token_id=1, | |||||
eos_token_id=None, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||||
decode_path = generator.generate(DummyState(decoder), tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||||
self.assertEqual(decode_path.eq(path).sum(), path.numel()) | |||||
# greedy check eos_token_id | |||||
for beam_search in [1, 3]: | |||||
decoder_output = torch.randn(2, 10, 5) | |||||
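# token id 4 acts as eos (eos_token_id=4 below): suppress it for the first steps, then force it at a chosen step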
decoder_output[:, :7, 4].fill_(-100) | |||||
decoder_output[0, 7, 4] = 1000  # sequence 0 emits eos at step 8
decoder_output[1, 5, 4] = 1000  # sequence 1 emits eos at step 6
path = decoder_output.argmax(dim=-1)  # 2 x 10
decoder = GreedyDummyDecoder(decoder_output) | |||||
with self.subTest(beam_search=beam_search): | |||||
generator = SequenceGenerator(decoder=decoder, max_length=decoder_output.size(1), num_beams=beam_search, | |||||
do_sample=False, temperature=1, top_k=50, top_p=0.5, bos_token_id=1, | |||||
eos_token_id=4, repetition_penalty=1, length_penalty=1, pad_token_id=0) | |||||
decode_path = generator.generate(DummyState(decoder), | |||||
tokens=decoder_output[:, 0].argmax(dim=-1, keepdim=True)) | |||||
self.assertEqual(decode_path.size(1), 8)  # decoded length is 8
self.assertEqual(decode_path[0].eq(path[0, :8]).sum(), 8) | |||||
self.assertEqual(decode_path[1, :6].eq(path[1, :6]).sum(), 6) |