diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py
index 6b08edc8..b887c6b1 100644
--- a/fastNLP/modules/encoder/_elmo.py
+++ b/fastNLP/modules/encoder/_elmo.py
@@ -418,8 +418,6 @@ class ConvTokenEmbedder(nn.Module):
 
         self.output_dim = config['lstm']['projection_dim']
         self._options = config
-        self.requires_grad = False
-        self._char_embedding_weights = char_emb_layer.weight.data
 
         char_cnn_options = self._options['char_cnn']
         if char_cnn_options['activation'] == 'tanh':
@@ -557,7 +555,6 @@ class _ElmoModel(nn.Module):
 
     def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
         super(_ElmoModel, self).__init__()
-        # self.pkl_dict = {}
         self.model_dir = model_dir
         dir = os.walk(self.model_dir)
         config_file = None
@@ -580,7 +577,6 @@
         config = json.load(open(os.path.join(model_dir, config_file), 'r'))
         self.weight_file = os.path.join(model_dir, weight_file)
         self.config = config
-        self.requires_grad = False
 
         OOV_TAG = '<oov>'
         PAD_TAG = '<pad>'
@@ -616,12 +612,10 @@
         char_emb_layer = nn.Embedding(len(char_vocab) + 1,
                                       int(config['char_cnn']['embedding']['dim']),
                                       padding_idx=len(char_vocab))
 
-        # Load the pretrained weights; here elmo_model is a dict holding the char_embed values and the state_dicts of char_cnn and lstm
-        elmo_pkl = open(os.path.join(self.model_dir, weight_file), "rb")
-        elmo_model = pickle.load(elmo_pkl)
-        elmo_pkl.close()
+        # Load the pretrained weights; here elmo_model holds the state_dicts of char_cnn and lstm
+        elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')
 
-        self.char_embed_weights = elmo_model["char_embed"]
+        char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']
         found_char_count = 0
         for char, index in char_vocab:  # adjust the character embedding
@@ -630,7 +624,7 @@
             if char in char_lexicon:
                 found_char_count += 1
             else:
                 index_in_pre = char_lexicon[OOV_TAG]
-            char_emb_layer.weight.data[index] = self.char_embed_weights[index_in_pre]
+            char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
         print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
         # generate the mapping from words to chars
@@ -659,7 +653,7 @@
         self.token_embedder = ConvTokenEmbedder(
             config, self.weight_file, None, char_emb_layer)
-
+        elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
         self.token_embedder.load_state_dict(elmo_model["char_cnn"])
 
         self.output_dim = config['lstm']['projection_dim']
 
@@ -707,7 +701,7 @@
         expanded_words[:, 0].fill_(self.bos_index)
         expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
         seq_len = seq_len + 2
-        zero_tensor = torch.zeros(expanded_words.shape).long()
+        zero_tensor = expanded_words.new_zeros(expanded_words.shape)
         mask = (expanded_words == zero_tensor).unsqueeze(-1)
         if hasattr(self, 'cached_word_embedding'):
             token_embedding = self.cached_word_embedding(expanded_words)
diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
index f7d840ad..050a423a 100644
--- a/fastNLP/modules/encoder/embedding.py
+++ b/fastNLP/modules/encoder/embedding.py
@@ -539,11 +539,11 @@ class ElmoEmbedding(ContextualEmbedding):
         self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
         if layers=='mix':
-            self.layer_weights = nn.Parameter(torch.zeros(self.model.config['encoder']['n_layers']+1),
+            self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers']+1),
                                               requires_grad=requires_grad)
             self.gamma = nn.Parameter(torch.ones(1),
                                       requires_grad=requires_grad)
             self._get_outputs = self._get_mixed_outputs
-            self._embed_size = self.model.config['encoder']['projection_dim'] * 2
+            self._embed_size = self.model.config['lstm']['projection_dim'] * 2
         else:
             layers = list(map(int, layers.split(',')))
             assert len(layers) > 0, "Must choose one output"
@@ -551,7 +551,7 @@ class ElmoEmbedding(ContextualEmbedding):
             assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
             self.layers = layers
             self._get_outputs = self._get_layer_outputs
-            self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2
+            self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2
 
         self.requires_grad = requires_grad
 
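
A note on the loading change in _elmo.py: torch.load with map_location='cpu' remaps tensors that were serialized on a GPU onto the CPU at load time, so the checkpoint opens on CPU-only machines, whereas a plain pickle.load would try to restore them onto their original CUDA device. A minimal sketch of the pattern, with a hypothetical checkpoint path:

    import torch

    # Hypothetical path; the patch builds it as os.path.join(self.model_dir, weight_file).
    weights_path = 'elmo_weights.pkl'

    # map_location='cpu' rewrites storage locations during deserialization, so tensors
    # saved from CUDA land on the CPU instead of failing on GPU-less machines.
    elmo_model = torch.load(weights_path, map_location='cpu')

    # Per the patch, the checkpoint is a dict of state_dicts keyed by sub-module.
    char_cnn_state = elmo_model['char_cnn']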
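
The zero_tensor change is a device fix: torch.zeros(shape).long() always allocates on the CPU, so comparing it against a CUDA expanded_words would raise a device-mismatch error, while Tensor.new_zeros inherits both device and dtype from the source tensor. A small self-contained sketch (tensor names are illustrative):

    import torch

    # Stand-in for the padded batch of word ids; any integer tensor works.
    expanded_words = torch.randint(1, 100, (2, 5))
    expanded_words[0, 3:] = 0  # simulate padding

    # Old: always CPU float, then cast; device does not follow expanded_words.
    zero_cpu = torch.zeros(expanded_words.shape).long()

    # New: same dtype and device as expanded_words, so it also works under CUDA.
    zero_tensor = expanded_words.new_zeros(expanded_words.shape)

    mask = (expanded_words == zero_tensor).unsqueeze(-1)  # padding mask, (2, 5, 1)

An equivalent and slightly cheaper form would be expanded_words.eq(0), which skips materializing the zero tensor entirely.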
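
For context on the layers=='mix' branch in embedding.py: ELMo combines its n_layers+1 layer representations with softmax-normalized scalar weights and a global scale gamma (Peters et al., 2018). The patch only corrects the config keys and _get_mixed_outputs is not shown here, so the following is an assumption-laden sketch of the standard scalar mix, not fastNLP's actual implementation:

    import torch
    import torch.nn.functional as F

    n_layers, batch, seq_len, repr_dim = 2, 4, 7, 512  # illustrative sizes

    layer_weights = torch.zeros(n_layers + 1)  # initialized to zeros, as in the patch
    gamma = torch.ones(1)                      # initialized to ones, as in the patch

    # Stand-in for the stacked layer outputs: (n_layers + 1, batch, seq_len, repr_dim),
    # where repr_dim corresponds to 2 * projection_dim in the patch.
    outputs = torch.randn(n_layers + 1, batch, seq_len, repr_dim)

    # Softmax over the layer dimension, weighted sum, then global scaling.
    weights = F.softmax(layer_weights, dim=0).view(-1, 1, 1, 1)
    mixed = gamma * (weights * outputs).sum(dim=0)  # (batch, seq_len, repr_dim)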