@@ -14,7 +14,7 @@ help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
 
 apidoc:
-	$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ)
+	$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) && python3 format.py
 
 server:
 	cd build/html && python -m http.server
@@ -0,0 +1,65 @@
+import os
+
+
+def shorten(file, to_delete, cut=False):
+    if file.endswith("index.rst") or file.endswith("conf.py"):
+        return
+    res = []
+    with open(file, "r") as fin:
+        lines = fin.readlines()
+        for line in lines:
+            if cut and line.rstrip() == "Submodules":
+                break
+            else:
+                res.append(line.rstrip())
+
+    for i, line in enumerate(res):
+        if line.endswith(" package"):
+            res[i] = res[i][:-len(" package")]
+            res[i + 1] = res[i + 1][:-len(" package")]
+        elif line.endswith(" module"):
+            res[i] = res[i][:-len(" module")]
+            res[i + 1] = res[i + 1][:-len(" module")]
+        else:
+            for name in to_delete:
+                if line.endswith(name):
+                    res[i] = "del"
+
+    with open(file, "w") as fout:
+        for line in res:
+            if line != "del":
+                print(line, file=fout)
+
+
+def clear(path='./source/'):
+    files = os.listdir(path)
+    to_delete = [
+        "fastNLP.core.dist_trainer",
+        "fastNLP.core.predictor",
+
+        "fastNLP.io.file_reader",
+        "fastNLP.io.config_io",
+
+        "fastNLP.embeddings.contextual_embedding",
+
+        "fastNLP.modules.dropout",
+        "fastNLP.models.base_model",
+        "fastNLP.models.bert",
+        "fastNLP.models.enas_utils",
+        "fastNLP.models.enas_controller",
+        "fastNLP.models.enas_model",
+        "fastNLP.models.enas_trainer",
+    ]
+    for file in files:
+        if not os.path.isdir(path + file):
+            res = file.split('.')
+            if len(res) > 4:
+                to_delete.append(file[:-4])
+            elif len(res) == 4:
+                shorten(path + file, to_delete, True)
+            else:
+                shorten(path + file, to_delete)
+    for file in to_delete:
+        os.remove(path + file + ".rst")
+
+
+clear()
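For reference, the post-processing this new script performs is mostly about keeping the generated RST headers well-formed: when `sphinx-apidoc` titles like "fastNLP.core package" lose their suffix, the underline on the next line must be trimmed by the same amount so its length still matches the title. A minimal, self-contained sketch of that pairwise trim (the sample lines are illustrative, not taken from the repo):

```python
# shorten() trims the title and its underline together, so the shortened
# RST section header stays consistent.
lines = ["fastNLP.core package", "===================="]
suffix = " package"
if lines[0].endswith(suffix):
    lines[0] = lines[0][:-len(suffix)]  # trim the title
    lines[1] = lines[1][:-len(suffix)]  # trim the underline to the same length
print(lines)  # ['fastNLP.core', '============']
```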
@@ -6,11 +6,10 @@ fastNLP.core
    :undoc-members:
    :show-inheritance:
 
-子模块
+Submodules
 ----------
 
 .. toctree::
-   :maxdepth: 1
 
    fastNLP.core.batch
    fastNLP.core.callback
@@ -6,11 +6,10 @@ fastNLP.embeddings
    :undoc-members:
    :show-inheritance:
 
-子模块
+Submodules
 ----------
 
 .. toctree::
-   :maxdepth: 1
 
    fastNLP.embeddings.bert_embedding
    fastNLP.embeddings.char_embedding
@@ -1,7 +1,8 @@
 fastNLP.io.data\_loader
-==========================
+=======================
 
 .. automodule:: fastNLP.io.data_loader
    :members:
    :undoc-members:
-   :show-inheritance:
\ No newline at end of file
+   :show-inheritance:
@@ -0,0 +1,7 @@
+fastNLP.io.file\_utils
+======================
+
+.. automodule:: fastNLP.io.file_utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
@@ -0,0 +1,8 @@
+fastNLP.io.loader
+=================
+
+.. automodule:: fastNLP.io.loader
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
@@ -0,0 +1,8 @@
+fastNLP.io.pipe
+===============
+
+.. automodule:: fastNLP.io.pipe
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
@@ -6,14 +6,23 @@ fastNLP.io
    :undoc-members:
    :show-inheritance:
 
-子模块
+Subpackages
+-----------
+
+.. toctree::
+
+   fastNLP.io.data_loader
+   fastNLP.io.loader
+   fastNLP.io.pipe
+
+Submodules
 ----------
 
 .. toctree::
-   :maxdepth: 1
 
    fastNLP.io.base_loader
-   fastNLP.io.embed_loader
    fastNLP.io.dataset_loader
-   fastNLP.io.data_loader
+   fastNLP.io.embed_loader
+   fastNLP.io.file_utils
    fastNLP.io.model_io
+   fastNLP.io.utils
@@ -0,0 +1,7 @@
+fastNLP.io.utils
+================
+
+.. automodule:: fastNLP.io.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
@@ -6,11 +6,10 @@ fastNLP.models
    :undoc-members:
    :show-inheritance:
 
-子模块
+Submodules
 ----------
 
 .. toctree::
-   :maxdepth: 1
 
    fastNLP.models.biaffine_parser
    fastNLP.models.cnn_text_classification
@@ -5,3 +5,4 @@ fastNLP.modules.encoder
    :members:
    :undoc-members:
    :show-inheritance:
+
@@ -6,12 +6,17 @@ fastNLP.modules
    :undoc-members:
    :show-inheritance:
 
-子模块
+Subpackages
 -----------
 
 .. toctree::
-   :titlesonly:
-   :maxdepth: 1
 
    fastNLP.modules.decoder
-   fastNLP.modules.encoder
+   fastNLP.modules.encoder
+
+Submodules
+----------
+
+.. toctree::
+
+   fastNLP.modules.utils
@@ -0,0 +1,7 @@
+fastNLP.modules.utils
+=====================
+
+.. automodule:: fastNLP.modules.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
@@ -1,16 +1,15 @@
-API 文档
-===============
+fastNLP
+=======
 
 .. automodule:: fastNLP
    :members:
    :undoc-members:
    :show-inheritance:
 
-内部模块
+Subpackages
 -----------
 
 .. toctree::
-   :maxdepth: 1
 
    fastNLP.core
    fastNLP.embeddings
@@ -2,7 +2,6 @@ fastNLP
 =======
 
 .. toctree::
-   :titlesonly:
    :maxdepth: 4
 
    fastNLP
@@ -1,3 +1,6 @@
+"""
+Distributed training code that is still under development.
+"""
 import torch
 import torch.cuda
 import torch.optim
@@ -41,7 +44,8 @@ def get_local_rank():
 
 class DistTrainer():
-    """Distributed Trainer that support distributed and mixed precision training
+    """
+    Distributed Trainer that supports distributed and mixed precision training.
     """
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
@@ -176,9 +176,9 @@ class BertWordPieceEncoder(nn.Module):
     def index_datasets(self, *datasets, field_name, add_cls_sep=True):
         """
         Uses bert's tokenizer to generate a new word_pieces field, adds it to the datasets and sets it as input; the pad
-        value of the word_pieces field is set to bert's pad value.
+        value of the word_pieces field is set to bert's pad value.
 
-        :param DataSet datasets: DataSet objects
+        :param ~fastNLP.DataSet datasets: DataSet objects
         :param str field_name: which field to generate the word_pieces from. Every entry in that field should be a List[str].
         :param bool add_cls_sep: if the sequence does not already start with [CLS] and end with [SEP], add them.
         :return:
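A hedged usage sketch of `index_datasets()` as documented above; the import path, model name and sample data are assumptions for illustration, not something this diff establishes:

```python
from fastNLP import DataSet
from fastNLP.embeddings.bert_embedding import BertWordPieceEncoder  # assumed path

train_ds = DataSet({'words': [['hello', 'world']]})  # made-up toy data
encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased')  # placeholder name
encoder.index_datasets(train_ds, field_name='words')
# train_ds now carries a 'word_pieces' field, marked as input and padded
# with BERT's pad value, as the docstring describes.
```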
@@ -1,4 +1,3 @@
-
 from abc import abstractmethod
 
 import torch
@@ -9,6 +8,10 @@ from ..core.sampler import SequentialSampler
 from ..core.utils import _move_model_to_device, _get_model_device
 from .embedding import TokenEmbedding
 
+__all__ = [
+    "ContextualEmbedding"
+]
+
 
 class ContextualEmbedding(TokenEmbedding):
     def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):
@@ -14,27 +14,56 @@
 __all__ = [
     'EmbedLoader',
-    'CSVLoader',
-    'JsonLoader',
 
     'DataBundle',
     'DataSetLoader',
 
-    'ConllLoader',
-    'Conll2003Loader',
-    'YelpLoader',
-    'YelpFullLoader',
-    'YelpPolarityLoader',
     'IMDBLoader',
-    'MatchingLoader',
-    'SNLILoader',
-    'MNLILoader',
-    'MTL16Loader',
-    'PeopleDailyCorpusLoader',
-    'QNLILoader',
-    'QuoraLoader',
-    'RTELoader',
     'SSTLoader',
     'SST2Loader',
+    'YelpLoader',
+    'YelpFullLoader',
+    'YelpPolarityLoader',
+
+    'ConllLoader',
+    'Conll2003Loader',
+    'Conll2003NERLoader',
+    'OntoNotesNERLoader',
+    'CTBLoader',
+
+    'Loader',
+    'CSVLoader',
+    'JsonLoader',
+    'CWSLoader',
+
+    'MNLILoader',
+    "QuoraLoader",
+    "SNLILoader",
+    "QNLILoader",
+    "RTELoader",
+
+    "YelpFullPipe",
+    "YelpPolarityPipe",
+    "SSTPipe",
+    "SST2Pipe",
+    "IMDBPipe",
+
+    "Conll2003NERPipe",
+    "OntoNotesNERPipe",
+
+    "MatchingBertPipe",
+    "RTEBertPipe",
+    "SNLIBertPipe",
+    "QuoraBertPipe",
+    "QNLIBertPipe",
+    "MNLIBertPipe",
+    "MatchingPipe",
+    "RTEPipe",
+    "SNLIPipe",
+    "QuoraPipe",
+    "QNLIPipe",
+    "MNLIPipe",
 
     'ModelLoader',
     'ModelSaver',
 ]
@@ -44,4 +73,5 @@ from .base_loader import DataBundle, DataSetLoader
 from .dataset_loader import CSVLoader, JsonLoader
 from .model_io import ModelLoader, ModelSaver
 
-from .data_loader import *
+from .loader import *
+from .pipe import *
@@ -128,7 +128,7 @@ class DataBundle:
         """
         Adds a vocabulary to the DataBundle.
 
-        :param Vocabulary vocab: the vocabulary
+        :param ~fastNLP.Vocabulary vocab: the vocabulary
         :param str field_name: name of the field this vocab corresponds to
         :return:
         """
@@ -138,7 +138,7 @@ class DataBundle:
     def set_dataset(self, dataset, name):
         """
 
-        :param DataSet dataset: the DataSet passed to the DataBundle
+        :param ~fastNLP.DataSet dataset: the DataSet passed to the DataBundle
         :param str name: name of the dataset
         :return:
         """
@@ -1,7 +1,9 @@
 """
 Utilities for reading, processing and saving config files.
-.. todo::
+
+.. todo::
     The classes in this module may be deprecated?
+
 """
 __all__ = [
     "ConfigLoader",
@@ -84,6 +84,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     Given a url, tries to find the file under {cache_dir}/{name}/{filename}, where filename is parsed from the url:
     (1) if cache_dir=None, then cache_dir=~/.fastNLP/; otherwise cache_dir is used as given
     (2) if name=None, the intermediate {name} directory level is omitted; otherwise the intermediate level is {name}
+
     If the file exists, its path is returned directly;
     if it does not, a download from the given url is attempted.
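A hedged sketch of the lookup-then-download behaviour this docstring describes, using the signature from the hunk header (the url is a placeholder):

```python
from fastNLP.io.file_utils import cached_path  # module per this diff's new rst page

# With cache_dir=None and name='embedding', the file is looked up under
# ~/.fastNLP/embedding/<filename>; on a cache miss it is downloaded first.
local_path = cached_path('http://example.com/glove.zip', cache_dir=None, name='embedding')
print(local_path)  # a pathlib.Path inside the cache directory
```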
@@ -126,8 +127,10 @@ def get_filepath(filepath):
     If filepath is a directory:
         if it contains more than one file, return filepath
         if it contains exactly one file, return filepath + filename
+
     If filepath is a file:
         return filepath
+
     :param str filepath: the path
     :return:
     """
@@ -237,7 +240,8 @@ def split_filename_suffix(filepath):
 def get_from_cache(url: str, cache_dir: Path = None) -> Path:
     """
     Tries to find the resource identified by url in cache_dir; if it is not found, downloads it from the url and stores the result under cache_dir, with the cache name inferred from the url. The downloaded
-    file is unpacked, and all unpacked files are placed in the cache_dir folder.
+    file is unpacked, and all unpacked files are placed in the cache_dir folder.
+
     If the resource downloaded from the url unpacks into multiple files, the directory path is returned; if there is only a single resource file, its concrete path is returned.
     """
     cache_dir.mkdir(parents=True, exist_ok=True)
@@ -1,30 +1,61 @@
 """
-Loaders read data into a :class:`~fastNLP.DataSet` or a :class:`~fastNLP.io.DataBundle`. Every Loader supports three
-methods: __init__(), _load() and load(). __init__() declares the reading parameters and documents the data format the
-Loader accepts and the fields of the resulting DataSet; _load(path) takes a file path, reads that single file and
-returns a DataSet; load(paths) reads the files under a path and returns a DataBundle. load() accepts three kinds of arguments
+Loaders read data into a :class:`~fastNLP.DataSet` or a :class:`~fastNLP.io.DataBundle`. Every Loader supports three
+methods: __init__(), _load() and load(). __init__() declares the reading parameters and documents the data format the
+Loader accepts and the fields of the resulting DataSet; _load(path) takes a file path, reads that single file and
+returns a DataSet; load(paths) reads the files under a path and returns a DataBundle. load() accepts three kinds of arguments::
 
-Example::
-
-    (0) If None is passed, the Loader tries to download the dataset automatically and cache it; not every dataset can be downloaded this way.
-    (1) If a single file path is passed, the returned DataBundle contains one DataSet named train, available as data_bundle.datasets['train'].
-    (2) If a directory is passed, the files whose names contain 'train', 'test' or 'dev' are read; other files are ignored.
-        Suppose a directory contains the files
-            -train.txt
-            -dev.txt
-            -test.txt
-            -other.txt
-        Read with Loader().load('/path/to/dir'); in the returned data_bundle the DataSets are available as data_bundle.datasets['train'],
-        data_bundle.datasets['dev'] and data_bundle.datasets['test'], while the content of other.txt is ignored.
-        Suppose a directory contains the files
-            -train.txt
-            -dev.txt
-        Read with Loader().load('/path/to/dir'); in the returned data_bundle the DataSets are available as data_bundle.datasets['train']
-        and data_bundle.datasets['dev'].
-    (3) If a dict is passed, its keys are dataset names and its values the corresponding file paths.
-        paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
-        Loader().load(paths)  # in the returned data_bundle the DataSets are available as data_bundle.datasets['train'],
-        data_bundle.datasets['dev'] and data_bundle.datasets['test']
+    (0) If None is passed, the Loader tries to download the dataset automatically and cache it; not every dataset can be downloaded this way.
+    (1) If a single file path is passed, the returned DataBundle contains one DataSet named train, available as data_bundle.datasets['train'].
+    (2) If a directory is passed, the files whose names contain 'train', 'test' or 'dev' are read; other files are ignored.
+        Suppose a directory contains the files
+            -train.txt
+            -dev.txt
+            -test.txt
+            -other.txt
+        Read with Loader().load('/path/to/dir'); in the returned data_bundle the DataSets are available as data_bundle.datasets['train'],
+        data_bundle.datasets['dev'] and data_bundle.datasets['test'], while the content of other.txt is ignored.
+        Suppose a directory contains the files
+            -train.txt
+            -dev.txt
+        Read with Loader().load('/path/to/dir'); in the returned data_bundle the DataSets are available as data_bundle.datasets['train']
+        and data_bundle.datasets['dev'].
+    (3) If a dict is passed, its keys are dataset names and its values the corresponding file paths.
+        paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
+        Loader().load(paths)  # in the returned data_bundle the DataSets are available as data_bundle.datasets['train'],
+        data_bundle.datasets['dev'] and data_bundle.datasets['test']
 """
+__all__ = [
+    'YelpLoader',
+    'YelpFullLoader',
+    'YelpPolarityLoader',
+    'IMDBLoader',
+    'SSTLoader',
+    'SST2Loader',
+
+    'ConllLoader',
+    'Conll2003Loader',
+    'Conll2003NERLoader',
+    'OntoNotesNERLoader',
+    'CTBLoader',
+
+    'Loader',
+    'CSVLoader',
+    'JsonLoader',
+    'CWSLoader',
+
+    'MNLILoader',
+    "QuoraLoader",
+    "SNLILoader",
+    "QNLILoader",
+    "RTELoader"
+]
+
+from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
+from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
+from .csv import CSVLoader
+from .cws import CWSLoader
+from .json import JsonLoader
+from .loader import Loader
+from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
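To make the three call styles concrete, a hedged usage sketch with one loader from the new package (the paths are placeholders; not every dataset supports the automatic download in style (0)):

```python
from fastNLP.io.loader import IMDBLoader

# (0) pass None: try to download and cache the dataset automatically
bundle = IMDBLoader().load(None)
# (1) a single file becomes the 'train' DataSet
bundle = IMDBLoader().load('/path/to/train.txt')
# (2) a directory: picks up files whose names contain train/dev/test
bundle = IMDBLoader().load('/path/to/dir')
# (3) an explicit mapping from split name to file path
bundle = IMDBLoader().load({'train': '/path/to/train', 'dev': '/path/to/dev'})
train_ds = bundle.datasets['train']
```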
@@ -1,12 +1,12 @@
 from typing import Dict, Union
 
 from .loader import Loader
-from ... import DataSet
+from ...core.dataset import DataSet
 from ..file_reader import _read_conll
-from ... import Instance
+from ...core.instance import Instance
 from .. import DataBundle
 from ..utils import check_loader_paths
-from ... import Const
+from ...core.const import Const
 
 
 class ConllLoader(Loader):
@@ -1,6 +1,6 @@
 from .loader import Loader
-from ...core import DataSet, Instance
+from ...core.dataset import DataSet
+from ...core.instance import Instance
 
 
 class CWSLoader(Loader):
@@ -1,4 +1,4 @@
-from ... import DataSet
+from ...core.dataset import DataSet
 from .. import DataBundle
 from ..utils import check_loader_paths
 from typing import Union, Dict
@@ -1,19 +1,12 @@
 import warnings
 
 from .loader import Loader
 from .json import JsonLoader
-from ...core import Const
+from ...core.const import Const
 from .. import DataBundle
 import os
 from typing import Union, Dict
-from ...core import DataSet
-from ...core import Instance
-
-__all__ = ['MNLILoader',
-           "QuoraLoader",
-           "SNLILoader",
-           "QNLILoader",
-           "RTELoader"]
+from ...core.dataset import DataSet
+from ...core.instance import Instance
 
 
 class MNLILoader(Loader):
@@ -1,8 +1,34 @@
 """
 Pipes transform data. Every Pipe has a process(DataBundle) method that takes a DataBundle object, modifies it in place and returns it;
-process_from_file(paths) takes file paths and returns a DataBundle. The DataBundle returned by process(DataBundle) or process_from_file(paths)
-usually holds DataSets containing the raw text, the index-converted input and the index-converted target; besides the DataSets, it also holds the vocabularies built while converting fields to indices.
+process_from_file(paths) takes file paths and returns a DataBundle. The DataBundle returned by process(DataBundle) or process_from_file(paths)
+usually holds DataSets containing the raw text, the index-converted input and the index-converted target; besides the DataSets, it also holds the vocabularies built while converting fields to indices.
+"""
+__all__ = [
+    "YelpFullPipe",
+    "YelpPolarityPipe",
+    "SSTPipe",
+    "SST2Pipe",
+    "IMDBPipe",
+
+    "Conll2003NERPipe",
+    "OntoNotesNERPipe",
+
+    "MatchingBertPipe",
+    "RTEBertPipe",
+    "SNLIBertPipe",
+    "QuoraBertPipe",
+    "QNLIBertPipe",
+    "MNLIBertPipe",
+    "MatchingPipe",
+    "RTEPipe",
+    "SNLIPipe",
+    "QuoraPipe",
+    "QNLIPipe",
+    "MNLIPipe",
+]
-"""
+
+from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
+from .conll import Conll2003NERPipe, OntoNotesNERPipe
+from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
+    MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
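For context, a hedged sketch of the Pipe contract this docstring describes, using one of the newly exported pipes (the path is a placeholder):

```python
from fastNLP.io.pipe import IMDBPipe

# one-step form: read raw files and process them into an indexed DataBundle
data_bundle = IMDBPipe().process_from_file('/path/to/imdb')

# equivalent two-step form: load first, then process in place
# data_bundle = IMDBLoader().load('/path/to/imdb')
# data_bundle = IMDBPipe().process(data_bundle)

print(data_bundle.vocabs)  # vocabularies built while indexing the fields
```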
@@ -1,17 +1,17 @@
 from nltk import Tree
 
 from ..base_loader import DataBundle
 from ...core.vocabulary import Vocabulary
 from ...core.const import Const
 from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
-from ...core import DataSet, Instance
+from ...core.dataset import DataSet
+from ...core.instance import Instance
 from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
 from .pipe import Pipe
 import re
 
 nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
-from ...core import cache_results
+from ...core.utils import cache_results
 
 
 class _CLSPipe(Pipe):
     """
@@ -257,7 +257,7 @@ class SSTPipe(_CLSPipe):
         "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
         "..."
 
-        :param DataBundle data_bundle: the DataBundle object to process
+        :param ~fastNLP.io.DataBundle data_bundle: the DataBundle object to process
         :return:
         """
         # extract the subtrees first
@@ -407,7 +407,7 @@ class IMDBPipe(_CLSPipe):
 
         :param DataBundle data_bundle: the DataSets in the given DataBundle must contain the fields raw_words and target; raw_words should be str,
             target should be str.
-        :return:DataBundle
+        :return: DataBundle
         """
         # replace <br />
         def replace_br(raw_words):
@@ -1,7 +1,7 @@
 from .pipe import Pipe
 from .. import DataBundle
 from .utils import iob2, iob2bioes
-from ... import Const
+from ...core.const import Const
 from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
 from .utils import _indexize, _add_words_field
@@ -19,15 +19,16 @@ class _NERPipe(Pipe):
     :param bool lower: whether to lowercase the words before building the vocabulary; in the vast majority of cases this should stay False.
     :param int target_pad_val: padding value for target; padded positions in the target field are filled with target_pad_val. Defaults to -100.
     """
-    def __init__(self, encoding_type:str='bio', lower:bool=False, target_pad_val=0):
-        if encoding_type == 'bio':
+
+    def __init__(self, encoding_type: str = 'bio', lower: bool = False, target_pad_val=0):
+        if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
             self.convert_tag = iob2bioes
         self.lower = lower
         self.target_pad_val = int(target_pad_val)
 
-    def process(self, data_bundle:DataBundle)->DataBundle:
+    def process(self, data_bundle: DataBundle) -> DataBundle:
         """
         The supported DataSet fields are
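A hedged illustration of the two tag encodings the constructor chooses between, using the helpers the `from .utils import iob2, iob2bioes` line above pulls in (the sample tags are made up):

```python
from fastNLP.io.pipe.utils import iob2, iob2bioes  # assumed module path

tags = ['I-PER', 'I-PER', 'O', 'I-LOC']
bio = iob2(tags)        # IOB1 -> BIO:   ['B-PER', 'I-PER', 'O', 'B-LOC']
bioes = iob2bioes(bio)  # BIO  -> BIOES: ['B-PER', 'E-PER', 'O', 'S-LOC']
```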
@@ -146,4 +147,3 @@ class OntoNotesNERPipe(_NERPipe):
     def process_from_file(self, paths):
         data_bundle = OntoNotesNERLoader().load(paths)
         return self.process(data_bundle)
-
@@ -2,10 +2,11 @@ import math
 
 from .pipe import Pipe
 from .utils import get_tokenizer
-from ...core import Const
-from ...core import Vocabulary
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
 from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
 
 
 class MatchingBertPipe(Pipe):
     """
     Bert pipe for Matching tasks; the output DataSet will contain the following fields
@@ -1,9 +1,9 @@
 from .. import DataBundle
 
 
 class Pipe:
-    def process(self, data_bundle:DataBundle)->DataBundle:
+    def process(self, data_bundle: DataBundle) -> DataBundle:
         raise NotImplementedError
 
-    def process_from_file(self, paths)->DataBundle:
+    def process_from_file(self, paths) -> DataBundle:
         raise NotImplementedError
@@ -1,6 +1,6 @@
 from typing import List
 
-from ...core import Vocabulary
-from ...core import Const
+from ...core.vocabulary import Vocabulary
+from ...core.const import Const
 
 
 def iob2(tags:List[str])->List[str]:
     """
@@ -6,12 +6,14 @@ from pathlib import Path
 
 def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
     """
-    Checks that the paths passed to the dataloader are valid. For a valid path, returns a dict containing at least the key 'train', similar to the following
-    {
-        'train': '/some/path/to/',  # always present; the vocabulary should be built on this one, the remaining files only need to be processed and indexed.
-        'test': 'xxx'  # may or may not be present
-        ...
-    }
+    Checks that the paths passed to the dataloader are valid. For a valid path, returns a dict containing at least the key 'train', similar to the following::
+
+        {
+            'train': '/some/path/to/',  # always present; the vocabulary should be built on this one, the remaining files only need to be processed and indexed.
+            'test': 'xxx'  # may or may not be present
+            ...
+        }
+
     If paths is invalid, the corresponding error is raised directly. An error is also raised if paths does not contain train.
 
     :param str paths: the path. Can be a single file path (that file is then taken to be the train file); or a directory, in which case a train file is looked for in it (a file whose name
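A hedged sketch of the contract described above (the paths are placeholders; in real use they must exist on disk, since the function validates them):

```python
from fastNLP.io.utils import check_loader_paths

# a single existing file is treated as the train file
splits = check_loader_paths('/path/to/train.txt')
# -> {'train': '/path/to/train.txt'}

# an explicit dict is validated and returned; omitting the 'train' key raises
splits = check_loader_paths({'train': '/path/to/train', 'dev': '/path/to/dev'})
```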
@@ -112,7 +112,7 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor):
     Generates a mask according to the shape of tensor.
 
     :param drop_p: float, the probability of setting an entry to 0.
-    :param tensor:torch.Tensor
+    :param tensor: torch.Tensor
     :return: torch.FloatTensor, same shape as tensor
     """
     mask_x = torch.ones_like(tensor)
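A hedged usage sketch of `get_dropout_mask()` as documented (the module path is an assumption based on the new fastNLP.modules.utils page in this diff):

```python
import torch
from fastNLP.modules.utils import get_dropout_mask  # assumed module path

x = torch.randn(4, 8)
mask = get_dropout_mask(0.3, x)  # FloatTensor shaped like x, entries zeroed with p=0.3
dropped = x * mask               # typical use: one shared mask applied across timesteps
```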
@@ -51,7 +51,7 @@ class ChineseNERLoader(DataSetLoader):
         :param paths:
         :param bool, bigrams: whether to generate the bigram feature, [a, b, c, d] -> [ab, bc, cd, d<eos>]
         :param bool, trigrams: whether to generate the trigram feature, [a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>]
-        :return: DataBundle
+        :return: ~fastNLP.io.DataBundle
             contains the following fields
                 raw_chars: List[str]
                 chars: List[int]
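The bigram/trigram construction described in that docstring can be spelled out as follows (a hedged illustration, not code from the diff):

```python
chars = ['a', 'b', 'c', 'd']
# bigrams: [a, b, c, d] -> [ab, bc, cd, d<eos>]
bigrams = [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
# trigrams: [a, b, c, d] -> [abc, bcd, cd<eos>, d<eos><eos>]
trigrams = [c1 + c2 + c3 for c1, c2, c3 in
            zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>', '<eos>'])]
print(bigrams)   # ['ab', 'bc', 'cd', 'd<eos>']
print(trigrams)  # ['abc', 'bcd', 'cd<eos>', 'd<eos><eos>']
```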