@@ -1,10 +1,15 @@
"""
.. todo::
doc
"""
__all__ = [
'DataBundle',
]
import _pickle as pickle
from typing import Union, Dict
import os
from typing import Union, Dict
from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary
@@ -1,4 +1,4 @@
"""
"""undocumented
.. warning::
本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。
@@ -23,10 +23,10 @@ __all__ = [
]
from .data_bundle import DataSetLoader
from .file_reader import _read_csv, _read_json
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json
from .data_bundle import DataSetLoader
class JsonLoader(DataSetLoader):
@@ -1,17 +1,22 @@
"""
.. todo::
doc
"""
__all__ = [
"EmbedLoader",
"EmbeddingOption",
]
import logging
import os
import warnings
import numpy as np
from ..core.vocabulary import Vocabulary
from .data_bundle import BaseLoader
from ..core.utils import Option
import logging
from ..core.vocabulary import Vocabulary
class EmbeddingOption(Option):
def __init__(self,
@@ -1,7 +1,11 @@
"""
"""undocumented
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
"""
__all__ = []
import json
from ..core import logger
@@ -24,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
headers = headers.split(sep)
start_idx += 1
elif not isinstance(headers, (list, tuple)):
raise TypeError("headers should be list or tuple, not {}." \
.format(type(headers)))
raise TypeError("headers should be list or tuple, not {}." \
.format(type(headers)))
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
@@ -82,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""
def parse_conll(sample):
sample = list(map(list, zip(*sample)))
sample = [sample[i] for i in indexes]
@@ -89,14 +94,15 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
if len(f) <= 0:
raise ValueError('empty field')
return sample
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
if start!='':
if start != '':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line.strip()
if line=='':
if line == '':
if len(sample):
try:
res = parse_conll(sample)
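For orientation (not part of the diff): a minimal sketch of how _read_conll is typically consumed, based on the docstring above; the file path and the column indexes are made-up assumptions.
from fastNLP.io.file_reader import _read_conll

# yields (line_number, sample) pairs, one sample per sentence block
for line_idx, sample in _read_conll('example.conll', indexes=[0, 3], dropna=True):
    words, tags = sample  # one list per requested column
    print(line_idx, words, tags)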
@@ -1,12 +1,27 @@
"""
.. todo::
doc
"""
__all__ = [
"cached_path",
"get_filepath",
"get_cache_path",
"split_filename_suffix",
"get_from_cache",
]
import os
import re
import shutil
import tempfile
from pathlib import Path
from urllib.parse import urlparse
import re
import requests
import tempfile
from tqdm import tqdm
import shutil
from requests import HTTPError
from tqdm import tqdm
from ..core import logger
PRETRAINED_BERT_MODEL_DIR = {
@@ -1,12 +1,24 @@
from ...core.dataset import DataSet
from ...core.instance import Instance
from .loader import Loader
import warnings
"""undocumented"""
__all__ = [
"YelpLoader",
"YelpFullLoader",
"YelpPolarityLoader",
"IMDBLoader",
"SSTLoader",
"SST2Loader",
]
import glob
import os
import random
import shutil
import glob
import time
import warnings
from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance
class YelpLoader(Loader):
@@ -58,7 +70,7 @@ class YelpLoader(Loader):
class YelpFullLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, re_download:bool=False):
def download(self, dev_ratio: float = 0.1, re_download: bool = False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -127,7 +139,7 @@ class YelpPolarityLoader(YelpLoader):
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
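For orientation (not part of the diff): a hedged sketch of the download-then-load flow shown above; download() is assumed to return the cached data directory, and a dev split is carved out of train.csv when dev_ratio > 0 and no dev file exists yet.
from fastNLP.io.loader.classification import YelpFullLoader

loader = YelpFullLoader()
data_dir = loader.download(dev_ratio=0.1)   # writes dev.csv next to train.csv if it is missing
data_bundle = loader.load(data_dir)         # DataBundle holding train/dev/test DataSets
print(data_bundle.datasets.keys())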
@@ -1,15 +1,28 @@
from typing import Dict, Union
"""undocumented"""
__all__ = [
"ConllLoader",
"Conll2003Loader",
"Conll2003NERLoader",
"OntoNotesNERLoader",
"CTBLoader",
"CNNERLoader",
"MsraNERLoader",
"WeiboNERLoader",
"PeopleDailyNERLoader"
]
from .loader import Loader
from ...core.dataset import DataSet
from ..file_reader import _read_conll
from ...core.instance import Instance
from ...core.const import Const
import glob
import os
import random
import shutil
import time
import random
from .loader import Loader
from ..file_reader import _read_conll
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
class ConllLoader(Loader):
@@ -47,6 +60,7 @@ class ConllLoader(Loader): | |||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
""" | |||
def __init__(self, headers, indexes=None, dropna=True): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
@@ -60,7 +74,7 @@ class ConllLoader(Loader): | |||
if len(indexes) != len(headers): | |||
raise ValueError | |||
self.indexes = indexes | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
@@ -101,12 +115,13 @@ class Conll2003Loader(ConllLoader): | |||
"[...]", "[...]", "[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'pos', 'chunk', 'ner', | |||
] | |||
super(Conll2003Loader, self).__init__(headers=headers) | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
@@ -127,7 +142,7 @@ class Conll2003Loader(ConllLoader): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def download(self, output_dir=None): | |||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||
@@ -158,12 +173,13 @@ class Conll2003NERLoader(ConllLoader): | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'target', | |||
] | |||
super().__init__(headers=headers, indexes=[0, 3]) | |||
def _load(self, path): | |||
""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
@@ -184,7 +200,7 @@ class Conll2003NERLoader(ConllLoader): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def download(self): | |||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||
@@ -204,13 +220,13 @@ class OntoNotesNERLoader(ConllLoader): | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10]) | |||
def _load(self, path:str): | |||
def _load(self, path: str): | |||
dataset = super()._load(path) | |||
def convert_to_bio(tags): | |||
bio_tags = [] | |||
flag = None | |||
@@ -227,7 +243,7 @@ class OntoNotesNERLoader(ConllLoader): | |||
flag = None | |||
bio_tags.append(bio_label) | |||
return bio_tags | |||
def convert_word(words): | |||
converted_words = [] | |||
for word in words: | |||
@@ -236,7 +252,7 @@ class OntoNotesNERLoader(ConllLoader): | |||
converted_words.append(word) | |||
continue | |||
# 以下是由于这些符号被转义了,再转回来 | |||
tfrs = {'-LRB-':'(', | |||
tfrs = {'-LRB-': '(', | |||
'-RRB-': ')', | |||
'-LSB-': '[', | |||
'-RSB-': ']', | |||
@@ -248,12 +264,12 @@ class OntoNotesNERLoader(ConllLoader): | |||
else: | |||
converted_words.append(word) | |||
return converted_words | |||
dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) | |||
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
return dataset | |||
def download(self): | |||
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " | |||
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") | |||
@@ -262,13 +278,13 @@ class OntoNotesNERLoader(ConllLoader): | |||
class CTBLoader(Loader): | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
def _load(self, path: str): | |||
pass | |||
class CNNERLoader(Loader): | |||
def _load(self, path:str): | |||
def _load(self, path: str): | |||
""" | |||
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample | |||
@@ -331,10 +347,11 @@ class MsraNERLoader(CNNERLoader): | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def download(self, dev_ratio:float=0.1, re_download:bool=False)->str: | |||
def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: | |||
""" | |||
自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language | |||
Processing Bakeoff: Word Segmentation and Named Entity Recognition. | |||
@@ -356,7 +373,7 @@ class MsraNERLoader(CNNERLoader): | |||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.conll')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
@@ -380,15 +397,15 @@ class MsraNERLoader(CNNERLoader): | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.conll')): | |||
os.remove(os.path.join(data_dir, 'middle_file.conll')) | |||
return data_dir | |||
class WeiboNERLoader(CNNERLoader): | |||
def __init__(self): | |||
super().__init__() | |||
def download(self)->str: | |||
def download(self) -> str: | |||
""" | |||
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for | |||
Chinese Social Media with Jointly Trained Embeddings. | |||
@@ -397,7 +414,7 @@ class WeiboNERLoader(CNNERLoader): | |||
""" | |||
dataset_name = 'weibo-ner' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
return data_dir | |||
@@ -427,11 +444,12 @@ class PeopleDailyNERLoader(CNNERLoader): | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def download(self) -> str: | |||
dataset_name = 'peopledaily' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
return data_dir
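For orientation (not part of the diff): a small usage sketch of the CoNLL loaders above. The local file path is hypothetical, since CoNLL-2003 cannot be downloaded automatically.
from fastNLP.io.loader.conll import Conll2003NERLoader

data_bundle = Conll2003NERLoader().load({'train': 'conll2003/train.txt'})
ds = data_bundle.get_dataset('train')        # fields: raw_words, target
print(ds[0]['raw_words'], ds[0]['target'])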
@@ -1,7 +1,13 @@
"""undocumented"""
__all__ = [
"CSVLoader",
]
from .loader import Loader
from ..file_reader import _read_csv
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..file_reader import _read_csv
from .loader import Loader
class CSVLoader(Loader):
@@ -1,11 +1,18 @@
from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance
"""undocumented"""
__all__ = [
"CWSLoader"
]
import glob
import os
import time
import shutil
import random
import shutil
import time
from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance
class CWSLoader(Loader):
@@ -1,7 +1,13 @@
"""undocumented"""
__all__ = [
"JsonLoader"
]
from .loader import Loader
from ..file_reader import _read_json
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..file_reader import _read_json
from .loader import Loader
class JsonLoader(Loader):
@@ -1,8 +1,15 @@
from ...core.dataset import DataSet
from .. import DataBundle
from ..utils import check_loader_paths
"""undocumented"""
__all__ = [
"Loader"
]
from typing import Union, Dict
from .. import DataBundle
from ..file_utils import _get_dataset_url, get_cache_path, cached_path
from ..utils import check_loader_paths
from ...core.dataset import DataSet
class Loader:
@@ -1,10 +1,21 @@
"""undocumented"""
__all__ = [
"MNLILoader",
"SNLILoader",
"QNLILoader",
"RTELoader",
"QuoraLoader",
]
import os
import warnings
from .loader import Loader
from typing import Union, Dict
from .json import JsonLoader
from ...core.const import Const
from .loader import Loader
from .. import DataBundle
import os
from typing import Union, Dict
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
@@ -22,10 +33,11 @@ class MNLILoader(Loader): | |||
"...", "...","." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
@@ -50,8 +62,8 @@ class MNLILoader(Loader): | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def load(self, paths:str=None): | |||
def load(self, paths: str = None): | |||
""" | |||
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, | |||
@@ -64,13 +76,13 @@ class MNLILoader(Loader): | |||
paths = self.download() | |||
if not os.path.isdir(paths): | |||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||
files = {'dev_matched':"dev_matched.tsv", | |||
"dev_mismatched":"dev_mismatched.tsv", | |||
"test_matched":"test_matched.tsv", | |||
"test_mismatched":"test_mismatched.tsv", | |||
"train":'train.tsv'} | |||
files = {'dev_matched': "dev_matched.tsv", | |||
"dev_mismatched": "dev_mismatched.tsv", | |||
"test_matched": "test_matched.tsv", | |||
"test_mismatched": "test_mismatched.tsv", | |||
"train": 'train.tsv'} | |||
datasets = {} | |||
for name, filename in files.items(): | |||
filepath = os.path.join(paths, filename) | |||
@@ -78,11 +90,11 @@ class MNLILoader(Loader): | |||
if 'test' not in name: | |||
raise FileNotFoundError(f"{name} not found in directory {filepath}.") | |||
datasets[name] = self._load(filepath) | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
""" | |||
如果你使用了这个数据,请引用 | |||
@@ -106,14 +118,15 @@ class SNLILoader(JsonLoader): | |||
"...", "...", "." | |||
""" | |||
def __init__(self): | |||
super().__init__(fields={ | |||
'sentence1': Const.RAW_WORDS(0), | |||
'sentence2': Const.RAW_WORDS(1), | |||
'gold_label': Const.TARGET, | |||
}) | |||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||
@@ -138,11 +151,11 @@ class SNLILoader(JsonLoader): | |||
paths = _paths | |||
else: | |||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
""" | |||
如果您的文章使用了这份数据,请引用 | |||
@@ -169,12 +182,13 @@ class QNLILoader(JsonLoader): | |||
test数据集没有target列 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
@@ -198,7 +212,7 @@ class QNLILoader(JsonLoader): | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
""" | |||
如果您的实验使用到了该数据,请引用 | |||
@@ -225,12 +239,13 @@ class RTELoader(Loader): | |||
test数据集没有target列 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
@@ -254,7 +269,7 @@ class RTELoader(Loader): | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
return self._get_dataset_path('rte') | |||
@@ -281,12 +296,13 @@ class QuoraLoader(Loader): | |||
"...","." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path:str): | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
@@ -298,6 +314,6 @@ class QuoraLoader(Loader): | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
raise RuntimeError("Quora cannot be downloaded automatically.")
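For orientation (not part of the diff): a hedged sketch of driving MNLILoader.load with a directory laid out as in the docstring above; the path is a placeholder.
from fastNLP.io.loader.matching import MNLILoader

data_bundle = MNLILoader().load('path/to/MNLI')   # expects train.tsv, dev_matched.tsv, dev_mismatched.tsv, ...
for name, dataset in data_bundle.datasets.items():
    print(name, len(dataset))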
@@ -1,26 +1,39 @@
"""undocumented"""
__all__ = [
"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
"SST2Pipe",
'IMDBPipe'
]
import re
from nltk import Tree
from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from ..data_bundle import DataBundle
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import Vocabulary
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
class _CLSPipe(Pipe): | |||
""" | |||
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 | |||
""" | |||
def __init__(self, tokenizer:str='spacy', lang='en'): | |||
def __init__(self, tokenizer: str = 'spacy', lang='en'): | |||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | |||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||
""" | |||
将DataBundle中的数据进行tokenize | |||
@@ -33,9 +46,9 @@ class _CLSPipe(Pipe): | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||
return data_bundle | |||
def _granularize(self, data_bundle, tag_map): | |||
""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
@@ -47,9 +60,9 @@ class _CLSPipe(Pipe): | |||
""" | |||
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET, | |||
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET, | |||
new_field_name=Const.TARGET) | |||
dataset.drop(lambda ins:ins[Const.TARGET] == -100) | |||
dataset.drop(lambda ins: ins[Const.TARGET] == -100) | |||
data_bundle.set_dataset(dataset, name) | |||
return data_bundle | |||
@@ -69,7 +82,7 @@ def _clean_str(words): | |||
t = ''.join(tt) | |||
if t != '': | |||
words_collection.append(t) | |||
return words_collection | |||
@@ -89,19 +102,20 @@ class YelpFullPipe(_CLSPipe): | |||
1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'): | |||
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
if granularity==2: | |||
if granularity == 2: | |||
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} | |||
elif granularity==3: | |||
self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2} | |||
elif granularity == 3: | |||
self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2} | |||
else: | |||
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} | |||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||
""" | |||
将DataBundle中的数据进行tokenize | |||
@@ -116,7 +130,7 @@ class YelpFullPipe(_CLSPipe): | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
传入的DataSet应该具备如下的结构 | |||
@@ -131,30 +145,30 @@ class YelpFullPipe(_CLSPipe): | |||
:param data_bundle: | |||
:return: | |||
""" | |||
# 复制一列words | |||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||
# 根据granularity设置tag | |||
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) | |||
# 删除空行 | |||
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
@@ -179,27 +193,28 @@ class YelpPolarityPipe(_CLSPipe): | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
def __init__(self, lower:bool=False, tokenizer:str='spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle): | |||
# 复制一列words | |||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
@@ -230,7 +245,7 @@ class SSTPipe(_CLSPipe): | |||
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.subtree = subtree | |||
@@ -238,15 +253,15 @@ class SSTPipe(_CLSPipe): | |||
self.lower = lower | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
if granularity==2: | |||
if granularity == 2: | |||
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} | |||
elif granularity==3: | |||
self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2} | |||
elif granularity == 3: | |||
self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2} | |||
else: | |||
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||
def process(self, data_bundle:DataBundle): | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | |||
@@ -277,26 +292,26 @@ class SSTPipe(_CLSPipe): | |||
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) | |||
ds.append(instance) | |||
data_bundle.set_dataset(ds, name) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||
# 根据granularity设置tag | |||
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
data_bundle = SSTLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
@@ -316,11 +331,12 @@ class SST2Pipe(_CLSPipe): | |||
:param bool lower: 是否对输入进行小写化。 | |||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||
""" | |||
def __init__(self, lower=False, tokenizer='spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle:DataBundle): | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
可以处理的DataSet应该具备如下的结构 | |||
@@ -335,15 +351,15 @@ class SST2Pipe(_CLSPipe): | |||
:return: | |||
""" | |||
_add_words_field(data_bundle, self.lower) | |||
data_bundle = self._tokenize(data_bundle=data_bundle) | |||
src_vocab = Vocabulary() | |||
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
name != 'train']) | |||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
datasets = [] | |||
@@ -351,18 +367,18 @@ class SST2Pipe(_CLSPipe): | |||
if dataset.has_field(Const.TARGET): | |||
datasets.append(dataset) | |||
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) | |||
data_bundle.set_vocab(src_vocab, Const.INPUT) | |||
data_bundle.set_vocab(tgt_vocab, Const.TARGET) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||
data_bundle.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
@@ -390,11 +406,12 @@ class IMDBPipe(_CLSPipe): | |||
:param bool lower: 是否将words列的数据小写。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
def __init__(self, lower:bool=False, tokenizer:str='spacy'): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle:DataBundle): | |||
def process(self, data_bundle: DataBundle): | |||
""" | |||
期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 | |||
@@ -409,25 +426,26 @@ class IMDBPipe(_CLSPipe): | |||
target列应该为str。 | |||
:return: DataBundle | |||
""" | |||
# 替换<br /> | |||
def replace_br(raw_words): | |||
raw_words = raw_words.replace("<br />", ' ') | |||
return raw_words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
_indexize(data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
dataset.set_input(Const.INPUT, Const.INPUT_LEN) | |||
dataset.set_target(Const.TARGET) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
""" | |||
@@ -437,8 +455,5 @@ class IMDBPipe(_CLSPipe): | |||
# 读取数据 | |||
data_bundle = IMDBLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle
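For orientation (not part of the diff): a minimal sketch of an end-to-end run of one of the classification pipes above; the path and the printed fields are assumptions.
from fastNLP.io.pipe.classification import IMDBPipe

data_bundle = IMDBPipe(lower=True, tokenizer='spacy').process_from_file('path/to/IMDB')
train = data_bundle.get_dataset('train')
print(train[0]['words'][:10], train[0]['target'])   # indexed word ids and target id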
@@ -1,13 +1,25 @@
"""undocumented"""
__all__ = [
"Conll2003NERPipe",
"Conll2003Pipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"PeopleDailyPipe",
"WeiboNERPipe"
]
from .pipe import Pipe
from .. import DataBundle
from .utils import _add_chars_field
from .utils import _indexize, _add_words_field
from .utils import iob2, iob2bioes
from ...core.const import Const
from .. import DataBundle
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field
from .utils import _add_chars_field
from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
from ...core.const import Const
from ...core.vocabulary import Vocabulary
class _NERPipe(Pipe):
""" | |||
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | |||
@@ -20,14 +32,14 @@ class _NERPipe(Pipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
self.lower = lower | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
支持的DataSet的field为 | |||
@@ -46,21 +58,21 @@ class _NERPipe(Pipe): | |||
# 转换tag | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# index | |||
_indexize(data_bundle) | |||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(*input_fields) | |||
data_bundle.set_target(*target_fields) | |||
return data_bundle | |||
@@ -84,7 +96,7 @@ class Conll2003NERPipe(_NERPipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
""" | |||
def process_from_file(self, paths) -> DataBundle: | |||
""" | |||
@@ -94,7 +106,7 @@ class Conll2003NERPipe(_NERPipe): | |||
# 读取数据 | |||
data_bundle = Conll2003NERLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
@@ -125,8 +137,8 @@ class Conll2003Pipe(Pipe): | |||
else: | |||
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
self.lower = lower | |||
def process(self, data_bundle)->DataBundle: | |||
def process(self, data_bundle) -> DataBundle: | |||
""" | |||
输入的DataSet应该类似于如下的形式 | |||
@@ -145,9 +157,9 @@ class Conll2003Pipe(Pipe): | |||
dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD]) | |||
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk') | |||
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner') | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# index | |||
_indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner']) | |||
# chunk中存在一些tag只在dev中出现,没在train中 | |||
@@ -155,18 +167,18 @@ class Conll2003Pipe(Pipe): | |||
tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk') | |||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk') | |||
data_bundle.set_vocab(tgt_vocab, 'chunk') | |||
input_fields = [Const.INPUT, Const.INPUT_LEN] | |||
target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
data_bundle.set_input(*input_fields) | |||
data_bundle.set_target(*target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths): | |||
""" | |||
@@ -194,7 +206,7 @@ class OntoNotesNERPipe(_NERPipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 | |||
""" | |||
def process_from_file(self, paths): | |||
data_bundle = OntoNotesNERLoader().load(paths) | |||
return self.process(data_bundle) | |||
@@ -211,13 +223,13 @@ class _CNNERPipe(Pipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
""" | |||
def __init__(self, encoding_type: str = 'bio'): | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
else: | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
""" | |||
支持的DataSet的field为 | |||
@@ -239,21 +251,21 @@ class _CNNERPipe(Pipe): | |||
# 转换tag | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
_add_chars_field(data_bundle, lower=False) | |||
# index | |||
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) | |||
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.CHAR_INPUT) | |||
data_bundle.set_input(*input_fields) | |||
data_bundle.set_target(*target_fields) | |||
return data_bundle | |||
@@ -272,6 +284,7 @@ class MsraNERPipe(_CNNERPipe): | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = MsraNERLoader().load(paths) | |||
return self.process(data_bundle) | |||
@@ -291,6 +304,7 @@ class PeopleDailyPipe(_CNNERPipe): | |||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = PeopleDailyNERLoader().load(paths) | |||
return self.process(data_bundle) | |||
@@ -312,6 +326,7 @@ class WeiboNERPipe(_CNNERPipe): | |||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = WeiboNERLoader().load(paths) | |||
return self.process(data_bundle)
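For orientation (not part of the diff): a hedged sketch of the NER pipes above; the CoNLL-2003 directory is a hypothetical local path, and encoding_type='bioes' triggers the iob2 -> bioes conversion.
from fastNLP.io.pipe.conll import Conll2003NERPipe

data_bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file('path/to/conll2003')
train = data_bundle.get_dataset('train')     # input: words, target, seq_len; target: target, seq_len
print(train[0]['target'])                    # tag ids after BIOES conversion and indexing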
@@ -1,3 +1,9 @@
"""undocumented"""
__all__ = [
"CWSPipe"
]
import re
from itertools import chain
@@ -1,9 +1,25 @@
"""undocumented"""
__all__ = [
"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",
"QuoraBertPipe",
"QNLIBertPipe",
"MNLIBertPipe",
"MatchingPipe",
"RTEPipe",
"SNLIPipe",
"QuoraPipe",
"QNLIPipe",
"MNLIPipe",
]
from .pipe import Pipe
from .utils import get_tokenizer
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
class MatchingBertPipe(Pipe):
@@ -24,12 +40,13 @@ class MatchingBertPipe(Pipe): | |||
:param bool lower: 是否将word小写化。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
def __init__(self, lower=False, tokenizer: str='raw'): | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
super().__init__() | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
@@ -43,62 +60,62 @@ class MatchingBertPipe(Pipe): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field(Const.TARGET): | |||
dataset.drop(lambda x: x[Const.TARGET] == '-') | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), ) | |||
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), ) | |||
if self.lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.INPUTS(0)].lower() | |||
dataset[Const.INPUTS(1)].lower() | |||
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], | |||
[Const.INPUTS(0), Const.INPUTS(1)]) | |||
# concat两个words | |||
def concat(ins): | |||
words0 = ins[Const.INPUTS(0)] | |||
words1 = ins[Const.INPUTS(1)] | |||
words = words0 + ['[SEP]'] + words1 | |||
return words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply(concat, new_field_name=Const.INPUT) | |||
dataset.delete_field(Const.INPUTS(0)) | |||
dataset.delete_field(Const.INPUTS(1)) | |||
word_vocab = Vocabulary() | |||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | |||
field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
'train' not in name]) | |||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field(Const.TARGET)] | |||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||
data_bundle.set_vocab(word_vocab, Const.INPUT) | |||
data_bundle.set_vocab(target_vocab, Const.TARGET) | |||
input_fields = [Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT) | |||
dataset.set_input(*input_fields, flag=True) | |||
for fields in target_fields: | |||
if dataset.has_field(fields): | |||
dataset.set_target(fields, flag=True) | |||
return data_bundle | |||
@@ -150,12 +167,13 @@ class MatchingPipe(Pipe): | |||
:param bool lower: 是否将所有raw_words转为小写。 | |||
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 | |||
""" | |||
def __init__(self, lower=False, tokenizer: str='raw'): | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
super().__init__() | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
""" | |||
@@ -169,7 +187,7 @@ class MatchingPipe(Pipe): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
""" | |||
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 | |||
@@ -186,35 +204,35 @@ class MatchingPipe(Pipe): | |||
""" | |||
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | |||
[Const.INPUTS(0), Const.INPUTS(1)]) | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field(Const.TARGET): | |||
dataset.drop(lambda x: x[Const.TARGET] == '-') | |||
if self.lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.INPUTS(0)].lower() | |||
dataset[Const.INPUTS(1)].lower() | |||
word_vocab = Vocabulary() | |||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | |||
field_name=[Const.INPUTS(0), Const.INPUTS(1)], | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
'train' not in name]) | |||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field(Const.TARGET)] | |||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||
data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) | |||
data_bundle.set_vocab(target_vocab, Const.TARGET) | |||
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)] | |||
target_fields = [Const.TARGET] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) | |||
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) | |||
@@ -222,7 +240,7 @@ class MatchingPipe(Pipe): | |||
for fields in target_fields: | |||
if dataset.has_field(fields): | |||
dataset.set_target(fields, flag=True) | |||
return data_bundle | |||
@@ -254,4 +272,3 @@ class MNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = MNLILoader().load(paths) | |||
return self.process(data_bundle)
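For orientation (not part of the diff): a sketch of the BERT-style matching pipes above; the constructor arguments are assumed to pass through to MatchingBertPipe, and the RTE data is resolved by the underlying loader.
from fastNLP.io.pipe.matching import RTEBertPipe

data_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file()
train = data_bundle.get_dataset('train')
print(train[0]['words'])   # words1 + ['[SEP]'] + words2, already mapped to indices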
@@ -1,3 +1,9 @@
"""undocumented"""
__all__ = [
"Pipe",
]
from .. import DataBundle
@@ -1,8 +1,18 @@
"""undocumented"""
__all__ = [
"iob2",
"iob2bioes",
"get_tokenizer",
]
from typing import List
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ...core.vocabulary import Vocabulary
def iob2(tags:List[str])->List[str]:
def iob2(tags: List[str]) -> List[str]:
"""
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
@@ -25,7 +35,8 @@ def iob2(tags:List[str])->List[str]:
tags[i] = "B" + tag[1:]
return tags
def iob2bioes(tags:List[str])->List[str]:
def iob2bioes(tags: List[str]) -> List[str]:
"""
将iob的tag转换为bioes编码
:param tags:
@@ -38,12 +49,12 @@ def iob2bioes(tags:List[str])->List[str]:
else:
split = tag.split('-')[0]
if split == 'B':
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
@@ -52,7 +63,7 @@ def iob2bioes(tags:List[str])->List[str]:
return new_tags
def get_tokenizer(tokenizer:str, lang='en'):
def get_tokenizer(tokenizer: str, lang='en'):
"""
:param str tokenizer: 获取tokenizer方法
@@ -97,13 +108,13 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con | |||
name != 'train']) | |||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) | |||
data_bundle.set_vocab(src_vocab, input_field_name) | |||
for target_field_name in target_field_names: | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name) | |||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name) | |||
data_bundle.set_vocab(tgt_vocab, target_field_name) | |||
return data_bundle | |||
@@ -116,7 +127,7 @@ def _add_words_field(data_bundle, lower=False): | |||
:return: 传入的DataBundle | |||
""" | |||
data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True) | |||
if lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.INPUT].lower() | |||
@@ -132,7 +143,7 @@ def _add_chars_field(data_bundle, lower=False): | |||
:return: 传入的DataBundle | |||
""" | |||
data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True) | |||
if lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset[Const.CHAR_INPUT].lower() | |||
@@ -147,6 +158,7 @@ def _drop_empty_instance(data_bundle, field_name): | |||
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 | |||
:return: 传入的DataBundle | |||
""" | |||
def empty_instance(ins): | |||
if field_name: | |||
field_value = ins[field_name] | |||
@@ -157,10 +169,8 @@ def _drop_empty_instance(data_bundle, field_name): | |||
if field_value in ((), {}, [], ''): | |||
return True | |||
return False | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.drop(empty_instance) | |||
return data_bundle
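For orientation (not part of the diff): the tag-scheme helpers above in isolation; the input tags are invented for the example.
from fastNLP.io.pipe.utils import iob2, iob2bioes

tags = ['I-PER', 'I-PER', 'O', 'I-LOC']      # IOB1-style input
print(iob2(tags))                            # ['B-PER', 'I-PER', 'O', 'B-LOC']
print(iob2bioes(iob2(tags)))                 # ['B-PER', 'E-PER', 'O', 'S-LOC']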
@@ -1,10 +1,20 @@
import os
"""
.. todo::
doc
"""
from typing import Union, Dict
__all__ = [
"check_loader_paths"
]
import os
from pathlib import Path
from typing import Union, Dict
from ..core import logger
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
"""
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::
@@ -33,11 +43,13 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
raise Exception(
"File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
raise Exception(
"File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
@@ -46,7 +58,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
elif isinstance(paths, dict):
if paths:
if 'train' not in paths:
@@ -65,6 +77,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
else:
raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
def get_tokenizer():
try:
import spacy
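For orientation (not part of the diff): check_loader_paths in isolation; the file paths are hypothetical and are expected to exist on disk.
from fastNLP.io.utils import check_loader_paths

print(check_loader_paths('data/train.txt'))
# -> {'train': 'data/train.txt'}
print(check_loader_paths({'train': 'data/train.txt', 'dev': 'data/dev.txt'}))
# -> {'train': 'data/train.txt', 'dev': 'data/dev.txt'}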