@@ -1,12 +1,11 @@
 import _pickle
+import os
 import numpy as np
 import torch
-import os
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
-from fastNLP.modules.utils import seq_mask


 class BaseTester(Action):
@@ -148,18 +147,19 @@ class POSTester(BaseTester):
         :param x: list of list, [batch_size, max_len]
         :return y: [batch_size, num_classes]
         """
-        seq_len = [len(seq) for seq in x]
+        self.seq_len = [len(seq) for seq in x]
         x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
-        self.mask = seq_mask(seq_len, self.max_len)
+        # self.mask = seq_mask(seq_len, self.max_len)
         y = network(x)
         return y

     def evaluate(self, predict, truth):
         truth = torch.Tensor(truth)
-        loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
-        results = torch.Tensor(prediction[0][0]).view((-1,))
+        loss = self.model.loss(predict, truth, self.seq_len)
+        prediction = self.model.prediction(predict, self.seq_len)
+        results = torch.Tensor(prediction).view(-1,)
         accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
         return [loss.data, accuracy]
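
Note: a minimal, self-contained sketch of the accuracy computation that the new evaluate performs (the padded prediction lists are illustrative; in the tester they come from model.prediction and are assumed here to be padded to max_len):

    import torch

    prediction = [[1, 2, 3, 0], [2, 4, 0, 0]]            # decoded tag paths, padded to max_len
    truth = torch.Tensor([[1, 2, 3, 0], [2, 4, 0, 0]])   # gold tags with the same padding

    results = torch.Tensor(prediction).view((-1,))
    accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
    print(accuracy)  # 1.0 for this toy batch
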
@@ -10,7 +10,6 @@ import torch.nn as nn

 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import POSTester
-from fastNLP.modules.utils import seq_mask
 from fastNLP.saver.model_saver import ModelSaver
@@ -289,13 +288,13 @@ class POSTrainer(BaseTrainer):
         """
         :param network: the PyTorch model
         :param x: list of list, [batch_size, max_len]
-        :return y: [batch_size, num_classes]
+        :return y: [batch_size, max_len, tag_size]
         """
-        seq_len = [len(seq) for seq in x]
+        self.seq_len = [len(seq) for seq in x]
         x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
-        self.mask = seq_mask(seq_len, self.max_len)
+        # self.mask = seq_mask(seq_len, self.max_len)
         y = network(x)
         return y
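
Note: the trainer no longer builds the mask; the model rebuilds it from self.seq_len via utils.seq_mask. A toy sketch of what such a mask looks like (this is an assumed implementation, not the actual utils.seq_mask):

    import torch

    def seq_mask(seq_len, max_len):
        # 1 for real tokens, 0 for padding, one row per sequence
        return torch.stack([torch.arange(max_len) < length
                            for length in torch.LongTensor(seq_len)]).byte()

    mask = seq_mask([3, 2], max_len=4)  # rows: [1, 1, 1, 0] and [1, 1, 0, 0]
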
@@ -318,7 +317,7 @@ class POSTrainer(BaseTrainer):
     def get_loss(self, predict, truth):
         """
         Compute loss given prediction and ground truth.
-        :param predict: prediction label vector, [batch_size, num_classes]
+        :param predict: prediction label vector, [batch_size, max_len, tag_size]
         :param truth: ground truth label vector, [batch_size, max_len]
         :return: a scalar
         """
@@ -328,7 +327,7 @@ class POSTrainer(BaseTrainer):
             self.loss_func = self.model.loss
         else:
             self.define_loss()
-        loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
+        loss = self.loss_func(predict, truth, self.seq_len)
         # print("loss={:.2f}".format(loss.data))
         return loss
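
Note: get_loss now just forwards self.seq_len to whichever loss is configured. A toy illustration of the hasattr dispatch above (ToyModel is hypothetical):

    import torch

    class ToyModel:
        # stands in for a model that defines its own loss(predict, truth, seq_len)
        def loss(self, predict, truth, seq_len):
            return torch.mean((predict.float() - truth.float()) ** 2)

    model = ToyModel()
    loss_func = model.loss if hasattr(model, "loss") else None  # mirrors the check above
    loss = loss_func(torch.rand(2, 4), torch.rand(2, 4), [4, 3])
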
@@ -1,9 +1,7 @@
 import torch
-import torch.nn as nn
-from torch.nn import functional as F

 from fastNLP.models.base_model import BaseModel
-from fastNLP.modules.decoder.CRF import ContionalRandomField
+from fastNLP.modules import decoder, encoder, utils


 class SeqLabeling(BaseModel):
@@ -23,75 +21,61 @@ class SeqLabeling(BaseModel):
                  use_crf=True):
         super(SeqLabeling, self).__init__()

-        self.Emb = nn.Embedding(vocab_size, word_emb_dim)
-        if init_emb:
-            self.Emb.weight = nn.Parameter(init_emb)
-
-        self.num_classes = num_classes
-        self.input_dim = word_emb_dim
-        self.layers = rnn_num_layer
-        self.hidden_dim = hidden_dim
-        self.bi_direction = bi_direction
-        self.dropout = dropout
-        self.mode = rnn_mode
-
-        if self.mode == "lstm":
-            self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
-                               bidirectional=self.bi_direction, dropout=self.dropout)
-        elif self.mode == "gru":
-            self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
-                              bidirectional=self.bi_direction, dropout=self.dropout)
-        elif self.mode == "rnn":
-            self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
-                              bidirectional=self.bi_direction, dropout=self.dropout)
-        else:
-            raise Exception
-
-        if bi_direction:
-            self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
-        else:
-            self.linear = nn.Linear(self.hidden_dim, self.num_classes)
-
-        self.use_crf = use_crf
-        if self.use_crf:
-            self.crf = ContionalRandomField(num_classes)
+        self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
+        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
+        self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
+        self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

     def forward(self, x):
         """
         :param x: LongTensor, [batch_size, max_len]
         :return y: [batch_size, max_len, num_classes]
         """
-        x = self.Emb(x)
+        x = self.Embedding(x)
         # [batch_size, max_len, word_emb_dim]
-        x, hidden = self.rnn(x)
+        x = self.Rnn(x)
         # [batch_size, max_len, hidden_size * direction]
-        y = self.linear(x)
+        x = self.Linear(x)
         # [batch_size, max_len, num_classes]
-        return y
+        return x

-    def loss(self, x, y, mask, batch_size, max_len):
+    def loss(self, x, y, seq_length):
         """
         Negative log likelihood loss.
-        :param x: FloatTensor, [batch_size, tag_size, tag_size]
+        :param x: FloatTensor, [batch_size, max_len, tag_size]
         :param y: LongTensor, [batch_size, max_len]
-        :param mask: ByteTensor, [batch_size, max_len]
-        :param batch_size: int
-        :param max_len: int
+        :param seq_length: list of int, [batch_size]
         :return loss: a scalar Tensor
-                prediction: list of tuple of (decode path(list), best score)
         """
         x = x.float()
         y = y.long()
+        batch_size = x.size(0)
+        max_len = x.size(1)
+        mask = utils.seq_mask(seq_length, max_len)
+        mask = mask.byte().view(batch_size, max_len)
+        # mask = x.new(batch_size, max_len)
+        total_loss = self.Crf(x, y, mask)
+        return torch.mean(total_loss)
+
+    def prediction(self, x, seq_length):
+        """
+        :param x: FloatTensor, [batch_size, max_len, tag_size]
+        :param seq_length: list of int, [batch_size]
+        :return prediction: list of decode paths, one per sequence
+        """
+        x = x.float()
+        batch_size = x.size(0)
+        max_len = x.size(1)
+        mask = utils.seq_mask(seq_length, max_len)
         mask = mask.byte()
-        # print(x.shape, y.shape, mask.shape)
-        if self.use_crf:
-            total_loss = self.crf(x, y, mask)
-            tag_seq = self.crf.viterbi_decode(x, mask)
-        else:
-            # error
-            loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
-            x = x.view(batch_size * max_len, -1)
-            score = F.log_softmax(x)
-            total_loss = loss_function(score, y.view(batch_size * max_len))
-            _, tag_seq = torch.max(score)
-            tag_seq = tag_seq.view(batch_size, max_len)
-        return torch.mean(total_loss), tag_seq
+        # mask = x.new(batch_size, max_len)
+        tag_seq = self.Crf.viterbi_decode(x, mask)
+        return tag_seq
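
Note: a rough end-to-end usage sketch of the refactored SeqLabeling (the module path and constructor keywords are assumptions; the full __init__ signature is outside this hunk):

    import torch
    from fastNLP.models.sequence_modeling import SeqLabeling  # assumed module path

    model = SeqLabeling(vocab_size=100, word_emb_dim=50, hidden_dim=64, num_classes=5)

    x = torch.LongTensor([[3, 7, 9, 0], [4, 5, 0, 0]])  # padded word indices, [2, 4]
    y = torch.LongTensor([[1, 2, 3, 0], [2, 4, 0, 0]])  # padded gold tags, [2, 4]
    seq_len = [3, 2]                                     # true lengths, as cached by data_forward

    feats = model(x)                          # [batch_size, max_len, num_classes]
    loss = model.loss(feats, y, seq_len)      # scalar CRF negative log-likelihood
    paths = model.prediction(feats, seq_len)  # one Viterbi path (list of tags) per sequence
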
@@ -0,0 +1,11 @@
+from . import aggregation
+from . import decoder
+from . import encoder
+from . import interaction
+
+__version__ = '0.0.0'
+
+__all__ = ['encoder',
+           'decoder',
+           'aggregation',
+           'interaction']
@@ -18,13 +18,13 @@ def seq_len_to_byte_mask(seq_lens):
     return mask


-class ContionalRandomField(nn.Module):
+class ConditionalRandomField(nn.Module):
     def __init__(self, tag_size, include_start_end_trans=True):
         """
         :param tag_size: int, num of tags
         :param include_start_end_trans: bool, whether to include start/end tag
         """
-        super(ContionalRandomField, self).__init__()
+        super(ConditionalRandomField, self).__init__()
         self.include_start_end_trans = include_start_end_trans
         self.tag_size = tag_size

@@ -47,7 +47,6 @@ class ContionalRandomField(nn.Module):
         """
         Computes the (batch_size,) denominator term for the log-likelihood, which is the
         sum of the likelihoods across all possible state sequences.
-
         :param feats:FloatTensor, batch_size x max_len x tag_size
         :param masks:ByteTensor, batch_size x max_len
         :return:FloatTensor, batch_size

@@ -128,7 +127,7 @@ class ContionalRandomField(nn.Module):
         return all_path_score - gold_path_score

-    def viterbi_decode(self, feats, masks):
+    def viterbi_decode(self, feats, masks, get_score=False):
         """
         Given a feats matrix, return best decode path and best score.
         :param feats:

@@ -147,28 +146,28 @@ class ContionalRandomField(nn.Module):
             for t in range(self.tag_size):
                 pre_scores = self.transition_m[:, t].view(
                     1, self.tag_size) + alpha
-                max_scroe, indice = pre_scores.max(dim=1)
-                new_alpha[:, t] = max_scroe + feats[:, i, t]
-                paths[:, i - 1, t] = indice
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
+                max_score, indices = pre_scores.max(dim=1)
+                new_alpha[:, t] = max_score + feats[:, i, t]
+                paths[:, i - 1, t] = indices
+            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())

         if self.include_start_end_trans:
             alpha += self.end_scores.view(1, -1)

-        max_scroes, indice = alpha.max(dim=1)
-        indice = indice.cpu().numpy()
+        max_scores, indices = alpha.max(dim=1)
+        indices = indices.cpu().numpy()
         final_paths = []
         paths = paths.cpu().numpy().astype(int)
         seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
         for b in range(batch_size):
-            path = [indice[b]]
+            path = [indices[b]]
             for i in range(seq_lens[b] - 2, -1, -1):
                 index = paths[b, i, path[-1]]
                 path.append(index)
             final_paths.append(path[::-1])
-        return list(zip(final_paths, max_scroes.detach().cpu().numpy()))
+        if get_score:
+            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
+        else:
+            return final_paths
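
Note: a small usage sketch of the extended viterbi_decode (random emission scores and an all-ones mask, so every position is treated as real):

    import torch
    from fastNLP.modules.decoder.CRF import ConditionalRandomField

    crf = ConditionalRandomField(tag_size=5)
    feats = torch.randn(2, 4, 5)     # [batch_size, max_len, tag_size]
    masks = torch.ones(2, 4).byte()  # [batch_size, max_len]

    paths = crf.viterbi_decode(feats, masks)                   # default: list of decode paths
    scored = crf.viterbi_decode(feats, masks, get_score=True)  # [(path, best_score), ...]
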
@@ -0,0 +1,3 @@
+from .CRF import ConditionalRandomField
+
+__all__ = ["ConditionalRandomField"]
@@ -0,0 +1,7 @@
+from .embedding import Embedding
+from .linear import Linear
+from .lstm import Lstm
+
+__all__ = ["Lstm",
+           "Embedding",
+           "Linear"]
@@ -1,10 +1,9 @@
 import torch.nn as nn


-class Lookuptable(nn.Module):
+class Embedding(nn.Module):
     """
     A simple lookup table
-
     Args:
         nums : the size of the lookup table
         dims : the size of each vector
@@ -12,13 +11,14 @@ class Lookuptable(nn.Module):
         sparse : If True, gradient matrix will be a sparse tensor. In this case,
                  only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
     """

-    def __init__(self, nums, dims, padding_idx=0, sparse=False):
-        super(Lookuptable, self).__init__()
-        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
-
-    def forward(self, x):
-        return self.embed(x)
+    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+        super(Embedding, self).__init__()
+        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
+        if init_emb:
+            self.embed.weight = nn.Parameter(init_emb)
+        self.dropout = nn.Dropout(dropout)

-
-if __name__ == "__main__":
-    model = Lookuptable(10, 20)
+    def forward(self, x):
+        x = self.embed(x)
+        return self.dropout(x)
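
Note: a brief sketch of the rewritten Embedding in isolation (toy sizes; the encoder package path comes from the new __init__ files above):

    import torch
    from fastNLP.modules.encoder import Embedding

    embed = Embedding(nums=10, dims=20)           # vocabulary of 10, 20-dim vectors, dropout=0.0
    x = torch.LongTensor([[1, 2, 0], [3, 0, 0]])  # padded indices, [batch_size=2, max_len=3]
    out = embed(x)                                # [2, 3, 20], lookup followed by dropout
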
@@ -0,0 +1,19 @@
+import torch.nn as nn
+
+
+class Linear(nn.Module):
+    """
+    Linear module
+    Args:
+        input_size : input size
+        output_size : output size
+        bias : whether to learn an additive bias, default True
+    """
+
+    def __init__(self, input_size, output_size, bias=True):
+        super(Linear, self).__init__()
+        self.linear = nn.Linear(input_size, output_size, bias)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return x
@@ -4,7 +4,6 @@ import torch.nn as nn

 class Lstm(nn.Module):
     """
     LSTM module
-
     Args:
         input_size : input size
         hidden_size : hidden size
@@ -12,11 +11,12 @@ class Lstm(nn.Module):
         dropout : dropout rate
         bidirectional : If True, becomes a bidirectional RNN
     """

-    def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):
+    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False):
         super(Lstm, self).__init__()
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                             dropout=dropout, bidirectional=bidirectional)

     def forward(self, x):
         x, _ = self.lstm(x)
         return x
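
Note: a brief shape sketch for the Lstm wrapper with its new defaults (batch_first input, as set in the constructor):

    import torch
    from fastNLP.modules.encoder import Lstm

    lstm = Lstm(input_size=20, hidden_size=50, bidirectional=True)
    x = torch.randn(2, 7, 20)  # [batch_size, max_len, input_size]
    out = lstm(x)              # [2, 7, 100]: hidden_size * 2 directions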