
fix bug in CRF comments; optimize PyTorch type conversion.

tags/v0.1.0
FengZiYjun 7 years ago
commit 83c032df5d
3 changed files with 33 additions and 36 deletions
  1. fastNLP/action/trainer.py (+13, -10)
  2. fastNLP/models/sequencce_modeling.py (+18, -24)
  3. fastNLP/modules/CRF.py (+2, -2)

fastNLP/action/trainer.py (+13, -10)

@@ -170,8 +170,8 @@ class BaseTrainer(Action):
             [[word_21, word_22, word_23], [label_21. label_22]],  # sample 2
             ...
             ]
-        :return batch_x: list. Each entry is a list of features of a sample.
-                batch_y: list. Each entry is a list of labels of a sample.
+        :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
+                batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
         """
         if self.iterator is None:
             self.iterator = iter(Batchifier(RandomSampler(data), batch_size, drop_last=True))
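For readers of the updated docstring, here is a minimal sketch of how variable-length samples can be padded into the [batch_size, max_len] layout it describes. This is an illustration only, not Batchifier's actual implementation; the pad_batch helper and the padding value 0 are assumptions:

def pad_batch(samples, pad_value=0):
    # samples: list of (words, labels) pairs with varying lengths
    batch_x = [words for words, _ in samples]
    batch_y = [labels for _, labels in samples]
    max_len = max(len(seq) for seq in batch_x)
    # pad every feature sequence to max_len -> [batch_size, max_len]
    batch_x = [seq + [pad_value] * (max_len - len(seq)) for seq in batch_x]
    return batch_x, batch_y

# toy usage with two samples of different lengths
batch_x, batch_y = pad_batch([([4, 8, 15], [1, 0, 1]), ([16, 23], [0, 1])])
print(batch_x)  # [[4, 8, 15], [16, 23, 0]]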
@@ -325,7 +325,6 @@ class POSTrainer(BaseTrainer):
         self.num_classes = train_args.num_classes
         self.max_len = None
         self.mask = None
-        self.batch_x = None

     def prepare_input(self, data_path):
         """
@@ -336,14 +335,18 @@ class POSTrainer(BaseTrainer):
         return data_train, data_dev, 0, 1

     def data_forward(self, network, x):
+        """
+        :param network: the PyTorch model
+        :param x: list of list, [batch_size, max_len]
+        :return y: [batch_size, num_classes]
+        """
         seq_len = [len(seq) for seq in x]
-        x = torch.LongTensor(x)
+        x = torch.Tensor(x).long()
         self.batch_size = x.size(0)
         self.max_len = x.size(1)
         self.mask = seq_mask(seq_len, self.max_len)
-        x = network(x)
-        self.batch_x = x
-        return x
+        y = network(x)
+        return y

     def mode(self, test=False):
         if test:
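A minimal sketch of what the rewritten data_forward does with its input before calling the network, with a hand-rolled mask standing in for seq_mask. The mask construction here is an assumption for illustration; the real seq_mask is defined elsewhere in fastNLP:

import torch

x = [[4, 8, 15, 0], [16, 23, 42, 7]]   # already-padded batch, [batch_size, max_len]
seq_len = [3, 4]                        # true lengths before padding

x = torch.Tensor(x).long()              # same result as torch.LongTensor(x) for this input
batch_size, max_len = x.size(0), x.size(1)

# assumed equivalent of seq_mask: 1 for real tokens, 0 for padding
mask = torch.arange(max_len).unsqueeze(0) < torch.LongTensor(seq_len).unsqueeze(1)
print(mask)
# tensor([[ True,  True,  True, False],
#         [ True,  True,  True,  True]])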
@@ -357,8 +360,8 @@
     def get_loss(self, predict, truth):
         """
         Compute loss given prediction and ground truth.
-        :param predict: prediction label vector
-        :param truth: ground truth label vector
+        :param predict: prediction label vector, [batch_size, num_classes]
+        :param truth: ground truth label vector, [batch_size, max_len]
         :return: a scalar
         """
         if self.loss_func is None:
@@ -366,7 +369,7 @@
             self.loss_func = self.model.loss
         else:
             self.define_loss()
-        return self.loss_func(self.batch_x, predict, self.mask, self.batch_size, self.max_len)
+        return self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len)


if __name__ == "__name__":
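The second hunk is the functional fix: the old call passed the cached network output (self.batch_x) where the gold labels belong, so the loss never saw truth. A toy sketch of the corrected wiring; the shapes and the dummy loss_func are assumptions, chosen purely to show which argument goes where:

import torch

batch_size, max_len, num_classes = 2, 4, 5

predict = torch.randn(batch_size, max_len, num_classes)     # scores from the network
truth = torch.randint(num_classes, (batch_size, max_len))   # gold label ids
mask = torch.ones(batch_size, max_len, dtype=torch.uint8)   # all positions valid in this toy case

def loss_func(predict, truth, mask, batch_size, max_len):
    # stand-in with the same argument order as the fixed call
    return torch.nn.functional.cross_entropy(
        predict.view(batch_size * max_len, -1), truth.view(-1))

loss = loss_func(predict, truth, mask, batch_size, max_len)  # corrected argument order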


fastNLP/models/sequencce_modeling.py (+18, -24)

@@ -12,7 +12,7 @@ class SeqLabeling(BaseModel):
     """

     def __init__(self, hidden_dim,
-                 rnn_num_layerd,
+                 rnn_num_layer,
                  num_classes,
                  vocab_size,
                  word_emb_dim=100,
@@ -29,7 +29,7 @@ class SeqLabeling(BaseModel):

         self.num_classes = num_classes
         self.input_dim = word_emb_dim
-        self.layers = rnn_num_layerd
+        self.layers = rnn_num_layer
         self.hidden_dim = hidden_dim
         self.bi_direction = bi_direction
         self.dropout = dropout
@@ -55,32 +55,26 @@
         self.crf = ContionalRandomField(num_classes)

     def forward(self, x):
-
-        x = self.embedding(x)
-        x, hidden = self.encode(x)
-        x = self.aggregate(x)
-        x = self.decode(x)
-        return x
-
-    def embedding(self, x):
-        return self.Emb(x)
-
-    def encode(self, x):
-        return self.rnn(x)
-
-    def aggregate(self, x):
-        return x
-
-    def decode(self, x):
-        x = self.linear(x)
-        return x
+        """
+        :param x: LongTensor, [batch_size, mex_len]
+        :return y: [batch_size, tag_size, tag_size]
+        """
+        x = self.Emb(x)
+        # [batch_size, max_len, word_emb_dim]
+        x, hidden = self.rnn(x)
+        # [batch_size, max_len, hidden_size * direction]
+        y = self.linear(x)
+        # [batch_size, max_len, num_classes]
+        return y

     def loss(self, x, y, mask, batch_size, max_len):
         """
         Negative log likelihood loss.
-        :param x:
-        :param y:
-        :param seq_len:
+        :param x: FloatTensor, [batch_size, tag_size, tag_size]
+        :param y: LongTensor, [batch_size, max_len]
+        :param mask: ByteTensor, [batch_size, max_len]
+        :param batch_size: int
+        :param max_len: int
         :return loss:
         prediction:
         """


fastNLP/modules/CRF.py (+2, -2)

@@ -82,7 +82,7 @@ class ContionalRandomField(nn.Module):
     def _glod_score(self, feats, tags, masks):
         """
         Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x tag_size x tag_size
+        :param feats: FloatTensor, batch_size x max_len x tag_size
         :param tags: LongTensor, batch_size x max_len
         :param masks: ByteTensor, batch_size x max_len
         :return:FloatTensor, batch_size
@@ -118,7 +118,7 @@ class ContionalRandomField(nn.Module):
     def forward(self, feats, tags, masks):
         """
         Calculate the neg log likelihood
-        :param feats:FloatTensor, batch_size x tag_size x tag_size
+        :param feats:FloatTensor, batch_size x max_len x tag_size
         :param tags:LongTensor, batch_size x max_len
         :param masks:ByteTensor batch_size x max_len
         :return:FloatTensor, batch_size
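Both docstring fixes concern the shape of feats: the emission scores the CRF consumes are per token position, so [batch_size, max_len, tag_size], not [batch_size, tag_size, tag_size]. A simplified sketch of a gold-path score under that shape, summing emission and transition terms and ignoring special start/end transitions; this is an illustration, not the module's _glod_score:

import torch

batch_size, max_len, tag_size = 2, 4, 5
feats = torch.randn(batch_size, max_len, tag_size)           # per-position emission scores
tags = torch.randint(tag_size, (batch_size, max_len))         # gold tag ids
masks = torch.ones(batch_size, max_len, dtype=torch.uint8)    # 1 = real token, 0 = padding
trans = torch.randn(tag_size, tag_size)                        # transition scores tag_i -> tag_j

# emission score of the gold tag at every position, masked
emit = feats.gather(2, tags.unsqueeze(2)).squeeze(2) * masks.float()
# transition score between consecutive gold tags, masked by the later position
tran = trans[tags[:, :-1], tags[:, 1:]] * masks[:, 1:].float()

gold_score = emit.sum(dim=1) + tran.sum(dim=1)                  # FloatTensor, [batch_size]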

