From 83c032df5d661e0860695d80c37296480555b833 Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Tue, 10 Jul 2018 18:51:42 +0800
Subject: [PATCH] fix bug in CRF comments; optimize PyTorch type conversion.

---
 fastNLP/action/trainer.py            | 23 ++++++++-------
 fastNLP/models/sequencce_modeling.py | 42 ++++++++++++----------------
 fastNLP/modules/CRF.py               |  4 +--
 3 files changed, 33 insertions(+), 36 deletions(-)

diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py
index 94a704f9..1f22ef28 100644
--- a/fastNLP/action/trainer.py
+++ b/fastNLP/action/trainer.py
@@ -170,8 +170,8 @@ class BaseTrainer(Action):
             [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
             ...
         ]
-        :return batch_x: list. Each entry is a list of features of a sample.
-                batch_y: list. Each entry is a list of labels of a sample.
+        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
         """
         if self.iterator is None:
             self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
@@ -325,7 +325,6 @@ class POSTrainer(BaseTrainer):
         self.num_classes = train_args.num_classes
         self.max_len = None
         self.mask = None
-        self.batch_x = None

     def prepare_input(self, data_path):
         """
@@ -336,14 +335,18 @@ class POSTrainer(BaseTrainer):
         return data_train, data_dev, 0, 1

     def data_forward(self, network, x):
+        """
+        :param network: the PyTorch model
+        :param x: list of list, [batch_size, max_len]
+        :return y: [batch_size, max_len, num_classes]
+        """
         seq_len = [len(seq) for seq in x]
-        x = torch.LongTensor(x)
+        x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
         self.mask = seq_mask(seq_len, self.max_len)
-        x = network(x)
-        self.batch_x = x
-        return x
+        y = network(x)
+        return y

     def mode(self, test=False):
         if test:
@@ -357,8 +360,8 @@ class POSTrainer(BaseTrainer):
     def get_loss(self, predict, truth):
         """
         Compute loss given prediction and ground truth.
-        :param predict: prediction label vector
-        :param truth: ground truth label vector
+        :param predict: prediction label vector, [batch_size, max_len, num_classes]
+        :param truth: ground truth label vector, [batch_size, max_len]
         :return: a scalar
         """
         if self.loss_func is None:
@@ -366,7 +369,7 @@ class POSTrainer(BaseTrainer):
                 self.loss_func = self.model.loss
             else:
                 self.define_loss()
-        return self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
+        return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)


 if __name__ == "__name__":
diff --git a/fastNLP/models/sequencce_modeling.py b/fastNLP/models/sequencce_modeling.py
index ba96d4b6..96f09f80 100644
--- a/fastNLP/models/sequencce_modeling.py
+++ b/fastNLP/models/sequencce_modeling.py
@@ -12,7 +12,7 @@ class SeqLabeling(BaseModel):
     """

     def __init__(self, hidden_dim,
-                 rnn_num_layerd,
+                 rnn_num_layer,
                  num_classes,
                  vocab_size,
                  word_emb_dim=100,
@@ -29,7 +29,7 @@ class SeqLabeling(BaseModel):

         self.num_classes = num_classes
         self.input_dim = word_emb_dim
-        self.layers = rnn_num_layerd
+        self.layers = rnn_num_layer
         self.hidden_dim = hidden_dim
         self.bi_direction = bi_direction
         self.dropout = dropout
@@ -55,32 +55,26 @@ class SeqLabeling(BaseModel):
             self.crf = ContionalRandomField(num_classes)

     def forward(self, x):
-
-        x = self.embedding(x)
-        x, hidden = self.encode(x)
-        x = self.aggregate(x)
-        x = self.decode(x)
-        return x
-
-    def embedding(self, x):
-        return self.Emb(x)
-
-    def encode(self, x):
-        return self.rnn(x)
-
-    def aggregate(self, x):
-        return x
-
-    def decode(self, x):
-        x = self.linear(x)
-        return x
+        """
+        :param x: LongTensor, [batch_size, max_len]
+        :return y: [batch_size, max_len, num_classes]
+        """
+        x = self.Emb(x)
+        # [batch_size, max_len, word_emb_dim]
+        x, hidden = self.rnn(x)
+        # [batch_size, max_len, hidden_size * direction]
+        y = self.linear(x)
+        # [batch_size, max_len, num_classes]
+        return y

     def loss(self, x, y, mask, batch_size, max_len):
         """
         Negative log likelihood loss.
-        :param x:
-        :param y:
-        :param seq_len:
+        :param x: FloatTensor, [batch_size, max_len, tag_size]
+        :param y: LongTensor, [batch_size, max_len]
+        :param mask: ByteTensor, [batch_size, max_len]
+        :param batch_size: int
+        :param max_len: int
         :return loss:
                 prediction:
         """
diff --git a/fastNLP/modules/CRF.py b/fastNLP/modules/CRF.py
index 6361b93d..96c84dca 100644
--- a/fastNLP/modules/CRF.py
+++ b/fastNLP/modules/CRF.py
@@ -82,7 +82,7 @@ class ContionalRandomField(nn.Module):
     def _glod_score(self, feats, tags, masks):
         """
         Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x tag_size x tag_size
+        :param feats: FloatTensor, batch_size x max_len x tag_size
         :param tags: LongTensor, batch_size x max_len
         :param masks: ByteTensor, batch_size x max_len
         :return:FloatTensor, batch_size
@@ -118,7 +118,7 @@ class ContionalRandomField(nn.Module):
     def forward(self, feats, tags, masks):
         """
         Calculate the neg log likelihood
-        :param feats:FloatTensor, batch_size x tag_size x tag_size
+        :param feats:FloatTensor, batch_size x max_len x tag_size
         :param tags:LongTensor, batch_size x max_len
         :param masks:ByteTensor batch_size x max_len
         :return:FloatTensor, batch_size
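
Note (annotation, not part of the patch): the comment fix above pins the CRF input
shape to feats: batch_size x max_len x tag_size. As a sanity check on that shape,
here is a minimal stand-alone sketch of a gold-path score in the spirit of
_glod_score. The name gold_score, the transitions argument, and its from-to
indexing convention are illustrative assumptions, and start/end transition scores
are omitted; this is not fastNLP's implementation.

import torch

def gold_score(feats, tags, masks, transitions):
    # feats: FloatTensor, [batch_size, max_len, tag_size] (per the fixed comment)
    # tags:  LongTensor,  [batch_size, max_len]
    # masks: ByteTensor,  [batch_size, max_len]
    # transitions: FloatTensor, [tag_size, tag_size]; assumed transitions[i][j] = score(i -> j)
    # Emission score of each gold tag, zeroed on padding positions.
    emit = feats.gather(2, tags.unsqueeze(2)).squeeze(2)  # [batch_size, max_len]
    emit = (emit * masks.float()).sum(1)                  # [batch_size]
    # Transition score between consecutive gold tags; both ends must be real tokens.
    trans = transitions[tags[:, :-1], tags[:, 1:]]        # [batch_size, max_len - 1]
    pair_mask = (masks[:, :-1] * masks[:, 1:]).float()
    trans = (trans * pair_mask).sum(1)                    # [batch_size]
    return emit + trans                                   # FloatTensor, [batch_size]

# Shape check: 2 sequences, max_len 5, 4 tags.
feats = torch.randn(2, 5, 4)
tags = torch.randint(0, 4, (2, 5)).long()
masks = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]).byte()
print(gold_score(feats, tags, masks, torch.randn(4, 4)).size())  # torch.Size([2])

The pair mask zeroes transitions that touch padding, so each sequence contributes
exactly seq_len - 1 transition scores.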
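A second annotation sketch for the trainer path this patch reshapes: the
list-of-lists batch is converted with torch.Tensor(x).long(), a length-based mask
stands in for seq_mask(seq_len, max_len), and the loss is called with the new
(predict, truth, mask, batch_size, max_len) argument order. The toy network,
make_mask, and the masked cross entropy in loss_func are stand-ins for
illustration, not fastNLP's SeqLabeling or its CRF loss.

import torch
import torch.nn as nn

def make_mask(seq_len, max_len):
    # Stand-in for fastNLP's seq_mask(seq_len, max_len):
    # mask[i, j] == 1 while j < seq_len[i], else 0 (padding).
    positions = torch.arange(max_len).unsqueeze(0)  # [1, max_len]
    lengths = torch.tensor(seq_len).unsqueeze(1)    # [batch_size, 1]
    return (positions < lengths).byte()             # [batch_size, max_len]

def loss_func(predict, truth, mask, batch_size, max_len):
    # Stand-in with the signature used by the new get_loss call:
    # masked token-level cross entropy instead of the real CRF NLL.
    token_loss = nn.functional.cross_entropy(
        predict.view(batch_size * max_len, -1),
        truth.view(batch_size * max_len),
        reduction="none")
    token_loss = token_loss.view(batch_size, max_len) * mask.float()
    return token_loss.sum() / mask.float().sum()

x = [[4, 9, 2], [7, 1, 0]]    # padded batch of token ids; 0 pads sample 2
seq_len = [3, 2]              # true lengths, known before padding
truth = torch.tensor([[1, 2, 0], [2, 1, 0]])

x = torch.Tensor(x).long()    # the conversion introduced by this patch
batch_size, max_len = x.size(0), x.size(1)
mask = make_mask(seq_len, max_len)

network = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 3))  # toy tagger
predict = network(x)          # [batch_size, max_len, num_classes]

print(loss_func(predict, truth, mask, batch_size, max_len))    # scalar masked loss

Running the sketch prints a scalar loss averaged over the five real tokens of the
batch; padded positions contribute nothing.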