diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index cd0d3527..bf5c2c36 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -14,27 +14,56 @@ __all__ = [ 'EmbedLoader', - 'CSVLoader', - 'JsonLoader', - 'DataBundle', 'DataSetLoader', - 'ConllLoader', - 'Conll2003Loader', + 'YelpLoader', + 'YelpFullLoader', + 'YelpPolarityLoader', 'IMDBLoader', - 'MatchingLoader', - 'SNLILoader', - 'MNLILoader', - 'MTL16Loader', - 'PeopleDailyCorpusLoader', - 'QNLILoader', - 'QuoraLoader', - 'RTELoader', 'SSTLoader', 'SST2Loader', - 'YelpLoader', - + + 'ConllLoader', + 'Conll2003Loader', + 'Conll2003NERLoader', + 'OntoNotesNERLoader', + 'CTBLoader', + + 'Loader', + 'CSVLoader', + 'JsonLoader', + + 'CWSLoader', + + 'MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader", + + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", + 'ModelLoader', 'ModelSaver', ] @@ -44,4 +73,5 @@ from .base_loader import DataBundle, DataSetLoader from .dataset_loader import CSVLoader, JsonLoader from .model_io import ModelLoader, ModelSaver -from .data_loader import * +from .loader import * +from .pipe import * diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index 8e436532..4905a34f 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -1,30 +1,61 @@ - """ Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle`中。所有的Loader都支持以下的 - 三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field - ; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法 - 支持以下三种类型的参数 +三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field +; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法 +支持以下三种类型的参数:: - Example:: - (0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 - (1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取 - (2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。 - 假设某个目录下的文件为 - -train.txt - -dev.txt - -test.txt - -other.txt - Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'], - data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。 - 假设某个目录下的文件为 - -train.txt - -dev.txt - Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取 - 对应的DataSet。 - (3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。 - paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} - Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'], - data_bundle.datasets['test'] + (0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 + (1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取 + (2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。 + 假设某个目录下的文件为 + -train.txt + -dev.txt + -test.txt + -other.txt + Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'], + data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。 + 假设某个目录下的文件为 + -train.txt + -dev.txt + Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取 + 对应的DataSet。 + (3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。 + paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} + Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'], + data_bundle.datasets['test'] """ +__all__ = [ + 'YelpLoader', + 'YelpFullLoader', + 'YelpPolarityLoader', + 'IMDBLoader', + 'SSTLoader', + 'SST2Loader', + + 'ConllLoader', + 'Conll2003Loader', + 'Conll2003NERLoader', + 'OntoNotesNERLoader', + 'CTBLoader', + + 'Loader', + 'CSVLoader', + 'JsonLoader', + + 'CWSLoader', + + 'MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader" +] +from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader +from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader +from .csv import CSVLoader +from .cws import CWSLoader +from .json import JsonLoader +from .loader import Loader +from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index eff98ba3..05f113c1 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -1,4 +1,3 @@ - import warnings from .loader import Loader from .json import JsonLoader @@ -9,12 +8,6 @@ from typing import Union, Dict from ...core import DataSet from ...core import Instance -__all__ = ['MNLILoader', - "QuoraLoader", - "SNLILoader", - "QNLILoader", - "RTELoader"] - class MNLILoader(Loader): """ diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 0cf8d949..4cec3ad5 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -1,8 +1,34 @@ - - """ Pipe用于处理数据,所有的Pipe都包含一个process(DataBundle)方法,传入一个DataBundle对象, 在传入DataBundle上进行原位修改,并将其返回; - process_from_file(paths)传入的文件路径,返回一个DataBundle。process(DataBundle)或者process_from_file(paths)的返回DataBundle - 中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。 +process_from_file(paths)传入的文件路径,返回一个DataBundle。process(DataBundle)或者process_from_file(paths)的返回DataBundle +中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。 + +""" +__all__ = [ + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", +] -""" \ No newline at end of file +from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe +from .conll import Conll2003NERPipe, OntoNotesNERPipe +from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ + MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index a64e5328..d370a28a 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -1,4 +1,3 @@ - from nltk import Tree from ..base_loader import DataBundle diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index 4f780614..e62d1a05 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -19,15 +19,16 @@ class _NERPipe(Pipe): :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 """ - def __init__(self, encoding_type:str='bio', lower:bool=False, target_pad_val=0): - if encoding_type == 'bio': + + def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0): + if encoding_type == 'bio': self.convert_tag = iob2 else: self.convert_tag = iob2bioes self.lower = lower self.target_pad_val = int(target_pad_val) - def process(self, data_bundle:DataBundle)->DataBundle: + def process(self, data_bundle: DataBundle) -> DataBundle: """ 支持的DataSet的field为 @@ -146,4 +147,3 @@ class OntoNotesNERPipe(_NERPipe): def process_from_file(self, paths): data_bundle = OntoNotesNERLoader().load(paths) return self.process(data_bundle) - diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 76a0eaf7..1e551f1d 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -6,6 +6,7 @@ from ...core import Const from ...core import Vocabulary from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader + class MatchingBertPipe(Pipe): """ Matching任务的Bert pipe,输出的DataSet将包含以下的field diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py index 14c3866a..76cc00ec 100644 --- a/fastNLP/io/pipe/pipe.py +++ b/fastNLP/io/pipe/pipe.py @@ -1,9 +1,9 @@ - from .. import DataBundle + class Pipe: - def process(self, data_bundle:DataBundle)->DataBundle: + def process(self, data_bundle: DataBundle) -> DataBundle: raise NotImplementedError - def process_from_file(self, paths)->DataBundle: + def process_from_file(self, paths) -> DataBundle: raise NotImplementedError