
update for basemodel bug

Commit 7ca8d045e7 by yh_cc, 4 years ago (tags/v1.0.0alpha)
10 changed files with 140 additions and 87 deletions
  1. fastNLP/embeddings/stack_embedding.py (+5, -2)
  2. fastNLP/embeddings/static_embedding.py (+1, -1)
  3. fastNLP/embeddings/utils.py (+13, -0)
  4. fastNLP/models/base_model.py (+0, -53)
  5. fastNLP/modules/encoder/bert.py (+6, -3)
  6. fastNLP/modules/encoder/gpt2.py (+2, -1)
  7. fastNLP/modules/encoder/roberta.py (+1, -1)
  8. tests/core/test_trainer.py (+72, -3)
  9. tests/embeddings/test_bert_embedding.py (+27, -23)
  10. tests/embeddings/test_stack_embeddings.py (+13, -0)

fastNLP/embeddings/stack_embedding.py (+5, -2)

@@ -13,6 +13,7 @@ import torch
 from torch import nn as nn

 from .embedding import TokenEmbedding
+from .utils import _check_vocab_has_same_index


 class StackEmbedding(TokenEmbedding):
@@ -44,8 +45,9 @@ class StackEmbedding(TokenEmbedding):
                 vocabs.append(embed.get_word_vocab())
         _vocab = vocabs[0]
         for vocab in vocabs[1:]:
-            assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
+            if _vocab != vocab:
+                _check_vocab_has_same_index(_vocab, vocab)

         super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
         assert isinstance(embeds, list)
         for embed in embeds:
@@ -60,6 +62,7 @@ class StackEmbedding(TokenEmbedding):
         :return:
         """
         assert isinstance(embed, TokenEmbedding)
+        _check_vocab_has_same_index(self.get_word_vocab(), embed.get_word_vocab())
         self._embed_size += embed.embed_size
         self.embeds.append(embed)
         return self
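To make the relaxed check concrete, here is a minimal, hedged sketch of appending to an existing StackEmbedding; embed_a and embed_b are hypothetical TokenEmbedding instances whose vocabularies are separate objects but map every word to the same index:

    # Hedged sketch: embed_a / embed_b are assumed, pre-built TokenEmbedding instances.
    stack = StackEmbedding([embed_a])
    stack = stack.append(embed_b)  # append() now calls _check_vocab_has_same_index and
                                   # raises AssertionError if any word maps to a different index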


fastNLP/embeddings/static_embedding.py (+1, -1)

@@ -81,7 +81,7 @@ class StaticEmbedding(TokenEmbedding):
                  init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
         r"""
-        :param vocab: Vocabulary. If this is None, all embeddings in the file are read.
+        :param Vocabulary vocab: the vocabulary. StaticEmbedding only loads vectors for words contained in the vocabulary; words not found in the pretrained vectors are randomly initialized.
         :param model_dir_or_name: a pretrained static embedding can be specified in two ways: either as an embedding folder (the folder should contain exactly one
             file with a .txt suffix) or a file path; or as the name of an embedding, in which case the cache is checked first and the model is downloaded automatically if it is not found.
             If None, an embedding of dimension embedding_dim is randomly initialized.


fastNLP/embeddings/utils.py (+13, -0)

@@ -89,3 +89,16 @@ def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):

     return torch.FloatTensor(sinusoid_table)
+
+
+def _check_vocab_has_same_index(vocab, other_vocab):
+    """
+    Check whether the two vocabularies assign the same index to each word.
+
+    :param Vocabulary vocab:
+    :param Vocabulary other_vocab:
+    :return:
+    """
+    if other_vocab != vocab:
+        for word, word_ix in vocab:
+            other_word_idx = other_vocab.to_index(word)
+            assert other_word_idx == word_ix, f"Word {word} has different index in vocabs, {word_ix} Vs. {other_word_idx}."
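As an illustration of what the helper enforces, a hedged sketch using fastNLP's Vocabulary (the word lists are made up; it assumes the default <pad>/<unk> entries and insertion-order indexing):

    from fastNLP import Vocabulary
    from fastNLP.embeddings.utils import _check_vocab_has_same_index

    vocab_a = Vocabulary().add_word_lst("this is a test".split())
    vocab_b = Vocabulary().add_word_lst("this is a test".split())
    _check_vocab_has_same_index(vocab_a, vocab_b)  # passes: same word-to-index mapping

    vocab_c = Vocabulary().add_word_lst("test a is this".split())
    _check_vocab_has_same_index(vocab_a, vocab_c)  # typically raises AssertionError: indices differ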

fastNLP/models/base_model.py (+0, -53)

@@ -34,56 +34,3 @@ class NaiveClassifier(BaseModel):
     def predict(self, x):
         return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
-
-
-class NaiveClassifier2(BaseModel):
-    r"""
-    A simple classifier example, usable in all kinds of tests
-    """
-
-    def __init__(self, in_feature_dim, out_feature_dim):
-        super(NaiveClassifier2, self).__init__()
-        self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
-
-    def forward(self, x):
-        return {"predict": self.mlp(x)}
-
-    def predict(self, x):
-        return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
-
-
-class NaiveClassifier3(BaseModel):
-    r"""
-    A simple classifier example, usable in all kinds of tests
-    """
-
-    def __init__(self, in_feature_dim, out_feature_dim):
-        super(NaiveClassifier3, self).__init__()
-        self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
-
-    @torch.cuda.amp.autocast()
-    def forward(self, x):
-        return {"predict": self.mlp(x)}
-
-    @torch.cuda.amp.autocast()
-    def predict(self, x):
-        return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
-
-
-class NaiveClassifier4(BaseModel):
-    r"""
-    A simple classifier example, usable in all kinds of tests
-    """
-
-    def __init__(self, in_feature_dim, out_feature_dim):
-        super(NaiveClassifier4, self).__init__()
-        self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
-
-    def forward(self, x):
-        with torch.cuda.amp.autocast():
-            return {"predict": self.mlp(x)}
-
-
-    def predict(self, x):
-        with torch.cuda.amp.autocast():
-            return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}

fastNLP/modules/encoder/bert.py (+6, -3)

@@ -477,7 +477,8 @@ class BertModel(nn.Module):
             if isinstance(module, nn.Linear) and module.bias is not None:
                 module.bias.data.zero_()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True,
+                position_ids=None):
         """

         :param torch.LongTensor input_ids: input ids of shape bsz x max_len
@@ -485,6 +486,7 @@ class BertModel(nn.Module):
         :param attention_mask: 1 for positions that should be attended to, 0 otherwise
         :param bool output_all_encoded_layers: whether to output all layers. By default the token embeddings (bpe, position and type embeddings)
             and the hidden states of every layer are returned; if False, only the last layer is returned
+        :param torch.LongTensor position_ids: bsz x max_len, the position ids
         :return: encode_layers: if output_all_encoded_layers is True, a list of num_layers+1 elements, each of shape
             bsz x max_len x hidden_size; otherwise a tensor of shape bsz x max_len x hidden_size;
             pooled_output: bsz x hidden_size, the representation of the cls token, usable for sentence classification
@@ -506,10 +508,11 @@ class BertModel(nn.Module):
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        # this will cause an issue when using DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
+        # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

-        embedding_output = self.embeddings(input_ids, token_type_ids)
+        embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
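For illustration only, a hedged sketch of the new keyword; model is assumed to be an already constructed fastNLP BertModel, and the ids are made up:

    import torch

    input_ids = torch.LongTensor([[2, 5, 7, 3]])                 # bsz x max_len
    position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)  # explicit bsz x max_len position ids
    encoded_layers, pooled_output = model(input_ids, position_ids=position_ids)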


fastNLP/modules/encoder/gpt2.py (+2, -1)

@@ -834,7 +834,8 @@ class GPT2Model(GPT2PreTrainedModel):
             # positions we want to attend and -10000.0 for masked positions.
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            # this will cause an issue when using DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
+            # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0
             # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)




fastNLP/modules/encoder/roberta.py (+1, -1)

@@ -39,7 +39,7 @@ class RobertaEmbeddings(BertEmbeddings):
             config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
         )

-    def forward(self, input_ids, token_type_ids, words_embeddings=None):
+    def forward(self, input_ids, token_type_ids, words_embeddings=None, **kwargs):
         position_ids = self.create_position_ids_from_input_ids(input_ids)

         return super().forward(


tests/core/test_trainer.py (+72, -3)

@@ -14,8 +14,12 @@ from fastNLP import CrossEntropyLoss
 from fastNLP import AccuracyMetric
 from fastNLP import SGD
 from fastNLP import Trainer
-from fastNLP.models.base_model import NaiveClassifier, NaiveClassifier2, NaiveClassifier3, NaiveClassifier4
+from fastNLP.models.base_model import NaiveClassifier
 from fastNLP import TorchLoaderIter
+from fastNLP.models import BaseModel
+from fastNLP.modules import MLP
+from pkg_resources import parse_version


 def prepare_fake_dataset():
@@ -577,6 +581,22 @@ class TrainerTestGround(unittest.TestCase):
         """

+
+class NaiveClassifier2(BaseModel):
+    r"""
+    A simple classifier example, usable in all kinds of tests
+    """
+
+    def __init__(self, in_feature_dim, out_feature_dim):
+        super(NaiveClassifier2, self).__init__()
+        self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+
+    def forward(self, x):
+        return {"predict": self.mlp(x)}
+
+    def predict(self, x):
+        return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+

 class Fp16TrainerTest(unittest.TestCase):
     def test_raise_error(self):
         data_set = prepare_fake_dataset()
@@ -605,7 +625,7 @@ class Fp16TrainerTest(unittest.TestCase):
                               metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
                               use_tqdm=True, check_code_level=2, fp16=True, device=torch.device('cpu'))

-    @unittest.skipIf(torch.cuda.is_available()==False, "Skip when no cuda device detch")
+    @unittest.skipIf(torch.cuda.is_available()==False or parse_version(torch.__version__) < parse_version('1.6'), "Skip when no cuda device detch")
     def test_run_fp16(self):
         data_set = prepare_fake_dataset()
         data_set.set_input("x", flag=True)
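The added parse_version guard reflects that native mixed precision (torch.cuda.amp) only exists from PyTorch 1.6 onward; a small sketch of the comparison the decorator performs:

    import torch
    from pkg_resources import parse_version

    # Skip the fp16 tests on older torch, where torch.cuda.amp is unavailable.
    too_old = parse_version(torch.__version__) < parse_version('1.6')
    print(too_old)  # True on e.g. torch 1.5.x, False on 1.6 and later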
@@ -627,7 +647,7 @@ class Fp16TrainerTest(unittest.TestCase):
                               use_tqdm=True, check_code_level=2, fp16=True, device=0, test_use_fp16=False)
         trainer.train(load_best_model=False)

-    @unittest.skipIf(torch.cuda.device_count()<2, "Skip when lower than 1 gpus.")
+    @unittest.skipIf(torch.cuda.device_count()<2 or parse_version(torch.__version__) < parse_version('1.6'), "Skip when lower than 1 gpus.")
     def test_run_data_parallel(self):
         data_set = prepare_fake_dataset()
         data_set.set_input("x", flag=True)
@@ -635,6 +655,21 @@ class Fp16TrainerTest(unittest.TestCase):

         train_set, dev_set = data_set.split(0.3)

+        class NaiveClassifier2(BaseModel):
+            r"""
+            A simple classifier example, usable in all kinds of tests
+            """
+
+            def __init__(self, in_feature_dim, out_feature_dim):
+                super(NaiveClassifier2, self).__init__()
+                self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+
+            def forward(self, x):
+                return {"predict": self.mlp(x)}
+
+            def predict(self, x):
+                return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+
         model = NaiveClassifier2(2, 1)
         with self.assertRaises(RuntimeError):
             trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
@@ -643,12 +678,46 @@ class Fp16TrainerTest(unittest.TestCase):
                               use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1])

         with self.assertRaises(RuntimeError):
+            class NaiveClassifier3(BaseModel):
+                r"""
+                A simple classifier example, usable in all kinds of tests
+                """
+
+                def __init__(self, in_feature_dim, out_feature_dim):
+                    super(NaiveClassifier3, self).__init__()
+                    self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+
+                @torch.cuda.amp.autocast()
+                def forward(self, x):
+                    return {"predict": self.mlp(x)}
+
+                @torch.cuda.amp.autocast()
+                def predict(self, x):
+                    return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+
             model = NaiveClassifier3(2, 1)
             trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
                               batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
                               metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
                               use_tqdm=True, check_code_level=2, fp16=True, device=[0, 1], test_use_fp16=True)

+        class NaiveClassifier4(BaseModel):
+            r"""
+            A simple classifier example, usable in all kinds of tests
+            """
+
+            def __init__(self, in_feature_dim, out_feature_dim):
+                super(NaiveClassifier4, self).__init__()
+                self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
+
+            def forward(self, x):
+                with torch.cuda.amp.autocast():
+                    return {"predict": self.mlp(x)}
+
+            def predict(self, x):
+                with torch.cuda.amp.autocast():
+                    return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
+
         model = NaiveClassifier4(2, 1)
         trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCEWithLogits(pred="predict", target="y"),
                           batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,


tests/embeddings/test_bert_embedding.py (+27, -23)

@@ -31,29 +31,33 @@ class TestDownload(unittest.TestCase):

 class TestBertEmbedding(unittest.TestCase):
     def test_bert_embedding_1(self):
-        vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
-        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
-        requires_grad = embed.requires_grad
-        embed.requires_grad = not requires_grad
-        embed.train()
-        words = torch.LongTensor([[2, 3, 4, 0]])
-        result = embed(words)
-        self.assertEqual(result.size(), (1, 4, 16))
-
-        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1)
-        embed.eval()
-        words = torch.LongTensor([[2, 3, 4, 0]])
-        result = embed(words)
-        self.assertEqual(result.size(), (1, 4, 16))
-
-        # auto-truncate instead of raising an error
-        embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
-                              auto_truncate=True)
-
-        words = torch.LongTensor([[2, 3, 4, 1]*10,
-                                  [2, 3]+[0]*38])
-        result = embed(words)
-        self.assertEqual(result.size(), (2, 40, 16))
+        for pool_method in ['first', 'last', 'max', 'avg']:
+            with self.subTest(pool_method=pool_method):
+                vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split())
+                embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                      pool_method=pool_method)
+                requires_grad = embed.requires_grad
+                embed.requires_grad = not requires_grad
+                embed.train()
+                words = torch.LongTensor([[2, 3, 4, 0]])
+                result = embed(words)
+                self.assertEqual(result.size(), (1, 4, 16))
+
+                embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                      pool_method=pool_method)
+                embed.eval()
+                words = torch.LongTensor([[2, 3, 4, 0]])
+                result = embed(words)
+                self.assertEqual(result.size(), (1, 4, 16))
+
+                # auto-truncate instead of raising an error
+                embed = BertEmbedding(vocab, model_dir_or_name='tests/data_for_tests/embedding/small_bert', word_dropout=0.1,
+                                      auto_truncate=True, pool_method=pool_method)
+
+                words = torch.LongTensor([[2, 3, 4, 1]*10,
+                                          [2, 3]+[0]*38])
+                result = embed(words)
+                self.assertEqual(result.size(), (2, 40, 16))

     def test_save_load(self):
         bert_save_test = 'bert_save_test'

tests/embeddings/test_stack_embeddings.py (+13, -0)

@@ -18,3 +18,16 @@ class TestCharEmbed(unittest.TestCase):
         y = embed(x)
         self.assertEqual(tuple(y.size()), (2, 3, 130))
+
+    def test_case_2(self):
+        # test that embeddings only need to share the same index mapping to be concatenated
+        ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])])
+        vocab1 = Vocabulary().from_dataset(ds, field_name='words')
+        vocab2 = Vocabulary().from_dataset(ds, field_name='words')
+        self.assertEqual(len(vocab1), 5)
+        cnn_embed = CNNCharEmbedding(vocab1, embed_size=60)
+        lstm_embed = LSTMCharEmbedding(vocab2, embed_size=70)
+        embed = StackEmbedding([cnn_embed, lstm_embed])
+        x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
+        y = embed(x)
+        self.assertEqual(tuple(y.size()), (2, 3, 130))

