|
@@ -1,12 +1,13 @@ |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
这个页面的代码大量参考了https://github.com/HIT-SCIR/ELMoForManyLangs/tree/master/elmoformanylangs |
|
|
|
|
|
|
|
|
这个页面的代码大量参考了 allenNLP |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, List, Callable |
|
|
from typing import Optional, Tuple, List, Callable |
|
|
|
|
|
|
|
|
import os |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
import h5py |
|
|
|
|
|
import numpy |
|
|
import torch |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.nn.functional as F |
|
@@ -16,7 +17,6 @@ import json |
|
|
|
|
|
|
|
|
from ..utils import get_dropout_mask |
|
|
from ..utils import get_dropout_mask |
|
|
import codecs |
|
|
import codecs |
|
|
from torch import autograd |
|
|
|
|
|
|
|
|
|
|
|
class LstmCellWithProjection(torch.nn.Module): |
|
|
class LstmCellWithProjection(torch.nn.Module): |
|
|
""" |
|
|
""" |
|
@@ -58,6 +58,7 @@ class LstmCellWithProjection(torch.nn.Module): |
|
|
respectively. The first dimension is 1 in order to match the Pytorch |
|
|
respectively. The first dimension is 1 in order to match the Pytorch |
|
|
API for returning stacked LSTM states. |
|
|
API for returning stacked LSTM states. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
def __init__(self, |
|
|
input_size: int, |
|
|
input_size: int, |
|
|
hidden_size: int, |
|
|
hidden_size: int, |
|
@@ -129,13 +130,13 @@ class LstmCellWithProjection(torch.nn.Module): |
|
|
# We have to use this '.data.new().fill_' pattern to create tensors with the correct |
|
|
# We have to use this '.data.new().fill_' pattern to create tensors with the correct |
|
|
# type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors. |
|
|
# type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors. |
|
|
output_accumulator = inputs.data.new(batch_size, |
|
|
output_accumulator = inputs.data.new(batch_size, |
|
|
total_timesteps, |
|
|
|
|
|
self.hidden_size).fill_(0) |
|
|
|
|
|
|
|
|
total_timesteps, |
|
|
|
|
|
self.hidden_size).fill_(0) |
|
|
if initial_state is None: |
|
|
if initial_state is None: |
|
|
full_batch_previous_memory = inputs.data.new(batch_size, |
|
|
full_batch_previous_memory = inputs.data.new(batch_size, |
|
|
self.cell_size).fill_(0) |
|
|
|
|
|
|
|
|
self.cell_size).fill_(0) |
|
|
full_batch_previous_state = inputs.data.new(batch_size, |
|
|
full_batch_previous_state = inputs.data.new(batch_size, |
|
|
self.hidden_size).fill_(0) |
|
|
|
|
|
|
|
|
self.hidden_size).fill_(0) |
|
|
else: |
|
|
else: |
|
|
full_batch_previous_state = initial_state[0].squeeze(0) |
|
|
full_batch_previous_state = initial_state[0].squeeze(0) |
|
|
full_batch_previous_memory = initial_state[1].squeeze(0) |
|
|
full_batch_previous_memory = initial_state[1].squeeze(0) |
|
@@ -169,7 +170,7 @@ class LstmCellWithProjection(torch.nn.Module): |
|
|
# Second conditional: Does the next shortest sequence beyond the current batch |
|
|
# Second conditional: Does the next shortest sequence beyond the current batch |
|
|
# index require computation use this timestep? |
|
|
# index require computation use this timestep? |
|
|
while current_length_index < (len(batch_lengths) - 1) and \ |
|
|
while current_length_index < (len(batch_lengths) - 1) and \ |
|
|
batch_lengths[current_length_index + 1] > index: |
|
|
|
|
|
|
|
|
batch_lengths[current_length_index + 1] > index: |
|
|
current_length_index += 1 |
|
|
current_length_index += 1 |
|
|
|
|
|
|
|
|
# Actually get the slices of the batch which we |
|
|
# Actually get the slices of the batch which we |
|
@@ -256,7 +257,7 @@ class LstmbiLm(nn.Module): |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) |
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) |
|
|
output, hx = self.encoder(inputs, None) # -> [N,L,C] |
|
|
output, hx = self.encoder(inputs, None) # -> [N,L,C] |
|
|
output, _ = nn.util.rnn.pad_packed_sequence(output, batch_first=self.batch_first) |
|
|
|
|
|
|
|
|
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
output = output[unsort_idx] |
|
|
output = output[unsort_idx] |
|
|
forward, backward = output.split(self.config['encoder']['dim'], 2) |
|
|
forward, backward = output.split(self.config['encoder']['dim'], 2) |
|
@@ -316,13 +317,13 @@ class ElmobiLm(torch.nn.Module): |
|
|
:param seq_len: batch_size |
|
|
:param seq_len: batch_size |
|
|
:return: torch.FloatTensor. num_layers x batch_size x max_len x hidden_size |
|
|
:return: torch.FloatTensor. num_layers x batch_size x max_len x hidden_size |
|
|
""" |
|
|
""" |
|
|
|
|
|
max_len = inputs.size(1) |
|
|
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) |
|
|
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = inputs[sort_idx] |
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) |
|
|
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) |
|
|
output, _ = self._lstm_forward(inputs, None) |
|
|
output, _ = self._lstm_forward(inputs, None) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) |
|
|
output = output[:, unsort_idx] |
|
|
output = output[:, unsort_idx] |
|
|
|
|
|
|
|
|
return output |
|
|
return output |
|
|
|
|
|
|
|
|
def _lstm_forward(self, |
|
|
def _lstm_forward(self, |
|
@@ -399,7 +400,7 @@ class ElmobiLm(torch.nn.Module): |
|
|
torch.cat([forward_state[1], backward_state[1]], -1))) |
|
|
torch.cat([forward_state[1], backward_state[1]], -1))) |
|
|
|
|
|
|
|
|
stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs) |
|
|
stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs) |
|
|
# Stack the hidden state and memory for each layer into 2 tensors of shape |
|
|
|
|
|
|
|
|
# Stack the hidden state and memory for each layer in。to 2 tensors of shape |
|
|
# (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size) |
|
|
# (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size) |
|
|
# respectively. |
|
|
# respectively. |
|
|
final_hidden_states, final_memory_states = zip(*final_states) |
|
|
final_hidden_states, final_memory_states = zip(*final_states) |
|
@@ -408,6 +409,66 @@ class ElmobiLm(torch.nn.Module): |
|
|
torch.cat(final_memory_states, 0)) |
|
|
torch.cat(final_memory_states, 0)) |
|
|
return stacked_sequence_outputs, final_state_tuple |
|
|
return stacked_sequence_outputs, final_state_tuple |
|
|
|
|
|
|
|
|
|
|
|
def load_weights(self, weight_file: str) -> None: |
|
|
|
|
|
""" |
|
|
|
|
|
Load the pre-trained weights from the file. |
|
|
|
|
|
""" |
|
|
|
|
|
requires_grad = False |
|
|
|
|
|
|
|
|
|
|
|
with h5py.File(weight_file, 'r') as fin: |
|
|
|
|
|
for i_layer, lstms in enumerate( |
|
|
|
|
|
zip(self.forward_layers, self.backward_layers) |
|
|
|
|
|
): |
|
|
|
|
|
for j_direction, lstm in enumerate(lstms): |
|
|
|
|
|
# lstm is an instance of LSTMCellWithProjection |
|
|
|
|
|
cell_size = lstm.cell_size |
|
|
|
|
|
|
|
|
|
|
|
dataset = fin['RNN_%s' % j_direction]['RNN']['MultiRNNCell']['Cell%s' % i_layer |
|
|
|
|
|
]['LSTMCell'] |
|
|
|
|
|
|
|
|
|
|
|
# tensorflow packs together both W and U matrices into one matrix, |
|
|
|
|
|
# but pytorch maintains individual matrices. In addition, tensorflow |
|
|
|
|
|
# packs the gates as input, memory, forget, output but pytorch |
|
|
|
|
|
# uses input, forget, memory, output. So we need to modify the weights. |
|
|
|
|
|
tf_weights = numpy.transpose(dataset['W_0'][...]) |
|
|
|
|
|
torch_weights = tf_weights.copy() |
|
|
|
|
|
|
|
|
|
|
|
# split the W from U matrices |
|
|
|
|
|
input_size = lstm.input_size |
|
|
|
|
|
input_weights = torch_weights[:, :input_size] |
|
|
|
|
|
recurrent_weights = torch_weights[:, input_size:] |
|
|
|
|
|
tf_input_weights = tf_weights[:, :input_size] |
|
|
|
|
|
tf_recurrent_weights = tf_weights[:, input_size:] |
|
|
|
|
|
|
|
|
|
|
|
# handle the different gate order convention |
|
|
|
|
|
for torch_w, tf_w in [[input_weights, tf_input_weights], |
|
|
|
|
|
[recurrent_weights, tf_recurrent_weights]]: |
|
|
|
|
|
torch_w[(1 * cell_size):(2 * cell_size), :] = tf_w[(2 * cell_size):(3 * cell_size), :] |
|
|
|
|
|
torch_w[(2 * cell_size):(3 * cell_size), :] = tf_w[(1 * cell_size):(2 * cell_size), :] |
|
|
|
|
|
|
|
|
|
|
|
lstm.input_linearity.weight.data.copy_(torch.FloatTensor(input_weights)) |
|
|
|
|
|
lstm.state_linearity.weight.data.copy_(torch.FloatTensor(recurrent_weights)) |
|
|
|
|
|
lstm.input_linearity.weight.requires_grad = requires_grad |
|
|
|
|
|
lstm.state_linearity.weight.requires_grad = requires_grad |
|
|
|
|
|
|
|
|
|
|
|
# the bias weights |
|
|
|
|
|
tf_bias = dataset['B'][...] |
|
|
|
|
|
# tensorflow adds 1.0 to forget gate bias instead of modifying the |
|
|
|
|
|
# parameters... |
|
|
|
|
|
tf_bias[(2 * cell_size):(3 * cell_size)] += 1 |
|
|
|
|
|
torch_bias = tf_bias.copy() |
|
|
|
|
|
torch_bias[(1 * cell_size):(2 * cell_size) |
|
|
|
|
|
] = tf_bias[(2 * cell_size):(3 * cell_size)] |
|
|
|
|
|
torch_bias[(2 * cell_size):(3 * cell_size) |
|
|
|
|
|
] = tf_bias[(1 * cell_size):(2 * cell_size)] |
|
|
|
|
|
lstm.state_linearity.bias.data.copy_(torch.FloatTensor(torch_bias)) |
|
|
|
|
|
lstm.state_linearity.bias.requires_grad = requires_grad |
|
|
|
|
|
|
|
|
|
|
|
# the projection weights |
|
|
|
|
|
proj_weights = numpy.transpose(dataset['W_P_0'][...]) |
|
|
|
|
|
lstm.state_projection.weight.data.copy_(torch.FloatTensor(proj_weights)) |
|
|
|
|
|
lstm.state_projection.weight.requires_grad = requires_grad |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LstmTokenEmbedder(nn.Module): |
|
|
class LstmTokenEmbedder(nn.Module): |
|
|
def __init__(self, config, word_emb_layer, char_emb_layer): |
|
|
def __init__(self, config, word_emb_layer, char_emb_layer): |
|
@@ -441,7 +502,7 @@ class LstmTokenEmbedder(nn.Module): |
|
|
chars_emb = self.char_emb_layer(chars) |
|
|
chars_emb = self.char_emb_layer(chars) |
|
|
# TODO 这里应该要考虑seq_len的问题 |
|
|
# TODO 这里应该要考虑seq_len的问题 |
|
|
_, (chars_outputs, __) = self.char_lstm(chars_emb) |
|
|
_, (chars_outputs, __) = self.char_lstm(chars_emb) |
|
|
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['char_dim'] * 2) |
|
|
|
|
|
|
|
|
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['embedding']['dim'] * 2) |
|
|
embs.append(chars_outputs) |
|
|
embs.append(chars_outputs) |
|
|
|
|
|
|
|
|
token_embedding = torch.cat(embs, dim=2) |
|
|
token_embedding = torch.cat(embs, dim=2) |
|
@@ -450,79 +511,143 @@ class LstmTokenEmbedder(nn.Module): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvTokenEmbedder(nn.Module): |
|
|
class ConvTokenEmbedder(nn.Module): |
|
|
def __init__(self, config, word_emb_layer, char_emb_layer): |
|
|
|
|
|
|
|
|
def __init__(self, config, weight_file, word_emb_layer, char_emb_layer, char_vocab): |
|
|
super(ConvTokenEmbedder, self).__init__() |
|
|
super(ConvTokenEmbedder, self).__init__() |
|
|
self.config = config |
|
|
|
|
|
|
|
|
self.weight_file = weight_file |
|
|
self.word_emb_layer = word_emb_layer |
|
|
self.word_emb_layer = word_emb_layer |
|
|
self.char_emb_layer = char_emb_layer |
|
|
self.char_emb_layer = char_emb_layer |
|
|
|
|
|
|
|
|
self.output_dim = config['encoder']['projection_dim'] |
|
|
self.output_dim = config['encoder']['projection_dim'] |
|
|
self.emb_dim = 0 |
|
|
|
|
|
if word_emb_layer is not None: |
|
|
|
|
|
self.emb_dim += word_emb_layer.weight.size(1) |
|
|
|
|
|
|
|
|
|
|
|
if char_emb_layer is not None: |
|
|
|
|
|
self.convolutions = [] |
|
|
|
|
|
cnn_config = config['token_embedder'] |
|
|
|
|
|
filters = cnn_config['filters'] |
|
|
|
|
|
char_embed_dim = cnn_config['char_dim'] |
|
|
|
|
|
|
|
|
|
|
|
for i, (width, num) in enumerate(filters): |
|
|
|
|
|
conv = torch.nn.Conv1d( |
|
|
|
|
|
in_channels=char_embed_dim, |
|
|
|
|
|
out_channels=num, |
|
|
|
|
|
kernel_size=width, |
|
|
|
|
|
bias=True |
|
|
|
|
|
) |
|
|
|
|
|
self.convolutions.append(conv) |
|
|
|
|
|
|
|
|
|
|
|
self.convolutions = nn.ModuleList(self.convolutions) |
|
|
|
|
|
|
|
|
|
|
|
self.n_filters = sum(f[1] for f in filters) |
|
|
|
|
|
self.n_highway = cnn_config['n_highway'] |
|
|
|
|
|
|
|
|
|
|
|
self.highways = Highway(self.n_filters, self.n_highway, activation=torch.nn.functional.relu) |
|
|
|
|
|
self.emb_dim += self.n_filters |
|
|
|
|
|
|
|
|
|
|
|
self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True) |
|
|
|
|
|
|
|
|
self._options = config |
|
|
|
|
|
self.requires_grad = False |
|
|
|
|
|
self._load_weights() |
|
|
|
|
|
self._char_embedding_weights = char_emb_layer.weight.data |
|
|
|
|
|
|
|
|
|
|
|
def _load_weights(self): |
|
|
|
|
|
self._load_cnn_weights() |
|
|
|
|
|
self._load_highway() |
|
|
|
|
|
self._load_projection() |
|
|
|
|
|
|
|
|
|
|
|
def _load_cnn_weights(self): |
|
|
|
|
|
cnn_options = self._options['token_embedder'] |
|
|
|
|
|
filters = cnn_options['filters'] |
|
|
|
|
|
char_embed_dim = cnn_options['embedding']['dim'] |
|
|
|
|
|
|
|
|
|
|
|
convolutions = [] |
|
|
|
|
|
for i, (width, num) in enumerate(filters): |
|
|
|
|
|
conv = torch.nn.Conv1d( |
|
|
|
|
|
in_channels=char_embed_dim, |
|
|
|
|
|
out_channels=num, |
|
|
|
|
|
kernel_size=width, |
|
|
|
|
|
bias=True |
|
|
|
|
|
) |
|
|
|
|
|
# load the weights |
|
|
|
|
|
with h5py.File(self.weight_file, 'r') as fin: |
|
|
|
|
|
weight = fin['CNN']['W_cnn_{}'.format(i)][...] |
|
|
|
|
|
bias = fin['CNN']['b_cnn_{}'.format(i)][...] |
|
|
|
|
|
|
|
|
|
|
|
w_reshaped = numpy.transpose(weight.squeeze(axis=0), axes=(2, 1, 0)) |
|
|
|
|
|
if w_reshaped.shape != tuple(conv.weight.data.shape): |
|
|
|
|
|
raise ValueError("Invalid weight file") |
|
|
|
|
|
conv.weight.data.copy_(torch.FloatTensor(w_reshaped)) |
|
|
|
|
|
conv.bias.data.copy_(torch.FloatTensor(bias)) |
|
|
|
|
|
|
|
|
|
|
|
conv.weight.requires_grad = self.requires_grad |
|
|
|
|
|
conv.bias.requires_grad = self.requires_grad |
|
|
|
|
|
|
|
|
|
|
|
convolutions.append(conv) |
|
|
|
|
|
self.add_module('char_conv_{}'.format(i), conv) |
|
|
|
|
|
|
|
|
|
|
|
self._convolutions = convolutions |
|
|
|
|
|
|
|
|
|
|
|
def _load_highway(self): |
|
|
|
|
|
# the highway layers have same dimensionality as the number of cnn filters |
|
|
|
|
|
cnn_options = self._options['token_embedder'] |
|
|
|
|
|
filters = cnn_options['filters'] |
|
|
|
|
|
n_filters = sum(f[1] for f in filters) |
|
|
|
|
|
n_highway = cnn_options['n_highway'] |
|
|
|
|
|
|
|
|
|
|
|
# create the layers, and load the weights |
|
|
|
|
|
self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu) |
|
|
|
|
|
for k in range(n_highway): |
|
|
|
|
|
# The AllenNLP highway is one matrix multplication with concatenation of |
|
|
|
|
|
# transform and carry weights. |
|
|
|
|
|
with h5py.File(self.weight_file, 'r') as fin: |
|
|
|
|
|
# The weights are transposed due to multiplication order assumptions in tf |
|
|
|
|
|
# vs pytorch (tf.matmul(X, W) vs pytorch.matmul(W, X)) |
|
|
|
|
|
w_transform = numpy.transpose(fin['CNN_high_{}'.format(k)]['W_transform'][...]) |
|
|
|
|
|
# -1.0 since AllenNLP is g * x + (1 - g) * f(x) but tf is (1 - g) * x + g * f(x) |
|
|
|
|
|
w_carry = -1.0 * numpy.transpose(fin['CNN_high_{}'.format(k)]['W_carry'][...]) |
|
|
|
|
|
weight = numpy.concatenate([w_transform, w_carry], axis=0) |
|
|
|
|
|
self._highways._layers[k].weight.data.copy_(torch.FloatTensor(weight)) |
|
|
|
|
|
self._highways._layers[k].weight.requires_grad = self.requires_grad |
|
|
|
|
|
|
|
|
|
|
|
b_transform = fin['CNN_high_{}'.format(k)]['b_transform'][...] |
|
|
|
|
|
b_carry = -1.0 * fin['CNN_high_{}'.format(k)]['b_carry'][...] |
|
|
|
|
|
bias = numpy.concatenate([b_transform, b_carry], axis=0) |
|
|
|
|
|
self._highways._layers[k].bias.data.copy_(torch.FloatTensor(bias)) |
|
|
|
|
|
self._highways._layers[k].bias.requires_grad = self.requires_grad |
|
|
|
|
|
|
|
|
|
|
|
def _load_projection(self): |
|
|
|
|
|
cnn_options = self._options['token_embedder'] |
|
|
|
|
|
filters = cnn_options['filters'] |
|
|
|
|
|
n_filters = sum(f[1] for f in filters) |
|
|
|
|
|
|
|
|
|
|
|
self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) |
|
|
|
|
|
with h5py.File(self.weight_file, 'r') as fin: |
|
|
|
|
|
weight = fin['CNN_proj']['W_proj'][...] |
|
|
|
|
|
bias = fin['CNN_proj']['b_proj'][...] |
|
|
|
|
|
self._projection.weight.data.copy_(torch.FloatTensor(numpy.transpose(weight))) |
|
|
|
|
|
self._projection.bias.data.copy_(torch.FloatTensor(bias)) |
|
|
|
|
|
|
|
|
|
|
|
self._projection.weight.requires_grad = self.requires_grad |
|
|
|
|
|
self._projection.bias.requires_grad = self.requires_grad |
|
|
|
|
|
|
|
|
def forward(self, words, chars): |
|
|
def forward(self, words, chars): |
|
|
embs = [] |
|
|
|
|
|
if self.word_emb_layer is not None: |
|
|
|
|
|
if hasattr(self, 'words_to_words'): |
|
|
|
|
|
words = self.words_to_words[words] |
|
|
|
|
|
word_emb = self.word_emb_layer(words) |
|
|
|
|
|
embs.append(word_emb) |
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
:param words: |
|
|
|
|
|
:param chars: Tensor Shape ``(batch_size, sequence_length, 50)``: |
|
|
|
|
|
:return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` : |
|
|
|
|
|
""" |
|
|
|
|
|
# the character id embedding |
|
|
|
|
|
# (batch_size * sequence_length, max_chars_per_token, embed_dim) |
|
|
|
|
|
# character_embedding = torch.nn.functional.embedding( |
|
|
|
|
|
# chars.view(-1, max_chars_per_token), |
|
|
|
|
|
# self._char_embedding_weights |
|
|
|
|
|
# ) |
|
|
|
|
|
batch_size, sequence_length, max_char_len = chars.size() |
|
|
|
|
|
character_embedding = self.char_emb_layer(chars).reshape(batch_size*sequence_length, max_char_len, -1) |
|
|
|
|
|
# run convolutions |
|
|
|
|
|
cnn_options = self._options['token_embedder'] |
|
|
|
|
|
if cnn_options['activation'] == 'tanh': |
|
|
|
|
|
activation = torch.tanh |
|
|
|
|
|
elif cnn_options['activation'] == 'relu': |
|
|
|
|
|
activation = torch.nn.functional.relu |
|
|
|
|
|
else: |
|
|
|
|
|
raise Exception("Unknown activation") |
|
|
|
|
|
|
|
|
if self.char_emb_layer is not None: |
|
|
|
|
|
batch_size, seq_len, _ = chars.size() |
|
|
|
|
|
chars = chars.view(batch_size * seq_len, -1) |
|
|
|
|
|
character_embedding = self.char_emb_layer(chars) |
|
|
|
|
|
character_embedding = torch.transpose(character_embedding, 1, 2) |
|
|
|
|
|
|
|
|
|
|
|
cnn_config = self.config['token_embedder'] |
|
|
|
|
|
if cnn_config['activation'] == 'tanh': |
|
|
|
|
|
activation = torch.nn.functional.tanh |
|
|
|
|
|
elif cnn_config['activation'] == 'relu': |
|
|
|
|
|
activation = torch.nn.functional.relu |
|
|
|
|
|
else: |
|
|
|
|
|
raise Exception("Unknown activation") |
|
|
|
|
|
|
|
|
# (batch_size * sequence_length, embed_dim, max_chars_per_token) |
|
|
|
|
|
character_embedding = torch.transpose(character_embedding, 1, 2) |
|
|
|
|
|
convs = [] |
|
|
|
|
|
for i in range(len(self._convolutions)): |
|
|
|
|
|
conv = getattr(self, 'char_conv_{}'.format(i)) |
|
|
|
|
|
convolved = conv(character_embedding) |
|
|
|
|
|
# (batch_size * sequence_length, n_filters for this width) |
|
|
|
|
|
convolved, _ = torch.max(convolved, dim=-1) |
|
|
|
|
|
convolved = activation(convolved) |
|
|
|
|
|
convs.append(convolved) |
|
|
|
|
|
|
|
|
convs = [] |
|
|
|
|
|
for i in range(len(self.convolutions)): |
|
|
|
|
|
convolved = self.convolutions[i](character_embedding) |
|
|
|
|
|
# (batch_size * sequence_length, n_filters for this width) |
|
|
|
|
|
convolved, _ = torch.max(convolved, dim=-1) |
|
|
|
|
|
convolved = activation(convolved) |
|
|
|
|
|
convs.append(convolved) |
|
|
|
|
|
char_emb = torch.cat(convs, dim=-1) |
|
|
|
|
|
char_emb = self.highways(char_emb) |
|
|
|
|
|
|
|
|
# (batch_size * sequence_length, n_filters) |
|
|
|
|
|
token_embedding = torch.cat(convs, dim=-1) |
|
|
|
|
|
|
|
|
embs.append(char_emb.view(batch_size, -1, self.n_filters)) |
|
|
|
|
|
|
|
|
# apply the highway layers (batch_size * sequence_length, n_filters) |
|
|
|
|
|
token_embedding = self._highways(token_embedding) |
|
|
|
|
|
|
|
|
token_embedding = torch.cat(embs, dim=2) |
|
|
|
|
|
|
|
|
# final projection (batch_size * sequence_length, embedding_dim) |
|
|
|
|
|
token_embedding = self._projection(token_embedding) |
|
|
|
|
|
|
|
|
return self.projection(token_embedding) |
|
|
|
|
|
|
|
|
# reshape to (batch_size, sequence_length+2, embedding_dim) |
|
|
|
|
|
return token_embedding.view(batch_size, sequence_length, -1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Highway(torch.nn.Module): |
|
|
class Highway(torch.nn.Module): |
|
@@ -543,6 +668,7 @@ class Highway(torch.nn.Module): |
|
|
activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``) |
|
|
activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``) |
|
|
The non-linearity to use in the highway layers. |
|
|
The non-linearity to use in the highway layers. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
def __init__(self, |
|
|
input_dim: int, |
|
|
input_dim: int, |
|
|
num_layers: int = 1, |
|
|
num_layers: int = 1, |
|
@@ -573,6 +699,7 @@ class Highway(torch.nn.Module): |
|
|
current_input = gate * linear_part + (1 - gate) * nonlinear_part |
|
|
current_input = gate * linear_part + (1 - gate) * nonlinear_part |
|
|
return current_input |
|
|
return current_input |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ElmoModel(nn.Module): |
|
|
class _ElmoModel(nn.Module): |
|
|
""" |
|
|
""" |
|
|
该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 |
|
|
该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 |
|
@@ -582,11 +709,32 @@ class _ElmoModel(nn.Module): |
|
|
(4) 设计一个保存token的embedding,允许缓存word的表示。 |
|
|
(4) 设计一个保存token的embedding,允许缓存word的表示。 |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
def __init__(self, model_dir:str, vocab:Vocabulary=None, cache_word_reprs:bool=False): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): |
|
|
super(_ElmoModel, self).__init__() |
|
|
super(_ElmoModel, self).__init__() |
|
|
config = json.load(open(os.path.join(model_dir, 'structure_config.json'), 'r')) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dir = os.walk(model_dir) |
|
|
|
|
|
config_file = None |
|
|
|
|
|
weight_file = None |
|
|
|
|
|
config_count = 0 |
|
|
|
|
|
weight_count = 0 |
|
|
|
|
|
for path, dir_list, file_list in dir: |
|
|
|
|
|
for file_name in file_list: |
|
|
|
|
|
if file_name.__contains__(".json"): |
|
|
|
|
|
config_file = file_name |
|
|
|
|
|
config_count += 1 |
|
|
|
|
|
elif file_name.__contains__(".hdf5"): |
|
|
|
|
|
weight_file = file_name |
|
|
|
|
|
weight_count += 1 |
|
|
|
|
|
if config_count > 1 or weight_count > 1: |
|
|
|
|
|
raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") |
|
|
|
|
|
elif config_count == 0 or weight_count == 0: |
|
|
|
|
|
raise Exception(f"No config file or weight file found in {model_dir}") |
|
|
|
|
|
|
|
|
|
|
|
config = json.load(open(os.path.join(model_dir, config_file), 'r')) |
|
|
|
|
|
self.weight_file = os.path.join(model_dir, weight_file) |
|
|
self.config = config |
|
|
self.config = config |
|
|
|
|
|
self.requires_grad = False |
|
|
|
|
|
|
|
|
OOV_TAG = '<oov>' |
|
|
OOV_TAG = '<oov>' |
|
|
PAD_TAG = '<pad>' |
|
|
PAD_TAG = '<pad>' |
|
@@ -595,48 +743,8 @@ class _ElmoModel(nn.Module): |
|
|
BOW_TAG = '<bow>' |
|
|
BOW_TAG = '<bow>' |
|
|
EOW_TAG = '<eow>' |
|
|
EOW_TAG = '<eow>' |
|
|
|
|
|
|
|
|
# 将加载embedding放到这里 |
|
|
|
|
|
token_embedder_states = torch.load(os.path.join(model_dir, 'token_embedder.pkl'), map_location='cpu') |
|
|
|
|
|
|
|
|
|
|
|
# For the model trained with word form word encoder. |
|
|
|
|
|
if config['token_embedder']['word_dim'] > 0: |
|
|
|
|
|
word_lexicon = {} |
|
|
|
|
|
with codecs.open(os.path.join(model_dir, 'word.dic'), 'r', encoding='utf-8') as fpi: |
|
|
|
|
|
for line in fpi: |
|
|
|
|
|
tokens = line.strip().split('\t') |
|
|
|
|
|
if len(tokens) == 1: |
|
|
|
|
|
tokens.insert(0, '\u3000') |
|
|
|
|
|
token, i = tokens |
|
|
|
|
|
word_lexicon[token] = int(i) |
|
|
|
|
|
# 做一些sanity check |
|
|
|
|
|
for special_word in [PAD_TAG, OOV_TAG, BOS_TAG, EOS_TAG]: |
|
|
|
|
|
assert special_word in word_lexicon, f"{special_word} not found in word.dic." |
|
|
|
|
|
# 根据vocab调整word_embedding |
|
|
|
|
|
pre_word_embedding = token_embedder_states.pop('word_emb_layer.embedding.weight') |
|
|
|
|
|
word_emb_layer = nn.Embedding(len(vocab)+2, config['token_embedder']['word_dim']) #多增加两个是为了<bos>与<eos> |
|
|
|
|
|
found_word_count = 0 |
|
|
|
|
|
for word, index in vocab: |
|
|
|
|
|
if index == vocab.unknown_idx: # 因为fastNLP的unknow是<unk> 而在这里是<oov>所以ugly强制适配一下 |
|
|
|
|
|
index_in_pre = word_lexicon[OOV_TAG] |
|
|
|
|
|
found_word_count += 1 |
|
|
|
|
|
elif index == vocab.padding_idx: # 需要pad对齐 |
|
|
|
|
|
index_in_pre = word_lexicon[PAD_TAG] |
|
|
|
|
|
found_word_count += 1 |
|
|
|
|
|
elif word in word_lexicon: |
|
|
|
|
|
index_in_pre = word_lexicon[word] |
|
|
|
|
|
found_word_count += 1 |
|
|
|
|
|
else: |
|
|
|
|
|
index_in_pre = word_lexicon[OOV_TAG] |
|
|
|
|
|
word_emb_layer.weight.data[index] = pre_word_embedding[index_in_pre] |
|
|
|
|
|
print(f"{found_word_count} out of {len(vocab)} words were found in pretrained elmo embedding.") |
|
|
|
|
|
word_emb_layer.weight.data[-1] = pre_word_embedding[word_lexicon[EOS_TAG]] |
|
|
|
|
|
word_emb_layer.weight.data[-2] = pre_word_embedding[word_lexicon[BOS_TAG]] |
|
|
|
|
|
self.word_vocab = vocab |
|
|
|
|
|
else: |
|
|
|
|
|
word_emb_layer = None |
|
|
|
|
|
|
|
|
|
|
|
# For the model trained with character-based word encoder. |
|
|
# For the model trained with character-based word encoder. |
|
|
if config['token_embedder']['char_dim'] > 0: |
|
|
|
|
|
|
|
|
if config['token_embedder']['embedding']['dim'] > 0: |
|
|
char_lexicon = {} |
|
|
char_lexicon = {} |
|
|
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: |
|
|
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: |
|
|
for line in fpi: |
|
|
for line in fpi: |
|
@@ -645,22 +753,26 @@ class _ElmoModel(nn.Module): |
|
|
tokens.insert(0, '\u3000') |
|
|
tokens.insert(0, '\u3000') |
|
|
token, i = tokens |
|
|
token, i = tokens |
|
|
char_lexicon[token] = int(i) |
|
|
char_lexicon[token] = int(i) |
|
|
|
|
|
|
|
|
# 做一些sanity check |
|
|
# 做一些sanity check |
|
|
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: |
|
|
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: |
|
|
assert special_word in char_lexicon, f"{special_word} not found in char.dic." |
|
|
assert special_word in char_lexicon, f"{special_word} not found in char.dic." |
|
|
|
|
|
|
|
|
# 从vocab中构建char_vocab |
|
|
# 从vocab中构建char_vocab |
|
|
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) |
|
|
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) |
|
|
# 需要保证<bow>与<eow>在里面 |
|
|
# 需要保证<bow>与<eow>在里面 |
|
|
char_vocab.add_word(BOW_TAG) |
|
|
|
|
|
char_vocab.add_word(EOW_TAG) |
|
|
|
|
|
|
|
|
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) |
|
|
|
|
|
|
|
|
for word, index in vocab: |
|
|
for word, index in vocab: |
|
|
char_vocab.add_word_lst(list(word)) |
|
|
char_vocab.add_word_lst(list(word)) |
|
|
# 保证<eos>, <bos>也在 |
|
|
|
|
|
char_vocab.add_word_lst(list(BOS_TAG)) |
|
|
|
|
|
char_vocab.add_word_lst(list(EOS_TAG)) |
|
|
|
|
|
# 根据char_lexicon调整 |
|
|
|
|
|
char_emb_layer = nn.Embedding(len(char_vocab), int(config['token_embedder']['char_dim'])) |
|
|
|
|
|
pre_char_embedding = token_embedder_states.pop('char_emb_layer.embedding.weight') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab)+1, vocab.padding_idx |
|
|
|
|
|
# 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示) |
|
|
|
|
|
char_emb_layer = nn.Embedding(len(char_vocab)+1, int(config['token_embedder']['embedding']['dim']), |
|
|
|
|
|
padding_idx=len(char_vocab)) |
|
|
|
|
|
with h5py.File(self.weight_file, 'r') as fin: |
|
|
|
|
|
char_embed_weights = fin['char_embed'][...] |
|
|
|
|
|
char_embed_weights = torch.from_numpy(char_embed_weights) |
|
|
found_char_count = 0 |
|
|
found_char_count = 0 |
|
|
for char, index in char_vocab: # 调整character embedding |
|
|
for char, index in char_vocab: # 调整character embedding |
|
|
if char in char_lexicon: |
|
|
if char in char_lexicon: |
|
@@ -668,79 +780,84 @@ class _ElmoModel(nn.Module): |
|
|
found_char_count += 1 |
|
|
found_char_count += 1 |
|
|
else: |
|
|
else: |
|
|
index_in_pre = char_lexicon[OOV_TAG] |
|
|
index_in_pre = char_lexicon[OOV_TAG] |
|
|
char_emb_layer.weight.data[index] = pre_char_embedding[index_in_pre] |
|
|
|
|
|
|
|
|
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] |
|
|
|
|
|
|
|
|
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") |
|
|
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") |
|
|
# 生成words到chars的映射 |
|
|
# 生成words到chars的映射 |
|
|
if config['token_embedder']['name'].lower() == 'cnn': |
|
|
if config['token_embedder']['name'].lower() == 'cnn': |
|
|
max_chars = config['token_embedder']['max_characters_per_token'] |
|
|
max_chars = config['token_embedder']['max_characters_per_token'] |
|
|
elif config['token_embedder']['name'].lower() == 'lstm': |
|
|
elif config['token_embedder']['name'].lower() == 'lstm': |
|
|
max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # 需要补充两个<bow>与<eow> |
|
|
|
|
|
|
|
|
max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # 需要补充两个<bow>与<eow> |
|
|
else: |
|
|
else: |
|
|
raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name'])) |
|
|
raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name'])) |
|
|
# 增加<bos>, <eos>所以加2. |
|
|
|
|
|
|
|
|
|
|
|
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars), |
|
|
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars), |
|
|
fill_value=char_vocab.to_index(PAD_TAG), dtype=torch.long), |
|
|
|
|
|
|
|
|
fill_value=len(char_vocab), |
|
|
|
|
|
dtype=torch.long), |
|
|
requires_grad=False) |
|
|
requires_grad=False) |
|
|
for word, index in vocab: |
|
|
|
|
|
if len(word)+2>max_chars: |
|
|
|
|
|
word = word[:max_chars-2] |
|
|
|
|
|
if index==vocab.padding_idx: # 如果是pad的话,需要和给定的对齐 |
|
|
|
|
|
word = PAD_TAG |
|
|
|
|
|
elif index==vocab.unknown_idx: |
|
|
|
|
|
word = OOV_TAG |
|
|
|
|
|
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] |
|
|
|
|
|
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) |
|
|
|
|
|
|
|
|
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab)+1)]: |
|
|
|
|
|
if len(word) + 2 > max_chars: |
|
|
|
|
|
word = word[:max_chars - 2] |
|
|
|
|
|
if index == self._pad_index: |
|
|
|
|
|
continue |
|
|
|
|
|
elif word == BOS_TAG or word == EOS_TAG: |
|
|
|
|
|
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [ |
|
|
|
|
|
char_vocab.to_index(EOW_TAG)] |
|
|
|
|
|
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) |
|
|
|
|
|
else: |
|
|
|
|
|
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [ |
|
|
|
|
|
char_vocab.to_index(EOW_TAG)] |
|
|
|
|
|
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) |
|
|
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) |
|
|
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) |
|
|
for index, word in enumerate([BOS_TAG, EOS_TAG]): # 加上<eos>, <bos> |
|
|
|
|
|
if len(word)+2>max_chars: |
|
|
|
|
|
word = word[:max_chars-2] |
|
|
|
|
|
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] |
|
|
|
|
|
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) |
|
|
|
|
|
self.words_to_chars_embedding[index+len(vocab)] = torch.LongTensor(char_ids) |
|
|
|
|
|
|
|
|
|
|
|
self.char_vocab = char_vocab |
|
|
self.char_vocab = char_vocab |
|
|
else: |
|
|
else: |
|
|
char_emb_layer = None |
|
|
char_emb_layer = None |
|
|
|
|
|
|
|
|
if config['token_embedder']['name'].lower() == 'cnn': |
|
|
if config['token_embedder']['name'].lower() == 'cnn': |
|
|
self.token_embedder = ConvTokenEmbedder( |
|
|
self.token_embedder = ConvTokenEmbedder( |
|
|
config, word_emb_layer, char_emb_layer) |
|
|
|
|
|
|
|
|
config, self.weight_file, None, char_emb_layer, self.char_vocab) |
|
|
elif config['token_embedder']['name'].lower() == 'lstm': |
|
|
elif config['token_embedder']['name'].lower() == 'lstm': |
|
|
self.token_embedder = LstmTokenEmbedder( |
|
|
self.token_embedder = LstmTokenEmbedder( |
|
|
config, word_emb_layer, char_emb_layer) |
|
|
|
|
|
self.token_embedder.load_state_dict(token_embedder_states, strict=False) |
|
|
|
|
|
if config['token_embedder']['word_dim'] > 0 and vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk |
|
|
|
|
|
words_to_words = nn.Parameter(torch.arange(len(vocab)+2).long(), requires_grad=False) |
|
|
|
|
|
|
|
|
config, None, char_emb_layer) |
|
|
|
|
|
|
|
|
|
|
|
if config['token_embedder']['word_dim'] > 0 \ |
|
|
|
|
|
and vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk |
|
|
|
|
|
words_to_words = nn.Parameter(torch.arange(len(vocab) + 2).long(), requires_grad=False) |
|
|
for word, idx in vocab: |
|
|
for word, idx in vocab: |
|
|
if vocab._is_word_no_create_entry(word): |
|
|
if vocab._is_word_no_create_entry(word): |
|
|
words_to_words[idx] = vocab.unknown_idx |
|
|
words_to_words[idx] = vocab.unknown_idx |
|
|
setattr(self.token_embedder, 'words_to_words', words_to_words) |
|
|
setattr(self.token_embedder, 'words_to_words', words_to_words) |
|
|
self.output_dim = config['encoder']['projection_dim'] |
|
|
self.output_dim = config['encoder']['projection_dim'] |
|
|
|
|
|
|
|
|
|
|
|
# 暂时只考虑 elmo |
|
|
if config['encoder']['name'].lower() == 'elmo': |
|
|
if config['encoder']['name'].lower() == 'elmo': |
|
|
self.encoder = ElmobiLm(config) |
|
|
self.encoder = ElmobiLm(config) |
|
|
elif config['encoder']['name'].lower() == 'lstm': |
|
|
elif config['encoder']['name'].lower() == 'lstm': |
|
|
self.encoder = LstmbiLm(config) |
|
|
self.encoder = LstmbiLm(config) |
|
|
self.encoder.load_state_dict(torch.load(os.path.join(model_dir, 'encoder.pkl'), |
|
|
|
|
|
map_location='cpu')) |
|
|
|
|
|
|
|
|
|
|
|
self.bos_index = len(vocab) |
|
|
|
|
|
self.eos_index = len(vocab) + 1 |
|
|
|
|
|
self._pad_index = vocab.padding_idx |
|
|
|
|
|
|
|
|
self.encoder.load_weights(self.weight_file) |
|
|
|
|
|
|
|
|
if cache_word_reprs: |
|
|
if cache_word_reprs: |
|
|
if config['token_embedder']['char_dim']>0: # 只有在使用了chars的情况下有用 |
|
|
|
|
|
|
|
|
if config['token_embedder']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用 |
|
|
print("Start to generate cache word representations.") |
|
|
print("Start to generate cache word representations.") |
|
|
batch_size = 320 |
|
|
batch_size = 320 |
|
|
num_batches = self.words_to_chars_embedding.size(0)//batch_size + \ |
|
|
|
|
|
int(self.words_to_chars_embedding.size(0)%batch_size!=0) |
|
|
|
|
|
self.cached_word_embedding = nn.Embedding(self.words_to_chars_embedding.size(0), |
|
|
|
|
|
|
|
|
# bos eos |
|
|
|
|
|
word_size = self.words_to_chars_embedding.size(0) |
|
|
|
|
|
num_batches = word_size // batch_size + \ |
|
|
|
|
|
int(word_size % batch_size != 0) |
|
|
|
|
|
|
|
|
|
|
|
self.cached_word_embedding = nn.Embedding(word_size, |
|
|
config['encoder']['projection_dim']) |
|
|
config['encoder']['projection_dim']) |
|
|
with torch.no_grad(): |
|
|
with torch.no_grad(): |
|
|
for i in range(num_batches): |
|
|
for i in range(num_batches): |
|
|
words = torch.arange(i*batch_size, min((i+1)*batch_size, self.words_to_chars_embedding.size(0))).long() |
|
|
|
|
|
|
|
|
words = torch.arange(i * batch_size, |
|
|
|
|
|
min((i + 1) * batch_size, word_size)).long() |
|
|
chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars |
|
|
chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars |
|
|
word_reprs = self.token_embedder(words.unsqueeze(1), chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] |
|
|
|
|
|
|
|
|
word_reprs = self.token_embedder(words.unsqueeze(1), |
|
|
|
|
|
chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] |
|
|
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) |
|
|
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) |
|
|
|
|
|
|
|
|
print("Finish generating cached word representations. Going to delete the character encoder.") |
|
|
print("Finish generating cached word representations. Going to delete the character encoder.") |
|
|
del self.token_embedder, self.words_to_chars_embedding |
|
|
del self.token_embedder, self.words_to_chars_embedding |
|
|
else: |
|
|
else: |
|
@@ -758,7 +875,7 @@ class _ElmoModel(nn.Module): |
|
|
seq_len = words.ne(self._pad_index).sum(dim=-1) |
|
|
seq_len = words.ne(self._pad_index).sum(dim=-1) |
|
|
expanded_words[:, 1:-1] = words |
|
|
expanded_words[:, 1:-1] = words |
|
|
expanded_words[:, 0].fill_(self.bos_index) |
|
|
expanded_words[:, 0].fill_(self.bos_index) |
|
|
expanded_words[torch.arange(batch_size).to(words), seq_len+1] = self.eos_index |
|
|
|
|
|
|
|
|
expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index |
|
|
seq_len = seq_len + 2 |
|
|
seq_len = seq_len + 2 |
|
|
if hasattr(self, 'cached_word_embedding'): |
|
|
if hasattr(self, 'cached_word_embedding'): |
|
|
token_embedding = self.cached_word_embedding(expanded_words) |
|
|
token_embedding = self.cached_word_embedding(expanded_words) |
|
@@ -767,16 +884,18 @@ class _ElmoModel(nn.Module): |
|
|
chars = self.words_to_chars_embedding[expanded_words] |
|
|
chars = self.words_to_chars_embedding[expanded_words] |
|
|
else: |
|
|
else: |
|
|
chars = None |
|
|
chars = None |
|
|
token_embedding = self.token_embedder(expanded_words, chars) |
|
|
|
|
|
|
|
|
token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim |
|
|
|
|
|
|
|
|
if self.config['encoder']['name'] == 'elmo': |
|
|
if self.config['encoder']['name'] == 'elmo': |
|
|
encoder_output = self.encoder(token_embedding, seq_len) |
|
|
encoder_output = self.encoder(token_embedding, seq_len) |
|
|
if encoder_output.size(2) < max_len+2: |
|
|
|
|
|
dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size, |
|
|
|
|
|
max_len + 2 - encoder_output.size(2), encoder_output.size(-1)) |
|
|
|
|
|
encoder_output = torch.cat([encoder_output, dummy_tensor], 2) |
|
|
|
|
|
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size |
|
|
|
|
|
token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) |
|
|
|
|
|
encoder_output = torch.cat([token_embedding, encoder_output], dim=0) |
|
|
|
|
|
|
|
|
if encoder_output.size(2) < max_len + 2: |
|
|
|
|
|
num_layers, _, output_len, hidden_size = encoder_output.size() |
|
|
|
|
|
dummy_tensor = encoder_output.new_zeros(num_layers, batch_size, |
|
|
|
|
|
max_len + 2 - output_len, hidden_size) |
|
|
|
|
|
encoder_output = torch.cat((encoder_output, dummy_tensor), 2) |
|
|
|
|
|
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size |
|
|
|
|
|
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) |
|
|
|
|
|
encoder_output = torch.cat((token_embedding, encoder_output), dim=0) |
|
|
elif self.config['encoder']['name'] == 'lstm': |
|
|
elif self.config['encoder']['name'] == 'lstm': |
|
|
encoder_output = self.encoder(token_embedding, seq_len) |
|
|
encoder_output = self.encoder(token_embedding, seq_len) |
|
|
else: |
|
|
else: |
|
@@ -784,5 +903,4 @@ class _ElmoModel(nn.Module): |
|
|
|
|
|
|
|
|
# 删除<eos>, <bos>. 这里没有精确地删除,但应该也不会影响最后的结果了。 |
|
|
# 删除<eos>, <bos>. 这里没有精确地删除,但应该也不会影响最后的结果了。 |
|
|
encoder_output = encoder_output[:, :, 1:-1] |
|
|
encoder_output = encoder_output[:, :, 1:-1] |
|
|
|
|
|
|
|
|
return encoder_output |
|
|
return encoder_output |