@@ -1,12 +1,11 @@
import _pickle
import os
import numpy as np
import torch
import os

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.modules.utils import seq_mask


class BaseTester(Action):

@@ -148,18 +147,19 @@ class POSTester(BaseTester):
        :param x: list of list, [batch_size, max_len]
        :return y: [batch_size, num_classes]
        """
        seq_len = [len(seq) for seq in x]
        self.seq_len = [len(seq) for seq in x]
        x = torch.Tensor(x).long()
        self.batch_size = x.size(0)
        self.max_len = x.size(1)
        self.mask = seq_mask(seq_len, self.max_len)
        # self.mask = seq_mask(seq_len, self.max_len)
        y = network(x)
        return y

    def evaluate(self, predict, truth):
        truth = torch.Tensor(truth)
        loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
        results = torch.Tensor(prediction[0][0]).view((-1,))
        loss = self.model.loss(predict, truth, self.seq_len)
        prediction = self.model.prediction(predict, self.seq_len)
        results = torch.Tensor(prediction).view(-1,)
        accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
        return [loss.data, accuracy]
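
For illustration only (not part of the diff): model.prediction() returns one decoded tag list per sample, and those lists can have different lengths, so an evaluation helper could score only the positions inside each true sequence length. The name masked_accuracy and its signature are assumptions, not fastNLP API.

def masked_accuracy(paths, truth, seq_len):
    # paths: list of decoded tag-id lists, one per sample (variable length)
    # truth: Tensor [batch_size, max_len] of gold tag ids
    # seq_len: list of true sequence lengths, [batch_size]
    correct, total = 0, 0
    for path, gold, length in zip(paths, truth, seq_len):
        for i in range(length):
            correct += int(path[i] == int(gold[i]))
            total += 1
    return correct / max(total, 1)
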
@@ -10,7 +10,6 @@ import torch.nn as nn
from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.action.tester import POSTester
from fastNLP.modules.utils import seq_mask
from fastNLP.saver.model_saver import ModelSaver

@@ -289,13 +288,13 @@ class POSTrainer(BaseTrainer):
        """
        :param network: the PyTorch model
        :param x: list of list, [batch_size, max_len]
        :return y: [batch_size, num_classes]
        :return y: [batch_size, max_len, tag_size]
        """
        seq_len = [len(seq) for seq in x]
        self.seq_len = [len(seq) for seq in x]
        x = torch.Tensor(x).long()
        self.batch_size = x.size(0)
        self.max_len = x.size(1)
        self.mask = seq_mask(seq_len, self.max_len)
        # self.mask = seq_mask(seq_len, self.max_len)
        y = network(x)
        return y

@@ -318,7 +317,7 @@ class POSTrainer(BaseTrainer):
    def get_loss(self, predict, truth):
        """
        Compute loss given prediction and ground truth.
        :param predict: prediction label vector, [batch_size, num_classes]
        :param predict: prediction label vector, [batch_size, max_len, tag_size]
        :param truth: ground truth label vector, [batch_size, max_len]
        :return: a scalar
        """

@@ -328,7 +327,7 @@ class POSTrainer(BaseTrainer):
            self.loss_func = self.model.loss
        else:
            self.define_loss()
        loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
        loss = self.loss_func(predict, truth, self.seq_len)
        # print("loss={:.2f}".format(loss.data))
        return loss

@@ -1,9 +1,7 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.modules.decoder.CRF import ContionalRandomField
from fastNLP.modules import decoder, encoder, utils


class SeqLabeling(BaseModel):

@@ -23,75 +21,61 @@ class SeqLabeling(BaseModel):
                 use_crf=True):
        super(SeqLabeling, self).__init__()
        self.Emb = nn.Embedding(vocab_size, word_emb_dim)
        if init_emb:
            self.Emb.weight = nn.Parameter(init_emb)
        self.num_classes = num_classes
        self.input_dim = word_emb_dim
        self.layers = rnn_num_layer
        self.hidden_dim = hidden_dim
        self.bi_direction = bi_direction
        self.dropout = dropout
        self.mode = rnn_mode
        if self.mode == "lstm":
            self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
                               bidirectional=self.bi_direction, dropout=self.dropout)
        elif self.mode == "gru":
            self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
                              bidirectional=self.bi_direction, dropout=self.dropout)
        elif self.mode == "rnn":
            self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
                              bidirectional=self.bi_direction, dropout=self.dropout)
        else:
            raise Exception
        if bi_direction:
            self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
        else:
            self.linear = nn.Linear(self.hidden_dim, self.num_classes)
        self.use_crf = use_crf
        if self.use_crf:
            self.crf = ContionalRandomField(num_classes)
        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
        self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

    def forward(self, x):
        """
        :param x: LongTensor, [batch_size, max_len]
        :return y: [batch_size, max_len, num_classes]
""" | |||
x = self.Emb(x) | |||
x = self.Embedding(x) | |||
# [batch_size, max_len, word_emb_dim] | |||
x, hidden = self.rnn(x) | |||
x = self.Rnn(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
y = self.linear(x) | |||
x = self.Linear(x) | |||
# [batch_size, max_len, num_classes] | |||
return y | |||
return x | |||
def loss(self, x, y, mask, batch_size, max_len): | |||
def loss(self, x, y, seq_length): | |||
""" | |||
Negative log likelihood loss. | |||
:param x: FloatTensor, [batch_size, tag_size, tag_size] | |||
:param x: FloatTensor, [batch_size, max_len, tag_size] | |||
:param y: LongTensor, [batch_size, max_len] | |||
:param mask: ByteTensor, [batch_size, max_len] | |||
:param batch_size: int | |||
:param max_len: int | |||
:param seq_length: list of int. [batch_size] | |||
:return loss: a scalar Tensor | |||
prediction: list of tuple of (decode path(list), best score) | |||
""" | |||
x = x.float() | |||
y = y.long() | |||
batch_size = x.size(0) | |||
max_len = x.size(1) | |||
mask = utils.seq_mask(seq_length, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
# mask = x.new(batch_size, max_len) | |||
total_loss = self.Crf(x, y, mask) | |||
return torch.mean(total_loss) | |||
def prediction(self, x, seq_length): | |||
""" | |||
        :param x: FloatTensor, [batch_size, max_len, tag_size]
        :param seq_length: list of int, [batch_size]
        :return prediction: list of tuple of (decode path(list), best score)
        """
        x = x.float()
        batch_size = x.size(0)
        max_len = x.size(1)
        mask = utils.seq_mask(seq_length, max_len)
        mask = mask.byte()
        # print(x.shape, y.shape, mask.shape)
        if self.use_crf:
            total_loss = self.crf(x, y, mask)
            tag_seq = self.crf.viterbi_decode(x, mask)
        else:
            # error
            loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
            x = x.view(batch_size * max_len, -1)
            score = F.log_softmax(x)
            total_loss = loss_function(score, y.view(batch_size * max_len))
            _, tag_seq = torch.max(score)
            tag_seq = tag_seq.view(batch_size, max_len)
        return torch.mean(total_loss), tag_seq
        # mask = x.new(batch_size, max_len)
        tag_seq = self.Crf.viterbi_decode(x, mask)
        return tag_seq
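
A purely illustrative usage sketch of the refactored class (it assumes the constructor keeps its old keyword argument names, which the hunk above does not show in full; the sizes are made up):

import torch

model = SeqLabeling(hidden_dim=100, rnn_num_layer=1, num_classes=5,
                    vocab_size=1000, word_emb_dim=50)
x = torch.randint(1, 1000, (2, 7)).long()     # [batch_size, max_len] word ids
y = torch.randint(0, 5, (2, 7)).long()        # [batch_size, max_len] gold tags
seq_len = [7, 5]                              # true lengths before padding

scores = model(x)                             # [batch_size, max_len, num_classes]
nll = model.loss(scores, y, seq_len)          # scalar CRF negative log-likelihood
paths = model.prediction(scores, seq_len)     # one decoded tag-id list per sample
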
@@ -0,0 +1,11 @@
from . import aggregation
from . import decoder
from . import encoder
from . import interaction

__version__ = '0.0.0'

__all__ = ['encoder',
           'decoder',
           'aggregation',
           'interaction']

@@ -18,13 +18,13 @@ def seq_len_to_byte_mask(seq_lens):
    return mask


class ContionalRandomField(nn.Module):
class ConditionalRandomField(nn.Module):
    def __init__(self, tag_size, include_start_end_trans=True):
        """
        :param tag_size: int, num of tags
        :param include_start_end_trans: bool, whether to include start/end tag
        """
        super(ContionalRandomField, self).__init__()
        super(ConditionalRandomField, self).__init__()
        self.include_start_end_trans = include_start_end_trans
        self.tag_size = tag_size

@@ -47,7 +47,6 @@ class ContionalRandomField(nn.Module):
        """
        Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.
        :param feats:FloatTensor, batch_size x max_len x tag_size
        :param masks:ByteTensor, batch_size x max_len
        :return:FloatTensor, batch_size
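
To make that denominator concrete, here is a minimal standalone sketch of the forward algorithm it describes. It is written against plain PyTorch, not the module's internal names; crf_log_partition and its arguments are illustrative assumptions.

import torch

def crf_log_partition(feats, masks, transitions):
    # feats: [batch, max_len, tag_size] emission scores
    # masks: [batch, max_len], 1 for real tokens, 0 for padding
    # transitions: [tag_size, tag_size], transitions[prev, cur]
    batch_size, max_len, tag_size = feats.size()
    alpha = feats[:, 0]                                    # [batch, tag_size]
    for i in range(1, max_len):
        # previous alpha + transition(prev -> cur) + emission(cur), log-sum-exp over prev
        scores = alpha.unsqueeze(2) + transitions.unsqueeze(0) + feats[:, i].unsqueeze(1)
        new_alpha = torch.logsumexp(scores, dim=1)         # [batch, tag_size]
        m = masks[:, i].float().unsqueeze(1)
        alpha = new_alpha * m + alpha * (1 - m)            # freeze finished sequences
    return torch.logsumexp(alpha, dim=1)                   # [batch]
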
@@ -128,7 +127,7 @@ class ContionalRandomField(nn.Module):
        return all_path_score - gold_path_score

    def viterbi_decode(self, feats, masks):
    def viterbi_decode(self, feats, masks, get_score=False):
        """
        Given a feats matrix, return best decode path and best score.
        :param feats:

@@ -147,28 +146,28 @@ class ContionalRandomField(nn.Module):
            for t in range(self.tag_size):
                pre_scores = self.transition_m[:, t].view(
                    1, self.tag_size) + alpha
                max_scroe, indice = pre_scores.max(dim=1)
                new_alpha[:, t] = max_scroe + feats[:, i, t]
                paths[:, i - 1, t] = indice
            alpha = new_alpha * \
                masks[:, i:i + 1].float() + alpha * \
                (1 - masks[:, i:i + 1].float())
                max_score, indices = pre_scores.max(dim=1)
                new_alpha[:, t] = max_score + feats[:, i, t]
                paths[:, i - 1, t] = indices
            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())
        if self.include_start_end_trans:
            alpha += self.end_scores.view(1, -1)
        max_scroes, indice = alpha.max(dim=1)
        indice = indice.cpu().numpy()
        max_scores, indices = alpha.max(dim=1)
        indices = indices.cpu().numpy()
        final_paths = []
        paths = paths.cpu().numpy().astype(int)
        seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
        for b in range(batch_size):
            path = [indice[b]]
            path = [indices[b]]
            for i in range(seq_lens[b] - 2, -1, -1):
                index = paths[b, i, path[-1]]
                path.append(index)
            final_paths.append(path[::-1])
        return list(zip(final_paths, max_scroes.detach().cpu().numpy()))
        if get_score:
            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
        else:
            return final_paths
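
An illustrative call showing the two return forms of the updated viterbi_decode (the tensors are placeholders, and the import path assumes the decoder package re-export added in the next file):

import torch
from fastNLP.modules.decoder import ConditionalRandomField

crf = ConditionalRandomField(tag_size=5)
feats = torch.randn(2, 7, 5)       # [batch_size, max_len, tag_size]
masks = torch.ones(2, 7).byte()    # [batch_size, max_len]

paths = crf.viterbi_decode(feats, masks)                    # list of tag-id lists
scored = crf.viterbi_decode(feats, masks, get_score=True)   # list of (path, best score) tuples
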
@@ -0,0 +1,3 @@
from .CRF import ConditionalRandomField

__all__ = ["ConditionalRandomField"]

@@ -0,0 +1,7 @@
from .embedding import Embedding
from .linear import Linear
from .lstm import Lstm

__all__ = ["Lstm",
           "Embedding",
           "Linear"]

@@ -1,10 +1,9 @@
import torch.nn as nn


class Lookuptable(nn.Module):
class Embedding(nn.Module):
    """
    A simple lookup table

    Args:
        nums : the size of the lookup table
        dims : the size of each vector

@@ -12,13 +11,14 @@ class Lookuptable(nn.Module):
        sparse : If True, gradient matrix will be a sparse tensor. In this case,
                 only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
    """

    def __init__(self, nums, dims, padding_idx=0, sparse=False):
        super(Lookuptable, self).__init__()
        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)

    def forward(self, x):
        return self.embed(x)

    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
        super(Embedding, self).__init__()
        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
        if init_emb is not None:
            self.embed.weight = nn.Parameter(init_emb)
        self.dropout = nn.Dropout(dropout)

if __name__ == "__main__":
    model = Lookuptable(10, 20)

    def forward(self, x):
        x = self.embed(x)
        return self.dropout(x)
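
A small, hedged example of the new module with pretrained vectors. The weight matrix is random and only stands in for real vectors, and the example assumes the init_emb check above uses the `is not None` form so a tensor can be passed:

import torch
from fastNLP.modules.encoder.embedding import Embedding

pretrained = torch.randn(1000, 50)             # [vocab_size, word_emb_dim], placeholder weights
embed = Embedding(nums=1000, dims=50, init_emb=pretrained, dropout=0.1)
ids = torch.randint(1, 1000, (2, 7)).long()    # [batch_size, max_len] word ids
out = embed(ids)                               # [batch_size, max_len, dims], dropout applied
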
@@ -0,0 +1,21 @@
import torch.nn as nn


class Linear(nn.Module):
    """
    Linear module
    Args:
        input_size : input size
        output_size : output size
        bias : whether to learn an additive bias, default True
    """

    def __init__(self, input_size, output_size, bias=True):
        super(Linear, self).__init__()
        self.linear = nn.Linear(input_size, output_size, bias)

    def forward(self, x):
        x = self.linear(x)
        return x

@@ -4,7 +4,6 @@ import torch.nn as nn
class Lstm(nn.Module):
    """
    LSTM module

    Args:
        input_size : input size
        hidden_size : hidden size

@@ -12,11 +11,12 @@ class Lstm(nn.Module):
        dropout : dropout rate
        bidirectional : If True, becomes a bidirectional RNN
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):
    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False):
        super(Lstm, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                            dropout=dropout, bidirectional=bidirectional)

    def forward(self, x):
        x, _ = self.lstm(x)
        return x
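
A brief, illustrative check of the wrapper with its new defaults (the sizes are arbitrary); with bidirectional=True the last dimension would be hidden_size * 2:

import torch
from fastNLP.modules.encoder.lstm import Lstm

lstm = Lstm(input_size=50)          # defaults: hidden_size=100, num_layers=1, unidirectional
x = torch.randn(2, 7, 50)           # [batch_size, max_len, input_size]
out = lstm(x)                       # [batch_size, max_len, hidden_size] -> [2, 7, 100]
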