class Const:
    """Canonical field names used by fastNLP.

    The singular names are plain class attributes. When several fields of
    the same kind are needed, the static helper methods build the numbered
    variants::

        INPUT       sequence input of the model    words    (numbered: words1, words2)
        CHAR_INPUT  character input of the model   chars    (numbered: chars1, chars2)
        INPUT_LEN   sequence length                seq_len  (numbered: seq_len1, seq_len2)
        OUTPUT      model output                   pred     (numbered: pred1, pred2)
        TARGET      gold target                    target   (numbered: target1, target2)

    All helper methods take a zero-based index: the index is converted with
    ``int()`` and incremented by one before being appended, so
    ``Const.INPUTS(0)`` is ``'words1'`` and ``Const.INPUTS(1)`` is
    ``'words2'``.
    """
    INPUT = 'words'
    CHAR_INPUT = 'chars'
    INPUT_LEN = 'seq_len'
    OUTPUT = 'pred'
    TARGET = 'target'

    @staticmethod
    def _numbered(base, i):
        """Return ``base`` followed by the one-based number for index ``i``.

        :param str base: field name to extend, e.g. ``'words'``.
        :param i: zero-based index; anything accepted by ``int()``.
        :return: ``base`` with ``int(i) + 1`` appended as a decimal string.
        """
        # Shared by every public helper so the numbering rule lives in
        # exactly one place.
        return base + str(int(i) + 1)

    @staticmethod
    def INPUTS(i):
        """Return the name of the ``INPUT`` field at zero-based index ``i``."""
        return Const._numbered(Const.INPUT, i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Return the name of the ``CHAR_INPUT`` field at zero-based index ``i``."""
        return Const._numbered(Const.CHAR_INPUT, i)

    @staticmethod
    def INPUT_LENS(i):
        """Return the name of the ``INPUT_LEN`` field at zero-based index ``i``."""
        return Const._numbered(Const.INPUT_LEN, i)

    @staticmethod
    def OUTPUTS(i):
        """Return the name of the ``OUTPUT`` field at zero-based index ``i``."""
        return Const._numbered(Const.OUTPUT, i)

    @staticmethod
    def TARGETS(i):
        """Return the name of the ``TARGET`` field at zero-based index ``i``."""
        return Const._numbered(Const.TARGET, i)
JsonLoader(DataSetLoader): `value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . - Default: ``True`` + Default: ``False`` """ def __init__(self, fields=None, dropna=False): super(JsonLoader, self).__init__() @@ -375,9 +375,9 @@ class CSVLoader(DataSetLoader): 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` :param str sep: CSV文件中列与列之间的分隔符. Default: "," :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . - Default: ``True`` + Default: ``False`` """ - def __init__(self, headers=None, sep=",", dropna=True): + def __init__(self, headers=None, sep=",", dropna=False): self.headers = headers self.sep = sep self.dropna = dropna diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py index e4fbeb28..f0c6f33f 100644 --- a/fastNLP/models/star_transformer.py +++ b/fastNLP/models/star_transformer.py @@ -1,8 +1,9 @@ """Star-Transformer 的 一个 Pytorch 实现. 
""" -from fastNLP.modules.encoder.star_transformer import StarTransformer -from fastNLP.core.utils import seq_lens_to_masks +from ..modules.encoder.star_transformer import StarTransformer +from ..core.utils import seq_lens_to_masks from ..modules.utils import get_embeddings +from ..core.const import Const import torch from torch import nn @@ -139,7 +140,7 @@ class STSeqLabel(nn.Module): nodes, _ = self.enc(words, mask) output = self.cls(nodes) output = output.transpose(1,2) # make hidden to be dim 1 - return {'output': output} # [bsz, n_cls, seq_len] + return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] def predict(self, words, seq_len): """ @@ -149,8 +150,8 @@ class STSeqLabel(nn.Module): :return output: [batch, seq_len] 输出序列中每个元素的分类 """ y = self.forward(words, seq_len) - _, pred = y['output'].max(1) - return {'output': pred} + _, pred = y[Const.OUTPUT].max(1) + return {Const.OUTPUT: pred} class STSeqCls(nn.Module): @@ -201,7 +202,7 @@ class STSeqCls(nn.Module): nodes, relay = self.enc(words, mask) y = 0.5 * (relay + nodes.max(1)[0]) output = self.cls(y) # [bsz, n_cls] - return {'output': output} + return {Const.OUTPUT: output} def predict(self, words, seq_len): """ @@ -211,8 +212,8 @@ class STSeqCls(nn.Module): :return output: [batch, num_cls] 输出序列的分类 """ y = self.forward(words, seq_len) - _, pred = y['output'].max(1) - return {'output': pred} + _, pred = y[Const.OUTPUT].max(1) + return {Const.OUTPUT: pred} class STNLICls(nn.Module): @@ -269,7 +270,7 @@ class STNLICls(nn.Module): y1 = enc(words1, mask1) y2 = enc(words2, mask2) output = self.cls(y1, y2) # [bsz, n_cls] - return {'output': output} + return {Const.OUTPUT: output} def predict(self, words1, words2, seq_len1, seq_len2): """ @@ -281,5 +282,5 @@ class STNLICls(nn.Module): :return output: [batch, num_cls] 输出分类的概率 """ y = self.forward(words1, words2, seq_len1, seq_len2) - _, pred = y['output'].max(1) - return {'output': pred} + _, pred = y[Const.OUTPUT].max(1) + return {Const.OUTPUT: pred} diff --git 
a/test/core/test_callbacks.py b/test/core/test_callbacks.py index 3329e7a1..cf3e2fff 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -3,7 +3,7 @@ import unittest import numpy as np import torch -from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ +from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ LRFinder, \ TensorboardCallback from fastNLP.core.dataset import DataSet diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py index 97379a7d..2e367567 100644 --- a/test/io/test_dataset_loader.py +++ b/test/io/test_dataset_loader.py @@ -1,7 +1,7 @@ import unittest from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ - CSVLoader, SNLILoader + CSVLoader, SNLILoader, JsonLoader class TestDatasetLoader(unittest.TestCase): @@ -24,3 +24,8 @@ class TestDatasetLoader(unittest.TestCase): def test_SNLILoader(self): ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') assert len(ds) == 3 + + def test_JsonLoader(self): + ds = JsonLoader().load('test/data_for_tests/sample_snli.jsonl') + assert len(ds) == 3 +