
Modify weight loading in _elmo.py

tags/v0.4.10
yh_cc, 6 years ago
commit e867641023
1 changed file with 288 additions and 170 deletions:
fastNLP/modules/encoder/_elmo.py (+288, -170)

@@ -1,12 +1,13 @@

""" """
这个页面的代码大量参考了https://github.com/HIT-SCIR/ELMoForManyLangs/tree/master/elmoformanylangs
这个页面的代码大量参考了 allenNLP
""" """



from typing import Optional, Tuple, List, Callable

import os

import h5py
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -16,7 +17,6 @@ import json


from ..utils import get_dropout_mask
import codecs
from torch import autograd


class LstmCellWithProjection(torch.nn.Module):
"""
@@ -58,6 +58,7 @@ class LstmCellWithProjection(torch.nn.Module):
respectively. The first dimension is 1 in order to match the Pytorch
API for returning stacked LSTM states.
"""

def __init__(self,
input_size: int,
hidden_size: int,
@@ -129,13 +130,13 @@ class LstmCellWithProjection(torch.nn.Module):
# We have to use this '.data.new().fill_' pattern to create tensors with the correct
# type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors.
output_accumulator = inputs.data.new(batch_size,
total_timesteps,
self.hidden_size).fill_(0)
if initial_state is None:
full_batch_previous_memory = inputs.data.new(batch_size,
self.cell_size).fill_(0)
full_batch_previous_state = inputs.data.new(batch_size,
self.hidden_size).fill_(0)
else:
full_batch_previous_state = initial_state[0].squeeze(0)
full_batch_previous_memory = initial_state[1].squeeze(0)
@@ -169,7 +170,7 @@ class LstmCellWithProjection(torch.nn.Module):
# Second conditional: Does the next shortest sequence beyond the current batch
# index require computation use this timestep?
while current_length_index < (len(batch_lengths) - 1) and \
batch_lengths[current_length_index + 1] > index:
current_length_index += 1


# Actually get the slices of the batch which we
@@ -256,7 +257,7 @@ class LstmbiLm(nn.Module):
inputs = inputs[sort_idx]
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first)
output, hx = self.encoder(inputs, None)  # -> [N,L,C]
output, _ = nn.util.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
output = output[unsort_idx]
forward, backward = output.split(self.config['encoder']['dim'], 2)
@@ -316,13 +317,13 @@ class ElmobiLm(torch.nn.Module):
:param seq_len: batch_size
:return: torch.FloatTensor. num_layers x batch_size x max_len x hidden_size
"""
max_len = inputs.size(1)
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
inputs = inputs[sort_idx]
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
output, _ = self._lstm_forward(inputs, None)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
output = output[:, unsort_idx]

return output


def _lstm_forward(self,
@@ -399,7 +400,7 @@ class ElmobiLm(torch.nn.Module):
torch.cat([forward_state[1], backward_state[1]], -1)))


stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs)
# Stack the hidden state and memory for each layer into 2 tensors of shape
# (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size)
# respectively.
final_hidden_states, final_memory_states = zip(*final_states)
@@ -408,6 +409,66 @@ class ElmobiLm(torch.nn.Module):
torch.cat(final_memory_states, 0))
return stacked_sequence_outputs, final_state_tuple


def load_weights(self, weight_file: str) -> None:
"""
Load the pre-trained weights from the file.
"""
requires_grad = False

with h5py.File(weight_file, 'r') as fin:
for i_layer, lstms in enumerate(
zip(self.forward_layers, self.backward_layers)
):
for j_direction, lstm in enumerate(lstms):
# lstm is an instance of LSTMCellWithProjection
cell_size = lstm.cell_size

dataset = fin['RNN_%s' % j_direction]['RNN']['MultiRNNCell']['Cell%s' % i_layer
]['LSTMCell']

# tensorflow packs together both W and U matrices into one matrix,
# but pytorch maintains individual matrices. In addition, tensorflow
# packs the gates as input, memory, forget, output but pytorch
# uses input, forget, memory, output. So we need to modify the weights.
tf_weights = numpy.transpose(dataset['W_0'][...])
torch_weights = tf_weights.copy()

# split the W from U matrices
input_size = lstm.input_size
input_weights = torch_weights[:, :input_size]
recurrent_weights = torch_weights[:, input_size:]
tf_input_weights = tf_weights[:, :input_size]
tf_recurrent_weights = tf_weights[:, input_size:]

# handle the different gate order convention
for torch_w, tf_w in [[input_weights, tf_input_weights],
[recurrent_weights, tf_recurrent_weights]]:
torch_w[(1 * cell_size):(2 * cell_size), :] = tf_w[(2 * cell_size):(3 * cell_size), :]
torch_w[(2 * cell_size):(3 * cell_size), :] = tf_w[(1 * cell_size):(2 * cell_size), :]

lstm.input_linearity.weight.data.copy_(torch.FloatTensor(input_weights))
lstm.state_linearity.weight.data.copy_(torch.FloatTensor(recurrent_weights))
lstm.input_linearity.weight.requires_grad = requires_grad
lstm.state_linearity.weight.requires_grad = requires_grad

# the bias weights
tf_bias = dataset['B'][...]
# tensorflow adds 1.0 to forget gate bias instead of modifying the
# parameters...
tf_bias[(2 * cell_size):(3 * cell_size)] += 1
torch_bias = tf_bias.copy()
torch_bias[(1 * cell_size):(2 * cell_size)
] = tf_bias[(2 * cell_size):(3 * cell_size)]
torch_bias[(2 * cell_size):(3 * cell_size)
] = tf_bias[(1 * cell_size):(2 * cell_size)]
lstm.state_linearity.bias.data.copy_(torch.FloatTensor(torch_bias))
lstm.state_linearity.bias.requires_grad = requires_grad

# the projection weights
proj_weights = numpy.transpose(dataset['W_P_0'][...])
lstm.state_projection.weight.data.copy_(torch.FloatTensor(proj_weights))
lstm.state_projection.weight.requires_grad = requires_grad
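Note: the gate reordering and forget-gate bias handling in load_weights can be illustrated with a small standalone sketch (not part of this commit; array names and sizes are illustrative). It mirrors the block swap from TF's (input, memory, forget, output) packing to PyTorch's (input, forget, memory, output) order and the +1.0 folded into the forget-gate bias:

import numpy as np

cell_size, input_size = 4, 3

# hypothetical TF-style packed weights: gate rows ordered [input, memory, forget, output]
tf_weights = np.random.randn(4 * cell_size, input_size)
torch_weights = tf_weights.copy()

# swap the second and third gate blocks to obtain PyTorch's [input, forget, memory, output] order
torch_weights[1 * cell_size:2 * cell_size, :] = tf_weights[2 * cell_size:3 * cell_size, :]
torch_weights[2 * cell_size:3 * cell_size, :] = tf_weights[1 * cell_size:2 * cell_size, :]

# TF adds 1.0 to the forget-gate bias at run time instead of storing it,
# so it is added to the stored bias (still in TF order) before the same swap
tf_bias = np.zeros(4 * cell_size)
tf_bias[2 * cell_size:3 * cell_size] += 1.0
torch_bias = tf_bias.copy()
torch_bias[1 * cell_size:2 * cell_size] = tf_bias[2 * cell_size:3 * cell_size]
torch_bias[2 * cell_size:3 * cell_size] = tf_bias[1 * cell_size:2 * cell_size]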



class LstmTokenEmbedder(nn.Module):
def __init__(self, config, word_emb_layer, char_emb_layer):
@@ -441,7 +502,7 @@ class LstmTokenEmbedder(nn.Module):
chars_emb = self.char_emb_layer(chars)
# TODO seq_len should be taken into account here
_, (chars_outputs, __) = self.char_lstm(chars_emb)
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['char_dim'] * 2)
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['embedding']['dim'] * 2)
embs.append(chars_outputs)


token_embedding = torch.cat(embs, dim=2)
@@ -450,79 +511,143 @@ class LstmTokenEmbedder(nn.Module):




class ConvTokenEmbedder(nn.Module):
def __init__(self, config, word_emb_layer, char_emb_layer):
def __init__(self, config, weight_file, word_emb_layer, char_emb_layer, char_vocab):
super(ConvTokenEmbedder, self).__init__()
self.config = config
self.weight_file = weight_file
self.word_emb_layer = word_emb_layer
self.char_emb_layer = char_emb_layer


self.output_dim = config['encoder']['projection_dim']
self.emb_dim = 0
if word_emb_layer is not None:
self.emb_dim += word_emb_layer.weight.size(1)

if char_emb_layer is not None:
self.convolutions = []
cnn_config = config['token_embedder']
filters = cnn_config['filters']
char_embed_dim = cnn_config['char_dim']

for i, (width, num) in enumerate(filters):
conv = torch.nn.Conv1d(
in_channels=char_embed_dim,
out_channels=num,
kernel_size=width,
bias=True
)
self.convolutions.append(conv)

self.convolutions = nn.ModuleList(self.convolutions)

self.n_filters = sum(f[1] for f in filters)
self.n_highway = cnn_config['n_highway']

self.highways = Highway(self.n_filters, self.n_highway, activation=torch.nn.functional.relu)
self.emb_dim += self.n_filters

self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True)
self._options = config
self.requires_grad = False
self._load_weights()
self._char_embedding_weights = char_emb_layer.weight.data

def _load_weights(self):
self._load_cnn_weights()
self._load_highway()
self._load_projection()

def _load_cnn_weights(self):
cnn_options = self._options['token_embedder']
filters = cnn_options['filters']
char_embed_dim = cnn_options['embedding']['dim']

convolutions = []
for i, (width, num) in enumerate(filters):
conv = torch.nn.Conv1d(
in_channels=char_embed_dim,
out_channels=num,
kernel_size=width,
bias=True
)
# load the weights
with h5py.File(self.weight_file, 'r') as fin:
weight = fin['CNN']['W_cnn_{}'.format(i)][...]
bias = fin['CNN']['b_cnn_{}'.format(i)][...]

w_reshaped = numpy.transpose(weight.squeeze(axis=0), axes=(2, 1, 0))
if w_reshaped.shape != tuple(conv.weight.data.shape):
raise ValueError("Invalid weight file")
conv.weight.data.copy_(torch.FloatTensor(w_reshaped))
conv.bias.data.copy_(torch.FloatTensor(bias))

conv.weight.requires_grad = self.requires_grad
conv.bias.requires_grad = self.requires_grad

convolutions.append(conv)
self.add_module('char_conv_{}'.format(i), conv)

self._convolutions = convolutions
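For reference, the transpose in _load_cnn_weights maps the checkpoint's kernel layout onto torch.nn.Conv1d's (out_channels, in_channels, kernel_size) layout. A minimal shape sketch (not part of this commit; the (1, width, in_channels, out_channels) layout and the sizes are assumptions for illustration):

import numpy as np
import torch

char_embed_dim, width, num_filters = 16, 3, 32  # illustrative sizes

# assumed hdf5 layout of one char-CNN kernel: (1, kernel_width, in_channels, out_channels)
tf_kernel = np.random.randn(1, width, char_embed_dim, num_filters).astype('float32')

conv = torch.nn.Conv1d(in_channels=char_embed_dim, out_channels=num_filters,
                       kernel_size=width, bias=True)

# drop the leading singleton axis and reorder to Conv1d's (out_channels, in_channels, kernel_size)
w_reshaped = np.transpose(tf_kernel.squeeze(axis=0), axes=(2, 1, 0))
assert w_reshaped.shape == tuple(conv.weight.data.shape)
conv.weight.data.copy_(torch.FloatTensor(w_reshaped))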

def _load_highway(self):
# the highway layers have same dimensionality as the number of cnn filters
cnn_options = self._options['token_embedder']
filters = cnn_options['filters']
n_filters = sum(f[1] for f in filters)
n_highway = cnn_options['n_highway']

# create the layers, and load the weights
self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu)
for k in range(n_highway):
# The AllenNLP highway is one matrix multiplication with concatenation of
# transform and carry weights.
with h5py.File(self.weight_file, 'r') as fin:
# The weights are transposed due to multiplication order assumptions in tf
# vs pytorch (tf.matmul(X, W) vs pytorch.matmul(W, X))
w_transform = numpy.transpose(fin['CNN_high_{}'.format(k)]['W_transform'][...])
# -1.0 since AllenNLP is g * x + (1 - g) * f(x) but tf is (1 - g) * x + g * f(x)
w_carry = -1.0 * numpy.transpose(fin['CNN_high_{}'.format(k)]['W_carry'][...])
weight = numpy.concatenate([w_transform, w_carry], axis=0)
self._highways._layers[k].weight.data.copy_(torch.FloatTensor(weight))
self._highways._layers[k].weight.requires_grad = self.requires_grad

b_transform = fin['CNN_high_{}'.format(k)]['b_transform'][...]
b_carry = -1.0 * fin['CNN_high_{}'.format(k)]['b_carry'][...]
bias = numpy.concatenate([b_transform, b_carry], axis=0)
self._highways._layers[k].bias.data.copy_(torch.FloatTensor(bias))
self._highways._layers[k].bias.requires_grad = self.requires_grad

def _load_projection(self):
cnn_options = self._options['token_embedder']
filters = cnn_options['filters']
n_filters = sum(f[1] for f in filters)

self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True)
with h5py.File(self.weight_file, 'r') as fin:
weight = fin['CNN_proj']['W_proj'][...]
bias = fin['CNN_proj']['b_proj'][...]
self._projection.weight.data.copy_(torch.FloatTensor(numpy.transpose(weight)))
self._projection.bias.data.copy_(torch.FloatTensor(bias))

self._projection.weight.requires_grad = self.requires_grad
self._projection.bias.requires_grad = self.requires_grad


def forward(self, words, chars):
embs = []
if self.word_emb_layer is not None:
if hasattr(self, 'words_to_words'):
words = self.words_to_words[words]
word_emb = self.word_emb_layer(words)
embs.append(word_emb)
"""
:param words:
:param chars: Tensor Shape ``(batch_size, sequence_length, 50)``:
:return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` :
"""
# the character id embedding
# (batch_size * sequence_length, max_chars_per_token, embed_dim)
# character_embedding = torch.nn.functional.embedding(
# chars.view(-1, max_chars_per_token),
# self._char_embedding_weights
# )
batch_size, sequence_length, max_char_len = chars.size()
character_embedding = self.char_emb_layer(chars).reshape(batch_size*sequence_length, max_char_len, -1)
# run convolutions
cnn_options = self._options['token_embedder']
if cnn_options['activation'] == 'tanh':
activation = torch.tanh
elif cnn_options['activation'] == 'relu':
activation = torch.nn.functional.relu
else:
raise Exception("Unknown activation")


if self.char_emb_layer is not None:
batch_size, seq_len, _ = chars.size()
chars = chars.view(batch_size * seq_len, -1)
character_embedding = self.char_emb_layer(chars)
character_embedding = torch.transpose(character_embedding, 1, 2)

cnn_config = self.config['token_embedder']
if cnn_config['activation'] == 'tanh':
activation = torch.nn.functional.tanh
elif cnn_config['activation'] == 'relu':
activation = torch.nn.functional.relu
else:
raise Exception("Unknown activation")
# (batch_size * sequence_length, embed_dim, max_chars_per_token)
character_embedding = torch.transpose(character_embedding, 1, 2)
convs = []
for i in range(len(self._convolutions)):
conv = getattr(self, 'char_conv_{}'.format(i))
convolved = conv(character_embedding)
# (batch_size * sequence_length, n_filters for this width)
convolved, _ = torch.max(convolved, dim=-1)
convolved = activation(convolved)
convs.append(convolved)


convs = []
for i in range(len(self.convolutions)):
convolved = self.convolutions[i](character_embedding)
# (batch_size * sequence_length, n_filters for this width)
convolved, _ = torch.max(convolved, dim=-1)
convolved = activation(convolved)
convs.append(convolved)
char_emb = torch.cat(convs, dim=-1)
char_emb = self.highways(char_emb)
# (batch_size * sequence_length, n_filters)
token_embedding = torch.cat(convs, dim=-1)


embs.append(char_emb.view(batch_size, -1, self.n_filters))
# apply the highway layers (batch_size * sequence_length, n_filters)
token_embedding = self._highways(token_embedding)


token_embedding = torch.cat(embs, dim=2)
# final projection (batch_size * sequence_length, embedding_dim)
token_embedding = self._projection(token_embedding)


return self.projection(token_embedding)
# reshape to (batch_size, sequence_length+2, embedding_dim)
return token_embedding.view(batch_size, sequence_length, -1)
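The -1.0 factors applied to W_carry and b_carry in _load_highway above rely on the identity sigmoid(-z) = 1 - sigmoid(z): negating the carry pre-activation turns TF's (1 - g) * x + g * f(x) into the g * x + (1 - g) * f(x) form computed by the Highway module below. A tiny numerical check (not part of this commit; elementwise weights are used for simplicity instead of the concatenated linear layer):

import torch

torch.manual_seed(0)
x = torch.randn(5)
w, b = torch.randn(5), torch.randn(5)
f = torch.relu

# TF-style highway: (1 - g) * x + g * f(x) with g = sigmoid(w * x + b)
g_tf = torch.sigmoid(w * x + b)
out_tf = (1 - g_tf) * x + g_tf * f(x)

# AllenNLP-style highway with negated carry weights:
# g' = sigmoid(-(w * x + b)) = 1 - g, so g' * x + (1 - g') * f(x) reproduces out_tf
g_torch = torch.sigmoid(-(w * x + b))
out_torch = g_torch * x + (1 - g_torch) * f(x)

assert torch.allclose(out_tf, out_torch)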




class Highway(torch.nn.Module):
@@ -543,6 +668,7 @@ class Highway(torch.nn.Module):
activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``)
The non-linearity to use in the highway layers.
"""

def __init__(self,
input_dim: int,
num_layers: int = 1,
@@ -573,6 +699,7 @@ class Highway(torch.nn.Module):
current_input = gate * linear_part + (1 - gate) * nonlinear_part
return current_input



class _ElmoModel(nn.Module):
"""
This Module does all of the heavy lifting for ElmoEmbedding. Its responsibilities include
@@ -582,11 +709,32 @@ class _ElmoModel(nn.Module):
(4) providing an embedding that stores tokens, allowing word representations to be cached.


""" """
def __init__(self, model_dir:str, vocab:Vocabulary=None, cache_word_reprs:bool=False):

def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
super(_ElmoModel, self).__init__()
config = json.load(open(os.path.join(model_dir, 'structure_config.json'), 'r'))


dir = os.walk(model_dir)
config_file = None
weight_file = None
config_count = 0
weight_count = 0
for path, dir_list, file_list in dir:
for file_name in file_list:
if file_name.__contains__(".json"):
config_file = file_name
config_count += 1
elif file_name.__contains__(".hdf5"):
weight_file = file_name
weight_count += 1
if config_count > 1 or weight_count > 1:
raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.")
elif config_count == 0 or weight_count == 0:
raise Exception(f"No config file or weight file found in {model_dir}")

config = json.load(open(os.path.join(model_dir, config_file), 'r'))
self.weight_file = os.path.join(model_dir, weight_file)
self.config = config
self.requires_grad = False


OOV_TAG = '<oov>'
PAD_TAG = '<pad>'
@@ -595,48 +743,8 @@ class _ElmoModel(nn.Module):
BOW_TAG = '<bow>'
EOW_TAG = '<eow>'


# embedding loading was moved here
token_embedder_states = torch.load(os.path.join(model_dir, 'token_embedder.pkl'), map_location='cpu')

# For the model trained with word form word encoder.
if config['token_embedder']['word_dim'] > 0:
word_lexicon = {}
with codecs.open(os.path.join(model_dir, 'word.dic'), 'r', encoding='utf-8') as fpi:
for line in fpi:
tokens = line.strip().split('\t')
if len(tokens) == 1:
tokens.insert(0, '\u3000')
token, i = tokens
word_lexicon[token] = int(i)
# some sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOS_TAG, EOS_TAG]:
assert special_word in word_lexicon, f"{special_word} not found in word.dic."
# adjust the word embedding according to vocab
pre_word_embedding = token_embedder_states.pop('word_emb_layer.embedding.weight')
word_emb_layer = nn.Embedding(len(vocab)+2, config['token_embedder']['word_dim'])  # two extra slots for <bos> and <eos>
found_word_count = 0
for word, index in vocab:
if index == vocab.unknown_idx:  # fastNLP's unknown token is <unk> while here it is <oov>, so force-adapt it (ugly)
index_in_pre = word_lexicon[OOV_TAG]
found_word_count += 1
elif index == vocab.padding_idx:  # padding needs to be aligned
index_in_pre = word_lexicon[PAD_TAG]
found_word_count += 1
elif word in word_lexicon:
index_in_pre = word_lexicon[word]
found_word_count += 1
else:
index_in_pre = word_lexicon[OOV_TAG]
word_emb_layer.weight.data[index] = pre_word_embedding[index_in_pre]
print(f"{found_word_count} out of {len(vocab)} words were found in pretrained elmo embedding.")
word_emb_layer.weight.data[-1] = pre_word_embedding[word_lexicon[EOS_TAG]]
word_emb_layer.weight.data[-2] = pre_word_embedding[word_lexicon[BOS_TAG]]
self.word_vocab = vocab
else:
word_emb_layer = None

# For the model trained with character-based word encoder.
if config['token_embedder']['char_dim'] > 0:
if config['token_embedder']['embedding']['dim'] > 0:
char_lexicon = {}
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
for line in fpi:
@@ -645,22 +753,26 @@ class _ElmoModel(nn.Module):
tokens.insert(0, '\u3000')
token, i = tokens
char_lexicon[token] = int(i)

# some sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
assert special_word in char_lexicon, f"{special_word} not found in char.dic."

# build char_vocab from vocab
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
# make sure <bow> and <eow> are included
char_vocab.add_word(BOW_TAG)
char_vocab.add_word(EOW_TAG)
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])
for word, index in vocab:
char_vocab.add_word_lst(list(word))
# make sure <eos> and <bos> are included as well
char_vocab.add_word_lst(list(BOS_TAG))
char_vocab.add_word_lst(list(EOS_TAG))
# adjust according to char_lexicon
char_emb_layer = nn.Embedding(len(char_vocab), int(config['token_embedder']['char_dim']))
pre_char_embedding = token_embedder_states.pop('char_emb_layer.embedding.weight')

self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab)+1, vocab.padding_idx
# adjust according to char_lexicon; one extra slot is reserved for word padding (its char representation is all zeros)
char_emb_layer = nn.Embedding(len(char_vocab)+1, int(config['token_embedder']['embedding']['dim']),
padding_idx=len(char_vocab))
with h5py.File(self.weight_file, 'r') as fin:
char_embed_weights = fin['char_embed'][...]
char_embed_weights = torch.from_numpy(char_embed_weights)
found_char_count = 0
for char, index in char_vocab:  # adjust the character embedding
if char in char_lexicon:
@@ -668,79 +780,84 @@ class _ElmoModel(nn.Module):
found_char_count += 1
else:
index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = pre_char_embedding[index_in_pre]
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]

print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# 生成words到chars的映射 # 生成words到chars的映射
if config['token_embedder']['name'].lower() == 'cnn': if config['token_embedder']['name'].lower() == 'cnn':
max_chars = config['token_embedder']['max_characters_per_token'] max_chars = config['token_embedder']['max_characters_per_token']
elif config['token_embedder']['name'].lower() == 'lstm': elif config['token_embedder']['name'].lower() == 'lstm':
max_chars = max(map(lambda x: len(x[0]), vocab)) + 2  # two extra for <bow> and <eow>
else:
raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name']))
# add 2 for <bos> and <eos>
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars),
fill_value=char_vocab.to_index(PAD_TAG), dtype=torch.long),
fill_value=len(char_vocab),
dtype=torch.long),
requires_grad=False)
for word, index in vocab:
if len(word)+2>max_chars:
word = word[:max_chars-2]
if index==vocab.padding_idx:  # if it is pad, align it with the given padding
word = PAD_TAG
elif index==vocab.unknown_idx:
word = OOV_TAG
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids))
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab)+1)]:
if len(word) + 2 > max_chars:
word = word[:max_chars - 2]
if index == self._pad_index:
continue
elif word == BOS_TAG or word == EOS_TAG:
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [
char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
else:
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [
char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)
for index, word in enumerate([BOS_TAG, EOS_TAG]):  # add <bos> and <eos>
if len(word)+2>max_chars:
word = word[:max_chars-2]
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids))
self.words_to_chars_embedding[index+len(vocab)] = torch.LongTensor(char_ids)

self.char_vocab = char_vocab
else:
char_emb_layer = None


if config['token_embedder']['name'].lower() == 'cnn':
self.token_embedder = ConvTokenEmbedder(
config, word_emb_layer, char_emb_layer)
config, self.weight_file, None, char_emb_layer, self.char_vocab)
elif config['token_embedder']['name'].lower() == 'lstm':
self.token_embedder = LstmTokenEmbedder(
config, word_emb_layer, char_emb_layer)
self.token_embedder.load_state_dict(token_embedder_states, strict=False)
if config['token_embedder']['word_dim'] > 0 and vocab._no_create_word_length > 0:  # a mapping is needed so that indices from dev/test point to unk
words_to_words = nn.Parameter(torch.arange(len(vocab)+2).long(), requires_grad=False)
config, None, char_emb_layer)

if config['token_embedder']['word_dim'] > 0 \
and vocab._no_create_word_length > 0:  # a mapping is needed so that indices from dev/test point to unk
words_to_words = nn.Parameter(torch.arange(len(vocab) + 2).long(), requires_grad=False)
for word, idx in vocab:
if vocab._is_word_no_create_entry(word):
words_to_words[idx] = vocab.unknown_idx
setattr(self.token_embedder, 'words_to_words', words_to_words)
self.output_dim = config['encoder']['projection_dim']


# only 'elmo' is considered for now
if config['encoder']['name'].lower() == 'elmo':
self.encoder = ElmobiLm(config)
elif config['encoder']['name'].lower() == 'lstm':
self.encoder = LstmbiLm(config)
self.encoder.load_state_dict(torch.load(os.path.join(model_dir, 'encoder.pkl'),
map_location='cpu'))


self.bos_index = len(vocab)
self.eos_index = len(vocab) + 1
self._pad_index = vocab.padding_idx
self.encoder.load_weights(self.weight_file)


if cache_word_reprs:
if config['token_embedder']['char_dim']>0:  # only useful when chars are used
if config['token_embedder']['embedding']['dim'] > 0:  # only useful when chars are used
print("Start to generate cache word representations.") print("Start to generate cache word representations.")
batch_size = 320 batch_size = 320
num_batches = self.words_to_chars_embedding.size(0)//batch_size + \
int(self.words_to_chars_embedding.size(0)%batch_size!=0)
self.cached_word_embedding = nn.Embedding(self.words_to_chars_embedding.size(0),
# bos eos
word_size = self.words_to_chars_embedding.size(0)
num_batches = word_size // batch_size + \
int(word_size % batch_size != 0)

self.cached_word_embedding = nn.Embedding(word_size,
config['encoder']['projection_dim'])
with torch.no_grad():
for i in range(num_batches):
words = torch.arange(i*batch_size, min((i+1)*batch_size, self.words_to_chars_embedding.size(0))).long()
words = torch.arange(i * batch_size,
min((i + 1) * batch_size, word_size)).long()
chars = self.words_to_chars_embedding[words].unsqueeze(1)  # batch_size x 1 x max_chars
word_reprs = self.token_embedder(words.unsqueeze(1), chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
word_reprs = self.token_embedder(words.unsqueeze(1),
chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)

print("Finish generating cached word representations. Going to delete the character encoder.") print("Finish generating cached word representations. Going to delete the character encoder.")
del self.token_embedder, self.words_to_chars_embedding del self.token_embedder, self.words_to_chars_embedding
else: else:
@@ -758,7 +875,7 @@ class _ElmoModel(nn.Module):
seq_len = words.ne(self._pad_index).sum(dim=-1)
expanded_words[:, 1:-1] = words
expanded_words[:, 0].fill_(self.bos_index)
expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
seq_len = seq_len + 2
if hasattr(self, 'cached_word_embedding'):
token_embedding = self.cached_word_embedding(expanded_words)
@@ -767,16 +884,18 @@ class _ElmoModel(nn.Module):
chars = self.words_to_chars_embedding[expanded_words]
else:
chars = None
token_embedding = self.token_embedder(expanded_words, chars)
token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim

if self.config['encoder']['name'] == 'elmo':
encoder_output = self.encoder(token_embedding, seq_len)
if encoder_output.size(2) < max_len+2:
dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size,
max_len + 2 - encoder_output.size(2), encoder_output.size(-1))
encoder_output = torch.cat([encoder_output, dummy_tensor], 2)
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
encoder_output = torch.cat([token_embedding, encoder_output], dim=0)
if encoder_output.size(2) < max_len + 2:
num_layers, _, output_len, hidden_size = encoder_output.size()
dummy_tensor = encoder_output.new_zeros(num_layers, batch_size,
max_len + 2 - output_len, hidden_size)
encoder_output = torch.cat((encoder_output, dummy_tensor), 2)
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
encoder_output = torch.cat((token_embedding, encoder_output), dim=0)
elif self.config['encoder']['name'] == 'lstm':
encoder_output = self.encoder(token_embedding, seq_len)
else:
@@ -784,5 +903,4 @@ class _ElmoModel(nn.Module):


# Remove <eos> and <bos>. The removal here is not exact, but it should not affect the final result.
encoder_output = encoder_output[:, :, 1:-1]

return encoder_output
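A shape-level sketch (not part of this commit; sizes are illustrative) of how the new 'elmo' branch assembles its output: the token embedding is duplicated along the feature dimension to act as layer 0 and stacked on top of the two BiLSTM layers, giving (num_layers + 1, batch_size, seq_len, 2 * projection_dim) once the <bos>/<eos> positions are cut off:

import torch

batch_size, max_len, projection_dim = 2, 7, 512

# hypothetical tensors with the shapes used in forward()
token_embedding = torch.randn(batch_size, max_len + 2, projection_dim)
encoder_output = torch.randn(2, batch_size, max_len + 2, 2 * projection_dim)  # two BiLSTM layers

sz = encoder_output.size()
layer0 = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
output = torch.cat((layer0, encoder_output), dim=0)  # (3, batch, max_len + 2, 2 * projection_dim)
output = output[:, :, 1:-1]                          # drop the <bos>/<eos> positions
assert output.shape == (3, batch_size, max_len, 2 * projection_dim)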
