Browse Source

add doc comments

tags/v0.2.0
yunfan 6 years ago
parent
commit
e1e0661deb
3 changed files with 21 additions and 1 deletions
  1. +1
    -0
      fastNLP/core/fieldarray.py
  2. +1
    -0
      fastNLP/io/dataset_loader.py
  3. +19
    -1
      fastNLP/models/cnn_text_classification.py

+ 1
- 0
fastNLP/core/fieldarray.py View File

@@ -20,6 +20,7 @@ class FieldArray(object):
self.padding_val = padding_val
self.is_target = is_target
self.is_input = is_input
# TODO: auto detect dtype
self.dtype = None


def __repr__(self):


+ 1
- 0
fastNLP/io/dataset_loader.py View File

@@ -1,3 +1,4 @@
#TODO: need fix for current DataSet
import os


from fastNLP.core.dataset import DataSet


+ 19
- 1
fastNLP/models/cnn_text_classification.py View File

@@ -37,8 +37,9 @@ class CNNText(torch.nn.Module):


def forward(self, word_seq):
"""

:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return x: torch.LongTensor, [batch_size, num_classes]
:return output: dict of torch.LongTensor, [batch_size, num_classes]
"""
x = self.embed(word_seq)  # [N,L] -> [N,L,C]
x = self.conv_pool(x)  # [N,L,C] -> [N,C]
@@ -47,14 +48,31 @@ class CNNText(torch.nn.Module):
return {'output':x}


def predict(self, word_seq):
"""

:param word_seq: torch.LongTensor, [batch_size, seq_len]
:return predict: dict of torch.LongTensor, [batch_size, seq_len]
"""
output = self(word_seq)
_, predict = output['output'].max(dim=1)
return {'predict': predict}


def get_loss(self, output, label_seq):
"""

:param output: output of forward(), [batch_size, seq_len]
:param label_seq: true label in DataSet, [batch_size, seq_len]
:return loss: torch.Tensor
"""
return self._loss(output, label_seq)


def evaluate(self, predict, label_seq):
"""

:param predict: iterable predict tensors
:param label_seq: iterable true label tensors
:return accuracy: dict of float
"""
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
predict, label_seq = predict.squeeze(), label_seq.squeeze()
correct = (predict == label_seq).long().sum().item()


Loading…
Cancel
Save