@@ -1,10 +1,15 @@
+"""
+.. todo::
+doc
+"""
__all__ = [
'DataBundle',
]
import _pickle as pickle
-from typing import Union, Dict
import os
+from typing import Union, Dict
from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary
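The hunk above only adds a placeholder docstring and reorders imports; the DataBundle class itself is untouched. For orientation, a minimal illustrative sketch of how a DataBundle is built and queried, using only the constructor and accessor that appear elsewhere in this diff (DataBundle(datasets=...), get_dataset); the sample field values are made up and this is not part of the patch:

    # Illustrative sketch only.
    from fastNLP import DataSet, Instance
    from fastNLP.io import DataBundle

    train = DataSet()
    train.append(Instance(raw_words="fastNLP is easy to use", target="1"))
    dev = DataSet()
    dev.append(Instance(raw_words="a second example", target="0"))

    # splits are stored by name, exactly as the loaders below do it
    bundle = DataBundle(datasets={"train": train, "dev": dev})
    print(bundle.get_dataset("train"))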
@@ -1,4 +1,4 @@
-"""
+"""undocumented
.. warning::
本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。
@@ -23,10 +23,10 @@ __all__ = [
]
+from .data_bundle import DataSetLoader
+from .file_reader import _read_csv, _read_json
from ..core.dataset import DataSet
from ..core.instance import Instance
-from .file_reader import _read_csv, _read_json
-from .data_bundle import DataSetLoader
class JsonLoader(DataSetLoader):
@@ -1,17 +1,22 @@
+"""
+.. todo::
+doc
+"""
__all__ = [
"EmbedLoader",
"EmbeddingOption",
]
+import logging
import os
import warnings
import numpy as np
-from ..core.vocabulary import Vocabulary
from .data_bundle import BaseLoader
from ..core.utils import Option
-import logging
+from ..core.vocabulary import Vocabulary
class EmbeddingOption(Option):
def __init__(self,
@@ -1,7 +1,11 @@
-"""
+"""undocumented
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
"""
+__all__ = []
import json
from ..core import logger
@@ -24,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
headers = headers.split(sep)
start_idx += 1
elif not isinstance(headers, (list, tuple)):
-raise TypeError("headers should be list or tuple, not {}." \
-.format(type(headers)))
+raise TypeError("headers should be list or tuple, not {}." \
+.format(type(headers)))
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
@@ -82,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""
def parse_conll(sample):
sample = list(map(list, zip(*sample)))
sample = [sample[i] for i in indexes]
@@ -89,14 +94,15 @@
if len(f) <= 0:
raise ValueError('empty field')
return sample
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
-if start!='':
+if start != '':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line.strip()
-if line=='':
+if line == '':
if len(sample):
try:
res = parse_conll(sample)
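_read_csv and _read_conll, whose formatting is touched above, are generators that yield (line number, item) pairs, as the docstring states. A hedged sketch of how a loader typically consumes _read_conll (file name and header names are hypothetical; the dict-building line mirrors ConllLoader._load further down in this diff):

    # Sketch only; assumes _read_conll(path, encoding, indexes, dropna) as shown in the hunk header above.
    from fastNLP import DataSet, Instance
    from fastNLP.io.file_reader import _read_conll

    ds = DataSet()
    headers = ['raw_words', 'target']              # hypothetical field names
    for line_idx, data in _read_conll('train.conll', indexes=[0, 3], dropna=True):
        ins = {h: data[i] for i, h in enumerate(headers)}
        ds.append(Instance(**ins))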
@@ -1,12 +1,27 @@
+"""
+.. todo::
+doc
+"""
+__all__ = [
+"cached_path",
+"get_filepath",
+"get_cache_path",
+"split_filename_suffix",
+"get_from_cache",
+]
import os
+import re
+import shutil
+import tempfile
from pathlib import Path
from urllib.parse import urlparse
-import re
import requests
-import tempfile
-from tqdm import tqdm
-import shutil
from requests import HTTPError
+from tqdm import tqdm
from ..core import logger
PRETRAINED_BERT_MODEL_DIR = {
@@ -1,12 +1,24 @@
-from ...core.dataset import DataSet
-from ...core.instance import Instance
-from .loader import Loader
-import warnings
+"""undocumented"""
+__all__ = [
+"YelpLoader",
+"YelpFullLoader",
+"YelpPolarityLoader",
+"IMDBLoader",
+"SSTLoader",
+"SST2Loader",
+]
+import glob
import os
import random
import shutil
-import glob
import time
+import warnings
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class YelpLoader(Loader):
@@ -58,7 +70,7 @@ class YelpLoader(Loader):
class YelpFullLoader(YelpLoader):
-def download(self, dev_ratio: float = 0.1, re_download:bool=False):
+def download(self, dev_ratio: float = 0.1, re_download: bool = False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -127,7 +139,7 @@ class YelpPolarityLoader(YelpLoader):
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
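The download(dev_ratio, re_download) methods touched above carve a dev split out of a freshly downloaded train.csv when no dev.csv exists, which is why dev_ratio must lie in (0, 1). A hedged usage sketch (cache location and split behaviour as described by the code above, not verified here):

    # Sketch of the intended call pattern for the Yelp loaders.
    from fastNLP.io.loader import YelpFullLoader

    loader = YelpFullLoader()
    data_dir = loader.download(dev_ratio=0.1, re_download=False)  # ~10% of train.csv becomes dev.csv
    data_bundle = loader.load(data_dir)                           # DataBundle with train/dev/test DataSets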
@@ -1,15 +1,28 @@
-from typing import Dict, Union
+"""undocumented"""
+__all__ = [
+"ConllLoader",
+"Conll2003Loader",
+"Conll2003NERLoader",
+"OntoNotesNERLoader",
+"CTBLoader",
+"CNNERLoader",
+"MsraNERLoader",
+"WeiboNERLoader",
+"PeopleDailyNERLoader"
+]
-from .loader import Loader
-from ...core.dataset import DataSet
-from ..file_reader import _read_conll
-from ...core.instance import Instance
-from ...core.const import Const
import glob
import os
+import random
import shutil
import time
-import random
+from .loader import Loader
+from ..file_reader import _read_conll
+from ...core.const import Const
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class ConllLoader(Loader):
@@ -47,6 +60,7 @@ class ConllLoader(Loader):
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
"""
def __init__(self, headers, indexes=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
@@ -60,7 +74,7 @@ class ConllLoader(Loader):
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes
def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -101,12 +115,13 @@ class Conll2003Loader(ConllLoader):
"[...]", "[...]", "[...]", "[...]"
"""
def __init__(self):
headers = [
'raw_words', 'pos', 'chunk', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)
def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -127,7 +142,7 @@ class Conll2003Loader(ConllLoader):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
def download(self, output_dir=None):
raise RuntimeError("conll2003 cannot be downloaded automatically.")
@@ -158,12 +173,13 @@ class Conll2003NERLoader(ConllLoader):
"[...]", "[...]"
"""
def __init__(self):
headers = [
'raw_words', 'target',
]
super().__init__(headers=headers, indexes=[0, 3])
def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -184,7 +200,7 @@ class Conll2003NERLoader(ConllLoader):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
def download(self):
raise RuntimeError("conll2003 cannot be downloaded automatically.")
@@ -204,13 +220,13 @@ class OntoNotesNERLoader(ConllLoader):
"[...]", "[...]"
"""
def __init__(self):
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10])
-def _load(self, path:str):
+def _load(self, path: str):
dataset = super()._load(path)
def convert_to_bio(tags):
bio_tags = []
flag = None
@@ -227,7 +243,7 @@ class OntoNotesNERLoader(ConllLoader):
flag = None
bio_tags.append(bio_label)
return bio_tags
def convert_word(words):
converted_words = []
for word in words:
@@ -236,7 +252,7 @@ class OntoNotesNERLoader(ConllLoader):
converted_words.append(word)
continue
# 以下是由于这些符号被转义了,再转回来
-tfrs = {'-LRB-':'(',
+tfrs = {'-LRB-': '(',
'-RRB-': ')',
'-LSB-': '[',
'-RSB-': ']',
@@ -248,12 +264,12 @@ class OntoNotesNERLoader(ConllLoader):
else:
converted_words.append(word)
return converted_words
dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET)
return dataset
def download(self):
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer "
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.")
@@ -262,13 +278,13 @@ class OntoNotesNERLoader(ConllLoader):
class CTBLoader(Loader):
def __init__(self):
super().__init__()
-def _load(self, path:str):
+def _load(self, path: str):
pass
class CNNERLoader(Loader):
-def _load(self, path:str):
+def _load(self, path: str):
"""
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample
@@ -331,10 +347,11 @@ class MsraNERLoader(CNNERLoader):
"[...]", "[...]"
"""
def __init__(self):
super().__init__()
-def download(self, dev_ratio:float=0.1, re_download:bool=False)->str:
+def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str:
"""
自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language
Processing Bakeoff: Word Segmentation and Named Entity Recognition.
@@ -356,7 +373,7 @@ class MsraNERLoader(CNNERLoader):
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.conll')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -380,15 +397,15 @@ class MsraNERLoader(CNNERLoader):
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.conll')):
os.remove(os.path.join(data_dir, 'middle_file.conll'))
return data_dir
class WeiboNERLoader(CNNERLoader):
def __init__(self):
super().__init__()
-def download(self)->str:
+def download(self) -> str:
"""
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for
Chinese Social Media with Jointly Trained Embeddings.
@@ -397,7 +414,7 @@ class WeiboNERLoader(CNNERLoader):
"""
dataset_name = 'weibo-ner'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
return data_dir
@@ -427,11 +444,12 @@ class PeopleDailyNERLoader(CNNERLoader):
"[...]", "[...]"
"""
def __init__(self):
super().__init__()
def download(self) -> str:
dataset_name = 'peopledaily'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
return data_dir
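All of the loaders above reduce to ConllLoader plus a fixed choice of headers (output field names) and indexes (which CoNLL columns to keep); Conll2003NERLoader, for example, is just headers=['raw_words', 'target'], indexes=[0, 3]. A hedged sketch of the generic loader on a custom column layout (file names and column layout are hypothetical):

    # Sketch only; a three-column CoNLL-style corpus with word/POS/NER columns.
    from fastNLP.io.loader import ConllLoader

    loader = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 2], dropna=True)
    data_bundle = loader.load({'train': 'train.conll', 'dev': 'dev.conll'})
    print(data_bundle.get_dataset('train'))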
@@ -1,7 +1,13 @@
+"""undocumented"""
+__all__ = [
+"CSVLoader",
+]
+from .loader import Loader
+from ..file_reader import _read_csv
from ...core.dataset import DataSet
from ...core.instance import Instance
-from ..file_reader import _read_csv
-from .loader import Loader
class CSVLoader(Loader):
@@ -1,11 +1,18 @@
-from .loader import Loader
-from ...core.dataset import DataSet
-from ...core.instance import Instance
+"""undocumented"""
+__all__ = [
+"CWSLoader"
+]
import glob
import os
-import time
-import shutil
import random
+import shutil
+import time
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
class CWSLoader(Loader):
@@ -1,7 +1,13 @@
+"""undocumented"""
+__all__ = [
+"JsonLoader"
+]
+from .loader import Loader
+from ..file_reader import _read_json
from ...core.dataset import DataSet
from ...core.instance import Instance
-from ..file_reader import _read_json
-from .loader import Loader
class JsonLoader(Loader):
@@ -1,8 +1,15 @@
-from ...core.dataset import DataSet
-from .. import DataBundle
-from ..utils import check_loader_paths
+"""undocumented"""
+__all__ = [
+"Loader"
+]
from typing import Union, Dict
+from .. import DataBundle
from ..file_utils import _get_dataset_url, get_cache_path, cached_path
+from ..utils import check_loader_paths
+from ...core.dataset import DataSet
class Loader:
@@ -1,10 +1,21 @@
+"""undocumented"""
+__all__ = [
+"MNLILoader",
+"SNLILoader",
+"QNLILoader",
+"RTELoader",
+"QuoraLoader",
+]
+import os
import warnings
-from .loader import Loader
+from typing import Union, Dict
from .json import JsonLoader
-from ...core.const import Const
+from .loader import Loader
from .. import DataBundle
-import os
-from typing import Union, Dict
+from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
@@ -22,10 +33,11 @@ class MNLILoader(Loader):
"...", "...","."
"""
def __init__(self):
super().__init__()
-def _load(self, path:str):
+def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
@@ -50,8 +62,8 @@ class MNLILoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
-def load(self, paths:str=None):
+def load(self, paths: str = None):
"""
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
@@ -64,13 +76,13 @@ class MNLILoader(Loader):
paths = self.download()
if not os.path.isdir(paths):
raise NotADirectoryError(f"{paths} is not a valid directory.")
-files = {'dev_matched':"dev_matched.tsv",
-"dev_mismatched":"dev_mismatched.tsv",
-"test_matched":"test_matched.tsv",
-"test_mismatched":"test_mismatched.tsv",
-"train":'train.tsv'}
+files = {'dev_matched': "dev_matched.tsv",
+"dev_mismatched": "dev_mismatched.tsv",
+"test_matched": "test_matched.tsv",
+"test_mismatched": "test_mismatched.tsv",
+"train": 'train.tsv'}
datasets = {}
for name, filename in files.items():
filepath = os.path.join(paths, filename)
@@ -78,11 +90,11 @@ class MNLILoader(Loader):
if 'test' not in name:
raise FileNotFoundError(f"{name} not found in directory {filepath}.")
datasets[name] = self._load(filepath)
data_bundle = DataBundle(datasets=datasets)
return data_bundle
def download(self):
"""
如果你使用了这个数据,请引用
@@ -106,14 +118,15 @@ class SNLILoader(JsonLoader):
"...", "...", "."
"""
def __init__(self):
super().__init__(fields={
'sentence1': Const.RAW_WORDS(0),
'sentence2': Const.RAW_WORDS(1),
'gold_label': Const.TARGET,
})
-def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
+def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
"""
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。
@@ -138,11 +151,11 @@ class SNLILoader(JsonLoader):
paths = _paths
else:
raise NotADirectoryError(f"{paths} is not a valid directory.")
datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle
def download(self):
"""
如果您的文章使用了这份数据,请引用
@@ -169,12 +182,13 @@ class QNLILoader(JsonLoader):
test数据集没有target列
"""
def __init__(self):
super().__init__()
def _load(self, path):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
@@ -198,7 +212,7 @@ class QNLILoader(JsonLoader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def download(self):
"""
如果您的实验使用到了该数据,请引用
@@ -225,12 +239,13 @@ class RTELoader(Loader):
test数据集没有target列
"""
def __init__(self):
super().__init__()
-def _load(self, path:str):
+def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
@@ -254,7 +269,7 @@ class RTELoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def download(self):
return self._get_dataset_path('rte')
@@ -281,12 +296,13 @@ class QuoraLoader(Loader):
"...","."
"""
def __init__(self):
super().__init__()
-def _load(self, path:str):
+def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
@@ -298,6 +314,6 @@ class QuoraLoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def download(self):
raise RuntimeError("Quora cannot be downloaded automatically.")
@@ -1,26 +1,39 @@
+"""undocumented"""
+__all__ = [
+"YelpFullPipe",
+"YelpPolarityPipe",
+"SSTPipe",
+"SST2Pipe",
+'IMDBPipe'
+]
+import re
from nltk import Tree
+from .pipe import Pipe
+from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from ..data_bundle import DataBundle
-from ...core.vocabulary import Vocabulary
-from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
+from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
+from ...core.vocabulary import Vocabulary
-from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
-from .pipe import Pipe
-import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
class _CLSPipe(Pipe):
"""
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列
"""
-def __init__(self, tokenizer:str='spacy', lang='en'):
+def __init__(self, tokenizer: str = 'spacy', lang='en'):
self.tokenizer = get_tokenizer(tokenizer, lang=lang)
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize
@@ -33,9 +46,9 @@ class _CLSPipe(Pipe):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
return data_bundle
def _granularize(self, data_bundle, tag_map):
"""
该函数对data_bundle中'target'列中的内容进行转换。
@@ -47,9 +60,9 @@ class _CLSPipe(Pipe):
"""
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
-dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET,
+dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
new_field_name=Const.TARGET)
-dataset.drop(lambda ins:ins[Const.TARGET] == -100)
+dataset.drop(lambda ins: ins[Const.TARGET] == -100)
data_bundle.set_dataset(dataset, name)
return data_bundle
@@ -69,7 +82,7 @@ def _clean_str(words):
t = ''.join(tt)
if t != '':
words_collection.append(t)
return words_collection
@@ -89,19 +102,20 @@ class YelpFullPipe(_CLSPipe):
1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
-def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'):
+def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
-if granularity==2:
+if granularity == 2:
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
-elif granularity==3:
-self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2}
+elif granularity == 3:
+self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
else:
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize
@@ -116,7 +130,7 @@ class YelpFullPipe(_CLSPipe):
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
return data_bundle
def process(self, data_bundle):
"""
传入的DataSet应该具备如下的结构
@@ -131,30 +145,30 @@ class YelpFullPipe(_CLSPipe):
:param data_bundle:
:return:
"""
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
# 删除空行
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""
@@ -179,27 +193,28 @@ class YelpPolarityPipe(_CLSPipe):
:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
-def __init__(self, lower:bool=False, tokenizer:str='spacy'):
+def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
def process(self, data_bundle):
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""
@@ -230,7 +245,7 @@ class SSTPipe(_CLSPipe):
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.subtree = subtree
@@ -238,15 +253,15 @@ class SSTPipe(_CLSPipe):
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
-if granularity==2:
+if granularity == 2:
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
-elif granularity==3:
-self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2}
+elif granularity == 3:
+self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2}
else:
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
-def process(self, data_bundle:DataBundle):
+def process(self, data_bundle: DataBundle):
"""
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与
@@ -277,26 +292,26 @@ class SSTPipe(_CLSPipe):
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
ds.append(instance)
data_bundle.set_dataset(ds, name)
_add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
data_bundle = SSTLoader().load(paths)
return self.process(data_bundle=data_bundle)
@@ -316,11 +331,12 @@ class SST2Pipe(_CLSPipe):
:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
def __init__(self, lower=False, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
-def process(self, data_bundle:DataBundle):
+def process(self, data_bundle: DataBundle):
"""
可以处理的DataSet应该具备如下的结构
@@ -335,15 +351,15 @@ class SST2Pipe(_CLSPipe):
:return:
"""
_add_words_field(data_bundle, self.lower)
data_bundle = self._tokenize(data_bundle=data_bundle)
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
-no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if
+no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
datasets = []
@@ -351,18 +367,18 @@ class SST2Pipe(_CLSPipe):
if dataset.has_field(Const.TARGET):
datasets.append(dataset)
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)
data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""
@@ -390,11 +406,12 @@ class IMDBPipe(_CLSPipe):
:param bool lower: 是否将words列的数据小写。
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
"""
-def __init__(self, lower:bool=False, tokenizer:str='spacy'):
+def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
-def process(self, data_bundle:DataBundle):
+def process(self, data_bundle: DataBundle):
"""
期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型
@@ -409,25 +426,26 @@ class IMDBPipe(_CLSPipe):
target列应该为str。
:return: DataBundle
"""
# 替换<br />
def replace_br(raw_words):
raw_words = raw_words.replace("<br />", ' ')
return raw_words
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
_add_words_field(data_bundle, lower=self.lower)
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
_indexize(data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(Const.INPUT, Const.INPUT_LEN)
dataset.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""
@@ -437,8 +455,5 @@ class IMDBPipe(_CLSPipe):
# 读取数据
data_bundle = IMDBLoader().load(paths)
data_bundle = self.process(data_bundle)
return data_bundle
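The granularity handling reformatted above (YelpFullPipe and SSTPipe) is just a label-collapsing dict applied by _granularize: any label missing from tag_map is mapped to -100 and that instance is dropped. A standalone illustration of that mapping, using the granularity=2 table from the diff:

    # Standalone illustration of the tag_map collapse; not the Pipe class itself.
    tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}   # granularity == 2: neutral "3" has no entry

    def granularize(target: str) -> int:
        return tag_map.get(target, -100)         # -100 marks instances to drop

    labels = ["1", "3", "5", "2"]
    kept = [granularize(t) for t in labels if granularize(t) != -100]
    print(kept)  # [0, 1, 0] -- the "3" review is filtered out

In practice this is driven from process_from_file, e.g. YelpFullPipe(granularity=2).process_from_file(paths).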
@@ -1,13 +1,25 @@
+"""undocumented"""
+__all__ = [
+"Conll2003NERPipe",
+"Conll2003Pipe",
+"OntoNotesNERPipe",
+"MsraNERPipe",
+"PeopleDailyPipe",
+"WeiboNERPipe"
+]
from .pipe import Pipe
-from .. import DataBundle
+from .utils import _add_chars_field
+from .utils import _indexize, _add_words_field
from .utils import iob2, iob2bioes
-from ...core.const import Const
+from .. import DataBundle
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
-from .utils import _indexize, _add_words_field
-from .utils import _add_chars_field
from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
+from ...core.const import Const
from ...core.vocabulary import Vocabulary
class _NERPipe(Pipe):
"""
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
@@ -20,14 +32,14 @@ class _NERPipe(Pipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
"""
def __init__(self, encoding_type: str = 'bio', lower: bool = False):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
self.lower = lower
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
支持的DataSet的field为
@@ -46,21 +58,21 @@ class _NERPipe(Pipe):
# 转换tag
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
_add_words_field(data_bundle, lower=self.lower)
# index
_indexize(data_bundle)
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle
@@ -84,7 +96,7 @@ class Conll2003NERPipe(_NERPipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
"""
def process_from_file(self, paths) -> DataBundle:
"""
@@ -94,7 +106,7 @@ class Conll2003NERPipe(_NERPipe):
# 读取数据
data_bundle = Conll2003NERLoader().load(paths)
data_bundle = self.process(data_bundle)
return data_bundle
@@ -125,8 +137,8 @@ class Conll2003Pipe(Pipe):
else:
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
self.lower = lower
-def process(self, data_bundle)->DataBundle:
+def process(self, data_bundle) -> DataBundle:
"""
输入的DataSet应该类似于如下的形式
@@ -145,9 +157,9 @@ class Conll2003Pipe(Pipe):
dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD])
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
_add_words_field(data_bundle, lower=self.lower)
# index
_indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner'])
# chunk中存在一些tag只在dev中出现,没在train中
@@ -155,18 +167,18 @@ class Conll2003Pipe(Pipe):
tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
data_bundle.set_vocab(tgt_vocab, 'chunk')
input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle
def process_from_file(self, paths):
"""
@@ -194,7 +206,7 @@ class OntoNotesNERPipe(_NERPipe):
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
"""
def process_from_file(self, paths):
data_bundle = OntoNotesNERLoader().load(paths)
return self.process(data_bundle)
@@ -211,13 +223,13 @@ class _CNNERPipe(Pipe): | |||||
:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 | ||||
""" | """ | ||||
def __init__(self, encoding_type: str = 'bio'): | def __init__(self, encoding_type: str = 'bio'): | ||||
if encoding_type == 'bio': | if encoding_type == 'bio': | ||||
self.convert_tag = iob2 | self.convert_tag = iob2 | ||||
else: | else: | ||||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | self.convert_tag = lambda words: iob2bioes(iob2(words)) | ||||
def process(self, data_bundle: DataBundle) -> DataBundle: | def process(self, data_bundle: DataBundle) -> DataBundle: | ||||
""" | """ | ||||
支持的DataSet的field为 | 支持的DataSet的field为 | ||||
@@ -239,21 +251,21 @@ class _CNNERPipe(Pipe): | |||||
# 转换tag | # 转换tag | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | ||||
_add_chars_field(data_bundle, lower=False) | _add_chars_field(data_bundle, lower=False) | ||||
# index | # index | ||||
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) | _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) | ||||
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] | input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] | ||||
target_fields = [Const.TARGET, Const.INPUT_LEN] | target_fields = [Const.TARGET, Const.INPUT_LEN] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.add_seq_len(Const.CHAR_INPUT) | dataset.add_seq_len(Const.CHAR_INPUT) | ||||
data_bundle.set_input(*input_fields) | data_bundle.set_input(*input_fields) | ||||
data_bundle.set_target(*target_fields) | data_bundle.set_target(*target_fields) | ||||
return data_bundle | return data_bundle | ||||
@@ -272,6 +284,7 @@ class MsraNERPipe(_CNNERPipe): | |||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | ||||
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
data_bundle = MsraNERLoader().load(paths) | data_bundle = MsraNERLoader().load(paths) | ||||
return self.process(data_bundle) | return self.process(data_bundle) | ||||
@@ -291,6 +304,7 @@ class PeopleDailyPipe(_CNNERPipe): | |||||
raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | ||||
target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 | ||||
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
data_bundle = PeopleDailyNERLoader().load(paths) | data_bundle = PeopleDailyNERLoader().load(paths) | ||||
return self.process(data_bundle) | return self.process(data_bundle) | ||||
@@ -312,6 +326,7 @@ class WeiboNERPipe(_CNNERPipe): | |||||
:param str encoding_type: the encoding scheme used for the target column; supports bioes and bio.
""" | """ | ||||
def process_from_file(self, paths=None) -> DataBundle: | def process_from_file(self, paths=None) -> DataBundle: | ||||
data_bundle = WeiboNERLoader().load(paths) | data_bundle = WeiboNERLoader().load(paths) | ||||
return self.process(data_bundle) | return self.process(data_bundle) |
@@ -1,3 +1,9 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | |||||
"CWSPipe" | |||||
] | |||||
import re | import re | ||||
from itertools import chain | from itertools import chain | ||||
@@ -1,9 +1,25 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | |||||
"MatchingBertPipe", | |||||
"RTEBertPipe", | |||||
"SNLIBertPipe", | |||||
"QuoraBertPipe", | |||||
"QNLIBertPipe", | |||||
"MNLIBertPipe", | |||||
"MatchingPipe", | |||||
"RTEPipe", | |||||
"SNLIPipe", | |||||
"QuoraPipe", | |||||
"QNLIPipe", | |||||
"MNLIPipe", | |||||
] | |||||
from .pipe import Pipe | from .pipe import Pipe | ||||
from .utils import get_tokenizer | from .utils import get_tokenizer | ||||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | |||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | |||||
class MatchingBertPipe(Pipe): | class MatchingBertPipe(Pipe): | ||||
@@ -24,12 +40,13 @@ class MatchingBertPipe(Pipe): | |||||
:param bool lower: whether to lowercase the words.
:param str tokenizer: which tokenizer to use to split sentences into words; supports spacy and raw, where raw splits on whitespace.
""" | """ | ||||
def __init__(self, lower=False, tokenizer: str='raw'): | |||||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||||
super().__init__() | super().__init__() | ||||
self.lower = bool(lower) | self.lower = bool(lower) | ||||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | self.tokenizer = get_tokenizer(tokenizer=tokenizer) | ||||
def _tokenize(self, data_bundle, field_names, new_field_names): | def _tokenize(self, data_bundle, field_names, new_field_names): | ||||
""" | """ | ||||
@@ -43,62 +60,62 @@ class MatchingBertPipe(Pipe): | |||||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | ||||
new_field_name=new_field_name) | new_field_name=new_field_name) | ||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle): | def process(self, data_bundle): | ||||
for dataset in data_bundle.datasets.values(): | for dataset in data_bundle.datasets.values(): | ||||
if dataset.has_field(Const.TARGET): | if dataset.has_field(Const.TARGET): | ||||
dataset.drop(lambda x: x[Const.TARGET] == '-') | dataset.drop(lambda x: x[Const.TARGET] == '-') | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), ) | dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), ) | ||||
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), ) | dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), ) | ||||
if self.lower: | if self.lower: | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset[Const.INPUTS(0)].lower() | dataset[Const.INPUTS(0)].lower() | ||||
dataset[Const.INPUTS(1)].lower() | dataset[Const.INPUTS(1)].lower() | ||||
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], | data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], | ||||
[Const.INPUTS(0), Const.INPUTS(1)]) | [Const.INPUTS(0), Const.INPUTS(1)]) | ||||
# concatenate the two word sequences
def concat(ins): | def concat(ins): | ||||
words0 = ins[Const.INPUTS(0)] | words0 = ins[Const.INPUTS(0)] | ||||
words1 = ins[Const.INPUTS(1)] | words1 = ins[Const.INPUTS(1)] | ||||
words = words0 + ['[SEP]'] + words1 | words = words0 + ['[SEP]'] + words1 | ||||
return words | return words | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.apply(concat, new_field_name=Const.INPUT) | dataset.apply(concat, new_field_name=Const.INPUT) | ||||
dataset.delete_field(Const.INPUTS(0)) | dataset.delete_field(Const.INPUTS(0)) | ||||
dataset.delete_field(Const.INPUTS(1)) | dataset.delete_field(Const.INPUTS(1)) | ||||
word_vocab = Vocabulary() | word_vocab = Vocabulary() | ||||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | ||||
field_name=Const.INPUT, | field_name=Const.INPUT, | ||||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | ||||
'train' not in name]) | 'train' not in name]) | ||||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | ||||
target_vocab = Vocabulary(padding=None, unknown=None) | target_vocab = Vocabulary(padding=None, unknown=None) | ||||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | ||||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | ||||
dataset.has_field(Const.TARGET)] | dataset.has_field(Const.TARGET)] | ||||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | ||||
data_bundle.set_vocab(word_vocab, Const.INPUT) | data_bundle.set_vocab(word_vocab, Const.INPUT) | ||||
data_bundle.set_vocab(target_vocab, Const.TARGET) | data_bundle.set_vocab(target_vocab, Const.TARGET) | ||||
input_fields = [Const.INPUT, Const.INPUT_LEN] | input_fields = [Const.INPUT, Const.INPUT_LEN] | ||||
target_fields = [Const.TARGET] | target_fields = [Const.TARGET] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.add_seq_len(Const.INPUT) | dataset.add_seq_len(Const.INPUT) | ||||
dataset.set_input(*input_fields, flag=True) | dataset.set_input(*input_fields, flag=True) | ||||
for fields in target_fields: | for fields in target_fields: | ||||
if dataset.has_field(fields): | if dataset.has_field(fields): | ||||
dataset.set_target(fields, flag=True) | dataset.set_target(fields, flag=True) | ||||
return data_bundle | return data_bundle | ||||
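The concat step above only joins the two tokenized sentences with a literal '[SEP]' token before a single vocabulary indexes the result; a toy illustration::

    words0 = ['a', 'man', 'is', 'running']
    words1 = ['someone', 'is', 'moving']
    words = words0 + ['[SEP]'] + words1
    # -> ['a', 'man', 'is', 'running', '[SEP]', 'someone', 'is', 'moving']
    # this joined sequence replaces the two separate word fields and receives a single seq_len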
@@ -150,12 +167,13 @@ class MatchingPipe(Pipe): | |||||
:param bool lower: whether to lowercase all raw_words.
:param str tokenizer: how the raw data is tokenized; supports spacy and raw, where spacy uses spaCy to split and raw splits on whitespace.
""" | """ | ||||
def __init__(self, lower=False, tokenizer: str='raw'): | |||||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||||
super().__init__() | super().__init__() | ||||
self.lower = bool(lower) | self.lower = bool(lower) | ||||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | self.tokenizer = get_tokenizer(tokenizer=tokenizer) | ||||
def _tokenize(self, data_bundle, field_names, new_field_names): | def _tokenize(self, data_bundle, field_names, new_field_names): | ||||
""" | """ | ||||
@@ -169,7 +187,7 @@ class MatchingPipe(Pipe): | |||||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | ||||
new_field_name=new_field_name) | new_field_name=new_field_name) | ||||
return data_bundle | return data_bundle | ||||
def process(self, data_bundle): | def process(self, data_bundle): | ||||
""" | """ | ||||
The DataSets in the DataBundle passed in should contain the following fields; the target column is optional
@@ -186,35 +204,35 @@ class MatchingPipe(Pipe): | |||||
""" | """ | ||||
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | ||||
[Const.INPUTS(0), Const.INPUTS(1)]) | [Const.INPUTS(0), Const.INPUTS(1)]) | ||||
for dataset in data_bundle.datasets.values(): | for dataset in data_bundle.datasets.values(): | ||||
if dataset.has_field(Const.TARGET): | if dataset.has_field(Const.TARGET): | ||||
dataset.drop(lambda x: x[Const.TARGET] == '-') | dataset.drop(lambda x: x[Const.TARGET] == '-') | ||||
if self.lower: | if self.lower: | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset[Const.INPUTS(0)].lower() | dataset[Const.INPUTS(0)].lower() | ||||
dataset[Const.INPUTS(1)].lower() | dataset[Const.INPUTS(1)].lower() | ||||
word_vocab = Vocabulary() | word_vocab = Vocabulary() | ||||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | ||||
field_name=[Const.INPUTS(0), Const.INPUTS(1)], | field_name=[Const.INPUTS(0), Const.INPUTS(1)], | ||||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | ||||
'train' not in name]) | 'train' not in name]) | ||||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) | word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) | ||||
target_vocab = Vocabulary(padding=None, unknown=None) | target_vocab = Vocabulary(padding=None, unknown=None) | ||||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | ||||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | ||||
dataset.has_field(Const.TARGET)] | dataset.has_field(Const.TARGET)] | ||||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | ||||
data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) | data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) | ||||
data_bundle.set_vocab(target_vocab, Const.TARGET) | data_bundle.set_vocab(target_vocab, Const.TARGET) | ||||
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)] | input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)] | ||||
target_fields = [Const.TARGET] | target_fields = [Const.TARGET] | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) | dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) | ||||
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) | dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) | ||||
@@ -222,7 +240,7 @@ class MatchingPipe(Pipe): | |||||
for fields in target_fields: | for fields in target_fields: | ||||
if dataset.has_field(fields): | if dataset.has_field(fields): | ||||
dataset.set_target(fields, flag=True) | dataset.set_target(fields, flag=True) | ||||
return data_bundle | return data_bundle | ||||
@@ -254,4 +272,3 @@ class MNLIPipe(MatchingPipe): | |||||
def process_from_file(self, paths=None): | def process_from_file(self, paths=None): | ||||
data_bundle = MNLILoader().load(paths) | data_bundle = MNLILoader().load(paths) | ||||
return self.process(data_bundle) | return self.process(data_bundle) | ||||
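A hedged end-to-end sketch for the matching pipes (the file location is hypothetical and the import path is assumed; the other subclasses listed in __all__ follow the same process_from_file pattern as MNLIPipe above)::

    from fastNLP.io.pipe import RTEPipe, RTEBertPipe   # assumed import path

    # non-Bert variant: the two sentences stay in separate word fields, each with its own length
    data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file('/path/to/RTE')
    print(data_bundle.datasets['train'][:2])

    # Bert variant: both sentences are merged into one field separated by [SEP]
    bert_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file('/path/to/RTE')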
@@ -1,3 +1,9 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | |||||
"Pipe", | |||||
] | |||||
from .. import DataBundle | from .. import DataBundle | ||||
@@ -1,8 +1,18 @@ | |||||
"""undocumented""" | |||||
__all__ = [ | |||||
"iob2", | |||||
"iob2bioes", | |||||
"get_tokenizer", | |||||
] | |||||
from typing import List | from typing import List | ||||
from ...core.vocabulary import Vocabulary | |||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.vocabulary import Vocabulary | |||||
def iob2(tags:List[str])->List[str]: | |||||
def iob2(tags: List[str]) -> List[str]: | |||||
""" | """ | ||||
Check whether the tags form valid IOB data; IOB1 input is automatically converted to IOB2. For the difference between the two formats, see
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | ||||
@@ -25,7 +35,8 @@ def iob2(tags:List[str])->List[str]: | |||||
tags[i] = "B" + tag[1:] | tags[i] = "B" + tag[1:] | ||||
return tags | return tags | ||||
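For example, an IOB1 sequence (where an entity may begin with an I- tag) is rewritten in place to IOB2::

    tags = ['I-PER', 'I-PER', 'O', 'I-LOC']
    print(iob2(tags))   # ['B-PER', 'I-PER', 'O', 'B-LOC']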
def iob2bioes(tags:List[str])->List[str]: | |||||
def iob2bioes(tags: List[str]) -> List[str]: | |||||
""" | """ | ||||
Convert IOB tags to the BIOES encoding
:param tags: | :param tags: | ||||
@@ -38,12 +49,12 @@ def iob2bioes(tags:List[str])->List[str]: | |||||
else: | else: | ||||
split = tag.split('-')[0] | split = tag.split('-')[0] | ||||
if split == 'B': | if split == 'B': | ||||
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': | |||||
if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I': | |||||
new_tags.append(tag) | new_tags.append(tag) | ||||
else: | else: | ||||
new_tags.append(tag.replace('B-', 'S-')) | new_tags.append(tag.replace('B-', 'S-')) | ||||
elif split == 'I': | elif split == 'I': | ||||
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I': | |||||
if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I': | |||||
new_tags.append(tag) | new_tags.append(tag) | ||||
else: | else: | ||||
new_tags.append(tag.replace('I-', 'E-')) | new_tags.append(tag.replace('I-', 'E-')) | ||||
@@ -52,7 +63,7 @@ def iob2bioes(tags:List[str])->List[str]: | |||||
return new_tags | return new_tags | ||||
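Continuing that example, the BIOES conversion marks single-token entities with S- and entity-final tokens with E-::

    print(iob2bioes(['B-PER', 'I-PER', 'O', 'B-LOC']))
    # -> ['B-PER', 'E-PER', 'O', 'S-LOC']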
def get_tokenizer(tokenizer:str, lang='en'): | |||||
def get_tokenizer(tokenizer: str, lang='en'): | |||||
""" | """ | ||||
:param str tokenizer: which tokenizer to fetch
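A minimal usage sketch, inferred from this diff rather than from separate documentation (the returned object is a callable that turns a sentence into a token list, as its use as self.tokenizer(words) above shows; 'raw' splits on whitespace, 'spacy' requires spaCy to be installed)::

    tokenize = get_tokenizer('raw', lang='en')
    tokenize('a man is running')   # expected: ['a', 'man', 'is', 'running']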
@@ -97,13 +108,13 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con | |||||
name != 'train']) | name != 'train']) | ||||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) | src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) | ||||
data_bundle.set_vocab(src_vocab, input_field_name) | data_bundle.set_vocab(src_vocab, input_field_name) | ||||
for target_field_name in target_field_names: | for target_field_name in target_field_names: | ||||
tgt_vocab = Vocabulary(unknown=None, padding=None) | tgt_vocab = Vocabulary(unknown=None, padding=None) | ||||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name) | tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name) | ||||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name) | tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name) | ||||
data_bundle.set_vocab(tgt_vocab, target_field_name) | data_bundle.set_vocab(tgt_vocab, target_field_name) | ||||
return data_bundle | return data_bundle | ||||
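_indexize is an internal helper, but its contract can be sketched with a toy bundle (the import path, the DataBundle keyword arguments and the Const name mapping, Const.INPUT == 'words' and Const.TARGET == 'target', are assumptions)::

    from fastNLP import DataSet, Instance
    from fastNLP.io import DataBundle
    from fastNLP.io.pipe.utils import _indexize   # internal helper, path assumed

    train = DataSet()
    train.append(Instance(words=['a', 'cat'], target='animal'))
    train.append(Instance(words=['a', 'car'], target='thing'))
    dev = DataSet()
    dev.append(Instance(words=['a', 'dog'], target='animal'))

    bundle = _indexize(DataBundle(datasets={'train': train, 'dev': dev}))
    # `words` now holds indices in every split; words seen only in dev were registered as
    # no-create-entry words, and the target vocabulary was built from the train split alone
    print(bundle.vocabs['words'], bundle.vocabs['target'])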
@@ -116,7 +127,7 @@ def _add_words_field(data_bundle, lower=False): | |||||
:return: the DataBundle that was passed in
""" | """ | ||||
data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True) | data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True) | ||||
if lower: | if lower: | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset[Const.INPUT].lower() | dataset[Const.INPUT].lower() | ||||
@@ -132,7 +143,7 @@ def _add_chars_field(data_bundle, lower=False): | |||||
:return: the DataBundle that was passed in
""" | """ | ||||
data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True) | data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True) | ||||
if lower: | if lower: | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset[Const.CHAR_INPUT].lower() | dataset[Const.CHAR_INPUT].lower() | ||||
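Both helpers above only copy the raw field under the model-facing name and, when lower=True, lowercase the copy while the raw field stays untouched; conceptually::

    raw_words = ['Deep', 'Learning', 'with', 'fastNLP']
    words = [w.lower() for w in raw_words]   # the copied field after .lower(); raw_words is kept as-is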
@@ -147,6 +158,7 @@ def _drop_empty_instance(data_bundle, field_name): | |||||
:param str field_name: which field to check; if None, the instance is dropped when any of its fields is empty
:return: the DataBundle that was passed in
""" | """ | ||||
def empty_instance(ins): | def empty_instance(ins): | ||||
if field_name: | if field_name: | ||||
field_value = ins[field_name] | field_value = ins[field_name] | ||||
@@ -157,10 +169,8 @@ def _drop_empty_instance(data_bundle, field_name): | |||||
if field_value in ((), {}, [], ''): | if field_value in ((), {}, [], ''): | ||||
return True | return True | ||||
return False | return False | ||||
for name, dataset in data_bundle.datasets.items(): | for name, dataset in data_bundle.datasets.items(): | ||||
dataset.drop(empty_instance) | dataset.drop(empty_instance) | ||||
return data_bundle | return data_bundle | ||||
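A toy sketch of what gets dropped (the helper is internal, so the import path is an assumption)::

    from fastNLP import DataSet, Instance
    from fastNLP.io import DataBundle
    from fastNLP.io.pipe.utils import _drop_empty_instance   # assumed path

    ds = DataSet()
    ds.append(Instance(raw_words='a cat sat', target='animal'))
    ds.append(Instance(raw_words='', target='animal'))   # empty field

    bundle = _drop_empty_instance(DataBundle(datasets={'train': ds}), field_name=None)
    print(len(bundle.datasets['train']))   # 1; with field_name=None any empty field drops the instance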
@@ -1,10 +1,20 @@ | |||||
import os | |||||
""" | |||||
.. todo:: | |||||
doc | |||||
""" | |||||
from typing import Union, Dict | |||||
__all__ = [ | |||||
"check_loader_paths" | |||||
] | |||||
import os | |||||
from pathlib import Path | from pathlib import Path | ||||
from typing import Union, Dict | |||||
from ..core import logger | from ..core import logger | ||||
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||||
""" | """ | ||||
Check that the file paths passed to the data loader are valid. For a valid path, a dict containing at least the key 'train' is returned, similar to the following::
@@ -33,11 +43,13 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
path_pair = ('train', filename) | path_pair = ('train', filename) | ||||
if 'dev' in filename: | if 'dev' in filename: | ||||
if path_pair: | if path_pair: | ||||
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||||
raise Exception( | |||||
"File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||||
path_pair = ('dev', filename) | path_pair = ('dev', filename) | ||||
if 'test' in filename: | if 'test' in filename: | ||||
if path_pair: | if path_pair: | ||||
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) | |||||
raise Exception( | |||||
"File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) | |||||
path_pair = ('test', filename) | path_pair = ('test', filename) | ||||
if path_pair: | if path_pair: | ||||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | files[path_pair[0]] = os.path.join(paths, path_pair[1]) | ||||
@@ -46,7 +58,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
return files | return files | ||||
else: | else: | ||||
raise FileNotFoundError(f"{paths} is not a valid file path.") | raise FileNotFoundError(f"{paths} is not a valid file path.") | ||||
elif isinstance(paths, dict): | elif isinstance(paths, dict): | ||||
if paths: | if paths: | ||||
if 'train' not in paths: | if 'train' not in paths: | ||||
@@ -65,6 +77,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
else: | else: | ||||
raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | ||||
def get_tokenizer(): | def get_tokenizer(): | ||||
try: | try: | ||||
import spacy | import spacy | ||||