From 2c00c1ae5aab51feae7dfec5331e58d82c2f46af Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 9 Jul 2019 17:32:22 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4elmo=E5=AF=B9h5py=E7=9A=84?= =?UTF-8?q?=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/modules/encoder/_elmo.py | 460 ++++++++++--------------------- requirements.txt | 1 - 2 files changed, 144 insertions(+), 317 deletions(-) diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py index a49634bb..6b08edc8 100644 --- a/fastNLP/modules/encoder/_elmo.py +++ b/fastNLP/modules/encoder/_elmo.py @@ -6,14 +6,13 @@ from typing import Optional, Tuple, List, Callable import os -import h5py -import numpy import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence from ...core.vocabulary import Vocabulary import json +import pickle from ..utils import get_dropout_mask import codecs @@ -244,13 +243,13 @@ class LstmbiLm(nn.Module): def __init__(self, config): super(LstmbiLm, self).__init__() self.config = config - self.encoder = nn.LSTM(self.config['encoder']['projection_dim'], - self.config['encoder']['dim'], - num_layers=self.config['encoder']['n_layers'], + self.encoder = nn.LSTM(self.config['lstm']['projection_dim'], + self.config['lstm']['dim'], + num_layers=self.config['lstm']['n_layers'], bidirectional=True, batch_first=True, dropout=self.config['dropout']) - self.projection = nn.Linear(self.config['encoder']['dim'], self.config['encoder']['projection_dim'], bias=True) + self.projection = nn.Linear(self.config['lstm']['dim'], self.config['lstm']['projection_dim'], bias=True) def forward(self, inputs, seq_len): sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) @@ -260,7 +259,7 @@ class LstmbiLm(nn.Module): output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) output = output[unsort_idx] - forward, backward = output.split(self.config['encoder']['dim'], 2) + forward, backward = output.split(self.config['lstm']['dim'], 2) return torch.cat([self.projection(forward), self.projection(backward)], dim=2) @@ -268,13 +267,13 @@ class ElmobiLm(torch.nn.Module): def __init__(self, config): super(ElmobiLm, self).__init__() self.config = config - input_size = config['encoder']['projection_dim'] - hidden_size = config['encoder']['projection_dim'] - cell_size = config['encoder']['dim'] - num_layers = config['encoder']['n_layers'] - memory_cell_clip_value = config['encoder']['cell_clip'] - state_projection_clip_value = config['encoder']['proj_clip'] - recurrent_dropout_probability = config['dropout'] + input_size = config['lstm']['projection_dim'] + hidden_size = config['lstm']['projection_dim'] + cell_size = config['lstm']['dim'] + num_layers = config['lstm']['n_layers'] + memory_cell_clip_value = config['lstm']['cell_clip'] + state_projection_clip_value = config['lstm']['proj_clip'] + recurrent_dropout_probability = 0.0 self.input_size = input_size self.hidden_size = hidden_size @@ -409,199 +408,52 @@ class ElmobiLm(torch.nn.Module): torch.cat(final_memory_states, 0)) return stacked_sequence_outputs, final_state_tuple - def load_weights(self, weight_file: str) -> None: - """ - Load the pre-trained weights from the file. 
- """ - requires_grad = False - - with h5py.File(weight_file, 'r') as fin: - for i_layer, lstms in enumerate( - zip(self.forward_layers, self.backward_layers) - ): - for j_direction, lstm in enumerate(lstms): - # lstm is an instance of LSTMCellWithProjection - cell_size = lstm.cell_size - - dataset = fin['RNN_%s' % j_direction]['RNN']['MultiRNNCell']['Cell%s' % i_layer - ]['LSTMCell'] - - # tensorflow packs together both W and U matrices into one matrix, - # but pytorch maintains individual matrices. In addition, tensorflow - # packs the gates as input, memory, forget, output but pytorch - # uses input, forget, memory, output. So we need to modify the weights. - tf_weights = numpy.transpose(dataset['W_0'][...]) - torch_weights = tf_weights.copy() - - # split the W from U matrices - input_size = lstm.input_size - input_weights = torch_weights[:, :input_size] - recurrent_weights = torch_weights[:, input_size:] - tf_input_weights = tf_weights[:, :input_size] - tf_recurrent_weights = tf_weights[:, input_size:] - - # handle the different gate order convention - for torch_w, tf_w in [[input_weights, tf_input_weights], - [recurrent_weights, tf_recurrent_weights]]: - torch_w[(1 * cell_size):(2 * cell_size), :] = tf_w[(2 * cell_size):(3 * cell_size), :] - torch_w[(2 * cell_size):(3 * cell_size), :] = tf_w[(1 * cell_size):(2 * cell_size), :] - - lstm.input_linearity.weight.data.copy_(torch.FloatTensor(input_weights)) - lstm.state_linearity.weight.data.copy_(torch.FloatTensor(recurrent_weights)) - lstm.input_linearity.weight.requires_grad = requires_grad - lstm.state_linearity.weight.requires_grad = requires_grad - - # the bias weights - tf_bias = dataset['B'][...] - # tensorflow adds 1.0 to forget gate bias instead of modifying the - # parameters... - tf_bias[(2 * cell_size):(3 * cell_size)] += 1 - torch_bias = tf_bias.copy() - torch_bias[(1 * cell_size):(2 * cell_size) - ] = tf_bias[(2 * cell_size):(3 * cell_size)] - torch_bias[(2 * cell_size):(3 * cell_size) - ] = tf_bias[(1 * cell_size):(2 * cell_size)] - lstm.state_linearity.bias.data.copy_(torch.FloatTensor(torch_bias)) - lstm.state_linearity.bias.requires_grad = requires_grad - - # the projection weights - proj_weights = numpy.transpose(dataset['W_P_0'][...]) - lstm.state_projection.weight.data.copy_(torch.FloatTensor(proj_weights)) - lstm.state_projection.weight.requires_grad = requires_grad - - -class LstmTokenEmbedder(nn.Module): - def __init__(self, config, word_emb_layer, char_emb_layer): - super(LstmTokenEmbedder, self).__init__() - self.config = config - self.word_emb_layer = word_emb_layer - self.char_emb_layer = char_emb_layer - self.output_dim = config['encoder']['projection_dim'] - emb_dim = 0 - if word_emb_layer is not None: - emb_dim += word_emb_layer.n_d - - if char_emb_layer is not None: - emb_dim += char_emb_layer.n_d * 2 - self.char_lstm = nn.LSTM(char_emb_layer.n_d, char_emb_layer.n_d, num_layers=1, bidirectional=True, - batch_first=True, dropout=config['dropout']) - - self.projection = nn.Linear(emb_dim, self.output_dim, bias=True) - - def forward(self, words, chars): - embs = [] - if self.word_emb_layer is not None: - if hasattr(self, 'words_to_words'): - words = self.words_to_words[words] - word_emb = self.word_emb_layer(words) - embs.append(word_emb) - - if self.char_emb_layer is not None: - batch_size, seq_len, _ = chars.shape - chars = chars.view(batch_size * seq_len, -1) - chars_emb = self.char_emb_layer(chars) - # TODO 这里应该要考虑seq_len的问题 - _, (chars_outputs, __) = self.char_lstm(chars_emb) - chars_outputs = 
chars_outputs.contiguous().view(-1, self.config['token_embedder']['embedding']['dim'] * 2) - embs.append(chars_outputs) - - token_embedding = torch.cat(embs, dim=2) - - return self.projection(token_embedding) - class ConvTokenEmbedder(nn.Module): - def __init__(self, config, weight_file, word_emb_layer, char_emb_layer, char_vocab): + def __init__(self, config, weight_file, word_emb_layer, char_emb_layer): super(ConvTokenEmbedder, self).__init__() self.weight_file = weight_file self.word_emb_layer = word_emb_layer self.char_emb_layer = char_emb_layer - self.output_dim = config['encoder']['projection_dim'] + self.output_dim = config['lstm']['projection_dim'] self._options = config self.requires_grad = False - self._load_weights() self._char_embedding_weights = char_emb_layer.weight.data - def _load_weights(self): - self._load_cnn_weights() - self._load_highway() - self._load_projection() - - def _load_cnn_weights(self): - cnn_options = self._options['token_embedder'] - filters = cnn_options['filters'] - char_embed_dim = cnn_options['embedding']['dim'] - - convolutions = [] - for i, (width, num) in enumerate(filters): - conv = torch.nn.Conv1d( - in_channels=char_embed_dim, - out_channels=num, - kernel_size=width, - bias=True - ) - # load the weights - with h5py.File(self.weight_file, 'r') as fin: - weight = fin['CNN']['W_cnn_{}'.format(i)][...] - bias = fin['CNN']['b_cnn_{}'.format(i)][...] - - w_reshaped = numpy.transpose(weight.squeeze(axis=0), axes=(2, 1, 0)) - if w_reshaped.shape != tuple(conv.weight.data.shape): - raise ValueError("Invalid weight file") - conv.weight.data.copy_(torch.FloatTensor(w_reshaped)) - conv.bias.data.copy_(torch.FloatTensor(bias)) - - conv.weight.requires_grad = self.requires_grad - conv.bias.requires_grad = self.requires_grad - - convolutions.append(conv) - self.add_module('char_conv_{}'.format(i), conv) - - self._convolutions = convolutions - - def _load_highway(self): - # the highway layers have same dimensionality as the number of cnn filters - cnn_options = self._options['token_embedder'] - filters = cnn_options['filters'] - n_filters = sum(f[1] for f in filters) - n_highway = cnn_options['n_highway'] - - # create the layers, and load the weights - self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu) - for k in range(n_highway): - # The AllenNLP highway is one matrix multplication with concatenation of - # transform and carry weights. - with h5py.File(self.weight_file, 'r') as fin: - # The weights are transposed due to multiplication order assumptions in tf - # vs pytorch (tf.matmul(X, W) vs pytorch.matmul(W, X)) - w_transform = numpy.transpose(fin['CNN_high_{}'.format(k)]['W_transform'][...]) - # -1.0 since AllenNLP is g * x + (1 - g) * f(x) but tf is (1 - g) * x + g * f(x) - w_carry = -1.0 * numpy.transpose(fin['CNN_high_{}'.format(k)]['W_carry'][...]) - weight = numpy.concatenate([w_transform, w_carry], axis=0) - self._highways._layers[k].weight.data.copy_(torch.FloatTensor(weight)) - self._highways._layers[k].weight.requires_grad = self.requires_grad - - b_transform = fin['CNN_high_{}'.format(k)]['b_transform'][...] - b_carry = -1.0 * fin['CNN_high_{}'.format(k)]['b_carry'][...] 
- bias = numpy.concatenate([b_transform, b_carry], axis=0) - self._highways._layers[k].bias.data.copy_(torch.FloatTensor(bias)) - self._highways._layers[k].bias.requires_grad = self.requires_grad - - def _load_projection(self): - cnn_options = self._options['token_embedder'] - filters = cnn_options['filters'] - n_filters = sum(f[1] for f in filters) - - self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) - with h5py.File(self.weight_file, 'r') as fin: - weight = fin['CNN_proj']['W_proj'][...] - bias = fin['CNN_proj']['b_proj'][...] - self._projection.weight.data.copy_(torch.FloatTensor(numpy.transpose(weight))) - self._projection.bias.data.copy_(torch.FloatTensor(bias)) - - self._projection.weight.requires_grad = self.requires_grad - self._projection.bias.requires_grad = self.requires_grad + char_cnn_options = self._options['char_cnn'] + if char_cnn_options['activation'] == 'tanh': + self.activation = torch.tanh + elif char_cnn_options['activation'] == 'relu': + self.activation = torch.nn.functional.relu + else: + raise Exception("Unknown activation") + + if char_emb_layer is not None: + self.char_conv = [] + cnn_config = config['char_cnn'] + filters = cnn_config['filters'] + char_embed_dim = cnn_config['embedding']['dim'] + convolutions = [] + + for i, (width, num) in enumerate(filters): + conv = torch.nn.Conv1d( + in_channels=char_embed_dim, + out_channels=num, + kernel_size=width, + bias=True + ) + convolutions.append(conv) + self.add_module('char_conv_{}'.format(i), conv) + + self._convolutions = convolutions + + n_filters = sum(f[1] for f in filters) + n_highway = cnn_config['n_highway'] + + self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu) + + self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) def forward(self, words, chars): """ @@ -616,15 +468,8 @@ class ConvTokenEmbedder(nn.Module): # self._char_embedding_weights # ) batch_size, sequence_length, max_char_len = chars.size() - character_embedding = self.char_emb_layer(chars).reshape(batch_size*sequence_length, max_char_len, -1) + character_embedding = self.char_emb_layer(chars).reshape(batch_size * sequence_length, max_char_len, -1) # run convolutions - cnn_options = self._options['token_embedder'] - if cnn_options['activation'] == 'tanh': - activation = torch.tanh - elif cnn_options['activation'] == 'relu': - activation = torch.nn.functional.relu - else: - raise Exception("Unknown activation") # (batch_size * sequence_length, embed_dim, max_chars_per_token) character_embedding = torch.transpose(character_embedding, 1, 2) @@ -634,7 +479,7 @@ class ConvTokenEmbedder(nn.Module): convolved = conv(character_embedding) # (batch_size * sequence_length, n_filters for this width) convolved, _ = torch.max(convolved, dim=-1) - convolved = activation(convolved) + convolved = self.activation(convolved) convs.append(convolved) # (batch_size * sequence_length, n_filters) @@ -712,8 +557,9 @@ class _ElmoModel(nn.Module): def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): super(_ElmoModel, self).__init__() - - dir = os.walk(model_dir) + # self.pkl_dict = {} + self.model_dir = model_dir + dir = os.walk(self.model_dir) config_file = None weight_file = None config_count = 0 @@ -723,7 +569,7 @@ class _ElmoModel(nn.Module): if file_name.__contains__(".json"): config_file = file_name config_count += 1 - elif file_name.__contains__(".hdf5"): + elif file_name.__contains__(".pkl"): weight_file = file_name weight_count += 1 if config_count 
> 1 or weight_count > 1: @@ -744,102 +590,86 @@ class _ElmoModel(nn.Module): EOW_TAG = '' # For the model trained with character-based word encoder. - if config['token_embedder']['embedding']['dim'] > 0: - char_lexicon = {} - with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: - for line in fpi: - tokens = line.strip().split('\t') - if len(tokens) == 1: - tokens.insert(0, '\u3000') - token, i = tokens - char_lexicon[token] = int(i) - - # 做一些sanity check - for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: - assert special_word in char_lexicon, f"{special_word} not found in char.dic." - - # 从vocab中构建char_vocab - char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) - # 需要保证在里面 - char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) - - for word, index in vocab: - char_vocab.add_word_lst(list(word)) - - self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab)+1, vocab.padding_idx - # 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示) - char_emb_layer = nn.Embedding(len(char_vocab)+1, int(config['token_embedder']['embedding']['dim']), - padding_idx=len(char_vocab)) - with h5py.File(self.weight_file, 'r') as fin: - char_embed_weights = fin['char_embed'][...] - char_embed_weights = torch.from_numpy(char_embed_weights) - found_char_count = 0 - for char, index in char_vocab: # 调整character embedding - if char in char_lexicon: - index_in_pre = char_lexicon.get(char) - found_char_count += 1 - else: - index_in_pre = char_lexicon[OOV_TAG] - char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] - - print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") - # 生成words到chars的映射 - if config['token_embedder']['name'].lower() == 'cnn': - max_chars = config['token_embedder']['max_characters_per_token'] - elif config['token_embedder']['name'].lower() == 'lstm': - max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # 需要补充两个 + char_lexicon = {} + with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: + for line in fpi: + tokens = line.strip().split('\t') + if len(tokens) == 1: + tokens.insert(0, '\u3000') + token, i = tokens + char_lexicon[token] = int(i) + + # 做一些sanity check + for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: + assert special_word in char_lexicon, f"{special_word} not found in char.dic." 
+ + # 从vocab中构建char_vocab + char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) + # 需要保证在里面 + char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) + + for word, index in vocab: + char_vocab.add_word_lst(list(word)) + + self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx + # 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示) + char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']), + padding_idx=len(char_vocab)) + + # 读入预训练权重 这里的elmo_model 是个dict 有char_embed的值以及char_cnn和 lstm 的 state_dict + elmo_pkl = open(os.path.join(self.model_dir, weight_file), "rb") + elmo_model = pickle.load(elmo_pkl) + elmo_pkl.close() + + self.char_embed_weights = elmo_model["char_embed"] + + found_char_count = 0 + for char, index in char_vocab: # 调整character embedding + if char in char_lexicon: + index_in_pre = char_lexicon.get(char) + found_char_count += 1 else: - raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name'])) - - self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars), - fill_value=len(char_vocab), - dtype=torch.long), - requires_grad=False) - for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab)+1)]: - if len(word) + 2 > max_chars: - word = word[:max_chars - 2] - if index == self._pad_index: - continue - elif word == BOS_TAG or word == EOS_TAG: - char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [ - char_vocab.to_index(EOW_TAG)] - char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) - else: - char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [ - char_vocab.to_index(EOW_TAG)] - char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) - self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) - - self.char_vocab = char_vocab - else: - char_emb_layer = None - - if config['token_embedder']['name'].lower() == 'cnn': - self.token_embedder = ConvTokenEmbedder( - config, self.weight_file, None, char_emb_layer, self.char_vocab) - elif config['token_embedder']['name'].lower() == 'lstm': - self.token_embedder = LstmTokenEmbedder( - config, None, char_emb_layer) - - if config['token_embedder']['word_dim'] > 0 \ - and vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk - words_to_words = nn.Parameter(torch.arange(len(vocab) + 2).long(), requires_grad=False) - for word, idx in vocab: - if vocab._is_word_no_create_entry(word): - words_to_words[idx] = vocab.unknown_idx - setattr(self.token_embedder, 'words_to_words', words_to_words) - self.output_dim = config['encoder']['projection_dim'] - - # 暂时只考虑 elmo - if config['encoder']['name'].lower() == 'elmo': - self.encoder = ElmobiLm(config) - elif config['encoder']['name'].lower() == 'lstm': - self.encoder = LstmbiLm(config) - - self.encoder.load_weights(self.weight_file) + index_in_pre = char_lexicon[OOV_TAG] + char_emb_layer.weight.data[index] = self.char_embed_weights[index_in_pre] + + print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") + # 生成words到chars的映射 + max_chars = config['char_cnn']['max_characters_per_token'] + + self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars), + fill_value=len(char_vocab), + dtype=torch.long), + requires_grad=False) + for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]: + if len(word) + 2 > max_chars: + word = 
word[:max_chars - 2] + if index == self._pad_index: + continue + elif word == BOS_TAG or word == EOS_TAG: + char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [ + char_vocab.to_index(EOW_TAG)] + char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) + else: + char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [ + char_vocab.to_index(EOW_TAG)] + char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) + self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) + + self.char_vocab = char_vocab + + self.token_embedder = ConvTokenEmbedder( + config, self.weight_file, None, char_emb_layer) + + self.token_embedder.load_state_dict(elmo_model["char_cnn"]) + + self.output_dim = config['lstm']['projection_dim'] + + # lstm encoder + self.encoder = ElmobiLm(config) + self.encoder.load_state_dict(elmo_model["lstm"]) if cache_word_reprs: - if config['token_embedder']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用 + if config['char_cnn']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用 print("Start to generate cache word representations.") batch_size = 320 # bos eos @@ -848,7 +678,7 @@ class _ElmoModel(nn.Module): int(word_size % batch_size != 0) self.cached_word_embedding = nn.Embedding(word_size, - config['encoder']['projection_dim']) + config['lstm']['projection_dim']) with torch.no_grad(): for i in range(num_batches): words = torch.arange(i * batch_size, @@ -877,6 +707,8 @@ class _ElmoModel(nn.Module): expanded_words[:, 0].fill_(self.bos_index) expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index seq_len = seq_len + 2 + zero_tensor = torch.zeros(expanded_words.shape).long() + mask = (expanded_words == zero_tensor).unsqueeze(-1) if hasattr(self, 'cached_word_embedding'): token_embedding = self.cached_word_embedding(expanded_words) else: @@ -886,20 +718,16 @@ class _ElmoModel(nn.Module): chars = None token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim - if self.config['encoder']['name'] == 'elmo': - encoder_output = self.encoder(token_embedding, seq_len) - if encoder_output.size(2) < max_len + 2: - num_layers, _, output_len, hidden_size = encoder_output.size() - dummy_tensor = encoder_output.new_zeros(num_layers, batch_size, - max_len + 2 - output_len, hidden_size) - encoder_output = torch.cat((encoder_output, dummy_tensor), 2) - sz = encoder_output.size() # 2, batch_size, max_len, hidden_size - token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) - encoder_output = torch.cat((token_embedding, encoder_output), dim=0) - elif self.config['encoder']['name'] == 'lstm': - encoder_output = self.encoder(token_embedding, seq_len) - else: - raise ValueError('Unknown encoder: {0}'.format(self.config['encoder']['name'])) + encoder_output = self.encoder(token_embedding, seq_len) + if encoder_output.size(2) < max_len + 2: + num_layers, _, output_len, hidden_size = encoder_output.size() + dummy_tensor = encoder_output.new_zeros(num_layers, batch_size, + max_len + 2 - output_len, hidden_size) + encoder_output = torch.cat((encoder_output, dummy_tensor), 2) + sz = encoder_output.size() # 2, batch_size, max_len, hidden_size + token_embedding = token_embedding.masked_fill(mask, 0) + token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) + encoder_output = torch.cat((token_embedding, encoder_output), dim=0) # 删除, . 
The removal of <bos> and <eos> here is not exact, but it should not affect the final result.
         encoder_output = encoder_output[:, :, 1:-1]
diff --git a/requirements.txt b/requirements.txt
index 90b67f2c..f71e2223 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,3 @@ tqdm>=4.28.1
 nltk>=3.4.1
 requests
 spacy
-h5py
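
Note on the new weight format: the loader above expects the .pkl weight file to unpickle
into a dict with three entries, matching the reads in _ElmoModel.__init__: "char_embed"
(the pretrained character-embedding matrix), "char_cnn" (a state_dict for ConvTokenEmbedder)
and "lstm" (a state_dict for ElmobiLm). Below is a minimal sketch of how such a file could be
written, assuming the three objects have already been populated (for example by the previous
h5py-based loading code); the helper name and the output path are illustrative only, not part
of this patch:

    import pickle

    def dump_elmo_weights(char_embed_weights, token_embedder, encoder, path="elmo_weights.pkl"):
        # Keys must match what _ElmoModel.__init__ reads back with pickle.load().
        weights = {
            "char_embed": char_embed_weights,         # tensor: one row per character in char.dic
            "char_cnn": token_embedder.state_dict(),  # ConvTokenEmbedder parameters
            "lstm": encoder.state_dict(),             # ElmobiLm parameters
        }
        with open(path, "wb") as f:
            pickle.dump(weights, f)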