@@ -180,11 +180,11 @@ class StaticEmbedding(TokenEmbedding):
         name of the model. Currently supported embeddings include {`en` or `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d,
         `en-word2vec-300` : GoogleNews-vectors-negative300}. In the second case, the cache is checked for the model first, and it is downloaded automatically if absent.
     :param requires_grad: whether gradients are required. Defaults to True.
-    :param init_method: how to initialize vectors for words not found in the pretrained file. Any of the torch.nn.init.* methods may be used. Defaults to torch.nn.init.xavier_uniform_.
-        The method is called with a tensor object.
+    :param init_method: how to initialize vectors for words not found in the pretrained file. Any of the torch.nn.init.* methods may be used. The method is called with a tensor object.
+    :param normalize: whether to normalize each vector so that its norm is 1.
     """
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None):
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
+                 normalize=False):
         super(StaticEmbedding, self).__init__(vocab)
 
         # First define which static embeddings are available for download; this will probably need a dedicated server,
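
Taken together, the hunk above adds a `normalize` switch to the constructor. A minimal usage sketch follows; the toy words, the embedding name, and the import path are assumptions for illustration, not part of the diff, and fastNLP's `Vocabulary` API (`add_word_lst`/`build_vocab`) is assumed to match the targeted version.

```python
import torch
from fastNLP import Vocabulary
# Import path is an assumption; it differs across fastNLP versions.
from fastNLP.modules.encoder.embedding import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst(["the", "quick", "brown", "fox"])  # toy corpus
vocab.build_vocab()

# normalize=True rescales every loaded vector to unit L2 norm.
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', normalize=True)
words = torch.LongTensor([[vocab.to_index("the"), vocab.to_index("fox")]])
print(embed(words).shape)  # expected: torch.Size([1, 2, 50])
```
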
@@ -202,7 +202,8 @@ class StaticEmbedding(TokenEmbedding):
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
         # read the embedding file
-        embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
+        embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
+                                                     normalize=normalize)
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
                                       max_norm=None, norm_type=2, scale_grad_by_freq=False,
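
This hunk threads `normalize` through to `_load_with_vocab` and then wraps the returned matrix in `nn.Embedding`. A standalone sketch of that wrapping step, with a random matrix standing in for the loaded one:

```python
import torch
import torch.nn as nn

matrix = torch.randn(10, 50)  # stand-in for the matrix from _load_with_vocab()
embedding = nn.Embedding(num_embeddings=matrix.shape[0], embedding_dim=matrix.shape[1],
                         padding_idx=0, max_norm=None, norm_type=2,
                         scale_grad_by_freq=False, sparse=False)
embedding.weight.data.copy_(matrix)  # install the pretrained vectors
print(embedding(torch.LongTensor([1, 2])).shape)  # torch.Size([2, 50])
```
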
@@ -257,10 +258,7 @@ class StaticEmbedding(TokenEmbedding):
         assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
         if not os.path.exists(embed_filepath):
             raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
-        if init_method is None:
-            init_method = nn.init.xavier_uniform_
         with open(embed_filepath, 'r', encoding='utf-8') as f:
-            found_count = 0
             line = f.readline().strip()
             parts = line.split()
             start_idx = 0
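
The reading loop above feeds into format detection that the diff elides: word2vec-style text files open with a `<vocab_size> <dim>` header, while GloVe-style files begin directly with a vector line. A hedged sketch of that logic, reconstructed from the `dim`/`f.seek(0)` lines visible in the next hunk:

```python
embed_filepath = "glove.6B.50d.txt"  # hypothetical path

with open(embed_filepath, 'r', encoding='utf-8') as f:
    parts = f.readline().strip().split()
    if len(parts) == 2:        # word2vec-style "<vocab_size> <dim>" header
        dim = int(parts[1])
        start_idx = 1          # vectors begin on the second line
    else:                      # GloVe-style: the first line is already a vector
        dim = len(parts) - 1
        start_idx = 0
        f.seek(0)              # re-read the first line as data
    print(dim, start_idx)
```
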
@@ -271,7 +269,8 @@ class StaticEmbedding(TokenEmbedding):
                 dim = len(parts) - 1
                 f.seek(0)
             matrix = torch.zeros(len(vocab), dim)
-            init_method(matrix)
+            if init_method is not None:
+                init_method(matrix)
             hit_flags = np.zeros(len(vocab), dtype=bool)
             for idx, line in enumerate(f, start_idx):
                 try:
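
With this guard, the matrix is only pre-filled when the caller actually supplies an initializer, which lets the `init_method is None` fallback added later in this diff take over. Any in-place `torch.nn.init.*` function qualifies as `init_method`, e.g.:

```python
import torch
from torch import nn

matrix = torch.zeros(4, 5)
init_method = nn.init.xavier_uniform_  # or nn.init.normal_, nn.init.uniform_, ...
if init_method is not None:
    init_method(matrix)  # fills the whole matrix in place; found rows are overwritten later
print(matrix)
```
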
@@ -286,7 +285,6 @@ class StaticEmbedding(TokenEmbedding):
                     if word in vocab:
                         index = vocab.to_index(word)
                         matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
-                        found_count += 1
                         hit_flags[index] = True
                 except Exception as e:
                     if error == 'ignore':
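
For reference, the `np.fromstring` call shown in context above parses one whitespace-separated vector line. A tiny standalone sketch (the sample line is made up):

```python
import numpy as np
import torch

line = "fox 0.1 -0.2 0.3"  # hypothetical embedding-file line
word, *nums = line.split()
vec = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=np.float32, count=3))
print(word, vec)  # fox tensor([ 0.1000, -0.2000,  0.3000])
```
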
@@ -294,7 +292,16 @@ class StaticEmbedding(TokenEmbedding):
                     else:
                         print("Error occurred at the {} line.".format(idx))
                         raise e
+            found_count = sum(hit_flags)
             print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
+            if init_method is None:
+                if len(vocab) - found_count > 0 and found_count > 0:  # some words were not found
+                    found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()]
+                    mean = found_vecs.mean(dim=0, keepdim=True)
+                    std = found_vecs.std(dim=0, keepdim=True)
+                    unfound_vec_num = np.sum(hit_flags == False)
+                    unfound_vecs = torch.randn(unfound_vec_num, dim) * std + mean
+                    matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs
+
+            if normalize:
+                matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12)
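
These last two additions are the heart of the change: when no `init_method` is given, words missing from the pretrained file receive vectors sampled from a Gaussian fitted to the found vectors, and `normalize=True` then rescales every row to unit L2 norm. A self-contained sketch with a toy four-word vocabulary:

```python
import numpy as np
import torch

dim = 4
hit_flags = np.array([True, False, True, False])  # which words were found
matrix = torch.zeros(len(hit_flags), dim)
matrix[0] = torch.tensor([1.0, 0.0, 0.0, 0.0])    # pretrained vector for word 0
matrix[2] = torch.tensor([0.0, 1.0, 0.0, 0.0])    # pretrained vector for word 2

found = torch.from_numpy(hit_flags.astype(np.uint8)).bool()
found_vecs = matrix[found]
mean = found_vecs.mean(dim=0, keepdim=True)
std = found_vecs.std(dim=0, keepdim=True)
matrix[~found] = torch.randn(int((~found).sum()), dim) * std + mean

# normalize=True branch
matrix /= torch.norm(matrix, dim=1, keepdim=True) + 1e-12
print(torch.norm(matrix, dim=1))  # all ones
```

Sampling from the found vectors' per-dimension mean and std keeps unknown words on the same scale as the pretrained ones, rather than the unrelated scale an Xavier initializer would give.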