diff --git a/fastNLP/modules/encoder/_elmo.py b/fastNLP/modules/encoder/_elmo.py index 4ebee819..a49634bb 100644 --- a/fastNLP/modules/encoder/_elmo.py +++ b/fastNLP/modules/encoder/_elmo.py @@ -1,12 +1,13 @@ - """ -这个页面的代码大量参考了https://github.com/HIT-SCIR/ELMoForManyLangs/tree/master/elmoformanylangs +这个页面的代码大量参考了 allenNLP """ - from typing import Optional, Tuple, List, Callable import os + +import h5py +import numpy import torch import torch.nn as nn import torch.nn.functional as F @@ -16,7 +17,6 @@ import json from ..utils import get_dropout_mask import codecs -from torch import autograd class LstmCellWithProjection(torch.nn.Module): """ @@ -58,6 +58,7 @@ class LstmCellWithProjection(torch.nn.Module): respectively. The first dimension is 1 in order to match the Pytorch API for returning stacked LSTM states. """ + def __init__(self, input_size: int, hidden_size: int, @@ -129,13 +130,13 @@ class LstmCellWithProjection(torch.nn.Module): # We have to use this '.data.new().fill_' pattern to create tensors with the correct # type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors. output_accumulator = inputs.data.new(batch_size, - total_timesteps, - self.hidden_size).fill_(0) + total_timesteps, + self.hidden_size).fill_(0) if initial_state is None: full_batch_previous_memory = inputs.data.new(batch_size, - self.cell_size).fill_(0) + self.cell_size).fill_(0) full_batch_previous_state = inputs.data.new(batch_size, - self.hidden_size).fill_(0) + self.hidden_size).fill_(0) else: full_batch_previous_state = initial_state[0].squeeze(0) full_batch_previous_memory = initial_state[1].squeeze(0) @@ -169,7 +170,7 @@ class LstmCellWithProjection(torch.nn.Module): # Second conditional: Does the next shortest sequence beyond the current batch # index require computation use this timestep? while current_length_index < (len(batch_lengths) - 1) and \ - batch_lengths[current_length_index + 1] > index: + batch_lengths[current_length_index + 1] > index: current_length_index += 1 # Actually get the slices of the batch which we @@ -256,7 +257,7 @@ class LstmbiLm(nn.Module): inputs = inputs[sort_idx] inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=self.batch_first) output, hx = self.encoder(inputs, None) # -> [N,L,C] - output, _ = nn.util.rnn.pad_packed_sequence(output, batch_first=self.batch_first) + output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) output = output[unsort_idx] forward, backward = output.split(self.config['encoder']['dim'], 2) @@ -316,13 +317,13 @@ class ElmobiLm(torch.nn.Module): :param seq_len: batch_size :return: torch.FloatTensor. 
num_layers x batch_size x max_len x hidden_size """ + max_len = inputs.size(1) sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) inputs = inputs[sort_idx] inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) output, _ = self._lstm_forward(inputs, None) _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) output = output[:, unsort_idx] - return output def _lstm_forward(self, @@ -399,7 +400,7 @@ class ElmobiLm(torch.nn.Module): torch.cat([forward_state[1], backward_state[1]], -1))) stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs) - # Stack the hidden state and memory for each layer into 2 tensors of shape + # Stack the hidden state and memory for each layer in。to 2 tensors of shape # (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size) # respectively. final_hidden_states, final_memory_states = zip(*final_states) @@ -408,6 +409,66 @@ class ElmobiLm(torch.nn.Module): torch.cat(final_memory_states, 0)) return stacked_sequence_outputs, final_state_tuple + def load_weights(self, weight_file: str) -> None: + """ + Load the pre-trained weights from the file. + """ + requires_grad = False + + with h5py.File(weight_file, 'r') as fin: + for i_layer, lstms in enumerate( + zip(self.forward_layers, self.backward_layers) + ): + for j_direction, lstm in enumerate(lstms): + # lstm is an instance of LSTMCellWithProjection + cell_size = lstm.cell_size + + dataset = fin['RNN_%s' % j_direction]['RNN']['MultiRNNCell']['Cell%s' % i_layer + ]['LSTMCell'] + + # tensorflow packs together both W and U matrices into one matrix, + # but pytorch maintains individual matrices. In addition, tensorflow + # packs the gates as input, memory, forget, output but pytorch + # uses input, forget, memory, output. So we need to modify the weights. + tf_weights = numpy.transpose(dataset['W_0'][...]) + torch_weights = tf_weights.copy() + + # split the W from U matrices + input_size = lstm.input_size + input_weights = torch_weights[:, :input_size] + recurrent_weights = torch_weights[:, input_size:] + tf_input_weights = tf_weights[:, :input_size] + tf_recurrent_weights = tf_weights[:, input_size:] + + # handle the different gate order convention + for torch_w, tf_w in [[input_weights, tf_input_weights], + [recurrent_weights, tf_recurrent_weights]]: + torch_w[(1 * cell_size):(2 * cell_size), :] = tf_w[(2 * cell_size):(3 * cell_size), :] + torch_w[(2 * cell_size):(3 * cell_size), :] = tf_w[(1 * cell_size):(2 * cell_size), :] + + lstm.input_linearity.weight.data.copy_(torch.FloatTensor(input_weights)) + lstm.state_linearity.weight.data.copy_(torch.FloatTensor(recurrent_weights)) + lstm.input_linearity.weight.requires_grad = requires_grad + lstm.state_linearity.weight.requires_grad = requires_grad + + # the bias weights + tf_bias = dataset['B'][...] + # tensorflow adds 1.0 to forget gate bias instead of modifying the + # parameters... 
+ tf_bias[(2 * cell_size):(3 * cell_size)] += 1 + torch_bias = tf_bias.copy() + torch_bias[(1 * cell_size):(2 * cell_size) + ] = tf_bias[(2 * cell_size):(3 * cell_size)] + torch_bias[(2 * cell_size):(3 * cell_size) + ] = tf_bias[(1 * cell_size):(2 * cell_size)] + lstm.state_linearity.bias.data.copy_(torch.FloatTensor(torch_bias)) + lstm.state_linearity.bias.requires_grad = requires_grad + + # the projection weights + proj_weights = numpy.transpose(dataset['W_P_0'][...]) + lstm.state_projection.weight.data.copy_(torch.FloatTensor(proj_weights)) + lstm.state_projection.weight.requires_grad = requires_grad + class LstmTokenEmbedder(nn.Module): def __init__(self, config, word_emb_layer, char_emb_layer): @@ -441,7 +502,7 @@ class LstmTokenEmbedder(nn.Module): chars_emb = self.char_emb_layer(chars) # TODO 这里应该要考虑seq_len的问题 _, (chars_outputs, __) = self.char_lstm(chars_emb) - chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['char_dim'] * 2) + chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['embedding']['dim'] * 2) embs.append(chars_outputs) token_embedding = torch.cat(embs, dim=2) @@ -450,79 +511,143 @@ class LstmTokenEmbedder(nn.Module): class ConvTokenEmbedder(nn.Module): - def __init__(self, config, word_emb_layer, char_emb_layer): + def __init__(self, config, weight_file, word_emb_layer, char_emb_layer, char_vocab): super(ConvTokenEmbedder, self).__init__() - self.config = config + self.weight_file = weight_file self.word_emb_layer = word_emb_layer self.char_emb_layer = char_emb_layer self.output_dim = config['encoder']['projection_dim'] - self.emb_dim = 0 - if word_emb_layer is not None: - self.emb_dim += word_emb_layer.weight.size(1) - - if char_emb_layer is not None: - self.convolutions = [] - cnn_config = config['token_embedder'] - filters = cnn_config['filters'] - char_embed_dim = cnn_config['char_dim'] - - for i, (width, num) in enumerate(filters): - conv = torch.nn.Conv1d( - in_channels=char_embed_dim, - out_channels=num, - kernel_size=width, - bias=True - ) - self.convolutions.append(conv) - - self.convolutions = nn.ModuleList(self.convolutions) - - self.n_filters = sum(f[1] for f in filters) - self.n_highway = cnn_config['n_highway'] - - self.highways = Highway(self.n_filters, self.n_highway, activation=torch.nn.functional.relu) - self.emb_dim += self.n_filters - - self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True) + self._options = config + self.requires_grad = False + self._load_weights() + self._char_embedding_weights = char_emb_layer.weight.data + + def _load_weights(self): + self._load_cnn_weights() + self._load_highway() + self._load_projection() + + def _load_cnn_weights(self): + cnn_options = self._options['token_embedder'] + filters = cnn_options['filters'] + char_embed_dim = cnn_options['embedding']['dim'] + + convolutions = [] + for i, (width, num) in enumerate(filters): + conv = torch.nn.Conv1d( + in_channels=char_embed_dim, + out_channels=num, + kernel_size=width, + bias=True + ) + # load the weights + with h5py.File(self.weight_file, 'r') as fin: + weight = fin['CNN']['W_cnn_{}'.format(i)][...] + bias = fin['CNN']['b_cnn_{}'.format(i)][...] 
+ + w_reshaped = numpy.transpose(weight.squeeze(axis=0), axes=(2, 1, 0)) + if w_reshaped.shape != tuple(conv.weight.data.shape): + raise ValueError("Invalid weight file") + conv.weight.data.copy_(torch.FloatTensor(w_reshaped)) + conv.bias.data.copy_(torch.FloatTensor(bias)) + + conv.weight.requires_grad = self.requires_grad + conv.bias.requires_grad = self.requires_grad + + convolutions.append(conv) + self.add_module('char_conv_{}'.format(i), conv) + + self._convolutions = convolutions + + def _load_highway(self): + # the highway layers have same dimensionality as the number of cnn filters + cnn_options = self._options['token_embedder'] + filters = cnn_options['filters'] + n_filters = sum(f[1] for f in filters) + n_highway = cnn_options['n_highway'] + + # create the layers, and load the weights + self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu) + for k in range(n_highway): + # The AllenNLP highway is one matrix multplication with concatenation of + # transform and carry weights. + with h5py.File(self.weight_file, 'r') as fin: + # The weights are transposed due to multiplication order assumptions in tf + # vs pytorch (tf.matmul(X, W) vs pytorch.matmul(W, X)) + w_transform = numpy.transpose(fin['CNN_high_{}'.format(k)]['W_transform'][...]) + # -1.0 since AllenNLP is g * x + (1 - g) * f(x) but tf is (1 - g) * x + g * f(x) + w_carry = -1.0 * numpy.transpose(fin['CNN_high_{}'.format(k)]['W_carry'][...]) + weight = numpy.concatenate([w_transform, w_carry], axis=0) + self._highways._layers[k].weight.data.copy_(torch.FloatTensor(weight)) + self._highways._layers[k].weight.requires_grad = self.requires_grad + + b_transform = fin['CNN_high_{}'.format(k)]['b_transform'][...] + b_carry = -1.0 * fin['CNN_high_{}'.format(k)]['b_carry'][...] + bias = numpy.concatenate([b_transform, b_carry], axis=0) + self._highways._layers[k].bias.data.copy_(torch.FloatTensor(bias)) + self._highways._layers[k].bias.requires_grad = self.requires_grad + + def _load_projection(self): + cnn_options = self._options['token_embedder'] + filters = cnn_options['filters'] + n_filters = sum(f[1] for f in filters) + + self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True) + with h5py.File(self.weight_file, 'r') as fin: + weight = fin['CNN_proj']['W_proj'][...] + bias = fin['CNN_proj']['b_proj'][...] 
+ self._projection.weight.data.copy_(torch.FloatTensor(numpy.transpose(weight))) + self._projection.bias.data.copy_(torch.FloatTensor(bias)) + + self._projection.weight.requires_grad = self.requires_grad + self._projection.bias.requires_grad = self.requires_grad def forward(self, words, chars): - embs = [] - if self.word_emb_layer is not None: - if hasattr(self, 'words_to_words'): - words = self.words_to_words[words] - word_emb = self.word_emb_layer(words) - embs.append(word_emb) + """ + :param words: + :param chars: Tensor Shape ``(batch_size, sequence_length, 50)``: + :return Tensor Shape ``(batch_size, sequence_length + 2, embedding_dim)`` : + """ + # the character id embedding + # (batch_size * sequence_length, max_chars_per_token, embed_dim) + # character_embedding = torch.nn.functional.embedding( + # chars.view(-1, max_chars_per_token), + # self._char_embedding_weights + # ) + batch_size, sequence_length, max_char_len = chars.size() + character_embedding = self.char_emb_layer(chars).reshape(batch_size*sequence_length, max_char_len, -1) + # run convolutions + cnn_options = self._options['token_embedder'] + if cnn_options['activation'] == 'tanh': + activation = torch.tanh + elif cnn_options['activation'] == 'relu': + activation = torch.nn.functional.relu + else: + raise Exception("Unknown activation") - if self.char_emb_layer is not None: - batch_size, seq_len, _ = chars.size() - chars = chars.view(batch_size * seq_len, -1) - character_embedding = self.char_emb_layer(chars) - character_embedding = torch.transpose(character_embedding, 1, 2) - - cnn_config = self.config['token_embedder'] - if cnn_config['activation'] == 'tanh': - activation = torch.nn.functional.tanh - elif cnn_config['activation'] == 'relu': - activation = torch.nn.functional.relu - else: - raise Exception("Unknown activation") + # (batch_size * sequence_length, embed_dim, max_chars_per_token) + character_embedding = torch.transpose(character_embedding, 1, 2) + convs = [] + for i in range(len(self._convolutions)): + conv = getattr(self, 'char_conv_{}'.format(i)) + convolved = conv(character_embedding) + # (batch_size * sequence_length, n_filters for this width) + convolved, _ = torch.max(convolved, dim=-1) + convolved = activation(convolved) + convs.append(convolved) - convs = [] - for i in range(len(self.convolutions)): - convolved = self.convolutions[i](character_embedding) - # (batch_size * sequence_length, n_filters for this width) - convolved, _ = torch.max(convolved, dim=-1) - convolved = activation(convolved) - convs.append(convolved) - char_emb = torch.cat(convs, dim=-1) - char_emb = self.highways(char_emb) + # (batch_size * sequence_length, n_filters) + token_embedding = torch.cat(convs, dim=-1) - embs.append(char_emb.view(batch_size, -1, self.n_filters)) + # apply the highway layers (batch_size * sequence_length, n_filters) + token_embedding = self._highways(token_embedding) - token_embedding = torch.cat(embs, dim=2) + # final projection (batch_size * sequence_length, embedding_dim) + token_embedding = self._projection(token_embedding) - return self.projection(token_embedding) + # reshape to (batch_size, sequence_length+2, embedding_dim) + return token_embedding.view(batch_size, sequence_length, -1) class Highway(torch.nn.Module): @@ -543,6 +668,7 @@ class Highway(torch.nn.Module): activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``) The non-linearity to use in the highway layers. 
""" + def __init__(self, input_dim: int, num_layers: int = 1, @@ -573,6 +699,7 @@ class Highway(torch.nn.Module): current_input = gate * linear_part + (1 - gate) * nonlinear_part return current_input + class _ElmoModel(nn.Module): """ 该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括 @@ -582,11 +709,32 @@ class _ElmoModel(nn.Module): (4) 设计一个保存token的embedding,允许缓存word的表示。 """ - def __init__(self, model_dir:str, vocab:Vocabulary=None, cache_word_reprs:bool=False): + + def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): super(_ElmoModel, self).__init__() - config = json.load(open(os.path.join(model_dir, 'structure_config.json'), 'r')) + dir = os.walk(model_dir) + config_file = None + weight_file = None + config_count = 0 + weight_count = 0 + for path, dir_list, file_list in dir: + for file_name in file_list: + if file_name.__contains__(".json"): + config_file = file_name + config_count += 1 + elif file_name.__contains__(".hdf5"): + weight_file = file_name + weight_count += 1 + if config_count > 1 or weight_count > 1: + raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") + elif config_count == 0 or weight_count == 0: + raise Exception(f"No config file or weight file found in {model_dir}") + + config = json.load(open(os.path.join(model_dir, config_file), 'r')) + self.weight_file = os.path.join(model_dir, weight_file) self.config = config + self.requires_grad = False OOV_TAG = '' PAD_TAG = '' @@ -595,48 +743,8 @@ class _ElmoModel(nn.Module): BOW_TAG = '' EOW_TAG = '' - # 将加载embedding放到这里 - token_embedder_states = torch.load(os.path.join(model_dir, 'token_embedder.pkl'), map_location='cpu') - - # For the model trained with word form word encoder. - if config['token_embedder']['word_dim'] > 0: - word_lexicon = {} - with codecs.open(os.path.join(model_dir, 'word.dic'), 'r', encoding='utf-8') as fpi: - for line in fpi: - tokens = line.strip().split('\t') - if len(tokens) == 1: - tokens.insert(0, '\u3000') - token, i = tokens - word_lexicon[token] = int(i) - # 做一些sanity check - for special_word in [PAD_TAG, OOV_TAG, BOS_TAG, EOS_TAG]: - assert special_word in word_lexicon, f"{special_word} not found in word.dic." - # 根据vocab调整word_embedding - pre_word_embedding = token_embedder_states.pop('word_emb_layer.embedding.weight') - word_emb_layer = nn.Embedding(len(vocab)+2, config['token_embedder']['word_dim']) #多增加两个是为了 - found_word_count = 0 - for word, index in vocab: - if index == vocab.unknown_idx: # 因为fastNLP的unknow是 而在这里是所以ugly强制适配一下 - index_in_pre = word_lexicon[OOV_TAG] - found_word_count += 1 - elif index == vocab.padding_idx: # 需要pad对齐 - index_in_pre = word_lexicon[PAD_TAG] - found_word_count += 1 - elif word in word_lexicon: - index_in_pre = word_lexicon[word] - found_word_count += 1 - else: - index_in_pre = word_lexicon[OOV_TAG] - word_emb_layer.weight.data[index] = pre_word_embedding[index_in_pre] - print(f"{found_word_count} out of {len(vocab)} words were found in pretrained elmo embedding.") - word_emb_layer.weight.data[-1] = pre_word_embedding[word_lexicon[EOS_TAG]] - word_emb_layer.weight.data[-2] = pre_word_embedding[word_lexicon[BOS_TAG]] - self.word_vocab = vocab - else: - word_emb_layer = None - # For the model trained with character-based word encoder. 
- if config['token_embedder']['char_dim'] > 0: + if config['token_embedder']['embedding']['dim'] > 0: char_lexicon = {} with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: for line in fpi: @@ -645,22 +753,26 @@ class _ElmoModel(nn.Module): tokens.insert(0, '\u3000') token, i = tokens char_lexicon[token] = int(i) + # 做一些sanity check for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: assert special_word in char_lexicon, f"{special_word} not found in char.dic." + # 从vocab中构建char_vocab char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) # 需要保证在里面 - char_vocab.add_word(BOW_TAG) - char_vocab.add_word(EOW_TAG) + char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) + for word, index in vocab: char_vocab.add_word_lst(list(word)) - # 保证, 也在 - char_vocab.add_word_lst(list(BOS_TAG)) - char_vocab.add_word_lst(list(EOS_TAG)) - # 根据char_lexicon调整 - char_emb_layer = nn.Embedding(len(char_vocab), int(config['token_embedder']['char_dim'])) - pre_char_embedding = token_embedder_states.pop('char_emb_layer.embedding.weight') + + self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab)+1, vocab.padding_idx + # 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示) + char_emb_layer = nn.Embedding(len(char_vocab)+1, int(config['token_embedder']['embedding']['dim']), + padding_idx=len(char_vocab)) + with h5py.File(self.weight_file, 'r') as fin: + char_embed_weights = fin['char_embed'][...] + char_embed_weights = torch.from_numpy(char_embed_weights) found_char_count = 0 for char, index in char_vocab: # 调整character embedding if char in char_lexicon: @@ -668,79 +780,84 @@ class _ElmoModel(nn.Module): found_char_count += 1 else: index_in_pre = char_lexicon[OOV_TAG] - char_emb_layer.weight.data[index] = pre_char_embedding[index_in_pre] + char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] + print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") # 生成words到chars的映射 if config['token_embedder']['name'].lower() == 'cnn': max_chars = config['token_embedder']['max_characters_per_token'] elif config['token_embedder']['name'].lower() == 'lstm': - max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # 需要补充两个 + max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # 需要补充两个 else: raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name'])) - # 增加, 所以加2. 
+ self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars), - fill_value=char_vocab.to_index(PAD_TAG), dtype=torch.long), + fill_value=len(char_vocab), + dtype=torch.long), requires_grad=False) - for word, index in vocab: - if len(word)+2>max_chars: - word = word[:max_chars-2] - if index==vocab.padding_idx: # 如果是pad的话,需要和给定的对齐 - word = PAD_TAG - elif index==vocab.unknown_idx: - word = OOV_TAG - char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] - char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) + for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab)+1)]: + if len(word) + 2 > max_chars: + word = word[:max_chars - 2] + if index == self._pad_index: + continue + elif word == BOS_TAG or word == EOS_TAG: + char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [ + char_vocab.to_index(EOW_TAG)] + char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) + else: + char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [ + char_vocab.to_index(EOW_TAG)] + char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) - for index, word in enumerate([BOS_TAG, EOS_TAG]): # 加上, - if len(word)+2>max_chars: - word = word[:max_chars-2] - char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] - char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) - self.words_to_chars_embedding[index+len(vocab)] = torch.LongTensor(char_ids) + self.char_vocab = char_vocab else: char_emb_layer = None if config['token_embedder']['name'].lower() == 'cnn': self.token_embedder = ConvTokenEmbedder( - config, word_emb_layer, char_emb_layer) + config, self.weight_file, None, char_emb_layer, self.char_vocab) elif config['token_embedder']['name'].lower() == 'lstm': self.token_embedder = LstmTokenEmbedder( - config, word_emb_layer, char_emb_layer) - self.token_embedder.load_state_dict(token_embedder_states, strict=False) - if config['token_embedder']['word_dim'] > 0 and vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk - words_to_words = nn.Parameter(torch.arange(len(vocab)+2).long(), requires_grad=False) + config, None, char_emb_layer) + + if config['token_embedder']['word_dim'] > 0 \ + and vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk + words_to_words = nn.Parameter(torch.arange(len(vocab) + 2).long(), requires_grad=False) for word, idx in vocab: if vocab._is_word_no_create_entry(word): words_to_words[idx] = vocab.unknown_idx setattr(self.token_embedder, 'words_to_words', words_to_words) self.output_dim = config['encoder']['projection_dim'] + # 暂时只考虑 elmo if config['encoder']['name'].lower() == 'elmo': self.encoder = ElmobiLm(config) elif config['encoder']['name'].lower() == 'lstm': self.encoder = LstmbiLm(config) - self.encoder.load_state_dict(torch.load(os.path.join(model_dir, 'encoder.pkl'), - map_location='cpu')) - self.bos_index = len(vocab) - self.eos_index = len(vocab) + 1 - self._pad_index = vocab.padding_idx + self.encoder.load_weights(self.weight_file) if cache_word_reprs: - if config['token_embedder']['char_dim']>0: # 只有在使用了chars的情况下有用 + if config['token_embedder']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用 print("Start to generate cache word representations.") batch_size = 320 - num_batches = 
self.words_to_chars_embedding.size(0)//batch_size + \ - int(self.words_to_chars_embedding.size(0)%batch_size!=0) - self.cached_word_embedding = nn.Embedding(self.words_to_chars_embedding.size(0), + # bos eos + word_size = self.words_to_chars_embedding.size(0) + num_batches = word_size // batch_size + \ + int(word_size % batch_size != 0) + + self.cached_word_embedding = nn.Embedding(word_size, config['encoder']['projection_dim']) with torch.no_grad(): for i in range(num_batches): - words = torch.arange(i*batch_size, min((i+1)*batch_size, self.words_to_chars_embedding.size(0))).long() + words = torch.arange(i * batch_size, + min((i + 1) * batch_size, word_size)).long() chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars - word_reprs = self.token_embedder(words.unsqueeze(1), chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] + word_reprs = self.token_embedder(words.unsqueeze(1), + chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) + print("Finish generating cached word representations. Going to delete the character encoder.") del self.token_embedder, self.words_to_chars_embedding else: @@ -758,7 +875,7 @@ class _ElmoModel(nn.Module): seq_len = words.ne(self._pad_index).sum(dim=-1) expanded_words[:, 1:-1] = words expanded_words[:, 0].fill_(self.bos_index) - expanded_words[torch.arange(batch_size).to(words), seq_len+1] = self.eos_index + expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index seq_len = seq_len + 2 if hasattr(self, 'cached_word_embedding'): token_embedding = self.cached_word_embedding(expanded_words) @@ -767,16 +884,18 @@ class _ElmoModel(nn.Module): chars = self.words_to_chars_embedding[expanded_words] else: chars = None - token_embedding = self.token_embedder(expanded_words, chars) + token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim + if self.config['encoder']['name'] == 'elmo': encoder_output = self.encoder(token_embedding, seq_len) - if encoder_output.size(2) < max_len+2: - dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size, - max_len + 2 - encoder_output.size(2), encoder_output.size(-1)) - encoder_output = torch.cat([encoder_output, dummy_tensor], 2) - sz = encoder_output.size() # 2, batch_size, max_len, hidden_size - token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) - encoder_output = torch.cat([token_embedding, encoder_output], dim=0) + if encoder_output.size(2) < max_len + 2: + num_layers, _, output_len, hidden_size = encoder_output.size() + dummy_tensor = encoder_output.new_zeros(num_layers, batch_size, + max_len + 2 - output_len, hidden_size) + encoder_output = torch.cat((encoder_output, dummy_tensor), 2) + sz = encoder_output.size() # 2, batch_size, max_len, hidden_size + token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) + encoder_output = torch.cat((token_embedding, encoder_output), dim=0) elif self.config['encoder']['name'] == 'lstm': encoder_output = self.encoder(token_embedding, seq_len) else: @@ -784,5 +903,4 @@ class _ElmoModel(nn.Module): # 删除, . 这里没有精确地删除,但应该也不会影响最后的结果了。 encoder_output = encoder_output[:, :, 1:-1] - return encoder_output
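
Notes on the patch (illustrative sketches, not part of the diff). The weight-loading code added above (ElmobiLm.load_weights, ConvTokenEmbedder._load_cnn_weights/_load_highway/_load_projection, and the char_embed lookup in _ElmoModel.__init__) all read the same hdf5 checkpoint. A quick way to check that a given weight file contains the datasets this patch expects is to walk it with h5py; the file name below is a placeholder, and the dataset names in the comment are the ones referenced in the diff.

import h5py

def print_datasets(name, obj):
    # visititems calls this for every group and dataset; only report datasets.
    if isinstance(obj, h5py.Dataset):
        print(name, obj.shape)

with h5py.File('elmo_weights.hdf5', 'r') as fin:   # placeholder path
    fin.visititems(print_datasets)

# Datasets the new loaders look up:
#   char_embed                                   -> character embedding matrix
#   CNN/W_cnn_{i}, CNN/b_cnn_{i}                 -> character convolutions
#   CNN_high_{k}/{W,b}_transform, {W,b}_carry    -> highway layers
#   CNN_proj/W_proj, CNN_proj/b_proj             -> projection to the encoder input
#   RNN_{0,1}/RNN/MultiRNNCell/Cell{layer}/LSTMCell/{W_0, B, W_P_0} -> biLM weights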
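
ElmobiLm.load_weights has to reconcile two conventions, as its comments note: TensorFlow packs W and U into one matrix with the gates ordered (input, memory, forget, output) and adds 1.0 to the forget-gate bias at run time, while LstmCellWithProjection keeps separate input/state linearities with gates ordered (input, forget, memory, output). A minimal numpy sketch of that conversion, with toy shapes and random arrays standing in for dataset['W_0'] and dataset['B']:

import numpy

cell_size, input_size, hidden_size = 4, 3, 2
rng = numpy.random.RandomState(0)

# stands in for numpy.transpose(dataset['W_0'][...]) and dataset['B'][...]
tf_weights = rng.randn(4 * cell_size, input_size + hidden_size)
tf_bias = rng.randn(4 * cell_size)

# TensorFlow adds 1.0 to the forget gate (third block) at run time, so bake it in.
tf_bias[2 * cell_size:3 * cell_size] += 1.0

def reorder_gates(block):
    # (input, memory, forget, output) -> (input, forget, memory, output)
    out = block.copy()
    out[1 * cell_size:2 * cell_size] = block[2 * cell_size:3 * cell_size]
    out[2 * cell_size:3 * cell_size] = block[1 * cell_size:2 * cell_size]
    return out

torch_weights = reorder_gates(tf_weights)
torch_bias = reorder_gates(tf_bias)

# Split the packed matrix into the part applied to the input (W -> input_linearity)
# and the part applied to the previous state (U -> state_linearity).
input_weights = torch_weights[:, :input_size]
recurrent_weights = torch_weights[:, input_size:]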
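
The -1.0 factor in _load_highway comes from the two highway formulations: the TensorFlow model was trained for (1 - g) * x + g * f(x), while the Highway module here computes gate * x + (1 - gate) * f(x). Because sigmoid(-z) = 1 - sigmoid(z), negating W_carry and b_carry makes the two agree. A toy check, where z stands for W_carry @ x + b_carry:

import torch

z = torch.randn(5)
g_tf = torch.sigmoid(z)        # gate as the TensorFlow graph computes it
g_torch = torch.sigmoid(-z)    # gate after loading -1.0 * (W_carry, b_carry)
assert torch.allclose(g_torch, 1 - g_tf)

x, fx = torch.randn(5), torch.randn(5)
tf_out = (1 - g_tf) * x + g_tf * fx          # trained behaviour
torch_out = g_torch * x + (1 - g_torch) * fx  # Highway with negated carry weights
assert torch.allclose(tf_out, torch_out)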
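
ConvTokenEmbedder.forward flattens the (batch_size, sequence_length, max_chars) character ids into one long batch, runs the width-w convolutions, max-pools over character positions, and only reshapes back after the highway and projection layers. A stripped-down shape walkthrough with made-up sizes (the real filter list and activation come from the config; the highway step is omitted here):

import torch
import torch.nn as nn

batch_size, seq_len, max_chars = 2, 7, 50
char_dim, n_chars = 16, 262
filters = [(1, 32), (2, 32), (3, 64)]        # (kernel width, number of filters)
projection_dim = 128

char_emb = nn.Embedding(n_chars, char_dim)
convs = nn.ModuleList([nn.Conv1d(char_dim, num, kernel_size=width)
                       for width, num in filters])
projection = nn.Linear(sum(num for _, num in filters), projection_dim)

chars = torch.randint(0, n_chars, (batch_size, seq_len, max_chars))
x = char_emb(chars).view(batch_size * seq_len, max_chars, char_dim)
x = x.transpose(1, 2)                        # (B*L, char_dim, max_chars)

convolved = []
for conv in convs:
    c, _ = conv(x).max(dim=-1)               # pool over character positions
    convolved.append(torch.relu(c))
token = torch.cat(convolved, dim=-1)         # (B*L, sum of filter counts)
token = projection(token)                    # highway layers would go before this
print(token.view(batch_size, seq_len, -1).shape)   # torch.Size([2, 7, 128])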
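
In _ElmoModel.__init__, every vocabulary word (plus the added BOS/EOS entries) is pre-encoded as a fixed-width row of character ids wrapped in BOW/EOW markers and padded out to max_characters_per_token; the extra row reserved for word-level padding stays all pad ids. A small stand-alone version of that per-word encoding (the ids are invented, and the dict lookup with a default stands in for char_vocab.to_index):

def word_to_char_ids(word, max_chars, bow_id, eow_id, pad_id, oov_id, char_to_id):
    # Truncate long words so BOW/EOW still fit, then pad to a fixed width.
    if len(word) + 2 > max_chars:
        word = word[:max_chars - 2]
    ids = [bow_id] + [char_to_id.get(c, oov_id) for c in word] + [eow_id]
    return ids + [pad_id] * (max_chars - len(ids))

char_to_id = {'h': 4, 'i': 5}
print(word_to_char_ids('hi', 8, bow_id=1, eow_id=2, pad_id=0, oov_id=3,
                       char_to_id=char_to_id))
# [1, 4, 5, 2, 0, 0, 0, 0]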
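
Finally, _ElmoModel.forward surrounds every (possibly padded) sentence with the extra BOS/EOS word ids before the token embedder runs, which is why the encoder output is later sliced with [:, :, 1:-1]. A toy example of that expansion (index values are invented):

import torch

pad_index, bos_index, eos_index = 0, 100, 101
words = torch.tensor([[5, 6, 7],
                      [8, 9, 0]])             # second sentence has one pad
batch_size, max_len = words.size()

seq_len = words.ne(pad_index).sum(dim=-1)     # tensor([3, 2])
expanded = words.new_full((batch_size, max_len + 2), pad_index)
expanded[:, 1:-1] = words
expanded[:, 0] = bos_index
expanded[torch.arange(batch_size), seq_len + 1] = eos_index
print(expanded)
# tensor([[100,   5,   6,   7, 101],
#         [100,   8,   9, 101,   0]])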