
optimize model definition of seq modeling; test_POS_pipeline ready to run

tags/v0.1.0
FengZiYjun 7 years ago
parent commit 1c0eac4c82
10 changed files with 121 additions and 97 deletions
  1. fastNLP/action/tester.py (+6, -6)
  2. fastNLP/action/trainer.py (+5, -6)
  3. fastNLP/models/sequence_modeling.py (+41, -57)
  4. fastNLP/modules/__init__.py (+11, -0)
  5. fastNLP/modules/decoder/CRF.py (+14, -15)
  6. fastNLP/modules/decoder/__init__.py (+3, -0)
  7. fastNLP/modules/encoder/__init__.py (+7, -0)
  8. fastNLP/modules/encoder/embedding.py (+10, -10)
  9. fastNLP/modules/encoder/linear.py (+21, -0)
  10. fastNLP/modules/encoder/lstm.py (+3, -3)

fastNLP/action/tester.py (+6, -6)

@@ -1,12 +1,11 @@
 import _pickle
+import os
 
 import numpy as np
 import torch
-import os
 
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
-from fastNLP.modules.utils import seq_mask
 
 
 class BaseTester(Action):
@@ -148,18 +147,19 @@ class POSTester(BaseTester):
         :param x: list of list, [batch_size, max_len]
         :return y: [batch_size, num_classes]
         """
-        seq_len = [len(seq) for seq in x]
+        self.seq_len = [len(seq) for seq in x]
         x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
-        self.mask = seq_mask(seq_len, self.max_len)
+        # self.mask = seq_mask(seq_len, self.max_len)
         y = network(x)
         return y
 
     def evaluate(self, predict, truth):
         truth = torch.Tensor(truth)
-        loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
-        results = torch.Tensor(prediction[0][0]).view((-1,))
+        loss = self.model.loss(predict, truth, self.seq_len)
+        prediction = self.model.prediction(predict, self.seq_len)
+        results = torch.Tensor(prediction).view(-1,)
         accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
         return [loss.data, accuracy]
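Note: the refactored evaluate path asks the model itself for both the loss and the decoded tag paths. A minimal sketch of that flow with a hypothetical evaluate_batch helper, assuming a SeqLabeling-style model that exposes the new loss/prediction methods and batches padded to one common max_len (so every decoded path has the same length):

    import torch

    def evaluate_batch(model, predict, truth, seq_len):
        # predict: [batch_size, max_len, tag_size] emission scores from model(x)
        # truth:   [batch_size, max_len] gold tag ids; seq_len: list of true lengths
        truth = torch.Tensor(truth)
        loss = model.loss(predict, truth, seq_len)    # scalar CRF negative log-likelihood
        paths = model.prediction(predict, seq_len)    # list of decoded tag-id paths
        results = torch.Tensor(paths).view(-1)        # valid because all paths share one length
        accuracy = float(torch.sum(results == truth.view(-1))) / results.shape[0]
        return loss.data, accuracy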




fastNLP/action/trainer.py (+5, -6)

@@ -10,7 +10,6 @@ import torch.nn as nn
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import POSTester
-from fastNLP.modules.utils import seq_mask
 from fastNLP.saver.model_saver import ModelSaver
 
 
@@ -289,13 +288,13 @@ class POSTrainer(BaseTrainer):
         """
         :param network: the PyTorch model
         :param x: list of list, [batch_size, max_len]
-        :return y: [batch_size, num_classes]
+        :return y: [batch_size, max_len, tag_size]
         """
-        seq_len = [len(seq) for seq in x]
+        self.seq_len = [len(seq) for seq in x]
        x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
-        self.mask = seq_mask(seq_len, self.max_len)
+        # self.mask = seq_mask(seq_len, self.max_len)
         y = network(x)
         return y
 
@@ -318,7 +317,7 @@ class POSTrainer(BaseTrainer):
     def get_loss(self, predict, truth):
         """
         Compute loss given prediction and ground truth.
-        :param predict: prediction label vector, [batch_size, num_classes]
+        :param predict: prediction label vector, [batch_size, tag_size, tag_size]
         :param truth: ground truth label vector, [batch_size, max_len]
         :return: a scalar
         """
@@ -328,7 +327,7 @@ class POSTrainer(BaseTrainer):
             self.loss_func = self.model.loss
         else:
             self.define_loss()
-        loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
+        loss = self.loss_func(predict, truth, self.seq_len)
         # print("loss={:.2f}".format(loss.data))
         return loss
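Note: POSTrainer.data_forward and POSTester.data_forward now follow the same contract: record the true sequence lengths, then feed a padded LongTensor through the network; the mask is rebuilt later inside the model. A minimal standalone sketch of that contract, assuming batches of already padded integer word ids:

    import torch

    def data_forward(network, x):
        # x: list of list of word ids, padded to a common max_len
        seq_len = [len(seq) for seq in x]   # true lengths, used later for the CRF mask
        x = torch.Tensor(x).long()          # [batch_size, max_len]
        y = network(x)                      # [batch_size, max_len, tag_size] emission scores
        return y, seq_len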




fastNLP/models/sequence_modeling.py (+41, -57)

@@ -1,9 +1,7 @@
 import torch
-import torch.nn as nn
-from torch.nn import functional as F
 
 from fastNLP.models.base_model import BaseModel
-from fastNLP.modules.decoder.CRF import ContionalRandomField
+from fastNLP.modules import decoder, encoder, utils
 
 
 class SeqLabeling(BaseModel):
@@ -23,75 +21,61 @@ class SeqLabeling(BaseModel):
                  use_crf=True):
         super(SeqLabeling, self).__init__()
 
-        self.Emb = nn.Embedding(vocab_size, word_emb_dim)
-        if init_emb:
-            self.Emb.weight = nn.Parameter(init_emb)
-
-        self.num_classes = num_classes
-        self.input_dim = word_emb_dim
-        self.layers = rnn_num_layer
-        self.hidden_dim = hidden_dim
-        self.bi_direction = bi_direction
-        self.dropout = dropout
-        self.mode = rnn_mode
-
-        if self.mode == "lstm":
-            self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
-                               bidirectional=self.bi_direction, dropout=self.dropout)
-        elif self.mode == "gru":
-            self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
-                              bidirectional=self.bi_direction, dropout=self.dropout)
-        elif self.mode == "rnn":
-            self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
-                              bidirectional=self.bi_direction, dropout=self.dropout)
-        else:
-            raise Exception
-        if bi_direction:
-            self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
-        else:
-            self.linear = nn.Linear(self.hidden_dim, self.num_classes)
-        self.use_crf = use_crf
-        if self.use_crf:
-            self.crf = ContionalRandomField(num_classes)
+        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
+        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
+        self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
+        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
 
     def forward(self, x):
         """
         :param x: LongTensor, [batch_size, mex_len]
         :return y: [batch_size, tag_size, tag_size]
         """
-        x = self.Emb(x)
+        x = self.Embedding(x)
         # [batch_size, max_len, word_emb_dim]
-        x, hidden = self.rnn(x)
+        x = self.Rnn(x)
         # [batch_size, max_len, hidden_size * direction]
-        y = self.linear(x)
+        x = self.Linear(x)
         # [batch_size, max_len, num_classes]
-        return y
+        return x
 
-    def loss(self, x, y, mask, batch_size, max_len):
+    def loss(self, x, y, seq_length):
         """
         Negative log likelihood loss.
-        :param x: FloatTensor, [batch_size, tag_size, tag_size]
+        :param x: FloatTensor, [batch_size, max_len, tag_size]
         :param y: LongTensor, [batch_size, max_len]
-        :param mask: ByteTensor, [batch_size, max_len]
-        :param batch_size: int
-        :param max_len: int
+        :param seq_length: list of int. [batch_size]
         :return loss: a scalar Tensor
-                prediction: list of tuple of (decode path(list), best score)
-
         """
         x = x.float()
         y = y.long()
+
+        batch_size = x.size(0)
+        max_len = x.size(1)
+
+        mask = utils.seq_mask(seq_length, max_len)
+        mask = mask.byte().view(batch_size, max_len)
+        # mask = x.new(batch_size, max_len)
+
+        total_loss = self.Crf(x, y, mask)
+
+        return torch.mean(total_loss)
+
+    def prediction(self, x, seq_length):
+        """
+        :param x: FloatTensor, [batch_size, tag_size, tag_size]
+        :param seq_length: int
+        :return prediction: list of tuple of (decode path(list), best score)
+        """
+        x = x.float()
+        batch_size = x.size(0)
+        max_len = x.size(1)
+
+        mask = utils.seq_mask(seq_length, max_len)
         mask = mask.byte()
-        # print(x.shape, y.shape, mask.shape)
-
-        if self.use_crf:
-            total_loss = self.crf(x, y, mask)
-            tag_seq = self.crf.viterbi_decode(x, mask)
-        else:
-            # error
-            loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
-            x = x.view(batch_size * max_len, -1)
-            score = F.log_softmax(x)
-            total_loss = loss_function(score, y.view(batch_size * max_len))
-            _, tag_seq = torch.max(score)
-            tag_seq = tag_seq.view(batch_size, max_len)
-        return torch.mean(total_loss), tag_seq
+        # mask = x.new(batch_size, max_len)
+
+        tag_seq = self.Crf.viterbi_decode(x, mask)
+
+        return tag_seq
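Note: after this refactor the model owns embedding, LSTM, projection and CRF, and it rebuilds the byte mask internally from sequence lengths via utils.seq_mask. A minimal sketch of the new call sequence with a hypothetical run_batch helper (constructor arguments are omitted since this commit leaves the __init__ signature untouched):

    import torch

    def run_batch(model, x, y):
        # model: an already-constructed SeqLabeling instance
        # x: list of list of word ids padded to max_len; y: [batch_size, max_len] gold tags
        seq_len = [len(seq) for seq in x]
        emissions = model(torch.Tensor(x).long())     # Embedding -> Lstm -> Linear scores
        loss = model.loss(emissions, y, seq_len)      # mean CRF negative log-likelihood
        paths = model.prediction(emissions, seq_len)  # Viterbi-decoded tag ids per sentence
        return loss, paths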

fastNLP/modules/__init__.py (+11, -0)

@@ -0,0 +1,11 @@
+from . import aggregation
+from . import decoder
+from . import encoder
+from . import interaction
+
+__version__ = '0.0.0'
+
+__all__ = ['encoder',
+           'decoder',
+           'aggregation',
+           'interaction']

fastNLP/modules/decoder/CRF.py (+14, -15)

@@ -18,13 +18,13 @@ def seq_len_to_byte_mask(seq_lens):
     return mask
 
 
-class ContionalRandomField(nn.Module):
+class ConditionalRandomField(nn.Module):
     def __init__(self, tag_size, include_start_end_trans=True):
         """
         :param tag_size: int, num of tags
         :param include_start_end_trans: bool, whether to include start/end tag
         """
-        super(ContionalRandomField, self).__init__()
+        super(ConditionalRandomField, self).__init__()
 
         self.include_start_end_trans = include_start_end_trans
         self.tag_size = tag_size
@@ -47,7 +47,6 @@ class ContionalRandomField(nn.Module):
         """
         Computes the (batch_size,) denominator term for the log-likelihood, which is the
         sum of the likelihoods across all possible state sequences.
-
         :param feats:FloatTensor, batch_size x max_len x tag_size
         :param masks:ByteTensor, batch_size x max_len
         :return:FloatTensor, batch_size
@@ -128,7 +127,7 @@ class ContionalRandomField(nn.Module):
 
         return all_path_score - gold_path_score
 
-    def viterbi_decode(self, feats, masks):
+    def viterbi_decode(self, feats, masks, get_score=False):
         """
         Given a feats matrix, return best decode path and best score.
         :param feats:
@@ -147,28 +146,28 @@ class ContionalRandomField(nn.Module):
             for t in range(self.tag_size):
                 pre_scores = self.transition_m[:, t].view(
                     1, self.tag_size) + alpha
-                max_scroe, indice = pre_scores.max(dim=1)
-                new_alpha[:, t] = max_scroe + feats[:, i, t]
-                paths[:, i - 1, t] = indice
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
+                max_score, indices = pre_scores.max(dim=1)
+                new_alpha[:, t] = max_score + feats[:, i, t]
+                paths[:, i - 1, t] = indices
+            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())
 
         if self.include_start_end_trans:
             alpha += self.end_scores.view(1, -1)
 
-        max_scroes, indice = alpha.max(dim=1)
-        indice = indice.cpu().numpy()
+        max_scores, indices = alpha.max(dim=1)
+        indices = indices.cpu().numpy()
         final_paths = []
         paths = paths.cpu().numpy().astype(int)
 
         seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
 
         for b in range(batch_size):
-            path = [indice[b]]
+            path = [indices[b]]
             for i in range(seq_lens[b] - 2, -1, -1):
                 index = paths[b, i, path[-1]]
                 path.append(index)
             final_paths.append(path[::-1])
-
-        return list(zip(final_paths, max_scroes.detach().cpu().numpy()))
+        if get_score:
+            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
+        else:
+            return final_paths
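Note: with the new get_score flag, callers that only want the decoded paths (such as SeqLabeling.prediction above) get a plain list, while the old (path, best score) tuples stay available on request. A minimal sketch with toy shapes, assuming an all-ones mask:

    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField

    crf = ConditionalRandomField(tag_size=5)
    feats = torch.randn(2, 7, 5)      # [batch_size, max_len, tag_size] emission scores
    masks = torch.ones(2, 7).byte()   # every position is valid in this toy batch

    paths = crf.viterbi_decode(feats, masks)                    # list of tag-id paths
    scored = crf.viterbi_decode(feats, masks, get_score=True)   # list of (path, score) tuples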

fastNLP/modules/decoder/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
+from .CRF import ConditionalRandomField
+
+__all__ = ["ConditionalRandomField"]

fastNLP/modules/encoder/__init__.py (+7, -0)

@@ -0,0 +1,7 @@
+from .embedding import Embedding
+from .linear import Linear
+from .lstm import Lstm
+
+__all__ = ["Lstm",
+           "Embedding",
+           "Linear"]

fastNLP/modules/encoder/embedding.py (+10, -10)

@@ -1,10 +1,9 @@
 import torch.nn as nn
 
 
-class Lookuptable(nn.Module):
+class Embedding(nn.Module):
     """
     A simple lookup table
-
     Args:
         nums : the size of the lookup table
         dims : the size of each vector
@@ -12,13 +11,14 @@ class Lookuptable(nn.Module):
         sparse : If True, gradient matrix will be a sparse tensor. In this case,
         only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
     """
-    def __init__(self, nums, dims, padding_idx=0, sparse=False):
-        super(Lookuptable, self).__init__()
-        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
 
-    def forward(self, x):
-        return self.embed(x)
+    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+        super(Embedding, self).__init__()
+        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
+        if init_emb:
+            self.embed.weight = nn.Parameter(init_emb)
+        self.dropout = nn.Dropout(dropout)
 
-
-if __name__ == "__main__":
-    model = Lookuptable(10, 20)
+    def forward(self, x):
+        x = self.embed(x)
+        return self.dropout(x)
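Note: the rewritten Embedding applies dropout after the lookup and can be seeded with pre-trained vectors. A minimal sketch with toy weights; assigning the weight after construction sidesteps the if init_emb: check above, whose truth value is ambiguous for a multi-element tensor:

    import torch
    import torch.nn as nn
    from fastNLP.modules.encoder.embedding import Embedding

    embed = Embedding(nums=100, dims=50, dropout=0.5)

    pretrained = torch.randn(100, 50)              # toy pre-trained vectors, [nums, dims]
    embed.embed.weight = nn.Parameter(pretrained)

    words = torch.LongTensor([[3, 7, 0], [12, 5, 0]])   # [batch_size, max_len] word ids
    vectors = embed(words)                              # [batch_size, max_len, dims] after dropout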

fastNLP/modules/encoder/linear.py (+21, -0)

@@ -0,0 +1,21 @@
+import torch.nn as nn
+
+
+class Linear(nn.Module):
+    """
+    Linear module
+    Args:
+        input_size : input size
+        hidden_size : hidden size
+        num_layers : number of hidden layers
+        dropout : dropout rate
+        bidirectional : If True, becomes a bidirectional RNN
+    """
+
+    def __init__(self, input_size, output_size, bias=True):
+        super(Linear, self).__init__()
+        self.linear = nn.Linear(input_size, output_size, bias)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x

fastNLP/modules/encoder/lstm.py (+3, -3)

@@ -4,7 +4,6 @@ import torch.nn as nn
 class Lstm(nn.Module):
     """
     LSTM module
-
     Args:
         input_size : input size
         hidden_size : hidden size
@@ -12,11 +11,12 @@ class Lstm(nn.Module):
         dropout : dropout rate
         bidirectional : If True, becomes a bidirectional RNN
     """
-    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):
+
+    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False):
         super(Lstm, self).__init__()
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                             dropout=dropout, bidirectional=bidirectional)
 
     def forward(self, x):
         x, _ = self.lstm(x)
         return x
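Note: with keyword defaults, the encoder.lstm.Lstm wrapper can now be built from just an input size, which is how SeqLabeling above constructs it. A minimal sketch with toy shapes:

    import torch
    from fastNLP.modules.encoder.lstm import Lstm

    rnn = Lstm(input_size=50)     # hidden_size=100, num_layers=1, unidirectional by default
    x = torch.randn(2, 7, 50)     # [batch_size, max_len, input_size] embedded words
    out = rnn(x)                  # [batch_size, max_len, hidden_size]; the (h, c) state is discarded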
