@@ -482,3 +482,80 @@ class Vocabulary(object):
    def __iter__(self):
        for word, index in self._word2idx.items():
            yield word, index
    def save(self, filepath):
        """
        :param str filepath: path to save the Vocabulary to
        :return:
        """
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(f'max_size\t{self.max_size}\n')
            f.write(f'min_freq\t{self.min_freq}\n')
            f.write(f'unknown\t{self.unknown}\n')
            f.write(f'padding\t{self.padding}\n')
            f.write(f'rebuild\t{self.rebuild}\n')
            f.write('\n')
            # idx: -2 means the vocabulary has not been built yet; -1 means the word was not assigned an index
            # no_create_entry: 1 if the word is a no_create_entry word, 0 otherwise
            # one entry per line: word \t count \t idx \t no_create_entry \n
            idx = -2
            for word, count in self.word_count.items():
                if self._word2idx is not None:
                    idx = self._word2idx.get(word, -1)
                is_no_create_entry = int(self._is_word_no_create_entry(word))
                f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n')
    @staticmethod
    def load(filepath):
        """
        :param str filepath: path to load the Vocabulary from
        :return: Vocabulary
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            vocab = Vocabulary()
            for line in f:
                line = line.strip()
                if line:
                    name, value = line.split('\t')
                    if name == 'max_size':
                        vocab.max_size = int(value) if value != 'None' else None
                    elif name == 'min_freq':
                        vocab.min_freq = int(value) if value != 'None' else None
                    elif name in ('unknown', 'padding'):
                        setattr(vocab, name, value if value != 'None' else None)
                    elif name == 'rebuild':
                        vocab.rebuild = (value == 'True')
                else:
                    break  # the blank line marks the end of the header block
            word_counter = {}
            no_create_entry_counter = {}
            word2idx = {}
            for line in f:
                line = line.strip()
                if line:
                    parts = line.split('\t')
                    word, count, idx, no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3])
                    if idx >= 0:
                        word2idx[word] = idx
                    word_counter[word] = count
                    if no_create_entry:
                        no_create_entry_counter[word] = count
            word_counter = Counter(word_counter)
            no_create_entry_counter = Counter(no_create_entry_counter)
            if len(word2idx) > 0:
                if vocab.padding:
                    word2idx[vocab.padding] = 0
                if vocab.unknown:
                    word2idx[vocab.unknown] = 1 if vocab.padding else 0
                idx2word = {value: key for key, value in word2idx.items()}

            vocab.word_count = word_counter
            vocab._no_create_word = no_create_entry_counter
            if word2idx:
                vocab._word2idx = word2idx
                vocab._idx2word = idx2word
            return vocab
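A minimal round-trip sketch of how these two methods fit together, mirroring the test at the bottom of this diff (the file path and tokens are illustrative, and the usual top-level import is assumed):

    from fastNLP import Vocabulary

    vocab = Vocabulary(unknown='[UNK]', padding='<pad>')
    vocab.add_word_lst(['apple', 'banana', 'apple'])
    vocab.add_word_lst(['cherry'], no_create_entry=True)  # counted, but mapped to unk at index time
    vocab.save('vocab_demo.txt')  # hypothetical path

    restored = Vocabulary.load('vocab_demo.txt')
    assert restored.to_index('apple') == vocab.to_index('apple')
    assert restored._is_word_no_create_entry('cherry')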
@@ -29,7 +29,17 @@ class BertEmbedding(ContextualEmbedding):
    The pretrained BERT model is limited to 512 tokens. Because the input words have not yet been split into
    word pieces (BertEmbedding performs the word-piece split on the input words), the sequence may exceed
    this maximum length after splitting.
    BertEmbedding supports automatic download of pretrained weights. Currently supported model names:

        en: English base, cased
        en-large-cased-wwm: English large, cased, whole word masking
        en-large-cased: English large, cased
        en-large-uncased: English large, uncased
        en-large-uncased-wwm: English large, uncased, whole word masking
        cn: Chinese BERT, whole word masking, by HIT
        cn-base: Chinese BERT, base-chinese
        cn-wwm-ext: Chinese BERT, whole word masking, by HIT, pretrained with extra data
        multi-base-cased: multilingual, cased
        multi-base-uncased: multilingual, uncased
    Example::
@@ -70,6 +80,9 @@ class BertEmbedding(ContextualEmbedding):
""" | """ | ||||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | ||||
        if word_dropout > 0:
            assert vocab.unknown is not None, "When word_dropout > 0, Vocabulary must contain the unknown token."
        if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
            if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
                logger.warning("For Chinese bert, pool_method should be chosen from 'first' or 'last' in order to achieve"
@@ -84,7 +97,8 @@ class BertEmbedding(ContextualEmbedding):
        self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)
        self._sep_index = self.model._sep_index
        self._cls_index = self.model._cls_index

        self.requires_grad = requires_grad
        self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
@@ -117,21 +131,35 @@ class BertEmbedding(ContextualEmbedding):
""" | """ | ||||
if self.word_dropout > 0 and self.training: | if self.word_dropout > 0 and self.training: | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
if self._word_sep_index: # 不能drop sep | |||||
sep_mask = words.eq(self._word_sep_index) | |||||
not_sep_mask = words.ne(self._sep_index) | |||||
not_cls_mask = words.ne(self._cls_index) | |||||
if self._word_sep_index: | |||||
not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index)) | |||||
replaceable_mask = not_sep_mask.__and__(not_cls_mask) | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | ||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | ||||
pad_mask = words.ne(0) | pad_mask = words.ne(0) | ||||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||||
mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk | |||||
words = words.masked_fill(mask, self._word_unk_index) | words = words.masked_fill(mask, self._word_unk_index) | ||||
if self._word_sep_index: | |||||
words.masked_fill_(sep_mask, self._word_sep_index) | |||||
return words | return words | ||||
class BertWordPieceEncoder(nn.Module):
    """
    Loads a bert model; after loading, call the index_datasets method to generate the word_pieces field in a dataset.

    BertWordPieceEncoder supports automatic download of pretrained weights. Currently supported model names:

        en: English base, cased
        en-large-cased-wwm: English large, cased, whole word masking
        en-large-cased: English large, cased
        en-large-uncased: English large, uncased
        en-large-uncased-wwm: English large, uncased, whole word masking
        cn: Chinese BERT, whole word masking, by HIT
        cn-base: Chinese BERT, base-chinese
        cn-wwm-ext: Chinese BERT, whole word masking, by HIT, pretrained with extra data
        multi-base-cased: multilingual, cased
        multi-base-uncased: multilingual, uncased
    """
    def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
@@ -149,6 +177,7 @@ class BertWordPieceEncoder(nn.Module):
        self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls)
        self._sep_index = self.model._sep_index
        self._cls_index = self.model._cls_index
        self._wordpiece_pad_index = self.model._wordpiece_pad_index
        self._wordpiece_unk_index = self.model._wordpiece_unknown_index
        self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
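Mirroring the test at the end of this diff, the intended usage of the encoder is roughly the following sketch (the model name, input indices, and the fastNLP.embeddings import path are assumptions, not part of the diff):

    import torch
    from fastNLP import DataSet
    from fastNLP.embeddings import BertWordPieceEncoder

    embed = BertWordPieceEncoder(model_dir_or_name='en', layers='-1')  # downloads weights on first use
    ds = DataSet({'words': ["this is a test . [SEP]".split()]})
    embed.index_datasets(ds, field_name='words')   # adds the 'word_pieces' field to ds
    result = embed(torch.LongTensor([[1, 2, 3, 4]]))  # roughly (batch, num_word_pieces, embed_size)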
@@ -212,15 +241,14 @@ class BertWordPieceEncoder(nn.Module):
""" | """ | ||||
if self.word_dropout > 0 and self.training: | if self.word_dropout > 0 and self.training: | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
if self._word_sep_index: # 不能drop sep | |||||
sep_mask = words.eq(self._wordpiece_unk_index) | |||||
not_sep_mask = words.ne(self._sep_index) | |||||
not_cls_mask = words.ne(self._cls_index) | |||||
replaceable_mask = not_sep_mask.__and__(not_cls_mask) | |||||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | ||||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | ||||
pad_mask = words.ne(self._wordpiece_pad_index) | pad_mask = words.ne(self._wordpiece_pad_index) | ||||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||||
words = words.masked_fill(mask, self._word_unk_index) | |||||
if self._word_sep_index: | |||||
words.masked_fill_(sep_mask, self._wordpiece_unk_index) | |||||
mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk | |||||
words = words.masked_fill(mask, self._wordpiece_unk_index) | |||||
return words | return words | ||||
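Both drop_word implementations share the same masking scheme: sample a Bernoulli mask at rate word_dropout, exclude pad/[CLS]/[SEP] positions, and fill the remaining masked positions with the unk index. A self-contained sketch with toy tensors (the special-token indices below are made up for illustration only):

    import torch

    PAD, CLS, SEP, UNK = 0, 101, 102, 100  # hypothetical indices, for illustration only
    words = torch.tensor([[CLS, 7, 8, 9, SEP, PAD, PAD]])

    word_dropout = 0.5
    mask = torch.bernoulli(torch.full_like(words, word_dropout, dtype=torch.float)).eq(1)
    replaceable = words.ne(PAD) & words.ne(CLS) & words.ne(SEP)
    words = words.masked_fill(mask & replaceable, UNK)
    # only the ordinary tokens (7, 8, 9) can have been replaced by UNK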
@@ -24,7 +24,13 @@ from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
class ElmoEmbedding(ContextualEmbedding):
    """
    ELMo embedding. After initialization, simply pass in words to get the corresponding embeddings.
    Currently supported model names:

        en: same as en-medium (hidden_size 2048; output_size 256)
        en-medium: hidden_size 2048; output_size 256
        en-original: hidden_size 4096; output_size 512
        en-original-5.5b: hidden_size 4096; output_size 512
        en-small: hidden_size 1024; output_size 128

    Example::
@@ -26,7 +26,24 @@ class StaticEmbedding(TokenEmbedding):
""" | """ | ||||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | ||||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | 如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | ||||
当前支持自动下载的预训练vector有以下的几种(待补充); | |||||
当前支持自动下载的预训练vector有: | |||||
en: 实际为en-glove-840b-300d(常用) | |||||
en-glove-6b-50d: glove官方的50d向量 | |||||
en-glove-6b-100d: glove官方的100d向量 | |||||
en-glove-6b-200d: glove官方的200d向量 | |||||
en-glove-6b-300d: glove官方的300d向量 | |||||
en-glove-42b-300d: glove官方使用42B数据训练版本 | |||||
en-glove-840b-300d: | |||||
en-glove-twitter-27b-25d: | |||||
en-glove-twitter-27b-50d: | |||||
en-glove-twitter-27b-100d: | |||||
en-glove-twitter-27b-200d: | |||||
en-word2vec-300d: word2vec官方发布的300d向量 | |||||
en-fasttext-crawl: fasttext官方发布的300d英文预训练 | |||||
cn-char-fastnlp-100d: fastNLP训练的100d的character embedding | |||||
cn-bi-fastnlp-100d: fastNLP训练的100d的bigram embedding | |||||
cn-tri-fastnlp-100d: fastNLP训练的100d的trigram embedding | |||||
cn-fasttext: fasttext官方发布的300d中文预训练embedding | |||||
Example:: | Example:: | ||||
@@ -70,7 +70,7 @@ PRETRAIN_STATIC_FILES = {
    'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip',
    'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip',
    'en-word2vec-300d': "GoogleNews-vectors-negative300.txt.gz",
    'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip",
    'en-fasttext-crawl': "crawl-300d-2M.vec.zip",
@@ -189,3 +189,32 @@ class TestOther(unittest.TestCase):
        vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
        # this will print a warning
        self.assertEqual(vocab.rebuild, True)
    def test_save_and_load(self):
        import os
        fp = 'vocab_save_test.txt'
        try:
            # check that word2idx round-trips unchanged and no_create_entry status is preserved
            words = list('abcdefaddfdkjfe')
            no_create_entry = list('12342331')
            unk = '[UNK]'
            vocab = Vocabulary(unknown=unk, max_size=500)
            vocab.add_word_lst(words)
            vocab.add_word_lst(no_create_entry, no_create_entry=True)
            vocab.save(fp)

            new_vocab = Vocabulary.load(fp)
            for word, index in vocab:
                self.assertEqual(new_vocab.to_index(word), index)
            for word in no_create_entry:
                self.assertTrue(new_vocab._is_word_no_create_entry(word))
            for word in words:
                self.assertFalse(new_vocab._is_word_no_create_entry(word))
            for idx in range(len(vocab)):
                self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx))
            self.assertEqual(vocab.unknown, new_vocab.unknown)
        finally:
            # clean up the temp file whether or not the assertions passed
            if os.path.exists(fp):
                os.remove(fp)
@@ -46,3 +46,4 @@ class TestBertWordPieceEncoder(unittest.TestCase):
        ds = DataSet({'words': ["this is a test . [SEP]".split()]})
        embed.index_datasets(ds, field_name='words')
        self.assertTrue(ds.has_field('word_pieces'))
        result = embed(torch.LongTensor([[1, 2, 3, 4]]))