diff --git a/docs/Makefile b/docs/Makefile
index 2b4de2d8..b9f1cf95 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -14,7 +14,7 @@ help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
apidoc:
- $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ)
+ $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) && python3 format.py
server:
cd build/html && python -m http.server
diff --git a/docs/format.py b/docs/format.py
new file mode 100644
index 00000000..7cc341c2
--- /dev/null
+++ b/docs/format.py
@@ -0,0 +1,65 @@
+import os
+
+
+def shorten(file, to_delete, cut=False):
+ if file.endswith("index.rst") or file.endswith("conf.py"):
+ return
+ res = []
+ with open(file, "r") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ if cut and line.rstrip() == "Submodules":
+ break
+ else:
+ res.append(line.rstrip())
+ for i, line in enumerate(res):
+ if line.endswith(" package"):
+ res[i] = res[i][:-len(" package")]
+ res[i + 1] = res[i + 1][:-len(" package")]
+ elif line.endswith(" module"):
+ res[i] = res[i][:-len(" module")]
+ res[i + 1] = res[i + 1][:-len(" module")]
+ else:
+ for name in to_delete:
+ if line.endswith(name):
+ res[i] = "del"
+
+ with open(file, "w") as fout:
+ for line in res:
+ if line != "del":
+ print(line, file=fout)
+
+
+def clear(path='./source/'):
+ files = os.listdir(path)
+ to_delete = [
+ "fastNLP.core.dist_trainer",
+ "fastNLP.core.predictor",
+
+ "fastNLP.io.file_reader",
+ "fastNLP.io.config_io",
+
+ "fastNLP.embeddings.contextual_embedding",
+
+ "fastNLP.modules.dropout",
+ "fastNLP.models.base_model",
+ "fastNLP.models.bert",
+ "fastNLP.models.enas_utils",
+ "fastNLP.models.enas_controller",
+ "fastNLP.models.enas_model",
+ "fastNLP.models.enas_trainer",
+ ]
+ for file in files:
+ if not os.path.isdir(path + file):
+ res = file.split('.')
+ if len(res) > 4:
+ to_delete.append(file[:-4])
+ elif len(res) == 4:
+ shorten(path + file, to_delete, True)
+ else:
+ shorten(path + file, to_delete)
+ for file in to_delete:
+ os.remove(path + file + ".rst")
+
+
+clear()
diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst
index cacc6622..08d161b7 100644
--- a/docs/source/fastNLP.core.rst
+++ b/docs/source/fastNLP.core.rst
@@ -6,11 +6,10 @@ fastNLP.core
:undoc-members:
:show-inheritance:
-子模块
+Submodules
----------
.. toctree::
- :maxdepth: 1
fastNLP.core.batch
fastNLP.core.callback
diff --git a/docs/source/fastNLP.embeddings.rst b/docs/source/fastNLP.embeddings.rst
index 6b168906..6872e91d 100644
--- a/docs/source/fastNLP.embeddings.rst
+++ b/docs/source/fastNLP.embeddings.rst
@@ -6,11 +6,10 @@ fastNLP.embeddings
:undoc-members:
:show-inheritance:
-子模块
+Submodules
----------
.. toctree::
- :maxdepth: 1
fastNLP.embeddings.bert_embedding
fastNLP.embeddings.char_embedding
diff --git a/docs/source/fastNLP.io.data_loader.rst b/docs/source/fastNLP.io.data_loader.rst
index 8f990102..0b4f5d0b 100644
--- a/docs/source/fastNLP.io.data_loader.rst
+++ b/docs/source/fastNLP.io.data_loader.rst
@@ -1,7 +1,8 @@
fastNLP.io.data\_loader
-==========================
+=======================
.. automodule:: fastNLP.io.data_loader
:members:
:undoc-members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
+
diff --git a/docs/source/fastNLP.io.file_utils.rst b/docs/source/fastNLP.io.file_utils.rst
new file mode 100644
index 00000000..944550d7
--- /dev/null
+++ b/docs/source/fastNLP.io.file_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.io.file\_utils
+======================
+
+.. automodule:: fastNLP.io.file_utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst
new file mode 100644
index 00000000..bbdc1d7a
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.rst
@@ -0,0 +1,8 @@
+fastNLP.io.loader
+=================
+
+.. automodule:: fastNLP.io.loader
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst
new file mode 100644
index 00000000..bf126585
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.rst
@@ -0,0 +1,8 @@
+fastNLP.io.pipe
+===============
+
+.. automodule:: fastNLP.io.pipe
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst
index a97ed67d..0a006709 100644
--- a/docs/source/fastNLP.io.rst
+++ b/docs/source/fastNLP.io.rst
@@ -6,14 +6,23 @@ fastNLP.io
:undoc-members:
:show-inheritance:
-子模块
+Subpackages
+-----------
+
+.. toctree::
+
+ fastNLP.io.data_loader
+ fastNLP.io.loader
+ fastNLP.io.pipe
+
+Submodules
----------
.. toctree::
- :maxdepth: 1
fastNLP.io.base_loader
- fastNLP.io.embed_loader
fastNLP.io.dataset_loader
- fastNLP.io.data_loader
+ fastNLP.io.embed_loader
+ fastNLP.io.file_utils
fastNLP.io.model_io
+ fastNLP.io.utils
diff --git a/docs/source/fastNLP.io.utils.rst b/docs/source/fastNLP.io.utils.rst
new file mode 100644
index 00000000..0b3f3938
--- /dev/null
+++ b/docs/source/fastNLP.io.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.io.utils
+================
+
+.. automodule:: fastNLP.io.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst
index 2ea546e2..36875b85 100644
--- a/docs/source/fastNLP.models.rst
+++ b/docs/source/fastNLP.models.rst
@@ -6,11 +6,10 @@ fastNLP.models
:undoc-members:
:show-inheritance:
-子模块
+Submodules
----------
.. toctree::
- :maxdepth: 1
fastNLP.models.biaffine_parser
fastNLP.models.cnn_text_classification
diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst
index 0562f12d..e60f9fa4 100644
--- a/docs/source/fastNLP.modules.encoder.rst
+++ b/docs/source/fastNLP.modules.encoder.rst
@@ -5,3 +5,4 @@ fastNLP.modules.encoder
:members:
:undoc-members:
:show-inheritance:
+
diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst
index 646ef2d3..06494b53 100644
--- a/docs/source/fastNLP.modules.rst
+++ b/docs/source/fastNLP.modules.rst
@@ -6,12 +6,17 @@ fastNLP.modules
:undoc-members:
:show-inheritance:
-子模块
+Subpackages
-----------
.. toctree::
- :titlesonly:
- :maxdepth: 1
fastNLP.modules.decoder
- fastNLP.modules.encoder
\ No newline at end of file
+ fastNLP.modules.encoder
+
+Submodules
+----------
+
+.. toctree::
+
+ fastNLP.modules.utils
diff --git a/docs/source/fastNLP.modules.utils.rst b/docs/source/fastNLP.modules.utils.rst
new file mode 100644
index 00000000..c0219435
--- /dev/null
+++ b/docs/source/fastNLP.modules.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.utils
+=====================
+
+.. automodule:: fastNLP.modules.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst
index 0057a184..e3ba429d 100644
--- a/docs/source/fastNLP.rst
+++ b/docs/source/fastNLP.rst
@@ -1,16 +1,15 @@
-API 文档
-===============
+fastNLP
+=======
.. automodule:: fastNLP
:members:
:undoc-members:
:show-inheritance:
-内部模块
+Subpackages
-----------
.. toctree::
- :maxdepth: 1
fastNLP.core
fastNLP.embeddings
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
index 9ca3c7f3..e9a92cb7 100644
--- a/docs/source/modules.rst
+++ b/docs/source/modules.rst
@@ -2,7 +2,6 @@ fastNLP
=======
.. toctree::
- :titlesonly:
:maxdepth: 4
fastNLP
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 4a423933..00db6361 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -1,3 +1,6 @@
+"""
+正在开发中的分布式训练代码
+"""
import torch
import torch.cuda
import torch.optim
@@ -41,7 +44,8 @@ def get_local_rank():
class DistTrainer():
- """Distributed Trainer that support distributed and mixed precision training
+ """
+ Distributed Trainer that support distributed and mixed precision training
"""
def __init__(self, train_data, model, optimizer=None, loss=None,
callbacks_all=None, callbacks_master=None,
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index ea5e84ac..723cd2d5 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -176,9 +176,9 @@ class BertWordPieceEncoder(nn.Module):
def index_datasets(self, *datasets, field_name, add_cls_sep=True):
"""
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
- bert的pad value。
+ bert的pad value。
- :param DataSet datasets: DataSet对象
+ :param ~fastNLP.DataSet datasets: DataSet对象
:param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
:param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。
:return:
diff --git a/fastNLP/embeddings/contextual_embedding.py b/fastNLP/embeddings/contextual_embedding.py
index 1831af4e..152b0ab9 100644
--- a/fastNLP/embeddings/contextual_embedding.py
+++ b/fastNLP/embeddings/contextual_embedding.py
@@ -1,4 +1,3 @@
-
from abc import abstractmethod
import torch
@@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from .embedding import TokenEmbedding
+__all__ = [
+ "ContextualEmbedding"
+]
+
class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):
diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index cd0d3527..bf5c2c36 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -14,27 +14,56 @@
__all__ = [
'EmbedLoader',
- 'CSVLoader',
- 'JsonLoader',
-
'DataBundle',
'DataSetLoader',
- 'ConllLoader',
- 'Conll2003Loader',
+ 'YelpLoader',
+ 'YelpFullLoader',
+ 'YelpPolarityLoader',
'IMDBLoader',
- 'MatchingLoader',
- 'SNLILoader',
- 'MNLILoader',
- 'MTL16Loader',
- 'PeopleDailyCorpusLoader',
- 'QNLILoader',
- 'QuoraLoader',
- 'RTELoader',
'SSTLoader',
'SST2Loader',
- 'YelpLoader',
-
+
+ 'ConllLoader',
+ 'Conll2003Loader',
+ 'Conll2003NERLoader',
+ 'OntoNotesNERLoader',
+ 'CTBLoader',
+
+ 'Loader',
+ 'CSVLoader',
+ 'JsonLoader',
+
+ 'CWSLoader',
+
+ 'MNLILoader',
+ "QuoraLoader",
+ "SNLILoader",
+ "QNLILoader",
+ "RTELoader",
+
+ "YelpFullPipe",
+ "YelpPolarityPipe",
+ "SSTPipe",
+ "SST2Pipe",
+ "IMDBPipe",
+
+ "Conll2003NERPipe",
+ "OntoNotesNERPipe",
+
+ "MatchingBertPipe",
+ "RTEBertPipe",
+ "SNLIBertPipe",
+ "QuoraBertPipe",
+ "QNLIBertPipe",
+ "MNLIBertPipe",
+ "MatchingPipe",
+ "RTEPipe",
+ "SNLIPipe",
+ "QuoraPipe",
+ "QNLIPipe",
+ "MNLIPipe",
+
'ModelLoader',
'ModelSaver',
]
@@ -44,4 +73,5 @@ from .base_loader import DataBundle, DataSetLoader
from .dataset_loader import CSVLoader, JsonLoader
from .model_io import ModelLoader, ModelSaver
-from .data_loader import *
+from .loader import *
+from .pipe import *
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index 01232627..429a8406 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -128,7 +128,7 @@ class DataBundle:
"""
向DataBunlde中增加vocab
- :param Vocabulary vocab: 词表
+ :param ~fastNLP.Vocabulary vocab: 词表
:param str field_name: 这个vocab对应的field名称
:return:
"""
@@ -138,7 +138,7 @@ class DataBundle:
def set_dataset(self, dataset, name):
"""
- :param DataSet dataset: 传递给DataBundle的DataSet
+ :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet
:param str name: dataset的名称
:return:
"""
diff --git a/fastNLP/io/config_io.py b/fastNLP/io/config_io.py
index 4acdbb96..ac349080 100644
--- a/fastNLP/io/config_io.py
+++ b/fastNLP/io/config_io.py
@@ -1,7 +1,9 @@
"""
用于读入和处理和保存 config 文件
- .. todo::
+
+.. todo::
这个模块中的类可能被抛弃?
+
"""
__all__ = [
"ConfigLoader",
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index b465ed9b..43fe2ab1 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -84,6 +84,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件,
(1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir
(2)如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name}
+
如果有该文件,就直接返回路径
如果没有该文件,则尝试用传入的url下载
@@ -126,8 +127,10 @@ def get_filepath(filepath):
如果filepath为文件夹,
如果内含多个文件, 返回filepath
如果只有一个文件, 返回filepath + filename
+
如果filepath为文件
返回filepath
+
:param str filepath: 路径
:return:
"""
@@ -237,7 +240,8 @@ def split_filename_suffix(filepath):
def get_from_cache(url: str, cache_dir: Path = None) -> Path:
"""
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的
- 文件解压,将解压后的文件全部放在cache_dir文件夹中。
+ 文件解压,将解压后的文件全部放在cache_dir文件夹中。
+
如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。
"""
cache_dir.mkdir(parents=True, exist_ok=True)
diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py
index 8e436532..8c0d391c 100644
--- a/fastNLP/io/loader/__init__.py
+++ b/fastNLP/io/loader/__init__.py
@@ -1,30 +1,61 @@
-
"""
-Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle`中。所有的Loader都支持以下的
- 三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field
- ; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法
- 支持以下三种类型的参数
+Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的
+三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field
+; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法
+支持以下三种类型的参数::
- Example::
- (0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。
- (1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取
- (2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。
- 假设某个目录下的文件为
- -train.txt
- -dev.txt
- -test.txt
- -other.txt
- Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'],
- data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。
- 假设某个目录下的文件为
- -train.txt
- -dev.txt
- Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取
- 对应的DataSet。
- (3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。
- paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
- Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'],
- data_bundle.datasets['test']
+ (0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。
+ (1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取
+ (2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。
+ 假设某个目录下的文件为
+ -train.txt
+ -dev.txt
+ -test.txt
+ -other.txt
+ Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'],
+ data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。
+ 假设某个目录下的文件为
+ -train.txt
+ -dev.txt
+ Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取
+ 对应的DataSet。
+ (3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。
+ paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
+ Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'],
+ data_bundle.datasets['test']
"""
+__all__ = [
+ 'YelpLoader',
+ 'YelpFullLoader',
+ 'YelpPolarityLoader',
+ 'IMDBLoader',
+ 'SSTLoader',
+ 'SST2Loader',
+
+ 'ConllLoader',
+ 'Conll2003Loader',
+ 'Conll2003NERLoader',
+ 'OntoNotesNERLoader',
+ 'CTBLoader',
+
+ 'Loader',
+ 'CSVLoader',
+ 'JsonLoader',
+
+ 'CWSLoader',
+
+ 'MNLILoader',
+ "QuoraLoader",
+ "SNLILoader",
+ "QNLILoader",
+ "RTELoader"
+]
+from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
+from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
+from .csv import CSVLoader
+from .cws import CWSLoader
+from .json import JsonLoader
+from .loader import Loader
+from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py
index 43790c15..b2c89ecc 100644
--- a/fastNLP/io/loader/conll.py
+++ b/fastNLP/io/loader/conll.py
@@ -1,12 +1,12 @@
from typing import Dict, Union
from .loader import Loader
-from ... import DataSet
+from ...core.dataset import DataSet
from ..file_reader import _read_conll
-from ... import Instance
+from ...core.instance import Instance
from .. import DataBundle
from ..utils import check_loader_paths
-from ... import Const
+from ...core.const import Const
class ConllLoader(Loader):
diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py
index 46c07f28..3af28116 100644
--- a/fastNLP/io/loader/cws.py
+++ b/fastNLP/io/loader/cws.py
@@ -1,6 +1,6 @@
-
from .loader import Loader
-from ...core import DataSet, Instance
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class CWSLoader(Loader):
diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py
index 4cf5bcf3..c59de29f 100644
--- a/fastNLP/io/loader/loader.py
+++ b/fastNLP/io/loader/loader.py
@@ -1,4 +1,4 @@
-from ... import DataSet
+from ...core.dataset import DataSet
from .. import DataBundle
from ..utils import check_loader_paths
from typing import Union, Dict
diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py
index eff98ba3..58fa0d6f 100644
--- a/fastNLP/io/loader/matching.py
+++ b/fastNLP/io/loader/matching.py
@@ -1,19 +1,12 @@
-
import warnings
from .loader import Loader
from .json import JsonLoader
-from ...core import Const
+from ...core.const import Const
from .. import DataBundle
import os
from typing import Union, Dict
-from ...core import DataSet
-from ...core import Instance
-
-__all__ = ['MNLILoader',
- "QuoraLoader",
- "SNLILoader",
- "QNLILoader",
- "RTELoader"]
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class MNLILoader(Loader):
diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py
index 0cf8d949..4cec3ad5 100644
--- a/fastNLP/io/pipe/__init__.py
+++ b/fastNLP/io/pipe/__init__.py
@@ -1,8 +1,34 @@
-
-
"""
Pipe用于处理数据,所有的Pipe都包含一个process(DataBundle)方法,传入一个DataBundle对象, 在传入DataBundle上进行原位修改,并将其返回;
- process_from_file(paths)传入的文件路径,返回一个DataBundle。process(DataBundle)或者process_from_file(paths)的返回DataBundle
- 中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。
+process_from_file(paths)传入的文件路径,返回一个DataBundle。process(DataBundle)或者process_from_file(paths)的返回DataBundle
+中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。
+
+"""
+__all__ = [
+ "YelpFullPipe",
+ "YelpPolarityPipe",
+ "SSTPipe",
+ "SST2Pipe",
+ "IMDBPipe",
+
+ "Conll2003NERPipe",
+ "OntoNotesNERPipe",
+
+ "MatchingBertPipe",
+ "RTEBertPipe",
+ "SNLIBertPipe",
+ "QuoraBertPipe",
+ "QNLIBertPipe",
+ "MNLIBertPipe",
+ "MatchingPipe",
+ "RTEPipe",
+ "SNLIPipe",
+ "QuoraPipe",
+ "QNLIPipe",
+ "MNLIPipe",
+]
-"""
\ No newline at end of file
+from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
+from .conll import Conll2003NERPipe, OntoNotesNERPipe
+from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
+ MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py
index a64e5328..429b6552 100644
--- a/fastNLP/io/pipe/classification.py
+++ b/fastNLP/io/pipe/classification.py
@@ -1,17 +1,17 @@
-
from nltk import Tree
from ..base_loader import DataBundle
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
-from ...core import DataSet, Instance
+from ...core.dataset import DataSet
+from ...core.instance import Instance
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
-from ...core import cache_results
+from ...core.utils import cache_results
class _CLSPipe(Pipe):
"""
@@ -257,7 +257,7 @@ class SSTPipe(_CLSPipe):
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
"..."
- :param DataBundle data_bundle: 需要处理的DataBundle对象
+ :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象
:return:
"""
# 先取出subtree
@@ -407,7 +407,7 @@ class IMDBPipe(_CLSPipe):
:param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str,
target列应该为str。
- :return:DataBundle
+ :return: DataBundle
"""
# 替换
def replace_br(raw_words):
diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py
index 4f780614..b9007344 100644
--- a/fastNLP/io/pipe/conll.py
+++ b/fastNLP/io/pipe/conll.py
@@ -1,7 +1,7 @@
from .pipe import Pipe
from .. import DataBundle
from .utils import iob2, iob2bioes
-from ... import Const
+from ...core.const import Const
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field
@@ -19,15 +19,16 @@ class _NERPipe(Pipe):
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。
"""
- def __init__(self, encoding_type:str='bio', lower:bool=False, target_pad_val=0):
- if encoding_type == 'bio':
+
+ def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0):
+ if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = iob2bioes
self.lower = lower
self.target_pad_val = int(target_pad_val)
- def process(self, data_bundle:DataBundle)->DataBundle:
+ def process(self, data_bundle: DataBundle) -> DataBundle:
"""
支持的DataSet的field为
@@ -146,4 +147,3 @@ class OntoNotesNERPipe(_NERPipe):
def process_from_file(self, paths):
data_bundle = OntoNotesNERLoader().load(paths)
return self.process(data_bundle)
-
diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py
index 76a0eaf7..93e854b1 100644
--- a/fastNLP/io/pipe/matching.py
+++ b/fastNLP/io/pipe/matching.py
@@ -2,10 +2,11 @@ import math
from .pipe import Pipe
from .utils import get_tokenizer
-from ...core import Const
-from ...core import Vocabulary
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
+
class MatchingBertPipe(Pipe):
"""
Matching任务的Bert pipe,输出的DataSet将包含以下的field
diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py
index 14c3866a..76cc00ec 100644
--- a/fastNLP/io/pipe/pipe.py
+++ b/fastNLP/io/pipe/pipe.py
@@ -1,9 +1,9 @@
-
from .. import DataBundle
+
class Pipe:
- def process(self, data_bundle:DataBundle)->DataBundle:
+ def process(self, data_bundle: DataBundle) -> DataBundle:
raise NotImplementedError
- def process_from_file(self, paths)->DataBundle:
+ def process_from_file(self, paths) -> DataBundle:
raise NotImplementedError
diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py
index 59bee96e..5e9ff8dc 100644
--- a/fastNLP/io/pipe/utils.py
+++ b/fastNLP/io/pipe/utils.py
@@ -1,6 +1,6 @@
from typing import List
-from ...core import Vocabulary
-from ...core import Const
+from ...core.vocabulary import Vocabulary
+from ...core.const import Const
def iob2(tags:List[str])->List[str]:
"""
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index a4ca2954..76b32b0a 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -6,12 +6,14 @@ from pathlib import Path
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
"""
- 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
- {
- 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
- 'test': 'xxx' # 可能有,也可能没有
- ...
- }
+ 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::
+
+ {
+ 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
+ 'test': 'xxx' # 可能有,也可能没有
+ ...
+ }
+
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 21608c5d..ead75711 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -112,7 +112,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor):
根据tensor的形状,生成一个mask
:param drop_p: float, 以多大的概率置为0。
- :param tensor:torch.Tensor
+ :param tensor: torch.Tensor
:return: torch.FloatTensor. 与tensor一样的shape
"""
mask_x = torch.ones_like(tensor)
diff --git a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py
index cec5ab76..0d292bdc 100644
--- a/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py
+++ b/reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py
@@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader):
:param paths:
:param bool, bigrams: 是否包含生成bigram feature, [a, b, c, d] -> [ab, bc, cd, d]
:param bool, trigrams: 是否包含trigram feature,[a, b, c, d] -> [abc, bcd, cd, d]
- :return: DataBundle
+ :return: ~fastNLP.io.DataBundle
包含以下的fields
raw_chars: List[str]
chars: List[int]