Browse Source

修复elmo的bug

tags/v0.4.10
yh_cc 5 years ago
parent
commit
9e81eae0ad
2 changed files with 9 additions and 15 deletions
  1. +6
    -12
      fastNLP/modules/encoder/_elmo.py
  2. +3
    -3
      fastNLP/modules/encoder/embedding.py

+ 6
- 12
fastNLP/modules/encoder/_elmo.py View File

@@ -418,8 +418,6 @@ class ConvTokenEmbedder(nn.Module):


self.output_dim = config['lstm']['projection_dim'] self.output_dim = config['lstm']['projection_dim']
self._options = config self._options = config
self.requires_grad = False
self._char_embedding_weights = char_emb_layer.weight.data


char_cnn_options = self._options['char_cnn'] char_cnn_options = self._options['char_cnn']
if char_cnn_options['activation'] == 'tanh': if char_cnn_options['activation'] == 'tanh':
@@ -557,7 +555,6 @@ class _ElmoModel(nn.Module):


def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
super(_ElmoModel, self).__init__() super(_ElmoModel, self).__init__()
# self.pkl_dict = {}
self.model_dir = model_dir self.model_dir = model_dir
dir = os.walk(self.model_dir) dir = os.walk(self.model_dir)
config_file = None config_file = None
@@ -580,7 +577,6 @@ class _ElmoModel(nn.Module):
config = json.load(open(os.path.join(model_dir, config_file), 'r')) config = json.load(open(os.path.join(model_dir, config_file), 'r'))
self.weight_file = os.path.join(model_dir, weight_file) self.weight_file = os.path.join(model_dir, weight_file)
self.config = config self.config = config
self.requires_grad = False


OOV_TAG = '<oov>' OOV_TAG = '<oov>'
PAD_TAG = '<pad>' PAD_TAG = '<pad>'
@@ -616,12 +612,10 @@ class _ElmoModel(nn.Module):
char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']), char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
padding_idx=len(char_vocab)) padding_idx=len(char_vocab))


# 读入预训练权重 这里的elmo_model 是个dict 有char_embed的值以及char_cnn和 lstm 的 state_dict
elmo_pkl = open(os.path.join(self.model_dir, weight_file), "rb")
elmo_model = pickle.load(elmo_pkl)
elmo_pkl.close()
# 读入预训练权重 这里的elmo_model 包含char_cnn和 lstm 的 state_dict
elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')


self.char_embed_weights = elmo_model["char_embed"]
char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']


found_char_count = 0 found_char_count = 0
for char, index in char_vocab: # 调整character embedding for char, index in char_vocab: # 调整character embedding
@@ -630,7 +624,7 @@ class _ElmoModel(nn.Module):
found_char_count += 1 found_char_count += 1
else: else:
index_in_pre = char_lexicon[OOV_TAG] index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = self.char_embed_weights[index_in_pre]
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]


print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# 生成words到chars的映射 # 生成words到chars的映射
@@ -659,7 +653,7 @@ class _ElmoModel(nn.Module):


self.token_embedder = ConvTokenEmbedder( self.token_embedder = ConvTokenEmbedder(
config, self.weight_file, None, char_emb_layer) config, self.weight_file, None, char_emb_layer)
elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
self.token_embedder.load_state_dict(elmo_model["char_cnn"]) self.token_embedder.load_state_dict(elmo_model["char_cnn"])


self.output_dim = config['lstm']['projection_dim'] self.output_dim = config['lstm']['projection_dim']
@@ -707,7 +701,7 @@ class _ElmoModel(nn.Module):
expanded_words[:, 0].fill_(self.bos_index) expanded_words[:, 0].fill_(self.bos_index)
expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
seq_len = seq_len + 2 seq_len = seq_len + 2
zero_tensor = torch.zeros(expanded_words.shape).long()
zero_tensor = expanded_words.new_zeros(expanded_words.shape)
mask = (expanded_words == zero_tensor).unsqueeze(-1) mask = (expanded_words == zero_tensor).unsqueeze(-1)
if hasattr(self, 'cached_word_embedding'): if hasattr(self, 'cached_word_embedding'):
token_embedding = self.cached_word_embedding(expanded_words) token_embedding = self.cached_word_embedding(expanded_words)


+ 3
- 3
fastNLP/modules/encoder/embedding.py View File

@@ -539,11 +539,11 @@ class ElmoEmbedding(ContextualEmbedding):
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)


if layers=='mix': if layers=='mix':
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['encoder']['n_layers']+1),
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers']+1),
requires_grad=requires_grad) requires_grad=requires_grad)
self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad) self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad)
self._get_outputs = self._get_mixed_outputs self._get_outputs = self._get_mixed_outputs
self._embed_size = self.model.config['encoder']['projection_dim'] * 2
self._embed_size = self.model.config['lstm']['projection_dim'] * 2
else: else:
layers = list(map(int, layers.split(','))) layers = list(map(int, layers.split(',')))
assert len(layers) > 0, "Must choose one output" assert len(layers) > 0, "Must choose one output"
@@ -551,7 +551,7 @@ class ElmoEmbedding(ContextualEmbedding):
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
self.layers = layers self.layers = layers
self._get_outputs = self._get_layer_outputs self._get_outputs = self._get_layer_outputs
self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2


self.requires_grad = requires_grad self.requires_grad = requires_grad




Loading…
Cancel
Save