diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
index a58668da..c48cb806 100644
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
@@ -180,11 +180,11 @@ class StaticEmbedding(TokenEmbedding):
         name. Currently supported embeddings include {`en` or `en-glove-840b-300` : glove.840B.300d,
         `en-glove-6b-50` : glove.6B.50d, `en-word2vec-300` : GoogleNews-vectors-negative300}. In the second case, the cache is checked for the model first, and it is downloaded automatically if absent.
     :param requires_grad: whether gradients are required. Defaults to True.
-    :param init_method: how to initialize vectors for words not found in the pre-trained embedding. Any method from torch.nn.init.* can be used. Defaults to torch.nn.init.xavier_uniform_.
-        The method is called with a tensor object.
-
+    :param init_method: how to initialize vectors for words not found in the pre-trained embedding. Any method from torch.nn.init.* can be used. The method is called with a tensor object.
+    :param normalize: whether to normalize the vectors so that each vector has a norm of 1.
     """
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None):
+    def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
+                 normalize=False):
         super(StaticEmbedding, self).__init__(vocab)
 
         # First, define which static embeddings can be downloaded. This will probably require setting up our own server,
@@ -202,7 +202,8 @@ class StaticEmbedding(TokenEmbedding):
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
         # load the embedding
-        embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
+        embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
+                                                     normalize=normalize)
         self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
                                       padding_idx=vocab.padding_idx,
                                       max_norm=None, norm_type=2, scale_grad_by_freq=False,
@@ -257,10 +258,7 @@ class StaticEmbedding(TokenEmbedding):
         assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
         if not os.path.exists(embed_filepath):
             raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
-        if init_method is None:
-            init_method = nn.init.xavier_uniform_
         with open(embed_filepath, 'r', encoding='utf-8') as f:
-            found_count = 0
             line = f.readline().strip()
             parts = line.split()
             start_idx = 0
@@ -271,7 +269,8 @@ class StaticEmbedding(TokenEmbedding):
                 dim = len(parts) - 1
                 f.seek(0)
             matrix = torch.zeros(len(vocab), dim)
-            init_method(matrix)
+            if init_method is not None:
+                init_method(matrix)
             hit_flags = np.zeros(len(vocab), dtype=bool)
             for idx, line in enumerate(f, start_idx):
                 try:
@@ -286,7 +285,6 @@ class StaticEmbedding(TokenEmbedding):
                     if word in vocab:
                         index = vocab.to_index(word)
                         matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
-                        found_count += 1
                         hit_flags[index] = True
                 except Exception as e:
                     if error == 'ignore':
@@ -294,7 +292,16 @@ class StaticEmbedding(TokenEmbedding):
                         continue
                     else:
                         print("Error occurred at the {} line.".format(idx))
                         raise e
+            found_count = sum(hit_flags)
             print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
+            if init_method is None:
+                if len(vocab) - found_count > 0 and found_count > 0:  # some words were not found
+                    found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()]
+                    mean = found_vecs.mean(dim=0, keepdim=True)
+                    std = found_vecs.std(dim=0, keepdim=True)
+                    unfound_vec_num = np.sum(hit_flags == False)
+                    unfound_vecs = torch.randn(unfound_vec_num, dim) * std + mean
+                    matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs
             if normalize:
                 matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12)
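
To make the new fallback concrete: when init_method is None and only some vocabulary words are covered by the pre-trained file, the uncovered rows are now drawn from a normal distribution whose per-dimension mean and std match the covered rows. The following self-contained sketch reproduces that logic outside the class; the vocabulary size, dimension, and hit pattern are made up for illustration, and it uses modern bool masks rather than the .byte() masks used in the patch:

    import numpy as np
    import torch

    vocab_size, dim = 10, 4
    matrix = torch.zeros(vocab_size, dim)          # embedding matrix being filled
    hit_flags = np.zeros(vocab_size, dtype=bool)   # True where the word was found

    # pretend the first six words were found in the pre-trained file
    hit_flags[:6] = True
    matrix[:6] = torch.randn(6, dim)

    found_count = int(hit_flags.sum())
    if 0 < found_count < vocab_size:               # some found, some missing
        mask = torch.from_numpy(hit_flags)         # bool tensor over the rows
        found_vecs = matrix[mask]
        mean = found_vecs.mean(dim=0, keepdim=True)
        std = found_vecs.std(dim=0, keepdim=True)
        # sample the missing rows from N(mean, std), dimension-wise
        matrix[~mask] = torch.randn(vocab_size - found_count, dim) * std + mean

This keeps the unfound vectors on the same scale as the pre-trained ones, instead of the old xavier_uniform_ default, which initialized them from a distribution unrelated to the loaded embeddings.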
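
A hedged usage sketch of the changed constructor, assuming a small hand-built Vocabulary and the en-glove-6b-50 alias mentioned in the docstring (the words below are placeholders, and the embedding file is downloaded or read from the cache):

    from fastNLP import Vocabulary
    from fastNLP.modules.encoder.embedding import StaticEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst(["the", "quick", "brown", "fox"])
    vocab.build_vocab()

    # init_method=None now triggers the mean/std fallback for unfound words;
    # normalize=True rescales every loaded vector to unit L2 norm.
    embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50',
                            init_method=None, normalize=True)

With normalize=True, torch.norm(embed.embedding.weight, dim=1) should be close to 1 for every non-padding row (the padding row stays zero), which is a quick sanity check for the flag.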