From e1e0661debb8a649ebad7c2837dcd7d3d65a6151 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Tue, 27 Nov 2018 18:39:57 +0800
Subject: [PATCH] add doc comments

---
 fastNLP/core/fieldarray.py                | 1 +
 fastNLP/io/dataset_loader.py              | 1 +
 fastNLP/models/cnn_text_classification.py | 20 +++++++++++++++++++-
 3 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index 880d9d39..3a63f788 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -20,6 +20,7 @@ class FieldArray(object):
         self.padding_val = padding_val
         self.is_target = is_target
         self.is_input = is_input
+        # TODO: auto detect dtype
         self.dtype = None

     def __repr__(self):
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 158a9e58..79cb30ad 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -1,3 +1,4 @@
+# TODO: needs to be fixed for the current DataSet
 import os

 from fastNLP.core.dataset import DataSet
diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py
index a4dcfef2..04b76fba 100644
--- a/fastNLP/models/cnn_text_classification.py
+++ b/fastNLP/models/cnn_text_classification.py
@@ -37,8 +37,9 @@ class CNNText(torch.nn.Module):

     def forward(self, word_seq):
         """
+
         :param word_seq: torch.LongTensor, [batch_size, seq_len]
-        :return x: torch.LongTensor, [batch_size, num_classes]
+        :return output: dict of torch.FloatTensor, [batch_size, num_classes]
         """
         x = self.embed(word_seq)  # [N,L] -> [N,L,C]
         x = self.conv_pool(x)  # [N,L,C] -> [N,C]
@@ -47,14 +48,31 @@ class CNNText(torch.nn.Module):
         return {'output':x}

     def predict(self, word_seq):
+        """
+
+        :param word_seq: torch.LongTensor, [batch_size, seq_len]
+        :return predict: dict of torch.LongTensor, [batch_size]
+        """
         output = self(word_seq)
         _, predict = output['output'].max(dim=1)
         return {'predict': predict}

     def get_loss(self, output, label_seq):
+        """
+
+        :param output: output of forward(), [batch_size, num_classes]
+        :param label_seq: true labels from the DataSet, [batch_size]
+        :return loss: torch.Tensor
+        """
         return self._loss(output, label_seq)

     def evaluate(self, predict, label_seq):
+        """
+
+        :param predict: iterable of prediction tensors
+        :param label_seq: iterable of true label tensors
+        :return accuracy: dict of float
+        """
         predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
         predict, label_seq = predict.squeeze(), label_seq.squeeze()
         correct = (predict == label_seq).long().sum().item()
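
Below is a minimal usage sketch of the interface documented by this patch, for reference only; it is not part of the diff. It assumes `model` is an already-constructed CNNText instance (the constructor arguments are not shown in this patch), and the vocabulary size, sequence length, and class count used for the dummy tensors are placeholders.

    import torch

    # Dummy inputs; shapes follow the docstrings added in the patch.
    batch_size, seq_len, num_classes = 4, 16, 2                 # assumed sizes
    word_seq = torch.randint(0, 100, (batch_size, seq_len))     # token ids, [batch_size, seq_len]
    label_seq = torch.randint(0, num_classes, (batch_size,))    # true labels, [batch_size]

    # `model` is assumed to be an initialized CNNText whose vocabulary and
    # number of classes match the dummy tensors above.
    output = model(word_seq)                            # forward(): {'output': [batch_size, num_classes]}
    loss = model.get_loss(output['output'], label_seq)  # scalar torch.Tensor
    pred = model.predict(word_seq)                      # {'predict': [batch_size]} class indices
    acc = model.evaluate(pred['predict'], label_seq)    # accuracy reported as a dict of float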