
optimize model definition of seq modeling; test_POS_pipeline ready to run

tags/v0.1.0
FengZiYjun, 7 years ago
commit 1c0eac4c82
10 changed files with 121 additions and 97 deletions

  1. fastNLP/action/tester.py (+6, -6)
  2. fastNLP/action/trainer.py (+5, -6)
  3. fastNLP/models/sequence_modeling.py (+41, -57)
  4. fastNLP/modules/__init__.py (+11, -0)
  5. fastNLP/modules/decoder/CRF.py (+14, -15)
  6. fastNLP/modules/decoder/__init__.py (+3, -0)
  7. fastNLP/modules/encoder/__init__.py (+7, -0)
  8. fastNLP/modules/encoder/embedding.py (+10, -10)
  9. fastNLP/modules/encoder/linear.py (+21, -0)
  10. fastNLP/modules/encoder/lstm.py (+3, -3)

fastNLP/action/tester.py (+6, -6)

@@ -1,12 +1,11 @@
import _pickle
import os

import numpy as np
import torch
import os

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.modules.utils import seq_mask


class BaseTester(Action):
@@ -148,18 +147,19 @@ class POSTester(BaseTester):
:param x: list of list, [batch_size, max_len]
:return y: [batch_size, max_len, tag_size]
"""
seq_len = [len(seq) for seq in x]
self.seq_len = [len(seq) for seq in x]
x = torch.Tensor(x).long()
self.batch_size = x.size(0)
self.max_len = x.size(1)
self.mask = seq_mask(seq_len, self.max_len)
# self.mask = seq_mask(seq_len, self.max_len)
y = network(x)
return y

def evaluate(self, predict, truth):
truth = torch.Tensor(truth)
loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
results = torch.Tensor(prediction[0][0]).view((-1,))
loss = self.model.loss(predict, truth, self.seq_len)
prediction = self.model.prediction(predict, self.seq_len)
results = torch.Tensor(prediction).view(-1,)
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
return [loss.data, accuracy]
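
With the mask handling moved into the model, the tester only records the raw sequence lengths in data_forward and asks the model for the loss and the Viterbi paths separately in evaluate. A minimal, self-contained sketch of the accuracy computation used above, with made-up decode paths and gold tags standing in for real model output:

import torch

# hypothetical padded Viterbi paths and gold tags for a batch of two sentences
prediction = [[1, 2, 2, 0], [3, 1, 0, 0]]
truth = torch.Tensor([[1, 2, 2, 0], [3, 1, 0, 0]])

results = torch.Tensor(prediction).view(-1,)
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
print(accuracy)  # 1.0 for this toy batch; note that padding positions are counted too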



fastNLP/action/trainer.py (+5, -6)

@@ -10,7 +10,6 @@ import torch.nn as nn
from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
from fastNLP.action.tester import POSTester
from fastNLP.modules.utils import seq_mask
from fastNLP.saver.model_saver import ModelSaver


@@ -289,13 +288,13 @@ class POSTrainer(BaseTrainer):
"""
:param network: the PyTorch model
:param x: list of list, [batch_size, max_len]
:return y: [batch_size, num_classes]
:return y: [batch_size, max_len, tag_size]
"""
seq_len = [len(seq) for seq in x]
self.seq_len = [len(seq) for seq in x]
x = torch.Tensor(x).long()
self.batch_size = x.size(0)
self.max_len = x.size(1)
self.mask = seq_mask(seq_len, self.max_len)
# self.mask = seq_mask(seq_len, self.max_len)
y = network(x)
return y

@@ -318,7 +317,7 @@ class POSTrainer(BaseTrainer):
def get_loss(self, predict, truth):
"""
Compute loss given prediction and ground truth.
:param predict: prediction label vector, [batch_size, num_classes]
:param predict: prediction label vector, [batch_size, max_len, tag_size]
:param truth: ground truth label vector, [batch_size, max_len]
:return: a scalar
"""
@@ -328,7 +327,7 @@ class POSTrainer(BaseTrainer):
self.loss_func = self.model.loss
else:
self.define_loss()
loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)
loss = self.loss_func(predict, truth, self.seq_len)
# print("loss={:.2f}".format(loss.data))
return loss



fastNLP/models/sequence_modeling.py (+41, -57)

@@ -1,9 +1,7 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.modules.decoder.CRF import ContionalRandomField
from fastNLP.modules import decoder, encoder, utils


class SeqLabeling(BaseModel):
@@ -23,75 +21,61 @@ class SeqLabeling(BaseModel):
use_crf=True):
super(SeqLabeling, self).__init__()

self.Emb = nn.Embedding(vocab_size, word_emb_dim)
if init_emb:
self.Emb.weight = nn.Parameter(init_emb)

self.num_classes = num_classes
self.input_dim = word_emb_dim
self.layers = rnn_num_layer
self.hidden_dim = hidden_dim
self.bi_direction = bi_direction
self.dropout = dropout
self.mode = rnn_mode

if self.mode == "lstm":
self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
bidirectional=self.bi_direction, dropout=self.dropout)
elif self.mode == "gru":
self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
bidirectional=self.bi_direction, dropout=self.dropout)
elif self.mode == "rnn":
self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True,
bidirectional=self.bi_direction, dropout=self.dropout)
else:
raise Exception
if bi_direction:
self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes)
else:
self.linear = nn.Linear(self.hidden_dim, self.num_classes)
self.use_crf = use_crf
if self.use_crf:
self.crf = ContionalRandomField(num_classes)
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

def forward(self, x):
"""
:param x: LongTensor, [batch_size, max_len]
:return y: [batch_size, max_len, num_classes]
"""
x = self.Emb(x)
x = self.Embedding(x)
# [batch_size, max_len, word_emb_dim]
x, hidden = self.rnn(x)
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]
y = self.linear(x)
x = self.Linear(x)
# [batch_size, max_len, num_classes]
return y
return x

def loss(self, x, y, mask, batch_size, max_len):
def loss(self, x, y, seq_length):
"""
Negative log likelihood loss.
:param x: FloatTensor, [batch_size, tag_size, tag_size]
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param y: LongTensor, [batch_size, max_len]
:param mask: ByteTensor, [batch_size, max_len]
:param batch_size: int
:param max_len: int
:param seq_length: list of int. [batch_size]
:return loss: a scalar Tensor
prediction: list of tuple of (decode path(list), best score)

"""
x = x.float()
y = y.long()

batch_size = x.size(0)
max_len = x.size(1)

mask = utils.seq_mask(seq_length, max_len)
mask = mask.byte().view(batch_size, max_len)
# mask = x.new(batch_size, max_len)

total_loss = self.Crf(x, y, mask)

return torch.mean(total_loss)

def prediction(self, x, seq_length):
"""
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param seq_length: list of int. [batch_size]
:return prediction: list of tuple of (decode path(list), best score)
"""
x = x.float()
batch_size = x.size(0)
max_len = x.size(1)

mask = utils.seq_mask(seq_length, max_len)
mask = mask.byte()
# print(x.shape, y.shape, mask.shape)

if self.use_crf:
total_loss = self.crf(x, y, mask)
tag_seq = self.crf.viterbi_decode(x, mask)
else:
# error
loss_function = nn.NLLLoss(ignore_index=0, size_average=False)
x = x.view(batch_size * max_len, -1)
score = F.log_softmax(x)
total_loss = loss_function(score, y.view(batch_size * max_len))
_, tag_seq = torch.max(score)
tag_seq = tag_seq.view(batch_size, max_len)
return torch.mean(total_loss), tag_seq
# mask = x.new(batch_size, max_len)

tag_seq = self.Crf.viterbi_decode(x, mask)

return tag_seq
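
The net effect of the refactor: SeqLabeling is now composed from the encoder/decoder building blocks, builds its own mask from the sequence lengths, and exposes loss and prediction as separate calls. A sketch of the intended call pattern follows; the constructor arguments are passed by keyword because the full signature is not visible in this diff, so treat those names as assumptions:

import torch
from fastNLP.models.sequence_modeling import SeqLabeling

# toy batch: two sentences of word ids padded to max_len 5, tags drawn from 3 classes
x = torch.Tensor([[2, 4, 5, 1, 0], [3, 6, 0, 0, 0]]).long()
y = torch.Tensor([[1, 2, 2, 1, 0], [2, 1, 0, 0, 0]]).long()
seq_len = [4, 2]

# keyword names mirror the attributes used inside the class; the exact signature may differ
model = SeqLabeling(vocab_size=10, word_emb_dim=8, hidden_dim=16, num_classes=3)

score = model(x)                          # [batch_size, max_len, num_classes]
loss = model.loss(score, y, seq_len)      # scalar CRF negative log-likelihood
paths = model.prediction(score, seq_len)  # one best tag sequence per sentence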

fastNLP/modules/__init__.py (+11, -0)

@@ -0,0 +1,11 @@
from . import aggregation
from . import decoder
from . import encoder
from . import interaction

__version__ = '0.0.0'

__all__ = ['encoder',
'decoder',
'aggregation',
'interaction']

fastNLP/modules/decoder/CRF.py (+14, -15)

@@ -18,13 +18,13 @@ def seq_len_to_byte_mask(seq_lens):
return mask


class ContionalRandomField(nn.Module):
class ConditionalRandomField(nn.Module):
def __init__(self, tag_size, include_start_end_trans=True):
"""
:param tag_size: int, num of tags
:param include_start_end_trans: bool, whether to include start/end tag
"""
super(ContionalRandomField, self).__init__()
super(ConditionalRandomField, self).__init__()

self.include_start_end_trans = include_start_end_trans
self.tag_size = tag_size
@@ -47,7 +47,6 @@ class ContionalRandomField(nn.Module):
"""
Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.

:param feats:FloatTensor, batch_size x max_len x tag_size
:param masks:ByteTensor, batch_size x max_len
:return:FloatTensor, batch_size
@@ -128,7 +127,7 @@ class ContionalRandomField(nn.Module):

return all_path_score - gold_path_score

def viterbi_decode(self, feats, masks):
def viterbi_decode(self, feats, masks, get_score=False):
"""
Given a feats matrix, return best decode path and best score.
:param feats:
@@ -147,28 +146,28 @@ class ContionalRandomField(nn.Module):
for t in range(self.tag_size):
pre_scores = self.transition_m[:, t].view(
1, self.tag_size) + alpha
max_scroe, indice = pre_scores.max(dim=1)
new_alpha[:, t] = max_scroe + feats[:, i, t]
paths[:, i - 1, t] = indice
alpha = new_alpha * \
masks[:, i:i + 1].float() + alpha * \
(1 - masks[:, i:i + 1].float())
max_score, indices = pre_scores.max(dim=1)
new_alpha[:, t] = max_score + feats[:, i, t]
paths[:, i - 1, t] = indices
alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())

if self.include_start_end_trans:
alpha += self.end_scores.view(1, -1)

max_scroes, indice = alpha.max(dim=1)
indice = indice.cpu().numpy()
max_scores, indices = alpha.max(dim=1)
indices = indices.cpu().numpy()
final_paths = []
paths = paths.cpu().numpy().astype(int)

seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]

for b in range(batch_size):
path = [indice[b]]
path = [indices[b]]
for i in range(seq_lens[b] - 2, -1, -1):
index = paths[b, i, path[-1]]
path.append(index)
final_paths.append(path[::-1])

return list(zip(final_paths, max_scroes.detach().cpu().numpy()))
if get_score:
return list(zip(final_paths, max_scores.detach().cpu().numpy()))
else:
return final_paths
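
viterbi_decode now returns plain tag paths by default and only pairs them with the best scores when get_score=True; SeqLabeling.prediction relies on the default. A small sketch of both call styles, using random emission scores and a hand-made mask (all values here are illustrative):

import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField

crf = ConditionalRandomField(tag_size=4)

feats = torch.randn(2, 5, 4)                 # batch_size x max_len x tag_size emissions
masks = torch.Tensor([[1, 1, 1, 1, 0],
                      [1, 1, 0, 0, 0]]).byte()

paths = crf.viterbi_decode(feats, masks)                   # [[t0, t1, ...], ...]
scored = crf.viterbi_decode(feats, masks, get_score=True)  # [(path, best_score), ...]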

fastNLP/modules/decoder/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
from .CRF import ConditionalRandomField

__all__ = ["ConditionalRandomField"]

fastNLP/modules/encoder/__init__.py (+7, -0)

@@ -0,0 +1,7 @@
from .embedding import Embedding
from .linear import Linear
from .lstm import Lstm

__all__ = ["Lstm",
"Embedding",
"Linear"]

fastNLP/modules/encoder/embedding.py (+10, -10)

@@ -1,10 +1,9 @@
import torch.nn as nn


class Lookuptable(nn.Module):
class Embedding(nn.Module):
"""
A simple lookup table

Args:
nums : the size of the lookup table
dims : the size of each vector
@@ -12,13 +11,14 @@ class Lookuptable(nn.Module):
sparse : If True, gradient matrix will be a sparse tensor. In this case,
only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
"""
def __init__(self, nums, dims, padding_idx=0, sparse=False):
super(Lookuptable, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
def forward(self, x):
return self.embed(x)

def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
super(Embedding, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
if init_emb:
self.embed.weight = nn.Parameter(init_emb)
self.dropout = nn.Dropout(dropout)

if __name__ == "__main__":
model = Lookuptable(10, 20)
def forward(self, x):
x = self.embed(x)
return self.dropout(x)
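
The rewritten Embedding can be seeded with a pre-trained weight matrix via init_emb and applies dropout after lookup. One caveat: `if init_emb:` is ambiguous for a Tensor with more than one element, so an explicit `init_emb is not None` check may be safer; the sketch below therefore only exercises the default path:

import torch
from fastNLP.modules.encoder.embedding import Embedding

embed = Embedding(nums=100, dims=50, dropout=0.1)  # vocabulary of 100, 50-dim vectors
ids = torch.Tensor([[3, 7, 0], [5, 1, 0]]).long()  # padded word ids, padding_idx=0
out = embed(ids)                                   # [batch_size, max_len, dims]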

fastNLP/modules/encoder/linear.py (+21, -0)

@@ -0,0 +1,21 @@
import torch.nn as nn


class Linear(nn.Module):
"""
Linear module
Args:
input_size : input size
output_size : output size
bias : If True, add a learnable bias (default True)
"""

def __init__(self, input_size, output_size, bias=True):
super(Linear, self).__init__()
self.linear = nn.Linear(input_size, output_size, bias)

def forward(self, x):
x = self.linear(x)
return x

fastNLP/modules/encoder/lstm.py (+3, -3)

@@ -4,7 +4,6 @@ import torch.nn as nn
class Lstm(nn.Module):
"""
LSTM module

Args:
input_size : input size
hidden_size : hidden size
@@ -12,11 +11,12 @@ class Lstm(nn.Module):
dropout : dropout rate
bidirectional : If True, becomes a bidirectional RNN
"""
def __init__(self, input_size, hidden_size, num_layers, dropout, bidirectional):

def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False):
super(Lstm, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
dropout=dropout, bidirectional=bidirectional)
def forward(self, x):
x, _ = self.lstm(x)
return x
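
With defaults on every argument except input_size, SeqLabeling can now build the recurrent layer from just the embedding and hidden sizes. A quick shape check using the defaults shown above:

import torch
from fastNLP.modules.encoder.lstm import Lstm

rnn = Lstm(input_size=50)      # hidden_size=100, one layer, unidirectional, no dropout
x = torch.randn(2, 7, 50)      # [batch_size, max_len, input_size]
out = rnn(x)
print(out.shape)               # torch.Size([2, 7, 100])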
