
Change StaticEmbedding's initialization; experiments show that with this initialization, ESIM reaches 88 test acc on SNLI more easily
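As the diff below shows, when init_method is None the loader now fills vectors of words missing from the pre-trained file by sampling from a Gaussian with the per-dimension mean and std of the found vectors. A minimal standalone sketch of that idea, using toy shapes (a hypothetical 5-word vocab, dim 4) rather than the loader's real inputs:

import torch

# toy stand-in for the loader's state: 5-word vocab, 4-dim embeddings
matrix = torch.zeros(5, 4)
hit_flags = torch.tensor([True, True, False, True, False])
matrix[hit_flags] = torch.randn(3, 4)  # pretend these rows came from the file

# fill the missing rows from N(mean, std) of the found rows
found_vecs = matrix[hit_flags]
mean = found_vecs.mean(dim=0, keepdim=True)
std = found_vecs.std(dim=0, keepdim=True)
num_missing = int((~hit_flags).sum())
matrix[~hit_flags] = torch.randn(num_missing, 4) * std + mean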

tags/v0.4.10
yh 6 years ago
parent commit 40c4d216d1
1 changed file with 17 additions and 10 deletions
fastNLP/modules/encoder/embedding.py  +17  -10

@@ -180,11 +180,11 @@ class StaticEmbedding(TokenEmbedding):
         name of the embedding. Currently supported embeddings include {`en` or `en-glove-840b-300`: glove.840B.300d, `en-glove-6b-50`: glove.6B.50d,
         `en-word2vec-300`: GoogleNews-vectors-negative300}. In the second case, the cache is checked for the model first, and it is downloaded automatically if absent.
     :param requires_grad: whether gradients are required. Defaults to True.
-    :param init_method: how to initialize vectors for words not found in the pre-trained file. Any method from torch.nn.init.* can be used; defaults to torch.nn.init.xavier_uniform_.
-        The method is called with a tensor object.
+    :param init_method: how to initialize vectors for words not found in the pre-trained file. Any method from torch.nn.init.* can be used. The method is called with a tensor object.
+    :param normalize: whether to normalize each vector so that its norm is 1.
     """
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None):
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
+                 normalize=False):
         super(StaticEmbedding, self).__init__(vocab)

         # First, define which static embeddings can be downloaded. We will probably need to set up a server of our own here,
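A hypothetical usage sketch of the updated signature; the Vocabulary calls and the model name below are assumptions for illustration, not part of this diff:

from fastNLP import Vocabulary
from fastNLP.modules.encoder.embedding import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("a toy sentence for the vocab".split())  # assumed fastNLP API
vocab.build_vocab()

# leaving init_method=None now triggers the mean/std fill for missing words;
# normalize=True rescales every vector to unit norm after loading
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', normalize=True)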
@@ -202,7 +202,8 @@ class StaticEmbedding(TokenEmbedding):
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")

         # load the pre-trained embedding with the given vocabulary
-        embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
+        embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
+                                                     normalize=normalize)
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
                                       max_norm=None, norm_type=2, scale_grad_by_freq=False,
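The hunk is cut off before the end of the nn.Embedding call, so it does not show how the loaded matrix ends up in the layer. For illustration only, a generic plain-PyTorch way to install a pre-built matrix into an embedding layer (a sketch, not necessarily what fastNLP does here):

import torch
import torch.nn as nn

matrix = torch.randn(10, 50)  # toy loaded embedding matrix
layer = nn.Embedding(num_embeddings=matrix.shape[0], embedding_dim=matrix.shape[1],
                     padding_idx=0, max_norm=None, norm_type=2,
                     scale_grad_by_freq=False, sparse=False)
layer.weight.data.copy_(matrix)  # overwrite the random init with the loaded vectors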
@@ -257,10 +258,7 @@ class StaticEmbedding(TokenEmbedding):
         assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
         if not os.path.exists(embed_filepath):
             raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
-        if init_method is None:
-            init_method = nn.init.xavier_uniform_
         with open(embed_filepath, 'r', encoding='utf-8') as f:
-            found_count = 0
             line = f.readline().strip()
             parts = line.split()
             start_idx = 0
@@ -271,7 +269,8 @@ class StaticEmbedding(TokenEmbedding):
                 dim = len(parts) - 1
                 f.seek(0)
             matrix = torch.zeros(len(vocab), dim)
-            init_method(matrix)
+            if init_method is not None:
+                init_method(matrix)
             hit_flags = np.zeros(len(vocab), dtype=bool)
             for idx, line in enumerate(f, start_idx):
                 try:
@@ -286,7 +285,6 @@ class StaticEmbedding(TokenEmbedding):
                     if word in vocab:
                         index = vocab.to_index(word)
                         matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
-                        found_count += 1
                         hit_flags[index] = True
                 except Exception as e:
                     if error == 'ignore':
@@ -294,7 +292,16 @@ class StaticEmbedding(TokenEmbedding):
                     else:
                         print("Error occurred at the {} line.".format(idx))
                         raise e
+            found_count = sum(hit_flags)
             print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
+            if init_method is None:
+                if len(vocab) - found_count > 0 and found_count > 0:  # some words were not found in the file
+                    found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()]
+                    mean = found_vecs.mean(dim=0, keepdim=True)
+                    std = found_vecs.std(dim=0, keepdim=True)
+                    unfound_vec_num = np.sum(hit_flags==False)
+                    unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean
+                    matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs
+
+            if normalize:
+                matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12)
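The new normalize branch is easy to sanity-check in isolation; the +1e-12 guards against division by zero for any all-zero row. A minimal check:

import torch

matrix = torch.randn(4, 50)  # toy stand-in for the loaded matrix
matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12)
print(torch.norm(matrix, dim=1))  # each row norm is now ~1.0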

