@@ -482,3 +482,80 @@ class Vocabulary(object): | |||
def __iter__(self): | |||
for word, index in self._word2idx.items(): | |||
yield word, index | |||
def save(self, filepath): | |||
""" | |||
:param str filepath: Vocabulary的储存路径 | |||
:return: | |||
""" | |||
with open(filepath, 'w', encoding='utf-8') as f: | |||
f.write(f'max_size\t{self.max_size}\n') | |||
f.write(f'min_freq\t{self.min_freq}\n') | |||
f.write(f'unknown\t{self.unknown}\n') | |||
f.write(f'padding\t{self.padding}\n') | |||
f.write(f'rebuild\t{self.rebuild}\n') | |||
f.write('\n') | |||
# idx: 如果idx为-2, 说明还没有进行build; 如果idx为-1,说明该词未编入 | |||
# no_create_entry: 如果为1,说明该词是no_create_entry; 0 otherwise | |||
# word \t count \t idx \t no_create_entry \n | |||
idx = -2 | |||
for word, count in self.word_count.items(): | |||
if self._word2idx is not None: | |||
idx = self._word2idx.get(word, -1) | |||
is_no_create_entry = int(self._is_word_no_create_entry(word)) | |||
f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') | |||
@staticmethod | |||
def load(filepath): | |||
""" | |||
:param str filepath: Vocabulary的读取路径 | |||
:return: Vocabulary | |||
""" | |||
with open(filepath, 'r', encoding='utf-8') as f: | |||
vocab = Vocabulary() | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
name, value = line.split() | |||
if name == 'max_size': | |||
vocab.max_size = int(value) if value!='None' else None | |||
elif name == 'min_freq': | |||
vocab.min_freq = int(value) if value!='None' else None | |||
elif name in ('unknown', 'padding'): | |||
setattr(vocab, name, value) | |||
elif name == 'rebuild': | |||
vocab.rebuild = True if value=='True' else False | |||
else: | |||
break | |||
word_counter = {} | |||
no_create_entry_counter = {} | |||
word2idx = {} | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split() | |||
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) | |||
if idx >= 0: | |||
word2idx[word] = idx | |||
word_counter[word] = count | |||
if no_create_entry_counter: | |||
no_create_entry_counter[word] = count | |||
word_counter = Counter(word_counter) | |||
no_create_entry_counter = Counter(no_create_entry_counter) | |||
if len(word2idx)>0: | |||
if vocab.padding: | |||
word2idx[vocab.padding] = 0 | |||
if vocab.unknown: | |||
word2idx[vocab.unknown] = 1 if vocab.padding else 0 | |||
idx2word = {value:key for key,value in word2idx.items()} | |||
vocab.word_count = word_counter | |||
vocab._no_create_word = no_create_entry_counter | |||
if word2idx: | |||
vocab._word2idx = word2idx | |||
vocab._idx2word = idx2word | |||
return vocab |
@@ -29,7 +29,17 @@ class BertEmbedding(ContextualEmbedding): | |||
预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word | |||
时切分),在分割之后长度可能会超过最大长度限制。 | |||
BertEmbedding可以支持自动下载权重,当前支持的模型有以下的几种(待补充): | |||
BertEmbedding可以支持自动下载权重,当前支持的模型: | |||
en: base-cased | |||
en-large-cased-wwm: | |||
en-large-cased: | |||
en-large-uncased: | |||
en-large-uncased-wwm | |||
cn: 中文BERT wwm by HIT | |||
cn-base: 中文BERT base-chinese | |||
cn-wwm-ext: 中文BERT wwm by HIT with extra data pretrain. | |||
multi-base-cased: multilingual cased | |||
multi-base-uncased: multilingual uncased | |||
Example:: | |||
@@ -70,6 +80,9 @@ class BertEmbedding(ContextualEmbedding): | |||
""" | |||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
if word_dropout>0: | |||
assert vocab.unknown != None, "When word_drop>0, Vocabulary must contain the unknown token." | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'): | |||
logger.warning("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" | |||
@@ -84,7 +97,8 @@ class BertEmbedding(ContextualEmbedding): | |||
self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) | |||
self._sep_index = self.model._sep_index | |||
self._cls_index = self.model._cls_index | |||
self.requires_grad = requires_grad | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
@@ -117,21 +131,35 @@ class BertEmbedding(ContextualEmbedding): | |||
""" | |||
if self.word_dropout > 0 and self.training: | |||
with torch.no_grad(): | |||
if self._word_sep_index: # 不能drop sep | |||
sep_mask = words.eq(self._word_sep_index) | |||
not_sep_mask = words.ne(self._sep_index) | |||
not_cls_mask = words.ne(self._cls_index) | |||
if self._word_sep_index: | |||
not_sep_mask = not_sep_mask.__and__(words.ne(self._word_sep_index)) | |||
replaceable_mask = not_sep_mask.__and__(not_cls_mask) | |||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||
pad_mask = words.ne(0) | |||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||
mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk | |||
words = words.masked_fill(mask, self._word_unk_index) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._word_sep_index) | |||
return words | |||
class BertWordPieceEncoder(nn.Module): | |||
""" | |||
读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。 | |||
BertWordPieceEncoder可以支持自动下载权重,当前支持的模型: | |||
en: base-cased | |||
en-large-cased-wwm: | |||
en-large-cased: | |||
en-large-uncased: | |||
en-large-uncased-wwm | |||
cn: 中文BERT wwm by HIT | |||
cn-base: 中文BERT base-chinese | |||
cn-wwm-ext: 中文BERT wwm by HIT with extra data pretrain. | |||
multi-base-cased: multilingual cased | |||
multi-base-uncased: multilingual uncased | |||
""" | |||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | |||
@@ -149,6 +177,7 @@ class BertWordPieceEncoder(nn.Module): | |||
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) | |||
self._sep_index = self.model._sep_index | |||
self._cls_index = self.model._cls_index | |||
self._wordpiece_pad_index = self.model._wordpiece_pad_index | |||
self._wordpiece_unk_index = self.model._wordpiece_unknown_index | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
@@ -212,15 +241,14 @@ class BertWordPieceEncoder(nn.Module): | |||
""" | |||
if self.word_dropout > 0 and self.training: | |||
with torch.no_grad(): | |||
if self._word_sep_index: # 不能drop sep | |||
sep_mask = words.eq(self._wordpiece_unk_index) | |||
not_sep_mask = words.ne(self._sep_index) | |||
not_cls_mask = words.ne(self._cls_index) | |||
replaceable_mask = not_sep_mask.__and__(not_cls_mask) | |||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||
pad_mask = words.ne(self._wordpiece_pad_index) | |||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||
words = words.masked_fill(mask, self._word_unk_index) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._wordpiece_unk_index) | |||
mask = pad_mask.__and__(mask).__and__(replaceable_mask) # pad的位置不为unk | |||
words = words.masked_fill(mask, self._wordpiece_unk_index) | |||
return words | |||
@@ -24,7 +24,13 @@ from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | |||
class ElmoEmbedding(ContextualEmbedding): | |||
""" | |||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。当前支持的使用名称初始化的模型有以下的这些(待补充) | |||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。 | |||
当前支持的使用名称初始化的模型: | |||
en: 即en-medium hidden_size 1024; output_size 12 | |||
en-medium: hidden_size 2048; output_size 256 | |||
en-origial: hidden_size 4096; output_size 512 | |||
en-original-5.5b: hidden_size 4096; output_size 512 | |||
en-small: hidden_size 1024; output_size 128 | |||
Example:: | |||
@@ -26,7 +26,24 @@ class StaticEmbedding(TokenEmbedding): | |||
""" | |||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | |||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | |||
当前支持自动下载的预训练vector有以下的几种(待补充); | |||
当前支持自动下载的预训练vector有: | |||
en: 实际为en-glove-840b-300d(常用) | |||
en-glove-6b-50d: glove官方的50d向量 | |||
en-glove-6b-100d: glove官方的100d向量 | |||
en-glove-6b-200d: glove官方的200d向量 | |||
en-glove-6b-300d: glove官方的300d向量 | |||
en-glove-42b-300d: glove官方使用42B数据训练版本 | |||
en-glove-840b-300d: | |||
en-glove-twitter-27b-25d: | |||
en-glove-twitter-27b-50d: | |||
en-glove-twitter-27b-100d: | |||
en-glove-twitter-27b-200d: | |||
en-word2vec-300d: word2vec官方发布的300d向量 | |||
en-fasttext-crawl: fasttext官方发布的300d英文预训练 | |||
cn-char-fastnlp-100d: fastNLP训练的100d的character embedding | |||
cn-bi-fastnlp-100d: fastNLP训练的100d的bigram embedding | |||
cn-tri-fastnlp-100d: fastNLP训练的100d的trigram embedding | |||
cn-fasttext: fasttext官方发布的300d中文预训练embedding | |||
Example:: | |||
@@ -70,7 +70,7 @@ PRETRAIN_STATIC_FILES = { | |||
'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', | |||
'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', | |||
'en-word2vec-300': "GoogleNews-vectors-negative300.txt.gz", | |||
'en-word2vec-300d': "GoogleNews-vectors-negative300.txt.gz", | |||
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", | |||
'en-fasttext-crawl': "crawl-300d-2M.vec.zip", | |||
@@ -189,3 +189,32 @@ class TestOther(unittest.TestCase): | |||
vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | |||
# this will print a warning | |||
self.assertEqual(vocab.rebuild, True) | |||
def test_save_and_load(self): | |||
fp = 'vocab_save_test.txt' | |||
try: | |||
# check word2idx没变,no_create_entry正常 | |||
words = list('abcdefaddfdkjfe') | |||
no_create_entry = list('12342331') | |||
unk = '[UNK]' | |||
vocab = Vocabulary(unknown=unk, max_size=500) | |||
vocab.add_word_lst(words) | |||
vocab.add_word_lst(no_create_entry, no_create_entry=True) | |||
vocab.save(fp) | |||
new_vocab = Vocabulary.load(fp) | |||
for word, index in vocab: | |||
self.assertEqual(new_vocab.to_index(word), index) | |||
for word in no_create_entry: | |||
self.assertTrue(new_vocab._is_word_no_create_entry(word)) | |||
for word in words: | |||
self.assertFalse(new_vocab._is_word_no_create_entry(word)) | |||
for idx in range(len(vocab)): | |||
self.assertEqual(vocab.to_word(idx), new_vocab.to_word(idx)) | |||
self.assertEqual(vocab.unknown, new_vocab.unknown) | |||
except: | |||
import os | |||
if os.path.exists(fp): | |||
os.remove(fp) |
@@ -46,3 +46,4 @@ class TestBertWordPieceEncoder(unittest.TestCase): | |||
ds = DataSet({'words': ["this is a test . [SEP]".split()]}) | |||
embed.index_datasets(ds, field_name='words') | |||
self.assertTrue(ds.has_field('word_pieces')) | |||
result = embed(torch.LongTensor([[1,2,3,4]])) |