@@ -0,0 +1,46 @@ | |||
class Const:
    """Canonical field names used throughout fastNLP.

    Naming scheme (0-based index ``i`` maps to a 1-based suffix)::

        INPUT       model sequence input    words    (indexed: words1, words2, ...)
        CHAR_INPUT  model character input   chars    (indexed: chars1, chars2, ...)
        INPUT_LEN   sequence length         seq_len  (indexed: seq_len1, seq_len2, ...)
        OUTPUT      model prediction        pred     (indexed: pred1, pred2, ...)
        TARGET      gold target             target   (indexed: target1, target2, ...)
    """
    INPUT = 'words'
    CHAR_INPUT = 'chars'
    INPUT_LEN = 'seq_len'
    OUTPUT = 'pred'
    TARGET = 'target'

    @staticmethod
    def _indexed(base, i):
        """Return ``base`` suffixed with the 1-based index for 0-based ``i``.

        ``i`` is coerced with ``int()`` so numeric strings are accepted,
        matching the original per-method behavior.
        """
        return base + str(int(i) + 1)

    @staticmethod
    def INPUTS(i):
        """Name of the i-th ``INPUT`` field, e.g. ``'words1'`` for ``i=0``."""
        return Const._indexed(Const.INPUT, i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Name of the i-th ``CHAR_INPUT`` field, e.g. ``'chars1'`` for ``i=0``."""
        return Const._indexed(Const.CHAR_INPUT, i)

    @staticmethod
    def INPUT_LENS(i):
        """Name of the i-th ``INPUT_LEN`` field, e.g. ``'seq_len1'`` for ``i=0``."""
        return Const._indexed(Const.INPUT_LEN, i)

    @staticmethod
    def OUTPUTS(i):
        """Name of the i-th ``OUTPUT`` field, e.g. ``'pred1'`` for ``i=0``."""
        return Const._indexed(Const.OUTPUT, i)

    @staticmethod
    def TARGETS(i):
        """Name of the i-th ``TARGET`` field, e.g. ``'target1'`` for ``i=0``."""
        return Const._indexed(Const.TARGET, i)
@@ -193,9 +193,9 @@ class ConllLoader(DataSetLoader): | |||
:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexs`` 一一对应 | |||
:param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||
""" | |||
def __init__(self, headers, indexs=None, dropna=True): | |||
def __init__(self, headers, indexs=None, dropna=False): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
raise TypeError('invalid headers: {}, should be list of strings'.format(headers)) | |||
@@ -314,7 +314,7 @@ class JsonLoader(DataSetLoader): | |||
`value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 | |||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``True`` | |||
Default: ``False`` | |||
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super(JsonLoader, self).__init__() | |||
@@ -375,9 +375,9 @@ class CSVLoader(DataSetLoader): | |||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``True`` | |||
Default: ``False`` | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=True): | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
@@ -1,8 +1,9 @@ | |||
"""Star-Transformer 的 一个 Pytorch 实现. | |||
""" | |||
from fastNLP.modules.encoder.star_transformer import StarTransformer | |||
from fastNLP.core.utils import seq_lens_to_masks | |||
from ..modules.encoder.star_transformer import StarTransformer | |||
from ..core.utils import seq_lens_to_masks | |||
from ..modules.utils import get_embeddings | |||
from ..core.const import Const | |||
import torch | |||
from torch import nn | |||
@@ -139,7 +140,7 @@ class STSeqLabel(nn.Module): | |||
nodes, _ = self.enc(words, mask) | |||
output = self.cls(nodes) | |||
output = output.transpose(1,2) # make hidden to be dim 1 | |||
return {'output': output} # [bsz, n_cls, seq_len] | |||
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] | |||
def predict(self, words, seq_len): | |||
""" | |||
@@ -149,8 +150,8 @@ class STSeqLabel(nn.Module): | |||
:return output: [batch, seq_len] 输出序列中每个元素的分类 | |||
""" | |||
y = self.forward(words, seq_len) | |||
_, pred = y['output'].max(1) | |||
return {'output': pred} | |||
_, pred = y[Const.OUTPUT].max(1) | |||
return {Const.OUTPUT: pred} | |||
class STSeqCls(nn.Module): | |||
@@ -201,7 +202,7 @@ class STSeqCls(nn.Module): | |||
nodes, relay = self.enc(words, mask) | |||
y = 0.5 * (relay + nodes.max(1)[0]) | |||
output = self.cls(y) # [bsz, n_cls] | |||
return {'output': output} | |||
return {Const.OUTPUT: output} | |||
def predict(self, words, seq_len): | |||
""" | |||
@@ -211,8 +212,8 @@ class STSeqCls(nn.Module): | |||
:return output: [batch, num_cls] 输出序列的分类 | |||
""" | |||
y = self.forward(words, seq_len) | |||
_, pred = y['output'].max(1) | |||
return {'output': pred} | |||
_, pred = y[Const.OUTPUT].max(1) | |||
return {Const.OUTPUT: pred} | |||
class STNLICls(nn.Module): | |||
@@ -269,7 +270,7 @@ class STNLICls(nn.Module): | |||
y1 = enc(words1, mask1) | |||
y2 = enc(words2, mask2) | |||
output = self.cls(y1, y2) # [bsz, n_cls] | |||
return {'output': output} | |||
return {Const.OUTPUT: output} | |||
def predict(self, words1, words2, seq_len1, seq_len2): | |||
""" | |||
@@ -281,5 +282,5 @@ class STNLICls(nn.Module): | |||
:return output: [batch, num_cls] 输出分类的概率 | |||
""" | |||
y = self.forward(words1, words2, seq_len1, seq_len2) | |||
_, pred = y['output'].max(1) | |||
return {'output': pred} | |||
_, pred = y[Const.OUTPUT].max(1) | |||
return {Const.OUTPUT: pred} |
@@ -3,7 +3,7 @@ import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
LRFinder, \ | |||
TensorboardCallback | |||
from fastNLP.core.dataset import DataSet | |||
@@ -1,7 +1,7 @@ | |||
import unittest | |||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ | |||
CSVLoader, SNLILoader | |||
CSVLoader, SNLILoader, JsonLoader | |||
class TestDatasetLoader(unittest.TestCase): | |||
@@ -24,3 +24,8 @@ class TestDatasetLoader(unittest.TestCase): | |||
def test_SNLILoader(self):
    """SNLILoader should read all three examples from the sample jsonl file."""
    dataset = SNLILoader().load('test/data_for_tests/sample_snli.jsonl')
    assert 3 == len(dataset)
def test_JsonLoader(self):
    """JsonLoader should read all three examples from the sample jsonl file."""
    dataset = JsonLoader().load('test/data_for_tests/sample_snli.jsonl')
    assert 3 == len(dataset)