
add __all__ and __doc__ for all files in module 'io', using 'undocumented' tags

ChenXin committed 5 years ago · commit 0d5f43b451 · tags/v0.4.10
19 changed files with 439 additions and 254 deletions
 1. +6  -1   fastNLP/io/data_bundle.py
 2. +3  -3   fastNLP/io/dataset_loader.py
 3. +7  -2   fastNLP/io/embed_loader.py
 4. +11 -5   fastNLP/io/file_reader.py
 5. +19 -4   fastNLP/io/file_utils.py
 6. +19 -7   fastNLP/io/loader/classification.py
 7. +51 -33  fastNLP/io/loader/conll.py
 8. +8  -2   fastNLP/io/loader/csv.py
 9. +12 -5   fastNLP/io/loader/cws.py
10. +8  -2   fastNLP/io/loader/json.py
11. +10 -3   fastNLP/io/loader/loader.py
12. +49 -33  fastNLP/io/loader/matching.py
13. +88 -73  fastNLP/io/pipe/classification.py
14. +47 -32  fastNLP/io/pipe/conll.py
15. +6  -0   fastNLP/io/pipe/cws.py
16. +46 -29  fastNLP/io/pipe/matching.py
17. +6  -0   fastNLP/io/pipe/pipe.py
18. +24 -14  fastNLP/io/pipe/utils.py
19. +19 -6   fastNLP/io/utils.py

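Note on the pattern applied in the hunks below: every touched module gets the same header treatment, namely a module docstring (tagged ``undocumented`` for modules that the doc build is apparently meant to skip), an explicit __all__ list, and imports regrouped into standard-library, third-party and package-relative blocks. A minimal, self-contained sketch of that layout follows; the module and class names in it are placeholders for illustration, not part of the commit:

"""undocumented
Illustrative helper module; the ``undocumented`` tag marks it as outside the documented API.
"""

__all__ = [
    "ExampleLoader",  # wildcard imports and the doc build only see names listed here
]

import json  # standard-library imports come first, alphabetically ordered
import os

# (third-party imports would form the second group)

# (package-relative imports such as ``from .loader import Loader`` would come last)


class ExampleLoader:
    """Placeholder class, present only so the sketch runs as written."""

    def load(self, path: str) -> dict:
        # Trivial body: read a JSON file from the given path.
        with open(os.path.abspath(path), "r", encoding="utf-8") as f:
            return json.load(f)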
fastNLP/io/data_bundle.py (+6, -1)

@@ -1,10 +1,15 @@
+"""
+.. todo::
+    doc
+"""
 __all__ = [
     'DataBundle',
 ]
 
 import _pickle as pickle
-from typing import Union, Dict
 import os
+from typing import Union, Dict
 
 from ..core.dataset import DataSet
 from ..core.vocabulary import Vocabulary

fastNLP/io/dataset_loader.py (+3, -3)

@@ -1,4 +1,4 @@
-"""
+"""undocumented
 .. warning::
 
     本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。
@@ -23,10 +23,10 @@ __all__ = [
 ]
 
 
-from .data_bundle import DataSetLoader
-from .file_reader import _read_csv, _read_json
 from ..core.dataset import DataSet
 from ..core.instance import Instance
+from .file_reader import _read_csv, _read_json
+from .data_bundle import DataSetLoader
 
 
 class JsonLoader(DataSetLoader):

fastNLP/io/embed_loader.py (+7, -2)

@@ -1,17 +1,22 @@
+"""
+.. todo::
+    doc
+"""
 __all__ = [
     "EmbedLoader",
     "EmbeddingOption",
 ]
 
+import logging
 import os
 import warnings
 
 import numpy as np
 
-from ..core.vocabulary import Vocabulary
 from .data_bundle import BaseLoader
 from ..core.utils import Option
-import logging
+from ..core.vocabulary import Vocabulary
 
 
 class EmbeddingOption(Option):
     def __init__(self,

fastNLP/io/file_reader.py (+11, -5)

@@ -1,7 +1,11 @@
-"""
+"""undocumented
 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
 """
 
+__all__ = []
+
 import json
 
 from ..core import logger
@@ -24,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
             headers = headers.split(sep)
             start_idx += 1
         elif not isinstance(headers, (list, tuple)):
-            raise TypeError("headers should be list or tuple, not {}." \
-                .format(type(headers)))
+            raise TypeError("headers should be list or tuple, not {}." \
+                            .format(type(headers)))
         for line_idx, line in enumerate(f, start_idx):
             contents = line.rstrip('\r\n').split(sep)
             if len(contents) != len(headers):
@@ -82,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
        :if False, raise ValueError when reading invalid data. default: True
    :return: generator, every time yield (line number, conll item)
    """
    def parse_conll(sample):
        sample = list(map(list, zip(*sample)))
        sample = [sample[i] for i in indexes]
@@ -89,14 +94,15 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
        if len(f) <= 0:
            raise ValueError('empty field')
        return sample
    with open(path, 'r', encoding=encoding) as f:
        sample = []
        start = next(f).strip()
-        if start!='':
+        if start != '':
            sample.append(start.split())
        for line_idx, line in enumerate(f, 1):
            line = line.strip()
-            if line=='':
+            if line == '':
                if len(sample):
                    try:
                        res = parse_conll(sample)

fastNLP/io/file_utils.py (+19, -4)

@@ -1,12 +1,27 @@
+"""
+.. todo::
+    doc
+"""
+
+__all__ = [
+    "cached_path",
+    "get_filepath",
+    "get_cache_path",
+    "split_filename_suffix",
+    "get_from_cache",
+]
+
 import os
+import re
+import shutil
+import tempfile
 from pathlib import Path
 from urllib.parse import urlparse
-import re
+
 import requests
-import tempfile
-from tqdm import tqdm
-import shutil
 from requests import HTTPError
+from tqdm import tqdm
 
 from ..core import logger
 
 
 PRETRAINED_BERT_MODEL_DIR = {

fastNLP/io/loader/classification.py (+19, -7)

@@ -1,12 +1,24 @@
-from ...core.dataset import DataSet
-from ...core.instance import Instance
-from .loader import Loader
-import warnings
+"""undocumented"""
+
+__all__ = [
+    "YelpLoader",
+    "YelpFullLoader",
+    "YelpPolarityLoader",
+    "IMDBLoader",
+    "SSTLoader",
+    "SST2Loader",
+]
+
+import glob
 import os
 import random
 import shutil
-import glob
 import time
+import warnings
+
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
 
 
 class YelpLoader(Loader):
@@ -58,7 +70,7 @@ class YelpLoader(Loader):
 
 
 class YelpFullLoader(YelpLoader):
-    def download(self, dev_ratio: float = 0.1, re_download:bool=False):
+    def download(self, dev_ratio: float = 0.1, re_download: bool = False):
         """
         自动下载数据集,如果你使用了这个数据集,请引用以下的文章
 
@@ -127,7 +139,7 @@ class YelpPolarityLoader(YelpLoader):
         if time.time() - modify_time > 1 and re_download:  # 通过这种比较丑陋的方式判断一下文件是否是才下载的
             shutil.rmtree(data_dir)
             data_dir = self._get_dataset_path(dataset_name=dataset_name)
 
         if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
             if dev_ratio > 0:
                 assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."

fastNLP/io/loader/conll.py (+51, -33)

@@ -1,15 +1,28 @@
-from typing import Dict, Union
+"""undocumented"""
+
+__all__ = [
+    "ConllLoader",
+    "Conll2003Loader",
+    "Conll2003NERLoader",
+    "OntoNotesNERLoader",
+    "CTBLoader",
+    "CNNERLoader",
+    "MsraNERLoader",
+    "WeiboNERLoader",
+    "PeopleDailyNERLoader"
+]
 
-from .loader import Loader
-from ...core.dataset import DataSet
-from ..file_reader import _read_conll
-from ...core.instance import Instance
-from ...core.const import Const
 import glob
 import os
+import random
 import shutil
 import time
-import random
 
+from .loader import Loader
+from ..file_reader import _read_conll
+from ...core.const import Const
+from ...core.dataset import DataSet
+from ...core.instance import Instance
 
 
 class ConllLoader(Loader):
@@ -47,6 +60,7 @@ class ConllLoader(Loader):
     :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
 
     """
     def __init__(self, headers, indexes=None, dropna=True):
         super(ConllLoader, self).__init__()
         if not isinstance(headers, (list, tuple)):
@@ -60,7 +74,7 @@ class ConllLoader(Loader):
         if len(indexes) != len(headers):
             raise ValueError
         self.indexes = indexes
 
     def _load(self, path):
         """
         传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -101,12 +115,13 @@ class Conll2003Loader(ConllLoader):
         "[...]", "[...]", "[...]", "[...]"
 
     """
     def __init__(self):
         headers = [
             'raw_words', 'pos', 'chunk', 'ner',
         ]
         super(Conll2003Loader, self).__init__(headers=headers)
 
     def _load(self, path):
         """
         传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -127,7 +142,7 @@ class Conll2003Loader(ConllLoader):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
             ds.append(Instance(**ins))
         return ds
 
     def download(self, output_dir=None):
         raise RuntimeError("conll2003 cannot be downloaded automatically.")
@@ -158,12 +173,13 @@ class Conll2003NERLoader(ConllLoader):
         "[...]", "[...]"
 
     """
     def __init__(self):
         headers = [
             'raw_words', 'target',
         ]
         super().__init__(headers=headers, indexes=[0, 3])
 
     def _load(self, path):
         """
         传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
@@ -184,7 +200,7 @@ class Conll2003NERLoader(ConllLoader):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
             ds.append(Instance(**ins))
         return ds
 
     def download(self):
         raise RuntimeError("conll2003 cannot be downloaded automatically.")
@@ -204,13 +220,13 @@ class OntoNotesNERLoader(ConllLoader):
         "[...]", "[...]"
 
     """
     def __init__(self):
         super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10])
 
-    def _load(self, path:str):
+    def _load(self, path: str):
         dataset = super()._load(path)
 
         def convert_to_bio(tags):
             bio_tags = []
             flag = None
@@ -227,7 +243,7 @@ class OntoNotesNERLoader(ConllLoader):
                     flag = None
                 bio_tags.append(bio_label)
             return bio_tags
 
         def convert_word(words):
             converted_words = []
             for word in words:
@@ -236,7 +252,7 @@ class OntoNotesNERLoader(ConllLoader):
                     converted_words.append(word)
                     continue
                 # 以下是由于这些符号被转义了,再转回来
-                tfrs = {'-LRB-':'(',
+                tfrs = {'-LRB-': '(',
                         '-RRB-': ')',
                         '-LSB-': '[',
                         '-RSB-': ']',
@@ -248,12 +264,12 @@ class OntoNotesNERLoader(ConllLoader):
                 else:
                     converted_words.append(word)
             return converted_words
 
         dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
         dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET)
         return dataset
 
     def download(self):
         raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer "
                            "https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.")
@@ -262,13 +278,13 @@ class OntoNotesNERLoader(ConllLoader):
 class CTBLoader(Loader):
     def __init__(self):
         super().__init__()
 
-    def _load(self, path:str):
+    def _load(self, path: str):
         pass
 
 
 class CNNERLoader(Loader):
-    def _load(self, path:str):
+    def _load(self, path: str):
         """
         支持加载形如以下格式的内容,一行两列,以空格隔开两个sample
 
@@ -331,10 +347,11 @@ class MsraNERLoader(CNNERLoader):
         "[...]", "[...]"
 
     """
     def __init__(self):
         super().__init__()
 
-    def download(self, dev_ratio:float=0.1, re_download:bool=False)->str:
+    def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str:
         """
         自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language
         Processing Bakeoff: Word Segmentation and Named Entity Recognition.
@@ -356,7 +373,7 @@ class MsraNERLoader(CNNERLoader):
         if time.time() - modify_time > 1 and re_download:  # 通过这种比较丑陋的方式判断一下文件是否是才下载的
             shutil.rmtree(data_dir)
             data_dir = self._get_dataset_path(dataset_name=dataset_name)
 
         if not os.path.exists(os.path.join(data_dir, 'dev.conll')):
             if dev_ratio > 0:
                 assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -380,15 +397,15 @@ class MsraNERLoader(CNNERLoader):
             finally:
                 if os.path.exists(os.path.join(data_dir, 'middle_file.conll')):
                     os.remove(os.path.join(data_dir, 'middle_file.conll'))
 
         return data_dir
 
 
 class WeiboNERLoader(CNNERLoader):
     def __init__(self):
         super().__init__()
 
-    def download(self)->str:
+    def download(self) -> str:
         """
         自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for
         Chinese Social Media with Jointly Trained Embeddings.
@@ -397,7 +414,7 @@ class WeiboNERLoader(CNNERLoader):
         """
         dataset_name = 'weibo-ner'
         data_dir = self._get_dataset_path(dataset_name=dataset_name)
 
         return data_dir
@@ -427,11 +444,12 @@ class PeopleDailyNERLoader(CNNERLoader):
         "[...]", "[...]"
 
     """
     def __init__(self):
         super().__init__()
 
     def download(self) -> str:
         dataset_name = 'peopledaily'
         data_dir = self._get_dataset_path(dataset_name=dataset_name)
         return data_dir

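One practical effect of the __all__ lists added above: a wildcard import from a loader module now exposes only the declared loaders and no longer leaks helper imports such as glob or os. A small self-contained illustration of that standard Python rule, using a throwaway in-memory module rather than fastNLP itself:

# Shows which names `from m import *` would pick up with and without __all__.
import types

def star_import_names(mod: types.ModuleType) -> list:
    # Same selection rule the import machinery applies for `from mod import *`:
    # use mod.__all__ if it exists, otherwise every name not starting with "_".
    if hasattr(mod, "__all__"):
        return list(mod.__all__)
    return [name for name in vars(mod) if not name.startswith("_")]

mod = types.ModuleType("toy_loader")
exec(
    "import os\n"
    "import glob\n"
    "class ConllLoader: pass\n"
    "class MsraNERLoader: pass\n",
    mod.__dict__,
)

print(sorted(star_import_names(mod)))   # ['ConllLoader', 'MsraNERLoader', 'glob', 'os']

mod.__all__ = ["ConllLoader", "MsraNERLoader"]
print(sorted(star_import_names(mod)))   # ['ConllLoader', 'MsraNERLoader']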
fastNLP/io/loader/csv.py (+8, -2)

@@ -1,7 +1,13 @@
+"""undocumented"""
+
+__all__ = [
+    "CSVLoader",
+]
+
+from .loader import Loader
+from ..file_reader import _read_csv
 from ...core.dataset import DataSet
 from ...core.instance import Instance
-from ..file_reader import _read_csv
-from .loader import Loader
 
 
 class CSVLoader(Loader):

fastNLP/io/loader/cws.py (+12, -5)

@@ -1,11 +1,18 @@
-from .loader import Loader
-from ...core.dataset import DataSet
-from ...core.instance import Instance
+"""undocumented"""
+
+__all__ = [
+    "CWSLoader"
+]
+
 import glob
 import os
-import time
-import shutil
 import random
+import shutil
+import time
+
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
 
 
 class CWSLoader(Loader):

fastNLP/io/loader/json.py (+8, -2)

@@ -1,7 +1,13 @@
+"""undocumented"""
+
+__all__ = [
+    "JsonLoader"
+]
+
+from .loader import Loader
+from ..file_reader import _read_json
 from ...core.dataset import DataSet
 from ...core.instance import Instance
-from ..file_reader import _read_json
-from .loader import Loader
 
 
 class JsonLoader(Loader):

fastNLP/io/loader/loader.py (+10, -3)

@@ -1,8 +1,15 @@
-from ...core.dataset import DataSet
-from .. import DataBundle
-from ..utils import check_loader_paths
+"""undocumented"""
+
+__all__ = [
+    "Loader"
+]
+
 from typing import Union, Dict
+
+from .. import DataBundle
 from ..file_utils import _get_dataset_url, get_cache_path, cached_path
+from ..utils import check_loader_paths
+from ...core.dataset import DataSet
 
 
 class Loader:

fastNLP/io/loader/matching.py (+49, -33)

@@ -1,10 +1,21 @@
+"""undocumented"""
+
+__all__ = [
+    "MNLILoader",
+    "SNLILoader",
+    "QNLILoader",
+    "RTELoader",
+    "QuoraLoader",
+]
+
+import os
 import warnings
-from .loader import Loader
+from typing import Union, Dict
 
 from .json import JsonLoader
-from ...core.const import Const
+from .loader import Loader
 from .. import DataBundle
-import os
-from typing import Union, Dict
+from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
@@ -22,10 +33,11 @@ class MNLILoader(Loader):
         "...", "...","."
 
     """
     def __init__(self):
         super().__init__()
 
-    def _load(self, path:str):
+    def _load(self, path: str):
         ds = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
@@ -50,8 +62,8 @@
                     if raw_words1 and raw_words2 and target:
                         ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
         return ds
 
-    def load(self, paths:str=None):
+    def load(self, paths: str = None):
         """
 
         :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
@@ -64,13 +76,13 @@ class MNLILoader(Loader):
             paths = self.download()
         if not os.path.isdir(paths):
             raise NotADirectoryError(f"{paths} is not a valid directory.")
 
-        files = {'dev_matched':"dev_matched.tsv",
-                 "dev_mismatched":"dev_mismatched.tsv",
-                 "test_matched":"test_matched.tsv",
-                 "test_mismatched":"test_mismatched.tsv",
-                 "train":'train.tsv'}
+        files = {'dev_matched': "dev_matched.tsv",
+                 "dev_mismatched": "dev_mismatched.tsv",
+                 "test_matched": "test_matched.tsv",
+                 "test_mismatched": "test_mismatched.tsv",
+                 "train": 'train.tsv'}
 
         datasets = {}
         for name, filename in files.items():
             filepath = os.path.join(paths, filename)
@@ -78,11 +90,11 @@ class MNLILoader(Loader):
             if 'test' not in name:
                 raise FileNotFoundError(f"{name} not found in directory {filepath}.")
             datasets[name] = self._load(filepath)
 
         data_bundle = DataBundle(datasets=datasets)
 
         return data_bundle
 
     def download(self):
         """
         如果你使用了这个数据,请引用
@@ -106,14 +118,15 @@ class SNLILoader(JsonLoader):
         "...", "...", "."
 
     """
     def __init__(self):
         super().__init__(fields={
             'sentence1': Const.RAW_WORDS(0),
             'sentence2': Const.RAW_WORDS(1),
             'gold_label': Const.TARGET,
         })
 
-    def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
+    def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
         """
         从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。
 
@@ -138,11 +151,11 @@ class SNLILoader(JsonLoader):
             paths = _paths
         else:
             raise NotADirectoryError(f"{paths} is not a valid directory.")
 
         datasets = {name: self._load(path) for name, path in paths.items()}
         data_bundle = DataBundle(datasets=datasets)
 
         return data_bundle
 
     def download(self):
         """
         如果您的文章使用了这份数据,请引用
@@ -169,12 +182,13 @@ class QNLILoader(JsonLoader):
     test数据集没有target列
 
     """
     def __init__(self):
         super().__init__()
 
     def _load(self, path):
         ds = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if path.endswith("test.tsv"):
@@ -198,7 +212,7 @@ class QNLILoader(JsonLoader):
                 if raw_words1 and raw_words2 and target:
                     ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
         return ds
 
     def download(self):
         """
         如果您的实验使用到了该数据,请引用
@@ -225,12 +239,13 @@ class RTELoader(Loader):
 
     test数据集没有target列
     """
     def __init__(self):
         super().__init__()
 
-    def _load(self, path:str):
+    def _load(self, path: str):
         ds = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if path.endswith("test.tsv"):
@@ -254,7 +269,7 @@ class RTELoader(Loader):
                 if raw_words1 and raw_words2 and target:
                     ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
         return ds
 
     def download(self):
         return self._get_dataset_path('rte')
@@ -281,12 +296,13 @@ class QuoraLoader(Loader):
         "...","."
 
     """
     def __init__(self):
         super().__init__()
 
-    def _load(self, path:str):
+    def _load(self, path: str):
         ds = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             for line in f:
                 line = line.strip()
@@ -298,6 +314,6 @@ class QuoraLoader(Loader):
                 if raw_words1 and raw_words2 and target:
                     ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
         return ds
 
     def download(self):
         raise RuntimeError("Quora cannot be downloaded automatically.")

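For orientation before the Pipe diffs that follow: the loaders above and the pipes below are meant to be chained, with a Loader reading raw files into a DataBundle and a Pipe tokenising it, building vocabularies and marking input/target fields, which is what the process_from_file methods in the next hunks wrap. A hedged usage sketch, assuming fastNLP 0.4.10 is installed and the IMDB data already sits in a local directory (the path below is a placeholder):

from fastNLP.io.loader.classification import IMDBLoader
from fastNLP.io.pipe.classification import IMDBPipe

# Loader step: read the train/dev/test files into a DataBundle.
data_bundle = IMDBLoader().load("/path/to/IMDB")  # placeholder path

# Pipe step: tokenise raw_words into words, build word/target vocabularies,
# and set the input/target fields used by downstream models.
data_bundle = IMDBPipe(lower=True, tokenizer="spacy").process(data_bundle)

print(data_bundle)  # summary of the datasets and vocabularies

The two steps together are what the one-call form IMDBPipe(...).process_from_file(paths) implements, as the hunks below show for each dataset.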
fastNLP/io/pipe/classification.py (+88, -73)

@@ -1,26 +1,39 @@
+"""undocumented"""
+
+__all__ = [
+    "YelpFullPipe",
+    "YelpPolarityPipe",
+    "SSTPipe",
+    "SST2Pipe",
+    'IMDBPipe'
+]
+
+import re
+
 from nltk import Tree
 
+from .pipe import Pipe
+from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
 from ..data_bundle import DataBundle
-from ...core.vocabulary import Vocabulary
-from ...core.const import Const
 from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
+from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
+from ...core.vocabulary import Vocabulary
 
-
-from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
-from .pipe import Pipe
-import re
 nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
 
 
 class _CLSPipe(Pipe):
     """
     分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列
 
     """
-    def __init__(self, tokenizer:str='spacy', lang='en'):
+    def __init__(self, tokenizer: str = 'spacy', lang='en'):
         self.tokenizer = get_tokenizer(tokenizer, lang=lang)
 
     def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
         """
         将DataBundle中的数据进行tokenize
@@ -33,9 +46,9 @@ class _CLSPipe(Pipe):
         new_field_name = new_field_name or field_name
         for name, dataset in data_bundle.datasets.items():
             dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
 
         return data_bundle
 
     def _granularize(self, data_bundle, tag_map):
         """
         该函数对data_bundle中'target'列中的内容进行转换。
@@ -47,9 +60,9 @@ class _CLSPipe(Pipe):
         """
         for name in list(data_bundle.datasets.keys()):
             dataset = data_bundle.get_dataset(name)
-            dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET,
+            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
                                 new_field_name=Const.TARGET)
-            dataset.drop(lambda ins:ins[Const.TARGET] == -100)
+            dataset.drop(lambda ins: ins[Const.TARGET] == -100)
             data_bundle.set_dataset(dataset, name)
         return data_bundle
@@ -69,7 +82,7 @@ def _clean_str(words):
         t = ''.join(tt)
         if t != '':
             words_collection.append(t)
 
     return words_collection
 
 
@@ -89,19 +102,20 @@ class YelpFullPipe(_CLSPipe):
         1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。
     :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
     """
-    def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'):
+    def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
         super().__init__(tokenizer=tokenizer, lang='en')
         self.lower = lower
         assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
         self.granularity = granularity
 
-        if granularity==2:
+        if granularity == 2:
             self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
-        elif granularity==3:
-            self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2}
+        elif granularity == 3:
+            self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
         else:
             self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
 
     def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
         """
         将DataBundle中的数据进行tokenize
@@ -116,7 +130,7 @@ class YelpFullPipe(_CLSPipe):
             dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
             dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
         return data_bundle
 
     def process(self, data_bundle):
         """
         传入的DataSet应该具备如下的结构
@@ -131,30 +145,30 @@ class YelpFullPipe(_CLSPipe):
         :param data_bundle:
         :return:
         """
 
         # 复制一列words
         data_bundle = _add_words_field(data_bundle, lower=self.lower)
 
         # 进行tokenize
         data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
 
         # 根据granularity设置tag
         data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
 
         # 删除空行
         data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)
 
         # index
         data_bundle = _indexize(data_bundle=data_bundle)
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
 
         data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
         data_bundle.set_target(Const.TARGET)
 
         return data_bundle
 
     def process_from_file(self, paths=None):
         """
 
@@ -179,27 +193,28 @@ class YelpPolarityPipe(_CLSPipe):
     :param bool lower: 是否对输入进行小写化。
     :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
     """
-    def __init__(self, lower:bool=False, tokenizer:str='spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
         super().__init__(tokenizer=tokenizer, lang='en')
         self.lower = lower
 
     def process(self, data_bundle):
         # 复制一列words
         data_bundle = _add_words_field(data_bundle, lower=self.lower)
 
         # 进行tokenize
         data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
         # index
         data_bundle = _indexize(data_bundle=data_bundle)
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
 
         data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
         data_bundle.set_target(Const.TARGET)
 
         return data_bundle
 
     def process_from_file(self, paths=None):
         """
 
@@ -230,7 +245,7 @@ class SSTPipe(_CLSPipe):
         0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。
     :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
     """
 
     def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
         super().__init__(tokenizer=tokenizer, lang='en')
         self.subtree = subtree
@@ -238,15 +253,15 @@ class SSTPipe(_CLSPipe):
         self.lower = lower
         assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
         self.granularity = granularity
 
-        if granularity==2:
+        if granularity == 2:
             self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
-        elif granularity==3:
-            self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2}
+        elif granularity == 3:
+            self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2}
         else:
             self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
 
-    def process(self, data_bundle:DataBundle):
+    def process(self, data_bundle: DataBundle):
         """
         对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与
 
@@ -277,26 +292,26 @@ class SSTPipe(_CLSPipe):
                     instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
                     ds.append(instance)
             data_bundle.set_dataset(ds, name)
 
         _add_words_field(data_bundle, lower=self.lower)
 
         # 进行tokenize
         data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
 
         # 根据granularity设置tag
         data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
 
         # index
         data_bundle = _indexize(data_bundle=data_bundle)
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
 
         data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
         data_bundle.set_target(Const.TARGET)
 
         return data_bundle
 
     def process_from_file(self, paths=None):
         data_bundle = SSTLoader().load(paths)
         return self.process(data_bundle=data_bundle)
@@ -316,11 +331,12 @@ class SST2Pipe(_CLSPipe):
     :param bool lower: 是否对输入进行小写化。
     :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
     """
 
     def __init__(self, lower=False, tokenizer='spacy'):
         super().__init__(tokenizer=tokenizer, lang='en')
         self.lower = lower
 
-    def process(self, data_bundle:DataBundle):
+    def process(self, data_bundle: DataBundle):
         """
         可以处理的DataSet应该具备如下的结构
 
@@ -335,15 +351,15 @@ class SST2Pipe(_CLSPipe):
         :return:
         """
         _add_words_field(data_bundle, self.lower)
 
         data_bundle = self._tokenize(data_bundle=data_bundle)
 
         src_vocab = Vocabulary()
         src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
-                               no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if
+                               no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
                                                         name != 'train'])
         src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
 
         tgt_vocab = Vocabulary(unknown=None, padding=None)
         tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
         datasets = []
@@ -351,18 +367,18 @@ class SST2Pipe(_CLSPipe):
             if dataset.has_field(Const.TARGET):
                 datasets.append(dataset)
         tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)
 
         data_bundle.set_vocab(src_vocab, Const.INPUT)
         data_bundle.set_vocab(tgt_vocab, Const.TARGET)
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
 
         data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
         data_bundle.set_target(Const.TARGET)
 
         return data_bundle
 
     def process_from_file(self, paths=None):
         """
 
@@ -390,11 +406,12 @@ class IMDBPipe(_CLSPipe):
     :param bool lower: 是否将words列的数据小写。
     :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
     """
-    def __init__(self, lower:bool=False, tokenizer:str='spacy'):
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
         super().__init__(tokenizer=tokenizer, lang='en')
         self.lower = lower
 
-    def process(self, data_bundle:DataBundle):
+    def process(self, data_bundle: DataBundle):
         """
         期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型
 
@@ -409,25 +426,26 @@ class IMDBPipe(_CLSPipe):
             target列应该为str。
         :return: DataBundle
         """
 
         # 替换<br />
         def replace_br(raw_words):
             raw_words = raw_words.replace("<br />", ' ')
             return raw_words
 
         for name, dataset in data_bundle.datasets.items():
             dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
 
         _add_words_field(data_bundle, lower=self.lower)
         self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
         _indexize(data_bundle)
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
             dataset.set_input(Const.INPUT, Const.INPUT_LEN)
             dataset.set_target(Const.TARGET)
 
         return data_bundle
 
     def process_from_file(self, paths=None):
         """
 
@@ -437,8 +455,5 @@ class IMDBPipe(_CLSPipe):
         # 读取数据
         data_bundle = IMDBLoader().load(paths)
         data_bundle = self.process(data_bundle)
 
         return data_bundle
-
-
-

fastNLP/io/pipe/conll.py (+47, -32)

@@ -1,13 +1,25 @@
+"""undocumented"""
+
+__all__ = [
+    "Conll2003NERPipe",
+    "Conll2003Pipe",
+    "OntoNotesNERPipe",
+    "MsraNERPipe",
+    "PeopleDailyPipe",
+    "WeiboNERPipe"
+]
+
 from .pipe import Pipe
-from .. import DataBundle
+from .utils import _add_chars_field
+from .utils import _indexize, _add_words_field
 from .utils import iob2, iob2bioes
-from ...core.const import Const
+from .. import DataBundle
 from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
-from .utils import _indexize, _add_words_field
-from .utils import _add_chars_field
 from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
+from ...core.const import Const
 from ...core.vocabulary import Vocabulary
 
 
 class _NERPipe(Pipe):
     """
     NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
@@ -20,14 +32,14 @@ class _NERPipe(Pipe):
     :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
     :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
     """
 
     def __init__(self, encoding_type: str = 'bio', lower: bool = False):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
             self.convert_tag = lambda words: iob2bioes(iob2(words))
         self.lower = lower
 
     def process(self, data_bundle: DataBundle) -> DataBundle:
         """
         支持的DataSet的field为
@@ -46,21 +58,21 @@ class _NERPipe(Pipe):
         # 转换tag
         for name, dataset in data_bundle.datasets.items():
             dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
 
         _add_words_field(data_bundle, lower=self.lower)
 
         # index
         _indexize(data_bundle)
 
         input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
 
         data_bundle.set_input(*input_fields)
         data_bundle.set_target(*target_fields)
 
         return data_bundle
@@ -84,7 +96,7 @@ class Conll2003NERPipe(_NERPipe):
     :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
     :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
     """
 
     def process_from_file(self, paths) -> DataBundle:
         """
 
@@ -94,7 +106,7 @@ class Conll2003NERPipe(_NERPipe):
         # 读取数据
         data_bundle = Conll2003NERLoader().load(paths)
         data_bundle = self.process(data_bundle)
 
         return data_bundle
@@ -125,8 +137,8 @@ class Conll2003Pipe(Pipe):
         else:
             self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
         self.lower = lower
 
-    def process(self, data_bundle)->DataBundle:
+    def process(self, data_bundle) -> DataBundle:
         """
         输入的DataSet应该类似于如下的形式
 
@@ -145,9 +157,9 @@ class Conll2003Pipe(Pipe):
             dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD])
             dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
             dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
 
         _add_words_field(data_bundle, lower=self.lower)
 
         # index
         _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner'])
         # chunk中存在一些tag只在dev中出现,没在train中
@@ -155,18 +167,18 @@ class Conll2003Pipe(Pipe):
         tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
         tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
         data_bundle.set_vocab(tgt_vocab, 'chunk')
 
         input_fields = [Const.INPUT, Const.INPUT_LEN]
         target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN]
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
 
         data_bundle.set_input(*input_fields)
         data_bundle.set_target(*target_fields)
 
         return data_bundle
 
     def process_from_file(self, paths):
         """
 
@@ -194,7 +206,7 @@ class OntoNotesNERPipe(_NERPipe):
     :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
     :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
     """
 
     def process_from_file(self, paths):
         data_bundle = OntoNotesNERLoader().load(paths)
         return self.process(data_bundle)
@@ -211,13 +223,13 @@ class _CNNERPipe(Pipe):
 
     :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
     """
 
     def __init__(self, encoding_type: str = 'bio'):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
             self.convert_tag = lambda words: iob2bioes(iob2(words))
 
     def process(self, data_bundle: DataBundle) -> DataBundle:
         """
         支持的DataSet的field为
@@ -239,21 +251,21 @@ class _CNNERPipe(Pipe):
         # 转换tag
         for name, dataset in data_bundle.datasets.items():
             dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
 
         _add_chars_field(data_bundle, lower=False)
 
         # index
         _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
 
         input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET, Const.INPUT_LEN]
 
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.CHAR_INPUT)
 
         data_bundle.set_input(*input_fields)
         data_bundle.set_target(*target_fields)
 
         return data_bundle
@@ -272,6 +284,7 @@ class MsraNERPipe(_CNNERPipe):
     target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
 
     """
 
     def process_from_file(self, paths=None) -> DataBundle:
         data_bundle = MsraNERLoader().load(paths)
         return self.process(data_bundle)
@@ -291,6 +304,7 @@ class PeopleDailyPipe(_CNNERPipe):
     raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
     target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
     """
 
     def process_from_file(self, paths=None) -> DataBundle:
         data_bundle = PeopleDailyNERLoader().load(paths)
         return self.process(data_bundle)
@@ -312,6 +326,7 @@ class WeiboNERPipe(_CNNERPipe):
 
     :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
     """
 
     def process_from_file(self, paths=None) -> DataBundle:
         data_bundle = WeiboNERLoader().load(paths)
         return self.process(data_bundle)

fastNLP/io/pipe/cws.py (+6, -0)

@@ -1,3 +1,9 @@
+"""undocumented"""
+
+__all__ = [
+    "CWSPipe"
+]
+
 import re
 from itertools import chain

+ 46
- 29
fastNLP/io/pipe/matching.py View File

@@ -1,9 +1,25 @@
"""undocumented"""

__all__ = [
"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",
"QuoraBertPipe",
"QNLIBertPipe",
"MNLIBertPipe",
"MatchingPipe",
"RTEPipe",
"SNLIPipe",
"QuoraPipe",
"QNLIPipe",
"MNLIPipe",
]


from .pipe import Pipe from .pipe import Pipe
from .utils import get_tokenizer from .utils import get_tokenizer
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
from ...core.const import Const from ...core.const import Const
from ...core.vocabulary import Vocabulary from ...core.vocabulary import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader




class MatchingBertPipe(Pipe): class MatchingBertPipe(Pipe):
@@ -24,12 +40,13 @@ class MatchingBertPipe(Pipe):
:param bool lower: 是否将word小写化。 :param bool lower: 是否将word小写化。
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
""" """
def __init__(self, lower=False, tokenizer: str='raw'):
def __init__(self, lower=False, tokenizer: str = 'raw'):
super().__init__() super().__init__()
self.lower = bool(lower) self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer) self.tokenizer = get_tokenizer(tokenizer=tokenizer)
def _tokenize(self, data_bundle, field_names, new_field_names): def _tokenize(self, data_bundle, field_names, new_field_names):
""" """


@@ -43,62 +60,62 @@ class MatchingBertPipe(Pipe):
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle
def process(self, data_bundle):
for dataset in data_bundle.datasets.values():
if dataset.has_field(Const.TARGET):
dataset.drop(lambda x: x[Const.TARGET] == '-')
for name, dataset in data_bundle.datasets.items():
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), )
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), )
if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
[Const.INPUTS(0), Const.INPUTS(1)])
# concatenate the two word sequences
def concat(ins):
words0 = ins[Const.INPUTS(0)]
words1 = ins[Const.INPUTS(1)]
words = words0 + ['[SEP]'] + words1
return words
for name, dataset in data_bundle.datasets.items():
dataset.apply(concat, new_field_name=Const.INPUT)
dataset.delete_field(Const.INPUTS(0))
dataset.delete_field(Const.INPUTS(1))
word_vocab = Vocabulary()
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
'train' not in name])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
dataset.has_field(Const.TARGET)]
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
data_bundle.set_vocab(word_vocab, Const.INPUT)
data_bundle.set_vocab(target_vocab, Const.TARGET)
input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(*input_fields, flag=True)
for fields in target_fields:
if dataset.has_field(fields):
dataset.set_target(fields, flag=True)
return data_bundle
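
The concat helper above is the key move of this BERT-style variant: the two tokenized sentences are joined into a single sequence around a literal '[SEP]' token before one shared vocabulary is built. A tiny illustration with made-up inputs (not taken from any dataset):

# Hypothetical tokenized premise/hypothesis pair.
words0 = ['a', 'man', 'is', 'running']
words1 = ['someone', 'is', 'moving']

words = words0 + ['[SEP]'] + words1
# -> ['a', 'man', 'is', 'running', '[SEP]', 'someone', 'is', 'moving']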




@@ -150,12 +167,13 @@ class MatchingPipe(Pipe):
:param bool lower: whether to lowercase all raw_words.
:param str tokenizer: how to tokenize the raw data. Supports spacy and raw: spacy tokenizes with spaCy, raw splits on whitespace.
"""
def __init__(self, lower=False, tokenizer: str='raw'):
def __init__(self, lower=False, tokenizer: str = 'raw'):
super().__init__()
self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)
def _tokenize(self, data_bundle, field_names, new_field_names):
"""


@@ -169,7 +187,7 @@ class MatchingPipe(Pipe):
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle
def process(self, data_bundle):
"""
The DataSets in the incoming DataBundle should contain the following fields; the target column is optional
@@ -186,35 +204,35 @@ class MatchingPipe(Pipe):
""" """
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
[Const.INPUTS(0), Const.INPUTS(1)]) [Const.INPUTS(0), Const.INPUTS(1)])
for dataset in data_bundle.datasets.values(): for dataset in data_bundle.datasets.values():
if dataset.has_field(Const.TARGET): if dataset.has_field(Const.TARGET):
dataset.drop(lambda x: x[Const.TARGET] == '-') dataset.drop(lambda x: x[Const.TARGET] == '-')
if self.lower: if self.lower:
for name, dataset in data_bundle.datasets.items(): for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower() dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower() dataset[Const.INPUTS(1)].lower()
word_vocab = Vocabulary() word_vocab = Vocabulary()
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
field_name=[Const.INPUTS(0), Const.INPUTS(1)], field_name=[Const.INPUTS(0), Const.INPUTS(1)],
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
'train' not in name]) 'train' not in name])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])
target_vocab = Vocabulary(padding=None, unknown=None) target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
dataset.has_field(Const.TARGET)] dataset.has_field(Const.TARGET)]
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
data_bundle.set_vocab(target_vocab, Const.TARGET) data_bundle.set_vocab(target_vocab, Const.TARGET)
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)] input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
target_fields = [Const.TARGET] target_fields = [Const.TARGET]
for name, dataset in data_bundle.datasets.items(): for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
@@ -222,7 +240,7 @@ class MatchingPipe(Pipe):
for fields in target_fields:
if dataset.has_field(fields):
dataset.set_target(fields, flag=True)
return data_bundle




@@ -254,4 +272,3 @@ class MNLIPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = MNLILoader().load(paths)
return self.process(data_bundle)
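
For orientation, a matching pipe such as the MNLIPipe above is normally driven end to end through process_from_file. A minimal usage sketch assuming fastNLP 0.4.x as changed by this commit; the data path is hypothetical and the dictionary keys reflect my reading of the Const values used above:

from fastNLP.io.pipe.matching import MNLIPipe

# Tokenize, lowercase, build vocabularies and index the MNLI splits in one call.
data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file('data/MNLI')

train_data = data_bundle.datasets['train']      # a fastNLP DataSet with input/target flags already set
target_vocab = data_bundle.vocabs['target']     # built from the train split only, as in process() above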


+ 6
- 0
fastNLP/io/pipe/pipe.py View File

@@ -1,3 +1,9 @@
"""undocumented"""

__all__ = [
"Pipe",
]

from .. import DataBundle
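
pipe.py itself only gains a module header and __all__ here, but the subclasses elsewhere in this commit suggest the interface a Pipe is expected to provide. The sketch below is an inference from those subclasses, not code from this repository:

from fastNLP.io import DataBundle
from fastNLP.io.pipe.pipe import Pipe


class ToyPipe(Pipe):
    """A hypothetical pipe illustrating the two methods the concrete pipes in this commit implement."""

    def process(self, data_bundle: DataBundle) -> DataBundle:
        # Tokenize fields, build vocabularies, set input/target flags, then hand the bundle back.
        return data_bundle

    def process_from_file(self, paths=None) -> DataBundle:
        # Concrete pipes pair a Loader with process(), e.g. MNLIPipe uses MNLILoader().load(paths).
        raise NotImplementedError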






+ 24
- 14
fastNLP/io/pipe/utils.py View File

@@ -1,8 +1,18 @@
"""undocumented"""

__all__ = [
"iob2",
"iob2bioes",
"get_tokenizer",
]

from typing import List
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ...core.vocabulary import Vocabulary



def iob2(tags:List[str])->List[str]:
def iob2(tags: List[str]) -> List[str]:
""" """
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
@@ -25,7 +35,8 @@ def iob2(tags:List[str])->List[str]:
tags[i] = "B" + tag[1:] tags[i] = "B" + tag[1:]
return tags return tags


def iob2bioes(tags:List[str])->List[str]:

def iob2bioes(tags: List[str]) -> List[str]:
""" """
将iob的tag转换为bioes编码 将iob的tag转换为bioes编码
:param tags: :param tags:
@@ -38,12 +49,12 @@ def iob2bioes(tags:List[str])->List[str]:
else:
split = tag.split('-')[0]
if split == 'B':
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
@@ -52,7 +63,7 @@ def iob2bioes(tags:List[str])->List[str]:
return new_tags




def get_tokenizer(tokenizer:str, lang='en'):
def get_tokenizer(tokenizer: str, lang='en'):
""" """


:param str tokenizer: which tokenizer to fetch
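
As the MatchingPipe docstring above notes, 'raw' simply splits on whitespace while 'spacy' delegates to spaCy. A quick sketch; the sample sentence is made up, and the spaCy variant additionally requires the spaCy English model to be installed:

from fastNLP.io.pipe.utils import get_tokenizer

tokenize = get_tokenizer('raw')                  # whitespace split
tokens = tokenize('a man is running')            # ['a', 'man', 'is', 'running']

# tokenize = get_tokenizer('spacy', lang='en')   # spaCy tokenization; needs the spaCy 'en' model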
@@ -97,13 +108,13 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
data_bundle.set_vocab(src_vocab, input_field_name)
for target_field_name in target_field_names:
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
data_bundle.set_vocab(tgt_vocab, target_field_name)
return data_bundle




@@ -116,7 +127,7 @@ def _add_words_field(data_bundle, lower=False):
:return: the DataBundle that was passed in
"""
data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True)
if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUT].lower()
@@ -132,7 +143,7 @@ def _add_chars_field(data_bundle, lower=False):
:return: the DataBundle that was passed in
"""
data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True)
if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.CHAR_INPUT].lower()
@@ -147,6 +158,7 @@ def _drop_empty_instance(data_bundle, field_name):
:param str field_name: which field to check; if None, an instance is dropped when any of its fields is empty
:return: the DataBundle that was passed in
"""
def empty_instance(ins):
if field_name:
field_value = ins[field_name]
@@ -157,10 +169,8 @@ def _drop_empty_instance(data_bundle, field_name):
if field_value in ((), {}, [], ''):
return True
return False
for name, dataset in data_bundle.datasets.items():
dataset.drop(empty_instance)
return data_bundle
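
A short worked example of the iob2 and iob2bioes helpers earlier in this file; the tag sequence is invented for illustration:

from fastNLP.io.pipe.utils import iob2, iob2bioes

tags = ['I-PER', 'I-PER', 'O', 'I-LOC']          # IOB1: entities may open with I-
iob2_tags = iob2(tags)                           # ['B-PER', 'I-PER', 'O', 'B-LOC']
bioes_tags = iob2bioes(iob2_tags)                # ['B-PER', 'E-PER', 'O', 'S-LOC']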



+ 19
- 6
fastNLP/io/utils.py View File

@@ -1,10 +1,20 @@
import os
"""
.. todo::
doc
"""


from typing import Union, Dict
__all__ = [
"check_loader_paths"
]

import os
from pathlib import Path
from typing import Union, Dict

from ..core import logger


def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:

def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
""" """
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::


@@ -33,11 +43,13 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
raise Exception(
"File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
raise Exception(
"File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
@@ -46,7 +58,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
elif isinstance(paths, dict):
if paths:
if 'train' not in paths:
@@ -65,6 +77,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
else:
raise TypeError(f"paths only supports str and dict. not {type(paths)}.")



def get_tokenizer():
try:
import spacy
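
To round off this file, a small sketch of the two argument forms check_loader_paths accepts, as described in the diff above; the file names are hypothetical and the referenced files must actually exist:

from fastNLP.io.utils import check_loader_paths

# A directory: file names containing 'train', 'dev' or 'test' are matched automatically.
paths = check_loader_paths('data/snli')

# An explicit mapping: a 'train' entry is required, other splits are optional.
paths = check_loader_paths({'train': 'data/train.tsv', 'dev': 'data/dev.tsv'})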

