@@ -13,6 +13,7 @@ import torch | |||
from torch import nn as nn | |||
from .embedding import TokenEmbedding | |||
from .utils import _check_vocab_has_same_index | |||
class StackEmbedding(TokenEmbedding): | |||
@@ -44,8 +45,9 @@ class StackEmbedding(TokenEmbedding): | |||
vocabs.append(embed.get_word_vocab()) | |||
_vocab = vocabs[0] | |||
for vocab in vocabs[1:]: | |||
assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." | |||
if _vocab!=vocab: | |||
_check_vocab_has_same_index(_vocab, vocab) | |||
super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) | |||
assert isinstance(embeds, list) | |||
for embed in embeds: | |||
@@ -60,6 +62,7 @@ class StackEmbedding(TokenEmbedding): | |||
:return: | |||
""" | |||
assert isinstance(embed, TokenEmbedding) | |||
_check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab()) | |||
self._embed_size += embed.embed_size | |||
self.embeds.append(embed) | |||
return self | |||
@@ -81,7 +81,7 @@ class StaticEmbedding(TokenEmbedding): | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||
r""" | |||
:param vocab: Vocabulary. 若该项为None则会读取所有的embedding。 | |||
:param Vocabulary vocab: 词表. StaticEmbedding只会加载包含在词表中的词的词向量,在预训练向量中没找到的使用随机初始化 | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 | |||
以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | |||
如果输入为None则使用embedding_dim的维度随机初始化一个embedding。 | |||
@@ -89,3 +89,16 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): | |||
return torch.FloatTensor(sinusoid_table) | |||
def _check_vocab_has_same_index(vocab, other_vocab): | |||
""" | |||
检查两个vocabulary是否含有相同的word idx | |||
:param Vocabulary vocab: | |||
:param Vocabulary other_vocab: | |||
:return: | |||
""" | |||
if other_vocab != vocab: | |||
for word, word_ix in vocab: | |||
other_word_idx = other_vocab.to_index(word) | |||
assert other_word_idx == word_ix, f"Word {word} has different index in vocabs, {word_ix} Vs. {other_word_idx}." |
@@ -34,56 +34,3 @@ class NaiveClassifier(BaseModel): | |||
def predict(self, x): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
class NaiveClassifier2(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier2, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
def forward(self, x): | |||
return {"predict": self.mlp(x)} | |||
def predict(self, x): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
class NaiveClassifier3(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier3, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
@torch.cuda.amp.autocast() | |||
def forward(self, x): | |||
return {"predict": self.mlp(x)} | |||
@torch.cuda.amp.autocast() | |||
def predict(self, x): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
class NaiveClassifier4(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier4, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
def forward(self, x): | |||
with torch.cuda.amp.autocast(): | |||
return {"predict": self.mlp(x)} | |||
def predict(self, x): | |||
with torch.cuda.amp.autocast(): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} |
@@ -477,7 +477,8 @@ class BertModel(nn.Module): | |||
if isinstance(module, nn.Linear) and module.bias is not None: | |||
module.bias.data.zero_() | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, | |||
position_ids=None): | |||
""" | |||
:param torch.LongTensor input_ids: bsz x max_len的输入id | |||
@@ -485,6 +486,7 @@ class BertModel(nn.Module): | |||
:param attention_mask: 需要attend的为1,不需要为0 | |||
:param bool output_all_encoded_layers: 是否输出所有层,默认输出token embedding(包含bpe, position以及type embedding) | |||
及每一层的hidden states。如果为False,只输出最后一层的结果 | |||
:param torch.LongTensor position_ids: bsz x max_len, position的id | |||
:return: encode_layers: 如果output_all_encoded_layers为True,返回list(共num_layers+1个元素),每个元素为 | |||
bsz x max_len x hidden_size否则返回bsz x max_len x hidden_size的tensor; | |||
pooled_output: bsz x hidden_size为cls的表示,可以用于句子的分类 | |||
@@ -506,10 +508,11 @@ class BertModel(nn.Module): | |||
# positions we want to attend and -10000.0 for masked positions. | |||
# Since we are adding it to the raw scores before the softmax, this is | |||
# effectively the same as removing these entirely. | |||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
# this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469 | |||
# extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||
embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids) | |||
encoded_layers = self.encoder(embedding_output, | |||
extended_attention_mask, | |||
output_all_encoded_layers=output_all_encoded_layers) | |||
@@ -834,7 +834,8 @@ class GPT2Model(GPT2PreTrainedModel): | |||
# positions we want to attend and -10000.0 for masked positions. | |||
# Since we are adding it to the raw scores before the softmax, this is | |||
# effectively the same as removing these entirely. | |||
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
# this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469 | |||
# attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
attention_mask = (1.0 - attention_mask) * -10000.0 | |||
# attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0) | |||
@@ -39,7 +39,7 @@ class RobertaEmbeddings(BertEmbeddings): | |||
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx | |||
) | |||
def forward(self, input_ids, token_type_ids, words_embeddings=None): | |||
def forward(self, input_ids, token_type_ids, words_embeddings=None, **kwargs): | |||
position_ids = self.create_position_ids_from_input_ids(input_ids) | |||
return super().forward( | |||
@@ -14,8 +14,12 @@ from fastNLP import CrossEntropyLoss | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import SGD | |||
from fastNLP import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier, NaiveClassifier2, NaiveClassifier3, NaiveClassifier4 | |||
from fastNLP.models.base_model import NaiveClassifier | |||
from fastNLP import TorchLoaderIter | |||
from fastNLP.models import BaseModel | |||
from fastNLP.modules import MLP | |||
from pkg_resources import parse_version | |||
def prepare_fake_dataset(): | |||
@@ -577,6 +581,22 @@ class TrainerTestGround(unittest.TestCase): | |||
""" | |||
class NaiveClassifier2(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier2, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
def forward(self, x): | |||
return {"predict": self.mlp(x)} | |||
def predict(self, x): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
class Fp16TrainerTest(unittest.TestCase): | |||
def test_raise_error(self): | |||
data_set = prepare_fake_dataset() | |||
@@ -605,7 +625,7 @@ class Fp16TrainerTest(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||
use_tqdm=True, check_code_level=2, fp16=True, device=torch.device('cpu')) | |||
@unittest.skipIf(torch.cuda.is_available()==False, "Skip when no cuda device detch") | |||
@unittest.skipIf(torch.cuda.is_available()==False or parse_version(torch.__version__) < parse_version('1.6'), "Skip when no cuda device detch") | |||
def test_run_fp16(self): | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
@@ -627,7 +647,7 @@ class Fp16TrainerTest(unittest.TestCase): | |||
use_tqdm=True, check_code_level=2, fp16=True, device=0, test_use_fp16=False) | |||
trainer.train(load_best_model=False) | |||
@unittest.skipIf(torch.cuda.device_count()<2, "Skip when lower than 1 gpus.") | |||
@unittest.skipIf(torch.cuda.device_count()<2 or parse_version(torch.__version__) < parse_version('1.6'), "Skip when lower than 1 gpus.") | |||
def test_run_data_parallel(self): | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
@@ -635,6 +655,21 @@ class Fp16TrainerTest(unittest.TestCase): | |||
train_set, dev_set = data_set.split(0.3) | |||
class NaiveClassifier2(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier2, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
def forward(self, x): | |||
return {"predict": self.mlp(x)} | |||
def predict(self, x): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
model = NaiveClassifier2(2, 1) | |||
with self.assertRaises(RuntimeError): | |||
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"), | |||
@@ -643,12 +678,46 @@ class Fp16TrainerTest(unittest.TestCase): | |||
use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1]) | |||
with self.assertRaises(RuntimeError): | |||
class NaiveClassifier3(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier3, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
@torch.cuda.amp.autocast() | |||
def forward(self, x): | |||
return {"predict": self.mlp(x)} | |||
@torch.cuda.amp.autocast() | |||
def predict(self, x): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
model = NaiveClassifier3(2, 1) | |||
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"), | |||
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||
use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1], test_use_fp16=True) | |||
class NaiveClassifier4(BaseModel): | |||
r""" | |||
一个简单的分类器例子,可用于各种测试 | |||
""" | |||
def __init__(self, in_feature_dim, out_feature_dim): | |||
super(NaiveClassifier4, self).__init__() | |||
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) | |||
def forward(self, x): | |||
with torch.cuda.amp.autocast(): | |||
return {"predict": self.mlp(x)} | |||
def predict(self, x): | |||
with torch.cuda.amp.autocast(): | |||
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} | |||
model = NaiveClassifier4(2, 1) | |||
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"), | |||
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set, | |||
@@ -31,29 +31,33 @@ class TestDownload(unittest.TestCase): | |||
class TestBertEmbedding(unittest.TestCase): | |||
def test_bert_embedding_1(self): | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
requires_grad = embed.requires_grad | |||
embed.requires_grad = not requires_grad | |||
embed.train() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
embed.eval() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
# 自动截断而不报错 | |||
embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
auto_truncate=True) | |||
words = torch.LongTensor([[2, 3, 4, 1]*10, | |||
[2, 3]+[0]*38]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (2, 40, 16)) | |||
for pool_method in ['first', 'last', 'max', 'avg']: | |||
with self.subTest(pool_method=pool_method): | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
pool_method=pool_method) | |||
requires_grad = embed.requires_grad | |||
embed.requires_grad = not requires_grad | |||
embed.train() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
pool_method=pool_method) | |||
embed.eval() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
# 自动截断而不报错 | |||
embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
auto_truncate=True, pool_method=pool_method) | |||
words = torch.LongTensor([[2, 3, 4, 1]*10, | |||
[2, 3]+[0]*38]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (2, 40, 16)) | |||
def test_save_load(self): | |||
bert_save_test = 'bert_save_test' | |||
@@ -18,3 +18,16 @@ class TestCharEmbed(unittest.TestCase): | |||
y = embed(x) | |||
self.assertEqual(tuple(y.size()), (2, 3, 130)) | |||
def test_case_2(self): | |||
# 测试只需要拥有一样的index就可以concat | |||
ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])]) | |||
vocab1 = Vocabulary().from_dataset(ds, field_name='words') | |||
vocab2 = Vocabulary().from_dataset(ds, field_name='words') | |||
self.assertEqual(len(vocab1), 5) | |||
cnn_embed = CNNCharEmbedding(vocab1, embed_size=60) | |||
lstm_embed = LSTMCharEmbedding(vocab2, embed_size=70) | |||
embed = StackEmbedding([cnn_embed, lstm_embed]) | |||
x = torch.LongTensor([[2, 1, 0], [4, 3, 4]]) | |||
y = embed(x) | |||
self.assertEqual(tuple(y.size()), (2, 3, 130)) | |||