|
|
@@ -10,6 +10,8 @@ from ..core.vocabulary import Vocabulary |
|
|
|
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path |
|
|
|
from .embedding import TokenEmbedding |
|
|
|
from ..modules.utils import _get_file_name_base_on_postfix |
|
|
|
from copy import deepcopy |
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
class StaticEmbedding(TokenEmbedding): |
|
|
|
""" |
|
|
@@ -46,12 +48,13 @@ class StaticEmbedding(TokenEmbedding): |
|
|
|
:param callable init_method: How to initialize the vectors of words not found in the pretrained file. Any method from torch.nn.init.* can be used; the method is called with a tensor object to initialize.
|
|
|
:param bool lower: Whether to lowercase the words in vocab before matching them against the pretrained vocabulary. If your vocabulary contains uppercase words, or uppercase words need
|
|
|
their own separate vector representation, set lower to False.
|
|
|
:param float word_dropout: The probability of replacing a word with unk. This both trains the unk token and acts as a form of regularization.
|
|
|
:param float dropout: The probability of applying dropout to the embedding output; 0.1 means 10% of the values are randomly set to 0.
|
|
|
|
|
|
:param bool normalize: Whether to normalize the vectors so that each vector has norm 1.
|
|
|
:param int min_freq: Words whose frequency in the Vocabulary is lower than this value are mapped to unk.
|
|
|
""" |
|
|
|
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True, |
|
|
|
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False): |
|
|
|
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1): |
|
|
|
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) |
|
|
|
|
|
|
|
# resolve the cache path
|
|
@@ -70,6 +73,28 @@ class StaticEmbedding(TokenEmbedding): |
|
|
|
else: |
|
|
|
raise ValueError(f"Cannot recognize {model_dir_or_name}.") |
|
|
|
|
|
|
|
# shrink the vocab
|
|
|
truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq) |
|
|
|
if truncate_vocab: |
|
|
|
truncated_vocab = deepcopy(vocab) |
|
|
|
truncated_vocab.min_freq = min_freq |
|
|
|
truncated_vocab.word2idx = None |
|
|
|
if lower:  # if lower is set, the frequencies of upper- and lower-case forms must be counted together
|
|
|
lowered_word_count = defaultdict(int) |
|
|
|
for word, count in truncated_vocab.word_count.items(): |
|
|
|
lowered_word_count[word.lower()] += count |
|
|
|
for word in truncated_vocab.word_count.keys(): |
|
|
|
word_count = truncated_vocab.word_count[word] |
|
|
|
if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq: |
|
|
|
truncated_vocab.add_word_lst([word]*(min_freq-word_count), |
|
|
|
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) |
|
|
|
truncated_vocab.build_vocab() |
|
|
|
truncated_words_to_words = torch.arange(len(vocab)).long() |
|
|
|
for word, index in vocab: |
|
|
|
truncated_words_to_words[index] = truncated_vocab.to_index(word) |
|
|
|
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") |
|
|
|
vocab = truncated_vocab |
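# A small worked example of the truncation above (hypothetical counts, min_freq=2, lower=True):
#
#     word_count = {"Cat": 1, "cat": 1, "dog": 3, "mouse": 1}
#     # lowered_word_count becomes {"cat": 2, "dog": 3, "mouse": 1}; "Cat" and "cat" are padded up
#     # to min_freq and survive the rebuild, while "mouse" stays below min_freq and is dropped, so
#     # truncated_words_to_words[vocab.to_index("mouse")] ends up pointing at truncated_vocab's unk.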
|
|
|
|
|
|
|
# load the embedding
|
|
|
if lower: |
|
|
|
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) |
|
|
@@ -84,9 +109,6 @@ class StaticEmbedding(TokenEmbedding): |
|
|
|
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) |
|
|
|
else: |
|
|
|
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) |
|
|
|
# the index mapping needs some adaptation
|
|
|
if not hasattr(self, 'words_to_words'): |
|
|
|
self.words_to_words = torch.arange(len(lowered_vocab)).long() |
|
|
|
if lowered_vocab.unknown: |
|
|
|
unknown_idx = lowered_vocab.unknown_idx |
|
|
|
else: |
|
|
@@ -108,6 +130,14 @@ class StaticEmbedding(TokenEmbedding): |
|
|
|
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) |
|
|
|
if normalize: |
|
|
|
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12) |
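# Effect of the normalization above (a one-row illustration, not from the source):
#
#     row = torch.tensor([3.0, 4.0])
#     row / (torch.norm(row) + 1e-12)   # -> tensor([0.6000, 0.8000]), i.e. unit L2 norm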
|
|
|
|
|
|
|
if truncate_vocab: |
|
|
|
for i in range(len(truncated_words_to_words)): |
|
|
|
index_in_truncated_vocab = truncated_words_to_words[i] |
|
|
|
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab] |
|
|
|
del self.words_to_words |
|
|
|
self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False) |
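# After the remap above, a single lookup goes from original-vocab indices straight to rows of the
# loaded matrix. Rough sketch (hypothetical indices): if "cat" is index 7 in the original vocab and
# index 3 in the truncated vocab, and self.words_to_words previously mapped 3 -> 3, then now
# self.words_to_words[7] == 3, so the forward pass can do roughly:
#
#     rows = self.words_to_words[word_indices]   # original-vocab indices -> embedding rows
#     output = self.embedding(rows)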
|
|
|
|
|
|
|
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], |
|
|
|
padding_idx=vocab.padding_idx, |
|
|
|
max_norm=None, norm_type=2, scale_grad_by_freq=False, |
|
|
@@ -184,6 +214,10 @@ class StaticEmbedding(TokenEmbedding): |
|
|
|
dim = len(parts) - 1 |
|
|
|
f.seek(0) |
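# Expected file layout (a sketch; whether a header line is skipped depends on start_idx): each
# line is a word followed by its vector components, whitespace-separated, e.g.
#
#     the 0.418 0.24968 -0.41242 ...
#     cat 0.13717 0.33794 -0.14069 ...
#
# dim is inferred above as len(parts) - 1 from the first data line, and f.seek(0) rewinds the file
# so that the loop below parses every line again.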
|
|
|
matrix = {} |
|
|
|
if vocab.padding: |
|
|
|
matrix[vocab.padding_idx] = torch.zeros(dim) |
|
|
|
if vocab.unknown: |
|
|
|
matrix[vocab.unknown_idx] = torch.zeros(dim) |
|
|
|
found_count = 0 |
|
|
|
for idx, line in enumerate(f, start_idx): |
|
|
|
try: |
|
|
@@ -208,35 +242,25 @@ class StaticEmbedding(TokenEmbedding): |
|
|
|
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) |
|
|
|
for word, index in vocab: |
|
|
|
if index not in matrix and not vocab._is_word_no_create_entry(word): |
|
|
|
if vocab.padding_idx == index: |
|
|
|
matrix[index] = torch.zeros(dim) |
|
|
|
elif vocab.unknown_idx in matrix:  # if unknown exists, initialize with unknown
|
|
|
if vocab.unknown_idx in matrix:  # if unknown exists, initialize with unknown
|
|
|
matrix[index] = matrix[vocab.unknown_idx] |
|
|
|
else: |
|
|
|
matrix[index] = None |
|
|
|
# the keys of matrix are the words that need their own entry
|
|
|
vectors = self._randomly_init_embed(len(matrix), dim, init_method) |
|
|
|
|
|
|
|
vectors = self._randomly_init_embed(len(vocab), dim, init_method) |
|
|
|
|
|
|
|
if vocab._no_create_word_length>0: |
|
|
|
if vocab.unknown is None:  # create a dedicated unknown
|
|
|
unknown_idx = len(matrix) |
|
|
|
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() |
|
|
|
else: |
|
|
|
unknown_idx = vocab.unknown_idx |
|
|
|
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), |
|
|
|
requires_grad=False) |
|
|
|
for word, index in vocab: |
|
|
|
vec = matrix.get(index, None) |
|
|
|
if vec is not None: |
|
|
|
vectors[index] = vec |
|
|
|
words_to_words[index] = index |
|
|
|
else: |
|
|
|
vectors[index] = vectors[unknown_idx] |
|
|
|
self.words_to_words = words_to_words |
|
|
|
if vocab.unknown is None:  # create a dedicated unknown
|
|
|
unknown_idx = len(matrix) |
|
|
|
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() |
|
|
|
else: |
|
|
|
for index, vec in matrix.items(): |
|
|
|
if vec is not None: |
|
|
|
vectors[index] = vec |
|
|
|
unknown_idx = vocab.unknown_idx |
|
|
|
self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(), |
|
|
|
requires_grad=False) |
|
|
|
|
|
|
|
for index, (index_in_vocab, vec) in enumerate(matrix.items()): |
|
|
|
if vec is not None: |
|
|
|
vectors[index] = vec |
|
|
|
self.words_to_words[index_in_vocab] = index |
|
|
|
|
|
|
|
return vectors |
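# Rough relationship between the returned vectors and self.words_to_words (hypothetical sizes): if
# the vocab has 10,000 words but only the entries collected in `matrix` get their own row, then
# self.words_to_words maps every vocab index either to its own row or to the unknown row, so that
#
#     vectors[self.words_to_words[vocab.to_index(word)]]
#
# is the vector actually used for `word`.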
|
|
|
|
|
|
|