diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py
index 1e663f1e..db60a86f 100644
--- a/fastNLP/io/data_bundle.py
+++ b/fastNLP/io/data_bundle.py
@@ -1,10 +1,15 @@
+"""
+.. todo::
+ doc
+"""
__all__ = [
'DataBundle',
]
import _pickle as pickle
-from typing import Union, Dict
import os
+from typing import Union, Dict
+
from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 82e96597..fca0de69 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -1,4 +1,4 @@
-"""
+"""undocumented
.. warning::
本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。
@@ -23,10 +23,10 @@ __all__ = [
]
+from .data_bundle import DataSetLoader
+from .file_reader import _read_csv, _read_json
from ..core.dataset import DataSet
from ..core.instance import Instance
-from .file_reader import _read_csv, _read_json
-from .data_bundle import DataSetLoader
class JsonLoader(DataSetLoader):
diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
index c58385e1..780d91e4 100644
--- a/fastNLP/io/embed_loader.py
+++ b/fastNLP/io/embed_loader.py
@@ -1,17 +1,22 @@
+"""
+.. todo::
+ doc
+"""
__all__ = [
"EmbedLoader",
"EmbeddingOption",
]
+import logging
import os
import warnings
import numpy as np
-from ..core.vocabulary import Vocabulary
from .data_bundle import BaseLoader
from ..core.utils import Option
-import logging
+from ..core.vocabulary import Vocabulary
+
class EmbeddingOption(Option):
def __init__(self,
diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py
index 0320572c..7a953098 100644
--- a/fastNLP/io/file_reader.py
+++ b/fastNLP/io/file_reader.py
@@ -1,7 +1,11 @@
-"""
+"""undocumented
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
"""
+
+__all__ = []
+
import json
+
from ..core import logger
@@ -24,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
headers = headers.split(sep)
start_idx += 1
elif not isinstance(headers, (list, tuple)):
- raise TypeError("headers should be list or tuple, not {}." \
- .format(type(headers)))
+ raise TypeError("headers should be list or tuple, not {}." \
+ .format(type(headers)))
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
@@ -82,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""
+
def parse_conll(sample):
sample = list(map(list, zip(*sample)))
sample = [sample[i] for i in indexes]
@@ -89,14 +94,15 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
if len(f) <= 0:
raise ValueError('empty field')
return sample
+
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
- if start!='':
+ if start != '':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line.strip()
- if line=='':
+ if line == '':
if len(sample):
try:
res = parse_conll(sample)
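
For reference, a minimal sketch of how the private readers in file_reader.py are consumed elsewhere in fastNLP.io (the file paths below are hypothetical; both helpers are generators yielding (line number, item) pairs, as their docstrings state):

    from fastNLP.io.file_reader import _read_conll, _read_csv

    # _read_conll yields one (line_idx, [column_i, ...]) pair per sentence,
    # keeping only the columns listed in `indexes`.
    for line_idx, sample in _read_conll('data/train.conll', indexes=[0, 3], dropna=True):
        words, tags = sample
        print(line_idx, words[:5], tags[:5])
        break

    # _read_csv yields one (line_idx, {header: value}) pair per row.
    for line_idx, row in _read_csv('data/train.tsv', sep='\t', dropna=True):
        print(line_idx, row)
        break
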
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index bd02158e..8ecdff25 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -1,12 +1,27 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "cached_path",
+ "get_filepath",
+ "get_cache_path",
+ "split_filename_suffix",
+ "get_from_cache",
+]
+
import os
+import re
+import shutil
+import tempfile
from pathlib import Path
from urllib.parse import urlparse
-import re
+
import requests
-import tempfile
-from tqdm import tqdm
-import shutil
from requests import HTTPError
+from tqdm import tqdm
+
from ..core import logger
PRETRAINED_BERT_MODEL_DIR = {
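
As a rough illustration of the cache helpers now listed in __all__ (a sketch assuming get_cache_path() returns the local cache directory and split_filename_suffix() splits off compound suffixes such as .tar.gz; the archive name is made up):

    from fastNLP.io.file_utils import get_cache_path, split_filename_suffix

    cache_dir = get_cache_path()                               # local fastNLP cache directory
    name, suffix = split_filename_suffix('conll2003.tar.gz')   # hypothetical archive name
    print(cache_dir, name, suffix)                             # expected: <cache_dir> conll2003 .tar.gz
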
diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py
index f64a26e7..ec00d2b4 100644
--- a/fastNLP/io/loader/classification.py
+++ b/fastNLP/io/loader/classification.py
@@ -1,12 +1,24 @@
-from ...core.dataset import DataSet
-from ...core.instance import Instance
-from .loader import Loader
-import warnings
+"""undocumented"""
+
+__all__ = [
+ "YelpLoader",
+ "YelpFullLoader",
+ "YelpPolarityLoader",
+ "IMDBLoader",
+ "SSTLoader",
+ "SST2Loader",
+]
+
+import glob
import os
import random
import shutil
-import glob
import time
+import warnings
+
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class YelpLoader(Loader):
@@ -58,7 +70,7 @@ class YelpLoader(Loader):
class YelpFullLoader(YelpLoader):
- def download(self, dev_ratio: float = 0.1, re_download:bool=False):
+ def download(self, dev_ratio: float = 0.1, re_download: bool = False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -127,7 +139,7 @@ class YelpPolarityLoader(YelpLoader):
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
-
+
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
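
A usage sketch for the Yelp loaders touched above, under the assumption that download() returns the dataset directory (carving a dev split out of train according to dev_ratio, as documented) and that load() accepts that directory:

    from fastNLP.io.loader.classification import YelpFullLoader

    loader = YelpFullLoader()
    data_dir = loader.download(dev_ratio=0.1)    # download, or reuse the cached copy
    data_bundle = loader.load(data_dir)          # DataBundle with train/dev/test DataSets
    print(data_bundle)
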
diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py
index b5241cff..1bd1b448 100644
--- a/fastNLP/io/loader/conll.py
+++ b/fastNLP/io/loader/conll.py
@@ -1,15 +1,28 @@
-from typing import Dict, Union
+"""undocumented"""
+
+__all__ = [
+ "ConllLoader",
+ "Conll2003Loader",
+ "Conll2003NERLoader",
+ "OntoNotesNERLoader",
+ "CTBLoader",
+ "CNNERLoader",
+ "MsraNERLoader",
+ "WeiboNERLoader",
+ "PeopleDailyNERLoader"
+]
-from .loader import Loader
-from ...core.dataset import DataSet
-from ..file_reader import _read_conll
-from ...core.instance import Instance
-from ...core.const import Const
import glob
import os
+import random
import shutil
import time
-import random
+
+from .loader import Loader
+from ..file_reader import _read_conll
+from ...core.const import Const
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class ConllLoader(Loader):
@@ -47,6 +60,7 @@ class ConllLoader(Loader):
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
"""
+
def __init__(self, headers, indexes=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
@@ -60,7 +74,7 @@ class ConllLoader(Loader):
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes
-
+
def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -101,12 +115,13 @@ class Conll2003Loader(ConllLoader):
"[...]", "[...]", "[...]", "[...]"
"""
+
def __init__(self):
headers = [
'raw_words', 'pos', 'chunk', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)
-
+
def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -127,7 +142,7 @@ class Conll2003Loader(ConllLoader):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
-
+
def download(self, output_dir=None):
raise RuntimeError("conll2003 cannot be downloaded automatically.")
@@ -158,12 +173,13 @@ class Conll2003NERLoader(ConllLoader):
"[...]", "[...]"
"""
+
def __init__(self):
headers = [
'raw_words', 'target',
]
super().__init__(headers=headers, indexes=[0, 3])
-
+
def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -184,7 +200,7 @@ class Conll2003NERLoader(ConllLoader):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
-
+
def download(self):
raise RuntimeError("conll2003 cannot be downloaded automatically.")
@@ -204,13 +220,13 @@ class OntoNotesNERLoader(ConllLoader):
"[...]", "[...]"
"""
-
+
def __init__(self):
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10])
-
- def _load(self, path:str):
+
+ def _load(self, path: str):
dataset = super()._load(path)
-
+
def convert_to_bio(tags):
bio_tags = []
flag = None
@@ -227,7 +243,7 @@ class OntoNotesNERLoader(ConllLoader):
flag = None
bio_tags.append(bio_label)
return bio_tags
-
+
def convert_word(words):
converted_words = []
for word in words:
@@ -236,7 +252,7 @@ class OntoNotesNERLoader(ConllLoader):
converted_words.append(word)
continue
# 以下是由于这些符号被转义了,再转回来
- tfrs = {'-LRB-':'(',
+ tfrs = {'-LRB-': '(',
'-RRB-': ')',
'-LSB-': '[',
'-RSB-': ']',
@@ -248,12 +264,12 @@ class OntoNotesNERLoader(ConllLoader):
else:
converted_words.append(word)
return converted_words
-
+
dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET)
-
+
return dataset
-
+
def download(self):
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer "
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.")
@@ -262,13 +278,13 @@ class OntoNotesNERLoader(ConllLoader):
class CTBLoader(Loader):
def __init__(self):
super().__init__()
-
- def _load(self, path:str):
+
+ def _load(self, path: str):
pass
class CNNERLoader(Loader):
- def _load(self, path:str):
+ def _load(self, path: str):
"""
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample
@@ -331,10 +347,11 @@ class MsraNERLoader(CNNERLoader):
"[...]", "[...]"
"""
+
def __init__(self):
super().__init__()
-
- def download(self, dev_ratio:float=0.1, re_download:bool=False)->str:
+
+ def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str:
"""
自动下载MSRA-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language
Processing Bakeoff: Word Segmentation and Named Entity Recognition.
@@ -356,7 +373,7 @@ class MsraNERLoader(CNNERLoader):
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
-
+
if not os.path.exists(os.path.join(data_dir, 'dev.conll')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -380,15 +397,15 @@ class MsraNERLoader(CNNERLoader):
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.conll')):
os.remove(os.path.join(data_dir, 'middle_file.conll'))
-
+
return data_dir
class WeiboNERLoader(CNNERLoader):
def __init__(self):
super().__init__()
-
- def download(self)->str:
+
+ def download(self) -> str:
"""
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for
Chinese Social Media with Jointly Trained Embeddings.
@@ -397,7 +414,7 @@ class WeiboNERLoader(CNNERLoader):
"""
dataset_name = 'weibo-ner'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
-
+
return data_dir
@@ -427,11 +444,12 @@ class PeopleDailyNERLoader(CNNERLoader):
"[...]", "[...]"
"""
+
def __init__(self):
super().__init__()
-
+
def download(self) -> str:
dataset_name = 'peopledaily'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
-
+
return data_dir
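
A small sketch of the ConllLoader interface whose imports are reordered above (headers/indexes follow the CoNLL-2003 layout described in the docstrings; the file path is hypothetical):

    from fastNLP.io.loader.conll import ConllLoader, MsraNERLoader

    # Column 0 as raw_words and column 3 as target, as Conll2003NERLoader does.
    loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, 3])
    dataset = loader._load('data/conll2003/train.txt')   # hypothetical path
    print(dataset[:2])

    # The Chinese NER loaders can fetch their data automatically.
    msra_dir = MsraNERLoader().download(dev_ratio=0.1)
    msra_bundle = MsraNERLoader().load(msra_dir)
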
diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py
index 5195cc8e..0d6e35fa 100644
--- a/fastNLP/io/loader/csv.py
+++ b/fastNLP/io/loader/csv.py
@@ -1,7 +1,13 @@
+"""undocumented"""
+
+__all__ = [
+ "CSVLoader",
+]
+
+from .loader import Loader
+from ..file_reader import _read_csv
from ...core.dataset import DataSet
from ...core.instance import Instance
-from ..file_reader import _read_csv
-from .loader import Loader
class CSVLoader(Loader):
diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py
index fab7639c..2fbb1091 100644
--- a/fastNLP/io/loader/cws.py
+++ b/fastNLP/io/loader/cws.py
@@ -1,11 +1,18 @@
-from .loader import Loader
-from ...core.dataset import DataSet
-from ...core.instance import Instance
+"""undocumented"""
+
+__all__ = [
+ "CWSLoader"
+]
+
import glob
import os
-import time
-import shutil
import random
+import shutil
+import time
+
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class CWSLoader(Loader):
diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py
index 8856b73a..012dee5a 100644
--- a/fastNLP/io/loader/json.py
+++ b/fastNLP/io/loader/json.py
@@ -1,7 +1,13 @@
+"""undocumented"""
+
+__all__ = [
+ "JsonLoader"
+]
+
+from .loader import Loader
+from ..file_reader import _read_json
from ...core.dataset import DataSet
from ...core.instance import Instance
-from ..file_reader import _read_json
-from .loader import Loader
class JsonLoader(Loader):
diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py
index e7b419ac..22636a27 100644
--- a/fastNLP/io/loader/loader.py
+++ b/fastNLP/io/loader/loader.py
@@ -1,8 +1,15 @@
-from ...core.dataset import DataSet
-from .. import DataBundle
-from ..utils import check_loader_paths
+"""undocumented"""
+
+__all__ = [
+ "Loader"
+]
+
from typing import Union, Dict
+
+from .. import DataBundle
from ..file_utils import _get_dataset_url, get_cache_path, cached_path
+from ..utils import check_loader_paths
+from ...core.dataset import DataSet
class Loader:
diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py
index 26455914..7f03ca3e 100644
--- a/fastNLP/io/loader/matching.py
+++ b/fastNLP/io/loader/matching.py
@@ -1,10 +1,21 @@
+"""undocumented"""
+
+__all__ = [
+ "MNLILoader",
+ "SNLILoader",
+ "QNLILoader",
+ "RTELoader",
+ "QuoraLoader",
+]
+
+import os
import warnings
-from .loader import Loader
+from typing import Union, Dict
+
from .json import JsonLoader
-from ...core.const import Const
+from .loader import Loader
from .. import DataBundle
-import os
-from typing import Union, Dict
+from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
@@ -22,10 +33,11 @@ class MNLILoader(Loader):
"...", "...","."
"""
+
def __init__(self):
super().__init__()
-
- def _load(self, path:str):
+
+ def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
@@ -50,8 +62,8 @@ class MNLILoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
-
- def load(self, paths:str=None):
+
+ def load(self, paths: str = None):
"""
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
@@ -64,13 +76,13 @@ class MNLILoader(Loader):
paths = self.download()
if not os.path.isdir(paths):
raise NotADirectoryError(f"{paths} is not a valid directory.")
-
- files = {'dev_matched':"dev_matched.tsv",
- "dev_mismatched":"dev_mismatched.tsv",
- "test_matched":"test_matched.tsv",
- "test_mismatched":"test_mismatched.tsv",
- "train":'train.tsv'}
-
+
+ files = {'dev_matched': "dev_matched.tsv",
+ "dev_mismatched": "dev_mismatched.tsv",
+ "test_matched": "test_matched.tsv",
+ "test_mismatched": "test_mismatched.tsv",
+ "train": 'train.tsv'}
+
datasets = {}
for name, filename in files.items():
filepath = os.path.join(paths, filename)
@@ -78,11 +90,11 @@ class MNLILoader(Loader):
if 'test' not in name:
raise FileNotFoundError(f"{name} not found in directory {filepath}.")
datasets[name] = self._load(filepath)
-
+
data_bundle = DataBundle(datasets=datasets)
-
+
return data_bundle
-
+
def download(self):
"""
如果你使用了这个数据,请引用
@@ -106,14 +118,15 @@ class SNLILoader(JsonLoader):
"...", "...", "."
"""
+
def __init__(self):
super().__init__(fields={
'sentence1': Const.RAW_WORDS(0),
'sentence2': Const.RAW_WORDS(1),
'gold_label': Const.TARGET,
})
-
- def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
+
+ def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
"""
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。
@@ -138,11 +151,11 @@ class SNLILoader(JsonLoader):
paths = _paths
else:
raise NotADirectoryError(f"{paths} is not a valid directory.")
-
+
datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle
-
+
def download(self):
"""
如果您的文章使用了这份数据,请引用
@@ -169,12 +182,13 @@ class QNLILoader(JsonLoader):
test数据集没有target列
"""
+
def __init__(self):
super().__init__()
-
+
def _load(self, path):
ds = DataSet()
-
+
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
@@ -198,7 +212,7 @@ class QNLILoader(JsonLoader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
-
+
def download(self):
"""
如果您的实验使用到了该数据,请引用
@@ -225,12 +239,13 @@ class RTELoader(Loader):
test数据集没有target列
"""
+
def __init__(self):
super().__init__()
-
- def _load(self, path:str):
+
+ def _load(self, path: str):
ds = DataSet()
-
+
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
@@ -254,7 +269,7 @@ class RTELoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
-
+
def download(self):
return self._get_dataset_path('rte')
@@ -281,12 +296,13 @@ class QuoraLoader(Loader):
"...","."
"""
+
def __init__(self):
super().__init__()
-
- def _load(self, path:str):
+
+ def _load(self, path: str):
ds = DataSet()
-
+
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
@@ -298,6 +314,6 @@ class QuoraLoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
-
+
def download(self):
raise RuntimeError("Quora cannot be downloaded automatically.")
diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py
index f42d5400..30c591a4 100644
--- a/fastNLP/io/pipe/classification.py
+++ b/fastNLP/io/pipe/classification.py
@@ -1,26 +1,39 @@
+"""undocumented"""
+
+__all__ = [
+ "YelpFullPipe",
+ "YelpPolarityPipe",
+ "SSTPipe",
+ "SST2Pipe",
+ 'IMDBPipe'
+]
+
+import re
+
from nltk import Tree
+from .pipe import Pipe
+from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from ..data_bundle import DataBundle
-from ...core.vocabulary import Vocabulary
-from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
+from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
+from ...core.vocabulary import Vocabulary
-from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
-from .pipe import Pipe
-import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
+
class _CLSPipe(Pipe):
"""
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列
"""
- def __init__(self, tokenizer:str='spacy', lang='en'):
+
+ def __init__(self, tokenizer: str = 'spacy', lang='en'):
self.tokenizer = get_tokenizer(tokenizer, lang=lang)
-
+
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize
@@ -33,9 +46,9 @@ class _CLSPipe(Pipe):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
-
+
return data_bundle
-
+
def _granularize(self, data_bundle, tag_map):
"""
该函数对data_bundle中'target'列中的内容进行转换。
@@ -47,9 +60,9 @@ class _CLSPipe(Pipe):
"""
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
- dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET,
+ dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
new_field_name=Const.TARGET)
- dataset.drop(lambda ins:ins[Const.TARGET] == -100)
+ dataset.drop(lambda ins: ins[Const.TARGET] == -100)
data_bundle.set_dataset(dataset, name)
return data_bundle
@@ -69,7 +82,7 @@ def _clean_str(words):
t = ''.join(tt)
if t != '':
words_collection.append(t)
-
+
return words_collection
@@ -89,19 +102,20 @@ class YelpFullPipe(_CLSPipe):
1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
- def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'):
+
+ def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
-
- if granularity==2:
+
+ if granularity == 2:
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
- elif granularity==3:
- self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2}
+ elif granularity == 3:
+ self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
else:
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
-
+
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize
@@ -116,7 +130,7 @@ class YelpFullPipe(_CLSPipe):
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
return data_bundle
-
+
def process(self, data_bundle):
"""
传入的DataSet应该具备如下的结构
@@ -131,30 +145,30 @@ class YelpFullPipe(_CLSPipe):
:param data_bundle:
:return:
"""
-
+
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
-
+
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
-
+
# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
-
+
# 删除空行
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)
-
+
# index
data_bundle = _indexize(data_bundle=data_bundle)
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
-
+
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
-
+
return data_bundle
-
+
def process_from_file(self, paths=None):
"""
@@ -179,27 +193,28 @@ class YelpPolarityPipe(_CLSPipe):
:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
- def __init__(self, lower:bool=False, tokenizer:str='spacy'):
+
+ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
-
+
def process(self, data_bundle):
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
-
+
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
-
+
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
-
+
return data_bundle
-
+
def process_from_file(self, paths=None):
"""
@@ -230,7 +245,7 @@ class SSTPipe(_CLSPipe):
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
-
+
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.subtree = subtree
@@ -238,15 +253,15 @@ class SSTPipe(_CLSPipe):
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
-
- if granularity==2:
+
+ if granularity == 2:
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
- elif granularity==3:
- self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2}
+ elif granularity == 3:
+ self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2}
else:
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
-
- def process(self, data_bundle:DataBundle):
+
+ def process(self, data_bundle: DataBundle):
"""
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与
@@ -277,26 +292,26 @@ class SSTPipe(_CLSPipe):
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
ds.append(instance)
data_bundle.set_dataset(ds, name)
-
+
_add_words_field(data_bundle, lower=self.lower)
-
+
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
-
+
# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
-
+
# index
data_bundle = _indexize(data_bundle=data_bundle)
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
-
+
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
-
+
return data_bundle
-
+
def process_from_file(self, paths=None):
data_bundle = SSTLoader().load(paths)
return self.process(data_bundle=data_bundle)
@@ -316,11 +331,12 @@ class SST2Pipe(_CLSPipe):
:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
+
def __init__(self, lower=False, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
-
- def process(self, data_bundle:DataBundle):
+
+ def process(self, data_bundle: DataBundle):
"""
可以处理的DataSet应该具备如下的结构
@@ -335,15 +351,15 @@ class SST2Pipe(_CLSPipe):
:return:
"""
_add_words_field(data_bundle, self.lower)
-
+
data_bundle = self._tokenize(data_bundle=data_bundle)
-
+
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
- no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if
+ no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
-
+
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
datasets = []
@@ -351,18 +367,18 @@ class SST2Pipe(_CLSPipe):
if dataset.has_field(Const.TARGET):
datasets.append(dataset)
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)
-
+
data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
-
+
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
-
+
return data_bundle
-
+
def process_from_file(self, paths=None):
"""
@@ -390,11 +406,12 @@ class IMDBPipe(_CLSPipe):
:param bool lower: 是否将words列的数据小写。
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
"""
- def __init__(self, lower:bool=False, tokenizer:str='spacy'):
+
+ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
-
- def process(self, data_bundle:DataBundle):
+
+ def process(self, data_bundle: DataBundle):
"""
期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型
@@ -409,25 +426,26 @@ class IMDBPipe(_CLSPipe):
target列应该为str。
:return: DataBundle
"""
+
# 替换
def replace_br(raw_words):
raw_words = raw_words.replace("
", ' ')
return raw_words
-
+
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
-
+
_add_words_field(data_bundle, lower=self.lower)
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
_indexize(data_bundle)
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(Const.INPUT, Const.INPUT_LEN)
dataset.set_target(Const.TARGET)
-
+
return data_bundle
-
+
def process_from_file(self, paths=None):
"""
@@ -437,8 +455,5 @@ class IMDBPipe(_CLSPipe):
# 读取数据
data_bundle = IMDBLoader().load(paths)
data_bundle = self.process(data_bundle)
-
+
return data_bundle
-
-
-
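
A sketch of the pipe entry point for the classification pipes reformatted above (process_from_file() chains the matching Loader and process(); lower/granularity/tokenizer are the documented parameters; the data directories are hypothetical, and Const.INPUT is assumed to map to the 'words' field):

    from fastNLP.io.pipe.classification import YelpFullPipe, IMDBPipe

    # Collapse the five Yelp ratings into 3 classes and tokenize with spacy.
    pipe = YelpFullPipe(lower=True, granularity=3, tokenizer='spacy')
    data_bundle = pipe.process_from_file('data/yelp_full')   # hypothetical directory
    print(data_bundle.get_dataset('train')[:2])
    print(data_bundle.get_vocab('words'))

    # IMDBPipe follows the same pattern for the binary IMDB task.
    imdb_bundle = IMDBPipe(lower=True).process_from_file('data/aclImdb')
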
diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py
index 617d1236..2efec8e0 100644
--- a/fastNLP/io/pipe/conll.py
+++ b/fastNLP/io/pipe/conll.py
@@ -1,13 +1,25 @@
+"""undocumented"""
+
+__all__ = [
+ "Conll2003NERPipe",
+ "Conll2003Pipe",
+ "OntoNotesNERPipe",
+ "MsraNERPipe",
+ "PeopleDailyPipe",
+ "WeiboNERPipe"
+]
+
from .pipe import Pipe
-from .. import DataBundle
+from .utils import _add_chars_field
+from .utils import _indexize, _add_words_field
from .utils import iob2, iob2bioes
-from ...core.const import Const
+from .. import DataBundle
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
-from .utils import _indexize, _add_words_field
-from .utils import _add_chars_field
from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
+from ...core.const import Const
from ...core.vocabulary import Vocabulary
+
class _NERPipe(Pipe):
"""
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
@@ -20,14 +32,14 @@ class _NERPipe(Pipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
"""
-
+
def __init__(self, encoding_type: str = 'bio', lower: bool = False):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
self.lower = lower
-
+
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
支持的DataSet的field为
@@ -46,21 +58,21 @@ class _NERPipe(Pipe):
# 转换tag
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
-
+
_add_words_field(data_bundle, lower=self.lower)
-
+
# index
_indexize(data_bundle)
-
+
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
-
+
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
-
+
return data_bundle
@@ -84,7 +96,7 @@ class Conll2003NERPipe(_NERPipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
"""
-
+
def process_from_file(self, paths) -> DataBundle:
"""
@@ -94,7 +106,7 @@ class Conll2003NERPipe(_NERPipe):
# 读取数据
data_bundle = Conll2003NERLoader().load(paths)
data_bundle = self.process(data_bundle)
-
+
return data_bundle
@@ -125,8 +137,8 @@ class Conll2003Pipe(Pipe):
else:
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
self.lower = lower
-
- def process(self, data_bundle)->DataBundle:
+
+ def process(self, data_bundle) -> DataBundle:
"""
输入的DataSet应该类似于如下的形式
@@ -145,9 +157,9 @@ class Conll2003Pipe(Pipe):
dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD])
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
-
+
_add_words_field(data_bundle, lower=self.lower)
-
+
# index
_indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner'])
# chunk中存在一些tag只在dev中出现,没在train中
@@ -155,18 +167,18 @@ class Conll2003Pipe(Pipe):
tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
data_bundle.set_vocab(tgt_vocab, 'chunk')
-
+
input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN]
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
-
+
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
-
+
return data_bundle
-
+
def process_from_file(self, paths):
"""
@@ -194,7 +206,7 @@ class OntoNotesNERPipe(_NERPipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
"""
-
+
def process_from_file(self, paths):
data_bundle = OntoNotesNERLoader().load(paths)
return self.process(data_bundle)
@@ -211,13 +223,13 @@ class _CNNERPipe(Pipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
"""
-
+
def __init__(self, encoding_type: str = 'bio'):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
-
+
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
支持的DataSet的field为
@@ -239,21 +251,21 @@ class _CNNERPipe(Pipe):
# 转换tag
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
-
+
_add_chars_field(data_bundle, lower=False)
-
+
# index
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
-
+
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)
-
+
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
-
+
return data_bundle
@@ -272,6 +284,7 @@ class MsraNERPipe(_CNNERPipe):
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
"""
+
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = MsraNERLoader().load(paths)
return self.process(data_bundle)
@@ -291,6 +304,7 @@ class PeopleDailyPipe(_CNNERPipe):
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
"""
+
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = PeopleDailyNERLoader().load(paths)
return self.process(data_bundle)
@@ -312,6 +326,7 @@ class WeiboNERPipe(_CNNERPipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
"""
+
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = WeiboNERLoader().load(paths)
return self.process(data_bundle)
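
A corresponding sketch for the NER pipes (encoding_type switches between the bio and bioes tag schemes handled by _NERPipe/_CNNERPipe above; the CoNLL directory is hypothetical, while MsraNERPipe can download its data when no path is given):

    from fastNLP.io.pipe.conll import Conll2003NERPipe, MsraNERPipe

    bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file('data/conll2003')
    print(bundle.get_dataset('train')[:2])
    print(bundle.get_vocab('target'))

    # Chinese MSRA NER: paths=None lets the loader download the data first.
    msra_bundle = MsraNERPipe(encoding_type='bio').process_from_file()
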
diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py
index 4ca0219c..748cf10a 100644
--- a/fastNLP/io/pipe/cws.py
+++ b/fastNLP/io/pipe/cws.py
@@ -1,3 +1,9 @@
+"""undocumented"""
+
+__all__ = [
+ "CWSPipe"
+]
+
import re
from itertools import chain
diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py
index ffa6375b..699438c8 100644
--- a/fastNLP/io/pipe/matching.py
+++ b/fastNLP/io/pipe/matching.py
@@ -1,9 +1,25 @@
+"""undocumented"""
+
+__all__ = [
+ "MatchingBertPipe",
+ "RTEBertPipe",
+ "SNLIBertPipe",
+ "QuoraBertPipe",
+ "QNLIBertPipe",
+ "MNLIBertPipe",
+ "MatchingPipe",
+ "RTEPipe",
+ "SNLIPipe",
+ "QuoraPipe",
+ "QNLIPipe",
+ "MNLIPipe",
+]
from .pipe import Pipe
from .utils import get_tokenizer
+from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
from ...core.const import Const
from ...core.vocabulary import Vocabulary
-from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
class MatchingBertPipe(Pipe):
@@ -24,12 +40,13 @@ class MatchingBertPipe(Pipe):
:param bool lower: 是否将word小写化。
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
"""
- def __init__(self, lower=False, tokenizer: str='raw'):
+
+ def __init__(self, lower=False, tokenizer: str = 'raw'):
super().__init__()
-
+
self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)
-
+
def _tokenize(self, data_bundle, field_names, new_field_names):
"""
@@ -43,62 +60,62 @@ class MatchingBertPipe(Pipe):
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle
-
+
def process(self, data_bundle):
for dataset in data_bundle.datasets.values():
if dataset.has_field(Const.TARGET):
dataset.drop(lambda x: x[Const.TARGET] == '-')
-
+
for name, dataset in data_bundle.datasets.items():
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), )
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), )
-
+
if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()
-
+
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
[Const.INPUTS(0), Const.INPUTS(1)])
-
+
# concat两个words
def concat(ins):
words0 = ins[Const.INPUTS(0)]
words1 = ins[Const.INPUTS(1)]
words = words0 + ['[SEP]'] + words1
return words
-
+
for name, dataset in data_bundle.datasets.items():
dataset.apply(concat, new_field_name=Const.INPUT)
dataset.delete_field(Const.INPUTS(0))
dataset.delete_field(Const.INPUTS(1))
-
+
word_vocab = Vocabulary()
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
'train' not in name])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
-
+
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
dataset.has_field(Const.TARGET)]
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
-
+
data_bundle.set_vocab(word_vocab, Const.INPUT)
data_bundle.set_vocab(target_vocab, Const.TARGET)
-
+
input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET]
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(*input_fields, flag=True)
for fields in target_fields:
if dataset.has_field(fields):
dataset.set_target(fields, flag=True)
-
+
return data_bundle
@@ -150,12 +167,13 @@ class MatchingPipe(Pipe):
:param bool lower: 是否将所有raw_words转为小写。
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。
"""
- def __init__(self, lower=False, tokenizer: str='raw'):
+
+ def __init__(self, lower=False, tokenizer: str = 'raw'):
super().__init__()
-
+
self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)
-
+
def _tokenize(self, data_bundle, field_names, new_field_names):
"""
@@ -169,7 +187,7 @@ class MatchingPipe(Pipe):
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle
-
+
def process(self, data_bundle):
"""
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有
@@ -186,35 +204,35 @@ class MatchingPipe(Pipe):
"""
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
[Const.INPUTS(0), Const.INPUTS(1)])
-
+
for dataset in data_bundle.datasets.values():
if dataset.has_field(Const.TARGET):
dataset.drop(lambda x: x[Const.TARGET] == '-')
-
+
if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()
-
+
word_vocab = Vocabulary()
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
field_name=[Const.INPUTS(0), Const.INPUTS(1)],
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
'train' not in name])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])
-
+
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
dataset.has_field(Const.TARGET)]
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
-
+
data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
data_bundle.set_vocab(target_vocab, Const.TARGET)
-
+
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
target_fields = [Const.TARGET]
-
+
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
@@ -222,7 +240,7 @@ class MatchingPipe(Pipe):
for fields in target_fields:
if dataset.has_field(fields):
dataset.set_target(fields, flag=True)
-
+
return data_bundle
@@ -254,4 +272,3 @@ class MNLIPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = MNLILoader().load(paths)
return self.process(data_bundle)
-
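
A sketch for the matching pipes (the Bert variants concatenate the two sentences with a [SEP] token into one words field, as in the concat() helper above, while the plain variants keep words1/words2 separate; RTE can be downloaded automatically, so no path is passed):

    from fastNLP.io.pipe.matching import RTEBertPipe, RTEPipe

    # Single concatenated input: words = sentence1 + ['[SEP]'] + sentence2.
    bert_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file()
    print(bert_bundle.get_dataset('train').get_field_names())

    # Two separate inputs (words1/words2) plus their seq_len fields.
    pair_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
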
diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py
index cc45dee4..a1435fd3 100644
--- a/fastNLP/io/pipe/pipe.py
+++ b/fastNLP/io/pipe/pipe.py
@@ -1,3 +1,9 @@
+"""undocumented"""
+
+__all__ = [
+ "Pipe",
+]
+
from .. import DataBundle
diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py
index 8facd8d9..f32f58b7 100644
--- a/fastNLP/io/pipe/utils.py
+++ b/fastNLP/io/pipe/utils.py
@@ -1,8 +1,18 @@
+"""undocumented"""
+
+__all__ = [
+ "iob2",
+ "iob2bioes",
+ "get_tokenizer",
+]
+
from typing import List
-from ...core.vocabulary import Vocabulary
+
from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+
-def iob2(tags:List[str])->List[str]:
+def iob2(tags: List[str]) -> List[str]:
"""
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
@@ -25,7 +35,8 @@ def iob2(tags:List[str])->List[str]:
tags[i] = "B" + tag[1:]
return tags
-def iob2bioes(tags:List[str])->List[str]:
+
+def iob2bioes(tags: List[str]) -> List[str]:
"""
将iob的tag转换为bioes编码
:param tags:
@@ -38,12 +49,12 @@ def iob2bioes(tags:List[str])->List[str]:
else:
split = tag.split('-')[0]
if split == 'B':
- if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
+ if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
- if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
+ if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
return new_tags
-def get_tokenizer(tokenizer:str, lang='en'):
+def get_tokenizer(tokenizer: str, lang='en'):
"""
:param str tokenizer: 获取tokenzier方法
@@ -97,13 +108,13 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
data_bundle.set_vocab(src_vocab, input_field_name)
-
+
for target_field_name in target_field_names:
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
data_bundle.set_vocab(tgt_vocab, target_field_name)
-
+
return data_bundle
@@ -116,7 +127,7 @@ def _add_words_field(data_bundle, lower=False):
:return: 传入的DataBundle
"""
data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True)
-
+
if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUT].lower()
@@ -132,7 +143,7 @@ def _add_chars_field(data_bundle, lower=False):
:return: 传入的DataBundle
"""
data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True)
-
+
if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.CHAR_INPUT].lower()
@@ -147,6 +158,7 @@ def _drop_empty_instance(data_bundle, field_name):
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉
:return: 传入的DataBundle
"""
+
def empty_instance(ins):
if field_name:
field_value = ins[field_name]
@@ -157,10 +169,8 @@ def _drop_empty_instance(data_bundle, field_name):
if field_value in ((), {}, [], ''):
return True
return False
-
+
for name, dataset in data_bundle.datasets.items():
dataset.drop(empty_instance)
-
+
return data_bundle
-
-
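
To make the tag-conversion helpers concrete, a small self-contained example (the tag sequence is made up; iob2() repairs IOB1 into IOB2 and iob2bioes() then expands it to BIOES, matching the logic shown above; copies are passed because iob2() modifies its argument in place):

    from fastNLP.io.pipe.utils import iob2, iob2bioes

    tags = ['I-PER', 'I-PER', 'O', 'I-LOC']       # IOB1-style input
    print(iob2(list(tags)))                       # ['B-PER', 'I-PER', 'O', 'B-LOC']
    print(iob2bioes(iob2(list(tags))))            # ['B-PER', 'E-PER', 'O', 'S-LOC']
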
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index faec2a55..e1de2ae7 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -1,10 +1,20 @@
-import os
+"""
+.. todo::
+ doc
+"""
-from typing import Union, Dict
+__all__ = [
+ "check_loader_paths"
+]
+
+import os
from pathlib import Path
+from typing import Union, Dict
+
from ..core import logger
-def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
+
+def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
"""
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::
@@ -33,11 +43,13 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
- raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
+ raise Exception(
+ "File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
- raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+ raise Exception(
+ "File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
@@ -46,7 +58,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
-
+
elif isinstance(paths, dict):
if paths:
if 'train' not in paths:
@@ -65,6 +77,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
else:
raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
+
def get_tokenizer():
try:
import spacy
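
Finally, a sketch of check_loader_paths() as re-typed above (a directory whose file names contain train/dev/test maps to the corresponding keys, and a dict must contain at least 'train'; the paths are hypothetical):

    from fastNLP.io.utils import check_loader_paths

    print(check_loader_paths('data/conll2003'))
    # -> {'train': '.../train.txt', 'dev': '.../dev.txt', 'test': '.../test.txt'}

    print(check_loader_paths({'train': 'data/train.txt', 'dev': 'data/dev.txt'}))
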