@@ -0,0 +1,46 @@ | |||||
class Const:
    """Field-name constants used throughout fastNLP.

    Summary of the conventional names::

        INPUT       model sequence input     words    (numbered: words1, words2, ...)
        CHAR_INPUT  model character input    chars    (numbered: chars1, chars2, ...)
        INPUT_LEN   sequence length          seq_len  (numbered: seq_len1, seq_len2, ...)
        OUTPUT      model prediction         pred     (numbered: pred1, pred2, ...)
        TARGET      ground-truth target      target   (numbered: target1, target2, ...)

    The ``*S(i)`` helpers take a 0-based index ``i`` and return the
    corresponding 1-based numbered field name (e.g. ``INPUTS(0) == 'words1'``).
    """
    INPUT = 'words'
    CHAR_INPUT = 'chars'
    INPUT_LEN = 'seq_len'
    OUTPUT = 'pred'
    TARGET = 'target'

    @staticmethod
    def _numbered(prefix, i):
        """Append the 1-based suffix for 0-based index ``i`` to ``prefix``.

        ``i`` may be anything ``int()`` accepts (e.g. a numeric string).
        """
        return prefix + str(int(i) + 1)

    @staticmethod
    def INPUTS(i):
        """Return the name of the ``i``-th (0-based) ``INPUT`` field."""
        return Const._numbered(Const.INPUT, i)

    @staticmethod
    def CHAR_INPUTS(i):
        """Return the name of the ``i``-th (0-based) ``CHAR_INPUT`` field."""
        return Const._numbered(Const.CHAR_INPUT, i)

    @staticmethod
    def INPUT_LENS(i):
        """Return the name of the ``i``-th (0-based) ``INPUT_LEN`` field."""
        return Const._numbered(Const.INPUT_LEN, i)

    @staticmethod
    def OUTPUTS(i):
        """Return the name of the ``i``-th (0-based) ``OUTPUT`` field."""
        return Const._numbered(Const.OUTPUT, i)

    @staticmethod
    def TARGETS(i):
        """Return the name of the ``i``-th (0-based) ``TARGET`` field."""
        return Const._numbered(Const.TARGET, i)
@@ -193,9 +193,9 @@ class ConllLoader(DataSetLoader): | |||||
:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexs`` 一一对应 | :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexs`` 一一对应 | ||||
:param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | :param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | ||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, headers, indexs=None, dropna=True): | |||||
def __init__(self, headers, indexs=None, dropna=False): | |||||
super(ConllLoader, self).__init__() | super(ConllLoader, self).__init__() | ||||
if not isinstance(headers, (list, tuple)): | if not isinstance(headers, (list, tuple)): | ||||
raise TypeError('invalid headers: {}, should be list of strings'.format(headers)) | raise TypeError('invalid headers: {}, should be list of strings'.format(headers)) | ||||
@@ -314,7 +314,7 @@ class JsonLoader(DataSetLoader): | |||||
`value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 | `value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 | ||||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | ||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | ||||
Default: ``True`` | |||||
Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, fields=None, dropna=False): | def __init__(self, fields=None, dropna=False): | ||||
super(JsonLoader, self).__init__() | super(JsonLoader, self).__init__() | ||||
@@ -375,9 +375,9 @@ class CSVLoader(DataSetLoader): | |||||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | ||||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | :param str sep: CSV文件中列与列之间的分隔符. Default: "," | ||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | ||||
Default: ``True`` | |||||
Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, headers=None, sep=",", dropna=True): | |||||
def __init__(self, headers=None, sep=",", dropna=False): | |||||
self.headers = headers | self.headers = headers | ||||
self.sep = sep | self.sep = sep | ||||
self.dropna = dropna | self.dropna = dropna | ||||
@@ -1,8 +1,9 @@ | |||||
"""Star-Transformer 的 一个 Pytorch 实现. | """Star-Transformer 的 一个 Pytorch 实现. | ||||
""" | """ | ||||
from fastNLP.modules.encoder.star_transformer import StarTransformer | |||||
from fastNLP.core.utils import seq_lens_to_masks | |||||
from ..modules.encoder.star_transformer import StarTransformer | |||||
from ..core.utils import seq_lens_to_masks | |||||
from ..modules.utils import get_embeddings | from ..modules.utils import get_embeddings | ||||
from ..core.const import Const | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
@@ -139,7 +140,7 @@ class STSeqLabel(nn.Module): | |||||
nodes, _ = self.enc(words, mask) | nodes, _ = self.enc(words, mask) | ||||
output = self.cls(nodes) | output = self.cls(nodes) | ||||
output = output.transpose(1,2) # make hidden to be dim 1 | output = output.transpose(1,2) # make hidden to be dim 1 | ||||
return {'output': output} # [bsz, n_cls, seq_len] | |||||
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len] | |||||
def predict(self, words, seq_len): | def predict(self, words, seq_len): | ||||
""" | """ | ||||
@@ -149,8 +150,8 @@ class STSeqLabel(nn.Module): | |||||
:return output: [batch, seq_len] 输出序列中每个元素的分类 | :return output: [batch, seq_len] 输出序列中每个元素的分类 | ||||
""" | """ | ||||
y = self.forward(words, seq_len) | y = self.forward(words, seq_len) | ||||
_, pred = y['output'].max(1) | |||||
return {'output': pred} | |||||
_, pred = y[Const.OUTPUT].max(1) | |||||
return {Const.OUTPUT: pred} | |||||
class STSeqCls(nn.Module): | class STSeqCls(nn.Module): | ||||
@@ -201,7 +202,7 @@ class STSeqCls(nn.Module): | |||||
nodes, relay = self.enc(words, mask) | nodes, relay = self.enc(words, mask) | ||||
y = 0.5 * (relay + nodes.max(1)[0]) | y = 0.5 * (relay + nodes.max(1)[0]) | ||||
output = self.cls(y) # [bsz, n_cls] | output = self.cls(y) # [bsz, n_cls] | ||||
return {'output': output} | |||||
return {Const.OUTPUT: output} | |||||
def predict(self, words, seq_len): | def predict(self, words, seq_len): | ||||
""" | """ | ||||
@@ -211,8 +212,8 @@ class STSeqCls(nn.Module): | |||||
:return output: [batch, num_cls] 输出序列的分类 | :return output: [batch, num_cls] 输出序列的分类 | ||||
""" | """ | ||||
y = self.forward(words, seq_len) | y = self.forward(words, seq_len) | ||||
_, pred = y['output'].max(1) | |||||
return {'output': pred} | |||||
_, pred = y[Const.OUTPUT].max(1) | |||||
return {Const.OUTPUT: pred} | |||||
class STNLICls(nn.Module): | class STNLICls(nn.Module): | ||||
@@ -269,7 +270,7 @@ class STNLICls(nn.Module): | |||||
y1 = enc(words1, mask1) | y1 = enc(words1, mask1) | ||||
y2 = enc(words2, mask2) | y2 = enc(words2, mask2) | ||||
output = self.cls(y1, y2) # [bsz, n_cls] | output = self.cls(y1, y2) # [bsz, n_cls] | ||||
return {'output': output} | |||||
return {Const.OUTPUT: output} | |||||
def predict(self, words1, words2, seq_len1, seq_len2): | def predict(self, words1, words2, seq_len1, seq_len2): | ||||
""" | """ | ||||
@@ -281,5 +282,5 @@ class STNLICls(nn.Module): | |||||
:return output: [batch, num_cls] 输出分类的概率 | :return output: [batch, num_cls] 输出分类的概率 | ||||
""" | """ | ||||
y = self.forward(words1, words2, seq_len1, seq_len2) | y = self.forward(words1, words2, seq_len1, seq_len2) | ||||
_, pred = y['output'].max(1) | |||||
return {'output': pred} | |||||
_, pred = y[Const.OUTPUT].max(1) | |||||
return {Const.OUTPUT: pred} |
@@ -3,7 +3,7 @@ import unittest | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||||
LRFinder, \ | LRFinder, \ | ||||
TensorboardCallback | TensorboardCallback | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
@@ -1,7 +1,7 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ | from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ | ||||
CSVLoader, SNLILoader | |||||
CSVLoader, SNLILoader, JsonLoader | |||||
class TestDatasetLoader(unittest.TestCase): | class TestDatasetLoader(unittest.TestCase): | ||||
@@ -24,3 +24,8 @@ class TestDatasetLoader(unittest.TestCase): | |||||
def test_SNLILoader(self): | def test_SNLILoader(self): | ||||
ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') | ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') | ||||
assert len(ds) == 3 | assert len(ds) == 3 | ||||
def test_JsonLoader(self): | |||||
ds = JsonLoader().load('test/data_for_tests/sample_snli.jsonl') | |||||
assert len(ds) == 3 | |||||