From 8acb0f4ad6f09dca5f655cd6830f85a852365036 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 6 Mar 2020 23:43:08 +0800 Subject: [PATCH] =?UTF-8?q?1.=E6=96=B0=E5=A2=9EVocabulary=E7=9A=84save?= =?UTF-8?q?=E4=B8=8Eload=E5=8A=9F=E8=83=BD;=202.=E4=BF=AE=E5=A4=8DBERTEmbe?= =?UTF-8?q?dding=20drop=5Fword=E7=9A=84bug;=203.=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E5=90=84=E7=A7=8DEmbedding=E7=9A=84=E4=BD=BF=E7=94=A8=E6=B3=A8?= =?UTF-8?q?=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/vocabulary.py | 77 ++++++++++++++++++++++++++ fastNLP/embeddings/bert_embedding.py | 54 +++++++++++++----- fastNLP/embeddings/elmo_embedding.py | 8 ++- fastNLP/embeddings/static_embedding.py | 19 ++++++- fastNLP/io/file_utils.py | 2 +- test/core/test_vocabulary.py | 29 ++++++++++ test/embeddings/test_bert_embedding.py | 1 + 7 files changed, 174 insertions(+), 16 deletions(-) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 3456061f..012cd493 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -482,3 +482,80 @@ class Vocabulary(object): def __iter__(self): for word, index in self._word2idx.items(): yield word, index + + def save(self, filepath): + """ + + :param str filepath: Vocabulary的储存路径 + :return: + """ + with open(filepath, 'w', encoding='utf-8') as f: + f.write(f'max_size\t{self.max_size}\n') + f.write(f'min_freq\t{self.min_freq}\n') + f.write(f'unknown\t{self.unknown}\n') + f.write(f'padding\t{self.padding}\n') + f.write(f'rebuild\t{self.rebuild}\n') + f.write('\n') + # idx: 如果idx为-2, 说明还没有进行build; 如果idx为-1,说明该词未编入 + # no_create_entry: 如果为1,说明该词是no_create_entry; 0 otherwise + # word \t count \t idx \t no_create_entry \n + idx = -2 + for word, count in self.word_count.items(): + if self._word2idx is not None: + idx = self._word2idx.get(word, -1) + is_no_create_entry = int(self._is_word_no_create_entry(word)) + f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') + + @staticmethod + def load(filepath): + """ + + :param str filepath: Vocabulary的读取路径 + :return: Vocabulary + """ + with open(filepath, 'r', encoding='utf-8') as f: + vocab = Vocabulary() + for line in f: + line = line.strip() + if line: + name, value = line.split() + if name == 'max_size': + vocab.max_size = int(value) if value!='None' else None + elif name == 'min_freq': + vocab.min_freq = int(value) if value!='None' else None + elif name in ('unknown', 'padding'): + setattr(vocab, name, value) + elif name == 'rebuild': + vocab.rebuild = True if value=='True' else False + else: + break + word_counter = {} + no_create_entry_counter = {} + word2idx = {} + for line in f: + line = line.strip() + if line: + parts = line.split() + word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) + if idx >= 0: + word2idx[word] = idx + word_counter[word] = count + if no_create_entry_counter: + no_create_entry_counter[word] = count + + word_counter = Counter(word_counter) + no_create_entry_counter = Counter(no_create_entry_counter) + if len(word2idx)>0: + if vocab.padding: + word2idx[vocab.padding] = 0 + if vocab.unknown: + word2idx[vocab.unknown] = 1 if vocab.padding else 0 + idx2word = {value:key for key,value in word2idx.items()} + + vocab.word_count = word_counter + vocab._no_create_word = no_create_entry_counter + if word2idx: + vocab._word2idx = word2idx + vocab._idx2word = idx2word + + return vocab diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 44824dc0..c81a4463 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -29,7 +29,17 @@ class BertEmbedding(ContextualEmbedding): 预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word 时切分),在分割之后长度可能会超过最大长度限制。 - BertEmbedding可以支持自动下载权重,当前支持的模型有以下的几种(待补充): + BertEmbedding可以支持自动下载权重,当前支持的模型: + en: base-cased + en-large-cased-wwm: + en-large-cased: + en-large-uncased: + en-large-uncased-wwm + cn: 中文BERT wwm by HIT + cn-base: 中文BERT base-chinese + cn-wwm-ext: 中文BERT wwm by HIT with extra data pretrain. + multi-base-cased: multilingual cased + multi-base-uncased: multilingual uncased Example:: @@ -70,6 +80,9 @@ class BertEmbedding(ContextualEmbedding): """ super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) + if word_dropout>0: + assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token." + if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'): logger.warning("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" @@ -84,7 +97,8 @@ class BertEmbedding(ContextualEmbedding): self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, pool_method=pool_method, include_cls_sep=include_cls_sep, pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) - + self._sep_index = self.model._sep_index + self._cls_index = self.model._cls_index self.requires_grad = requires_grad self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size @@ -117,21 +131,35 @@ class BertEmbedding(ContextualEmbedding): """ if self.word_dropout > 0 and self.training: with torch.no_grad(): - if self._word_sep_index: # 不能drop sep - sep_mask = words.eq(self._word_sep_index) + not_sep_mask = words.ne(self._sep_index) + not_cls_mask = words.ne(self._cls_index) + if self._word_sep_index: + not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index)) + replaceable_mask = not_sep_mask.__and__(not_cls_mask) mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 pad_mask = words.ne(0) - mask = pad_mask.__and__(mask) # pad的位置不为unk + mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk words = words.masked_fill(mask, self._word_unk_index) - if self._word_sep_index: - words.masked_fill_(sep_mask, self._word_sep_index) return words class BertWordPieceEncoder(nn.Module): """ 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 + + BertWordPieceEncoder可以支持自动下载权重,当前支持的模型: + en: base-cased + en-large-cased-wwm: + en-large-cased: + en-large-uncased: + en-large-uncased-wwm + cn: 中文BERT wwm by HIT + cn-base: 中文BERT base-chinese + cn-wwm-ext: 中文BERT wwm by HIT with extra data pretrain. + multi-base-cased: multilingual cased + multi-base-uncased: multilingual uncased + """ def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, @@ -149,6 +177,7 @@ class BertWordPieceEncoder(nn.Module): self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) self._sep_index = self.model._sep_index + self._cls_index = self.model._cls_index self._wordpiece_pad_index = self.model._wordpiece_pad_index self._wordpiece_unk_index = self.model._wordpiece_unknown_index self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size @@ -212,15 +241,14 @@ class BertWordPieceEncoder(nn.Module): """ if self.word_dropout > 0 and self.training: with torch.no_grad(): - if self._word_sep_index: # 不能drop sep - sep_mask = words.eq(self._wordpiece_unk_index) + not_sep_mask = words.ne(self._sep_index) + not_cls_mask = words.ne(self._cls_index) + replaceable_mask = not_sep_mask.__and__(not_cls_mask) mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 pad_mask = words.ne(self._wordpiece_pad_index) - mask = pad_mask.__and__(mask) # pad的位置不为unk - words = words.masked_fill(mask, self._word_unk_index) - if self._word_sep_index: - words.masked_fill_(sep_mask, self._wordpiece_unk_index) + mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk + words = words.masked_fill(mask, self._wordpiece_unk_index) return words diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py index f2d643f7..ce077ebf 100644 --- a/fastNLP/embeddings/elmo_embedding.py +++ b/fastNLP/embeddings/elmo_embedding.py @@ -24,7 +24,13 @@ from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder class ElmoEmbedding(ContextualEmbedding): """ - 使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。当前支持的使用名称初始化的模型有以下的这些(待补充) + 使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。 + 当前支持的使用名称初始化的模型: + en: 即en-medium hidden_size 1024; output_size 12 + en-medium: hidden_size 2048; output_size 256 + en-origial: hidden_size 4096; output_size 512 + en-original-5.5b: hidden_size 4096; output_size 512 + en-small: hidden_size 1024; output_size 128 Example:: diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py index 42d30bac..6d32431d 100644 --- a/fastNLP/embeddings/static_embedding.py +++ b/fastNLP/embeddings/static_embedding.py @@ -26,7 +26,24 @@ class StaticEmbedding(TokenEmbedding): """ StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, 如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 - 当前支持自动下载的预训练vector有以下的几种(待补充); + 当前支持自动下载的预训练vector有: + en: 实际为en-glove-840b-300d(常用) + en-glove-6b-50d: glove官方的50d向量 + en-glove-6b-100d: glove官方的100d向量 + en-glove-6b-200d: glove官方的200d向量 + en-glove-6b-300d: glove官方的300d向量 + en-glove-42b-300d: glove官方使用42B数据训练版本 + en-glove-840b-300d: + en-glove-twitter-27b-25d: + en-glove-twitter-27b-50d: + en-glove-twitter-27b-100d: + en-glove-twitter-27b-200d: + en-word2vec-300d: word2vec官方发布的300d向量 + en-fasttext-crawl: fasttext官方发布的300d英文预训练 + cn-char-fastnlp-100d: fastNLP训练的100d的character embedding + cn-bi-fastnlp-100d: fastNLP训练的100d的bigram embedding + cn-tri-fastnlp-100d: fastNLP训练的100d的trigram embedding + cn-fasttext: fasttext官方发布的300d中文预训练embedding Example:: diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index 5195ec74..4f29de43 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -70,7 +70,7 @@ PRETRAIN_STATIC_FILES = { 'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', 'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', - 'en-word2vec-300': "GoogleNews-vectors-negative300.txt.gz", + 'en-word2vec-300d': "GoogleNews-vectors-negative300.txt.gz", 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", 'en-fasttext-crawl': "crawl-300d-2M.vec.zip", diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py index 5d8d4269..81a01092 100644 --- a/test/core/test_vocabulary.py +++ b/test/core/test_vocabulary.py @@ -189,3 +189,32 @@ class TestOther(unittest.TestCase): vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) # this will print a warning self.assertEqual(vocab.rebuild, True) + + def test_save_and_load(self): + fp = 'vocab_save_test.txt' + try: + # check word2idx没变,no_create_entry正常 + words = list('abcdefaddfdkjfe') + no_create_entry = list('12342331') + unk = '[UNK]' + vocab = Vocabulary(unknown=unk, max_size=500) + + vocab.add_word_lst(words) + vocab.add_word_lst(no_create_entry, no_create_entry=True) + vocab.save(fp) + + new_vocab = Vocabulary.load(fp) + + for word, index in vocab: + self.assertEqual(new_vocab.to_index(word), index) + for word in no_create_entry: + self.assertTrue(new_vocab._is_word_no_create_entry(word)) + for word in words: + self.assertFalse(new_vocab._is_word_no_create_entry(word)) + for idx in range(len(vocab)): + self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx)) + self.assertEqual(vocab.unknown, new_vocab.unknown) + except: + import os + if os.path.exists(fp): + os.remove(fp) diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py index 2a8550c3..cc35a4e4 100644 --- a/test/embeddings/test_bert_embedding.py +++ b/test/embeddings/test_bert_embedding.py @@ -46,3 +46,4 @@ class TestBertWordPieceEncoder(unittest.TestCase): ds = DataSet({'words': ["this is a test . [SEP]".split()]}) embed.index_datasets(ds, field_name='words') self.assertTrue(ds.has_field('word_pieces')) + result = embed(torch.LongTensor([[1,2,3,4]]))