
Remove elmo's dependency on h5py

tags/v0.4.10
yh committed 5 years ago
commit 2c00c1ae5a
2 changed files with 144 additions and 317 deletions:
  1. fastNLP/modules/encoder/_elmo.py  +144 −316
  2. requirements.txt  +0 −1
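The weight file moves from HDF5 to a pickled dict, which is what lets h5py go. A minimal sketch of producing such a file (an assumed conversion workflow, not part of this commit; save_elmo_pkl is a hypothetical helper, but the dict keys follow the loading code in __init__ below):

import pickle

def save_elmo_pkl(char_embed_weights, token_embedder, encoder, path):
    # keys match what _ElmoModel.__init__ reads: a char_embed tensor plus
    # state_dicts for the char-CNN token embedder and the biLM encoder
    elmo_model = {
        "char_embed": char_embed_weights,
        "char_cnn": token_embedder.state_dict(),
        "lstm": encoder.state_dict(),
    }
    with open(path, "wb") as fout:
        pickle.dump(elmo_model, fout)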

fastNLP/modules/encoder/_elmo.py  (+144 −316)

@@ -6,14 +6,13 @@ from typing import Optional, Tuple, List, Callable

import os

import h5py
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence
from ...core.vocabulary import Vocabulary
import json
import pickle

from ..utils import get_dropout_mask
import codecs
@@ -244,13 +243,13 @@ class LstmbiLm(nn.Module):
def __init__(self, config):
super(LstmbiLm, self).__init__()
self.config = config
self.encoder = nn.LSTM(self.config['encoder']['projection_dim'],
self.config['encoder']['dim'],
num_layers=self.config['encoder']['n_layers'],
self.encoder = nn.LSTM(self.config['lstm']['projection_dim'],
self.config['lstm']['dim'],
num_layers=self.config['lstm']['n_layers'],
bidirectional=True,
batch_first=True,
dropout=self.config['dropout'])
self.projection = nn.Linear(self.config['encoder']['dim'], self.config['encoder']['projection_dim'], bias=True)
self.projection = nn.Linear(self.config['lstm']['dim'], self.config['lstm']['projection_dim'], bias=True)

def forward(self, inputs, seq_len):
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
@@ -260,7 +259,7 @@ class LstmbiLm(nn.Module):
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=self.batch_first)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
output = output[unsort_idx]
forward, backward = output.split(self.config['encoder']['dim'], 2)
forward, backward = output.split(self.config['lstm']['dim'], 2)
return torch.cat([self.projection(forward), self.projection(backward)], dim=2)
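Aside: the sort→pack→unsort pattern used in forward above, as a standalone sketch (dimensions illustrative). pack_padded_sequence here expects lengths sorted in decreasing order, hence the sort and the inverse permutation afterwards:

import torch
import torch.nn as nn

lstm = nn.LSTM(8, 16, batch_first=True, bidirectional=True)
inputs, seq_len = torch.randn(4, 10, 8), torch.tensor([10, 3, 7, 5])
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
packed = nn.utils.rnn.pack_padded_sequence(inputs[sort_idx], sort_lens, batch_first=True)
output, _ = lstm(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
output = output[unsort_idx]  # restore the original batch order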


@@ -268,13 +267,13 @@ class ElmobiLm(torch.nn.Module):
def __init__(self, config):
super(ElmobiLm, self).__init__()
self.config = config
input_size = config['encoder']['projection_dim']
hidden_size = config['encoder']['projection_dim']
cell_size = config['encoder']['dim']
num_layers = config['encoder']['n_layers']
memory_cell_clip_value = config['encoder']['cell_clip']
state_projection_clip_value = config['encoder']['proj_clip']
recurrent_dropout_probability = config['dropout']
input_size = config['lstm']['projection_dim']
hidden_size = config['lstm']['projection_dim']
cell_size = config['lstm']['dim']
num_layers = config['lstm']['n_layers']
memory_cell_clip_value = config['lstm']['cell_clip']
state_projection_clip_value = config['lstm']['proj_clip']
recurrent_dropout_probability = 0.0

self.input_size = input_size
self.hidden_size = hidden_size
@@ -409,199 +408,52 @@ class ElmobiLm(torch.nn.Module):
torch.cat(final_memory_states, 0))
return stacked_sequence_outputs, final_state_tuple

def load_weights(self, weight_file: str) -> None:
"""
Load the pre-trained weights from the file.
"""
requires_grad = False

with h5py.File(weight_file, 'r') as fin:
for i_layer, lstms in enumerate(
zip(self.forward_layers, self.backward_layers)
):
for j_direction, lstm in enumerate(lstms):
# lstm is an instance of LSTMCellWithProjection
cell_size = lstm.cell_size

dataset = fin['RNN_%s' % j_direction]['RNN']['MultiRNNCell']['Cell%s' % i_layer]['LSTMCell']

# tensorflow packs together both W and U matrices into one matrix,
# but pytorch maintains individual matrices. In addition, tensorflow
# packs the gates as input, memory, forget, output but pytorch
# uses input, forget, memory, output. So we need to modify the weights.
tf_weights = numpy.transpose(dataset['W_0'][...])
torch_weights = tf_weights.copy()

# split the W from U matrices
input_size = lstm.input_size
input_weights = torch_weights[:, :input_size]
recurrent_weights = torch_weights[:, input_size:]
tf_input_weights = tf_weights[:, :input_size]
tf_recurrent_weights = tf_weights[:, input_size:]

# handle the different gate order convention
for torch_w, tf_w in [[input_weights, tf_input_weights],
[recurrent_weights, tf_recurrent_weights]]:
torch_w[(1 * cell_size):(2 * cell_size), :] = tf_w[(2 * cell_size):(3 * cell_size), :]
torch_w[(2 * cell_size):(3 * cell_size), :] = tf_w[(1 * cell_size):(2 * cell_size), :]

lstm.input_linearity.weight.data.copy_(torch.FloatTensor(input_weights))
lstm.state_linearity.weight.data.copy_(torch.FloatTensor(recurrent_weights))
lstm.input_linearity.weight.requires_grad = requires_grad
lstm.state_linearity.weight.requires_grad = requires_grad

# the bias weights
tf_bias = dataset['B'][...]
# tensorflow adds 1.0 to forget gate bias instead of modifying the
# parameters...
tf_bias[(2 * cell_size):(3 * cell_size)] += 1
torch_bias = tf_bias.copy()
torch_bias[(1 * cell_size):(2 * cell_size)] = tf_bias[(2 * cell_size):(3 * cell_size)]
torch_bias[(2 * cell_size):(3 * cell_size)] = tf_bias[(1 * cell_size):(2 * cell_size)]
lstm.state_linearity.bias.data.copy_(torch.FloatTensor(torch_bias))
lstm.state_linearity.bias.requires_grad = requires_grad

# the projection weights
proj_weights = numpy.transpose(dataset['W_P_0'][...])
lstm.state_projection.weight.data.copy_(torch.FloatTensor(proj_weights))
lstm.state_projection.weight.requires_grad = requires_grad
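Aside: the gate-order fix above, isolated. TensorFlow packs the gates as (input, memory, forget, output) while torch expects (input, forget, memory, output), so the second and third row blocks swap (a hypothetical helper, for illustration only):

def swap_tf_gate_order(tf_w, cell_size):
    # tf rows: [input | memory | forget | output]
    # torch rows: [input | forget | memory | output]
    torch_w = tf_w.copy()
    torch_w[1 * cell_size:2 * cell_size] = tf_w[2 * cell_size:3 * cell_size]
    torch_w[2 * cell_size:3 * cell_size] = tf_w[1 * cell_size:2 * cell_size]
    return torch_w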


class LstmTokenEmbedder(nn.Module):
def __init__(self, config, word_emb_layer, char_emb_layer):
super(LstmTokenEmbedder, self).__init__()
self.config = config
self.word_emb_layer = word_emb_layer
self.char_emb_layer = char_emb_layer
self.output_dim = config['encoder']['projection_dim']
emb_dim = 0
if word_emb_layer is not None:
emb_dim += word_emb_layer.n_d

if char_emb_layer is not None:
emb_dim += char_emb_layer.n_d * 2
self.char_lstm = nn.LSTM(char_emb_layer.n_d, char_emb_layer.n_d, num_layers=1, bidirectional=True,
batch_first=True, dropout=config['dropout'])

self.projection = nn.Linear(emb_dim, self.output_dim, bias=True)

def forward(self, words, chars):
embs = []
if self.word_emb_layer is not None:
if hasattr(self, 'words_to_words'):
words = self.words_to_words[words]
word_emb = self.word_emb_layer(words)
embs.append(word_emb)

if self.char_emb_layer is not None:
batch_size, seq_len, _ = chars.shape
chars = chars.view(batch_size * seq_len, -1)
chars_emb = self.char_emb_layer(chars)
# TODO: seq_len should be taken into account here
_, (chars_outputs, __) = self.char_lstm(chars_emb)
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['embedding']['dim'] * 2)
embs.append(chars_outputs)

token_embedding = torch.cat(embs, dim=2)

return self.projection(token_embedding)


class ConvTokenEmbedder(nn.Module):
def __init__(self, config, weight_file, word_emb_layer, char_emb_layer, char_vocab):
def __init__(self, config, weight_file, word_emb_layer, char_emb_layer):
super(ConvTokenEmbedder, self).__init__()
self.weight_file = weight_file
self.word_emb_layer = word_emb_layer
self.char_emb_layer = char_emb_layer

self.output_dim = config['encoder']['projection_dim']
self.output_dim = config['lstm']['projection_dim']
self._options = config
self.requires_grad = False
self._load_weights()
self._char_embedding_weights = char_emb_layer.weight.data

def _load_weights(self):
self._load_cnn_weights()
self._load_highway()
self._load_projection()

def _load_cnn_weights(self):
cnn_options = self._options['token_embedder']
filters = cnn_options['filters']
char_embed_dim = cnn_options['embedding']['dim']

convolutions = []
for i, (width, num) in enumerate(filters):
conv = torch.nn.Conv1d(
in_channels=char_embed_dim,
out_channels=num,
kernel_size=width,
bias=True
)
# load the weights
with h5py.File(self.weight_file, 'r') as fin:
weight = fin['CNN']['W_cnn_{}'.format(i)][...]
bias = fin['CNN']['b_cnn_{}'.format(i)][...]

w_reshaped = numpy.transpose(weight.squeeze(axis=0), axes=(2, 1, 0))
if w_reshaped.shape != tuple(conv.weight.data.shape):
raise ValueError("Invalid weight file")
conv.weight.data.copy_(torch.FloatTensor(w_reshaped))
conv.bias.data.copy_(torch.FloatTensor(bias))

conv.weight.requires_grad = self.requires_grad
conv.bias.requires_grad = self.requires_grad

convolutions.append(conv)
self.add_module('char_conv_{}'.format(i), conv)

self._convolutions = convolutions

def _load_highway(self):
# the highway layers have same dimensionality as the number of cnn filters
cnn_options = self._options['token_embedder']
filters = cnn_options['filters']
n_filters = sum(f[1] for f in filters)
n_highway = cnn_options['n_highway']

# create the layers, and load the weights
self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu)
for k in range(n_highway):
# The AllenNLP highway is one matrix multiplication with concatenation of
# transform and carry weights.
with h5py.File(self.weight_file, 'r') as fin:
# The weights are transposed due to multiplication order assumptions in tf
# vs pytorch (tf.matmul(X, W) vs pytorch.matmul(W, X))
w_transform = numpy.transpose(fin['CNN_high_{}'.format(k)]['W_transform'][...])
# -1.0 since AllenNLP is g * x + (1 - g) * f(x) but tf is (1 - g) * x + g * f(x)
w_carry = -1.0 * numpy.transpose(fin['CNN_high_{}'.format(k)]['W_carry'][...])
weight = numpy.concatenate([w_transform, w_carry], axis=0)
self._highways._layers[k].weight.data.copy_(torch.FloatTensor(weight))
self._highways._layers[k].weight.requires_grad = self.requires_grad

b_transform = fin['CNN_high_{}'.format(k)]['b_transform'][...]
b_carry = -1.0 * fin['CNN_high_{}'.format(k)]['b_carry'][...]
bias = numpy.concatenate([b_transform, b_carry], axis=0)
self._highways._layers[k].bias.data.copy_(torch.FloatTensor(bias))
self._highways._layers[k].bias.requires_grad = self.requires_grad
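Aside: why the carry parameters are negated above. Since sigmoid(-z) = 1 - sigmoid(z), flipping the sign of W_carry and b_carry turns tf's gate into AllenNLP's; a quick check with illustrative values:

import torch
x, z = torch.randn(5), torch.randn(5)          # dummy input and raw carry logits
g = torch.sigmoid(z)
tf_out = (1 - g) * x + g * torch.relu(x)       # tf: (1 - g) * x + g * f(x)
g2 = torch.sigmoid(-z)                         # negated logits, so g2 = 1 - g
allen_out = g2 * x + (1 - g2) * torch.relu(x)  # AllenNLP: g * x + (1 - g) * f(x)
assert torch.allclose(tf_out, allen_out)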

def _load_projection(self):
cnn_options = self._options['token_embedder']
filters = cnn_options['filters']
n_filters = sum(f[1] for f in filters)

self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True)
with h5py.File(self.weight_file, 'r') as fin:
weight = fin['CNN_proj']['W_proj'][...]
bias = fin['CNN_proj']['b_proj'][...]
self._projection.weight.data.copy_(torch.FloatTensor(numpy.transpose(weight)))
self._projection.bias.data.copy_(torch.FloatTensor(bias))

self._projection.weight.requires_grad = self.requires_grad
self._projection.bias.requires_grad = self.requires_grad
char_cnn_options = self._options['char_cnn']
if char_cnn_options['activation'] == 'tanh':
self.activation = torch.tanh
elif char_cnn_options['activation'] == 'relu':
self.activation = torch.nn.functional.relu
else:
raise Exception("Unknown activation")

if char_emb_layer is not None:
self.char_conv = []
cnn_config = config['char_cnn']
filters = cnn_config['filters']
char_embed_dim = cnn_config['embedding']['dim']
convolutions = []

for i, (width, num) in enumerate(filters):
conv = torch.nn.Conv1d(
in_channels=char_embed_dim,
out_channels=num,
kernel_size=width,
bias=True
)
convolutions.append(conv)
self.add_module('char_conv_{}'.format(i), conv)

self._convolutions = convolutions

n_filters = sum(f[1] for f in filters)
n_highway = cnn_config['n_highway']

self._highways = Highway(n_filters, n_highway, activation=torch.nn.functional.relu)

self._projection = torch.nn.Linear(n_filters, self.output_dim, bias=True)
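With loading moved out of the embedder, __init__ now only builds the layers; the pretrained parameters arrive afterwards via load_state_dict, mirroring the call in _ElmoModel below:

embedder = ConvTokenEmbedder(config, weight_file, None, char_emb_layer)
embedder.load_state_dict(elmo_model["char_cnn"])  # elmo_model: the unpickled weight dict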

def forward(self, words, chars):
"""
@@ -616,15 +468,8 @@ class ConvTokenEmbedder(nn.Module):
# self._char_embedding_weights
# )
batch_size, sequence_length, max_char_len = chars.size()
character_embedding = self.char_emb_layer(chars).reshape(batch_size*sequence_length, max_char_len, -1)
character_embedding = self.char_emb_layer(chars).reshape(batch_size * sequence_length, max_char_len, -1)
# run convolutions
cnn_options = self._options['token_embedder']
if cnn_options['activation'] == 'tanh':
activation = torch.tanh
elif cnn_options['activation'] == 'relu':
activation = torch.nn.functional.relu
else:
raise Exception("Unknown activation")

# (batch_size * sequence_length, embed_dim, max_chars_per_token)
character_embedding = torch.transpose(character_embedding, 1, 2)
@@ -634,7 +479,7 @@ class ConvTokenEmbedder(nn.Module):
convolved = conv(character_embedding)
# (batch_size * sequence_length, n_filters for this width)
convolved, _ = torch.max(convolved, dim=-1)
convolved = activation(convolved)
convolved = self.activation(convolved)
convs.append(convolved)

# (batch_size * sequence_length, n_filters)
@@ -712,8 +557,9 @@ class _ElmoModel(nn.Module):

def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
super(_ElmoModel, self).__init__()

dir = os.walk(model_dir)
# self.pkl_dict = {}
self.model_dir = model_dir
dir = os.walk(self.model_dir)
config_file = None
weight_file = None
config_count = 0
@@ -723,7 +569,7 @@ class _ElmoModel(nn.Module):
if file_name.__contains__(".json"):
config_file = file_name
config_count += 1
elif file_name.__contains__(".hdf5"):
elif file_name.__contains__(".pkl"):
weight_file = file_name
weight_count += 1
if config_count > 1 or weight_count > 1:
@@ -744,102 +590,86 @@ class _ElmoModel(nn.Module):
EOW_TAG = '<eow>'

# For the model trained with character-based word encoder.
if config['token_embedder']['embedding']['dim'] > 0:
char_lexicon = {}
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
for line in fpi:
tokens = line.strip().split('\t')
if len(tokens) == 1:
tokens.insert(0, '\u3000')
token, i = tokens
char_lexicon[token] = int(i)

# run some sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
assert special_word in char_lexicon, f"{special_word} not found in char.dic."

# build char_vocab from vocab
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
# make sure <bow> and <eow> are included
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])

for word, index in vocab:
char_vocab.add_word_lst(list(word))

self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab)+1, vocab.padding_idx
# adjusted according to char_lexicon; one extra slot is reserved for word padding (whose char representation is all zeros)
char_emb_layer = nn.Embedding(len(char_vocab)+1, int(config['token_embedder']['embedding']['dim']),
padding_idx=len(char_vocab))
with h5py.File(self.weight_file, 'r') as fin:
char_embed_weights = fin['char_embed'][...]
char_embed_weights = torch.from_numpy(char_embed_weights)
found_char_count = 0
for char, index in char_vocab: # adjust the character embedding
if char in char_lexicon:
index_in_pre = char_lexicon.get(char)
found_char_count += 1
else:
index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]

print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# build the mapping from words to chars
if config['token_embedder']['name'].lower() == 'cnn':
max_chars = config['token_embedder']['max_characters_per_token']
elif config['token_embedder']['name'].lower() == 'lstm':
max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # plus two positions for <bow> and <eow>
char_lexicon = {}
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
for line in fpi:
tokens = line.strip().split('\t')
if len(tokens) == 1:
tokens.insert(0, '\u3000')
token, i = tokens
char_lexicon[token] = int(i)

# run some sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
assert special_word in char_lexicon, f"{special_word} not found in char.dic."

# build char_vocab from vocab
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
# make sure <bow> and <eow> are included
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])

for word, index in vocab:
char_vocab.add_word_lst(list(word))

self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx
# adjusted according to char_lexicon; one extra slot is reserved for word padding (whose char representation is all zeros)
char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
padding_idx=len(char_vocab))

# load the pretrained weights; elmo_model here is a dict holding the char_embed tensor plus the state_dicts for char_cnn and lstm
elmo_pkl = open(os.path.join(self.model_dir, weight_file), "rb")
elmo_model = pickle.load(elmo_pkl)
elmo_pkl.close()

self.char_embed_weights = elmo_model["char_embed"]

found_char_count = 0
for char, index in char_vocab: # adjust the character embedding
if char in char_lexicon:
index_in_pre = char_lexicon.get(char)
found_char_count += 1
else:
raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name']))

self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars),
fill_value=len(char_vocab),
dtype=torch.long),
requires_grad=False)
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab)+1)]:
if len(word) + 2 > max_chars:
word = word[:max_chars - 2]
if index == self._pad_index:
continue
elif word == BOS_TAG or word == EOS_TAG:
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [
char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
else:
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [
char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)

self.char_vocab = char_vocab
else:
char_emb_layer = None

if config['token_embedder']['name'].lower() == 'cnn':
self.token_embedder = ConvTokenEmbedder(
config, self.weight_file, None, char_emb_layer, self.char_vocab)
elif config['token_embedder']['name'].lower() == 'lstm':
self.token_embedder = LstmTokenEmbedder(
config, None, char_emb_layer)

if config['token_embedder']['word_dim'] > 0 \
and vocab._no_create_word_length > 0: # need a mapping so that idx from dev/test point to unk
words_to_words = nn.Parameter(torch.arange(len(vocab) + 2).long(), requires_grad=False)
for word, idx in vocab:
if vocab._is_word_no_create_entry(word):
words_to_words[idx] = vocab.unknown_idx
setattr(self.token_embedder, 'words_to_words', words_to_words)
self.output_dim = config['encoder']['projection_dim']

# only elmo is considered for now
if config['encoder']['name'].lower() == 'elmo':
self.encoder = ElmobiLm(config)
elif config['encoder']['name'].lower() == 'lstm':
self.encoder = LstmbiLm(config)

self.encoder.load_weights(self.weight_file)
index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = self.char_embed_weights[index_in_pre]

print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# build the mapping from words to chars
max_chars = config['char_cnn']['max_characters_per_token']

self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars),
fill_value=len(char_vocab),
dtype=torch.long),
requires_grad=False)
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]:
if len(word) + 2 > max_chars:
word = word[:max_chars - 2]
if index == self._pad_index:
continue
elif word == BOS_TAG or word == EOS_TAG:
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [
char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
else:
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [
char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)

self.char_vocab = char_vocab

self.token_embedder = ConvTokenEmbedder(
config, self.weight_file, None, char_emb_layer)

self.token_embedder.load_state_dict(elmo_model["char_cnn"])

self.output_dim = config['lstm']['projection_dim']

# lstm encoder
self.encoder = ElmobiLm(config)
self.encoder.load_state_dict(elmo_model["lstm"])

if cache_word_reprs:
if config['token_embedder']['embedding']['dim'] > 0: # only useful when chars are used
if config['char_cnn']['embedding']['dim'] > 0: # only useful when chars are used
print("Start to generate cache word representations.")
batch_size = 320
# bos eos
@@ -848,7 +678,7 @@ class _ElmoModel(nn.Module):
int(word_size % batch_size != 0)

self.cached_word_embedding = nn.Embedding(word_size,
config['encoder']['projection_dim'])
config['lstm']['projection_dim'])
with torch.no_grad():
for i in range(num_batches):
words = torch.arange(i * batch_size,
@@ -877,6 +707,8 @@ class _ElmoModel(nn.Module):
expanded_words[:, 0].fill_(self.bos_index)
expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
seq_len = seq_len + 2
zero_tensor = torch.zeros(expanded_words.shape).long()
mask = (expanded_words == zero_tensor).unsqueeze(-1)
if hasattr(self, 'cached_word_embedding'):
token_embedding = self.cached_word_embedding(expanded_words)
else:
@@ -886,20 +718,16 @@ class _ElmoModel(nn.Module):
chars = None
token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim

if self.config['encoder']['name'] == 'elmo':
encoder_output = self.encoder(token_embedding, seq_len)
if encoder_output.size(2) < max_len + 2:
num_layers, _, output_len, hidden_size = encoder_output.size()
dummy_tensor = encoder_output.new_zeros(num_layers, batch_size,
max_len + 2 - output_len, hidden_size)
encoder_output = torch.cat((encoder_output, dummy_tensor), 2)
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
encoder_output = torch.cat((token_embedding, encoder_output), dim=0)
elif self.config['encoder']['name'] == 'lstm':
encoder_output = self.encoder(token_embedding, seq_len)
else:
raise ValueError('Unknown encoder: {0}'.format(self.config['encoder']['name']))
encoder_output = self.encoder(token_embedding, seq_len)
if encoder_output.size(2) < max_len + 2:
num_layers, _, output_len, hidden_size = encoder_output.size()
dummy_tensor = encoder_output.new_zeros(num_layers, batch_size,
max_len + 2 - output_len, hidden_size)
encoder_output = torch.cat((encoder_output, dummy_tensor), 2)
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
token_embedding = token_embedding.masked_fill(mask, 0)
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
encoder_output = torch.cat((token_embedding, encoder_output), dim=0)

# remove <eos>, <bos>. The removal is not exact here, but it should not affect the final result.
encoder_output = encoder_output[:, :, 1:-1]
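For orientation, a hypothetical end-to-end use of the rewritten module (the model directory is assumed to hold the .json config, the .pkl weight file, and char.dic, per __init__ above; the path, vocab contents, and output-shape comment are illustrative):

import torch
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.modules.encoder._elmo import _ElmoModel

vocab = Vocabulary()
vocab.add_word_lst("a tiny example sentence".split())
vocab.build_vocab()

elmo = _ElmoModel("path/to/elmo_dir", vocab=vocab)  # illustrative path
words = torch.LongTensor([[vocab.to_index(w) for w in "a tiny example sentence".split()]])
layers = elmo(words)  # layers x batch x seq_len x (2 * projection_dim), per the concatenation above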


requirements.txt  (+0 −1)

@@ -4,4 +4,3 @@ tqdm>=4.28.1
nltk>=3.4.1
requests
spacy
h5py
