@@ -418,8 +418,6 @@ class ConvTokenEmbedder(nn.Module):
 
         self.output_dim = config['lstm']['projection_dim']
         self._options = config
-        self.requires_grad = False
-        self._char_embedding_weights = char_emb_layer.weight.data
 
         char_cnn_options = self._options['char_cnn']
         if char_cnn_options['activation'] == 'tanh':
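
The two attributes deleted here were dead weight: `requires_grad` only has meaning on tensors and Parameters, so assigning it as a plain module attribute froze nothing, and `_char_embedding_weights` merely aliased data the embedding layer already owns. A minimal sketch of the flag that does take effect (illustrative, not part of this patch):

    import torch.nn as nn

    emb = nn.Embedding(10, 4)
    for p in emb.parameters():
        p.requires_grad_(False)  # the per-tensor flag is what actually freezes weights
    assert not emb.weight.requires_grad
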
@@ -557,7 +555,6 @@ class _ElmoModel(nn.Module):
 
     def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
         super(_ElmoModel, self).__init__()
-        # self.pkl_dict = {}
         self.model_dir = model_dir
         dir = os.walk(self.model_dir)
         config_file = None

@@ -580,7 +577,6 @@ class _ElmoModel(nn.Module):
         config = json.load(open(os.path.join(model_dir, config_file), 'r'))
         self.weight_file = os.path.join(model_dir, weight_file)
         self.config = config
-        self.requires_grad = False
 
         OOV_TAG = '<oov>'
         PAD_TAG = '<pad>'

@@ -616,12 +612,10 @@ class _ElmoModel(nn.Module):
 
         char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
                                       padding_idx=len(char_vocab))
 
-        # Load the pretrained weights; here elmo_model is a dict holding the char_embed values and the state_dicts of char_cnn and lstm
-        elmo_pkl = open(os.path.join(self.model_dir, weight_file), "rb")
-        elmo_model = pickle.load(elmo_pkl)
-        elmo_pkl.close()
-        self.char_embed_weights = elmo_model["char_embed"]
+        # Load the pretrained weights; here elmo_model holds the state_dicts of char_cnn and lstm
+        elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')
+        char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']
 
         found_char_count = 0
         for char, index in char_vocab:  # adjust the character embeddings
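
This hunk replaces the hand-rolled pickle round trip with a single `torch.load` call: the file handle is managed internally, and `map_location='cpu'` deserializes tensors saved on a GPU onto the CPU, so the checkpoint loads on machines without the original device. The character embedding matrix now comes out of the `char_cnn` state_dict rather than a separate `char_embed` entry. A minimal round-trip sketch under the dict layout assumed above (file name and keys are illustrative):

    import torch
    import torch.nn as nn

    # Save a plain dict of state_dicts, mirroring the checkpoint layout.
    ckpt = {"char_cnn": nn.Embedding(5, 3).state_dict()}
    torch.save(ckpt, "elmo.pt")

    # torch.load manages the file itself; map_location='cpu' retargets
    # tensors that were originally saved on a GPU.
    restored = torch.load("elmo.pt", map_location="cpu")
    print(restored["char_cnn"]["weight"].shape)  # torch.Size([5, 3])
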
@@ -630,7 +624,7 @@ class _ElmoModel(nn.Module):
                 found_char_count += 1
             else:
                 index_in_pre = char_lexicon[OOV_TAG]
-            char_emb_layer.weight.data[index] = self.char_embed_weights[index_in_pre]
+            char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
 
         print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
         # Build the word-to-character mapping
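
The loop patched above re-indexes the pretrained character embedding into the order of the runtime `char_vocab`: a character present in the pretrained `char_lexicon` copies its row, and anything unseen falls back to the `<oov>` row. A self-contained sketch of the same remapping (toy sizes; names mirror the code above):

    import torch
    import torch.nn as nn

    pretrained = torch.randn(4, 8)                       # pretrained char rows
    char_lexicon = {'a': 0, 'b': 1, '<oov>': 2, '<pad>': 3}
    char_vocab = [('a', 0), ('z', 1), ('b', 2)]          # runtime (char, index) pairs

    emb = nn.Embedding(len(char_vocab) + 1, 8, padding_idx=len(char_vocab))
    for ch, idx in char_vocab:
        src = char_lexicon.get(ch, char_lexicon['<oov>'])  # unseen chars -> <oov> row
        emb.weight.data[idx] = pretrained[src]
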
@@ -659,7 +653,7 @@ class _ElmoModel(nn.Module):
 
         self.token_embedder = ConvTokenEmbedder(
             config, self.weight_file, None, char_emb_layer)
-
+        elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
         self.token_embedder.load_state_dict(elmo_model["char_cnn"])
 
         self.output_dim = config['lstm']['projection_dim']
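
The inserted assignment is the point of this hunk: `load_state_dict` overwrites every tensor it finds in the target module, so the freshly remapped `char_emb_layer.weight` has to be written back into `elmo_model["char_cnn"]` first, or loading the checkpoint would restore the original, un-remapped rows. A toy demonstration of the clobbering (layer and key are illustrative):

    import torch
    import torch.nn as nn

    layer = nn.Linear(2, 2, bias=False)
    ckpt = {"weight": torch.zeros(2, 2)}

    layer.weight.data.fill_(1.0)        # local adjustment to the weights
    layer.load_state_dict(ckpt)         # clobbered: the checkpoint zeros win
    assert layer.weight.sum().item() == 0.0

    layer.weight.data.fill_(1.0)        # adjust again, but this time...
    ckpt["weight"] = layer.weight.data  # ...sync the checkpoint entry first
    layer.load_state_dict(ckpt)         # now loading preserves the adjustment
    assert layer.weight.sum().item() == 4.0
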
@@ -707,7 +701,7 @@ class _ElmoModel(nn.Module):
         expanded_words[:, 0].fill_(self.bos_index)
         expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
         seq_len = seq_len + 2
-        zero_tensor = torch.zeros(expanded_words.shape).long()
+        zero_tensor = expanded_words.new_zeros(expanded_words.shape)
         mask = (expanded_words == zero_tensor).unsqueeze(-1)
         if hasattr(self, 'cached_word_embedding'):
             token_embedding = self.cached_word_embedding(expanded_words)
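
`Tensor.new_zeros` builds the comparison tensor with the dtype and device of `expanded_words` in one step; the old `torch.zeros(...).long()` always allocated on the CPU and would raise a device-mismatch error once the model ran on a GPU. A small sketch of the difference (shapes are illustrative):

    import torch

    words = torch.randint(1, 10, (2, 5))   # stands in for expanded_words
    if torch.cuda.is_available():
        words = words.cuda()

    zeros = words.new_zeros(words.shape)   # inherits dtype and device from words
    assert zeros.dtype == words.dtype and zeros.device == words.device
    mask = (words == zeros).unsqueeze(-1)  # safe on CPU or GPU alike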