From 0d5f43b451473fe25703cb1f9798fcf03eb64c76 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 26 Aug 2019 10:25:01 +0800 Subject: [PATCH] add __all__ and __doc__ for all files in module 'io', using 'undocumented' tags --- fastNLP/io/data_bundle.py | 7 +- fastNLP/io/dataset_loader.py | 6 +- fastNLP/io/embed_loader.py | 9 +- fastNLP/io/file_reader.py | 16 ++- fastNLP/io/file_utils.py | 23 +++- fastNLP/io/loader/classification.py | 26 +++-- fastNLP/io/loader/conll.py | 84 +++++++++------ fastNLP/io/loader/csv.py | 10 +- fastNLP/io/loader/cws.py | 17 ++- fastNLP/io/loader/json.py | 10 +- fastNLP/io/loader/loader.py | 13 ++- fastNLP/io/loader/matching.py | 82 ++++++++------ fastNLP/io/pipe/classification.py | 161 +++++++++++++++------------- fastNLP/io/pipe/conll.py | 79 ++++++++------ fastNLP/io/pipe/cws.py | 6 ++ fastNLP/io/pipe/matching.py | 75 ++++++++----- fastNLP/io/pipe/pipe.py | 6 ++ fastNLP/io/pipe/utils.py | 38 ++++--- fastNLP/io/utils.py | 25 +++-- 19 files changed, 439 insertions(+), 254 deletions(-) diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 1e663f1e..db60a86f 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -1,10 +1,15 @@ +""" +.. todo:: + doc +""" __all__ = [ 'DataBundle', ] import _pickle as pickle -from typing import Union, Dict import os +from typing import Union, Dict + from ..core.dataset import DataSet from ..core.vocabulary import Vocabulary diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 82e96597..fca0de69 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -1,4 +1,4 @@ -""" +"""undocumented .. warning:: 本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 @@ -23,10 +23,10 @@ __all__ = [ ] +from .data_bundle import DataSetLoader +from .file_reader import _read_csv, _read_json from ..core.dataset import DataSet from ..core.instance import Instance -from .file_reader import _read_csv, _read_json -from .data_bundle import DataSetLoader class JsonLoader(DataSetLoader): diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index c58385e1..780d91e4 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -1,17 +1,22 @@ +""" +.. todo:: + doc +""" __all__ = [ "EmbedLoader", "EmbeddingOption", ] +import logging import os import warnings import numpy as np -from ..core.vocabulary import Vocabulary from .data_bundle import BaseLoader from ..core.utils import Option -import logging +from ..core.vocabulary import Vocabulary + class EmbeddingOption(Option): def __init__(self, diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 0320572c..7a953098 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -1,7 +1,11 @@ -""" +"""undocumented 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API """ + +__all__ = [] + import json + from ..core import logger @@ -24,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): headers = headers.split(sep) start_idx += 1 elif not isinstance(headers, (list, tuple)): - raise TypeError("headers should be list or tuple, not {}." \ - .format(type(headers))) + raise TypeError("headers should be list or tuple, not {}." \ + .format(type(headers))) for line_idx, line in enumerate(f, start_idx): contents = line.rstrip('\r\n').split(sep) if len(contents) != len(headers): @@ -82,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): :if False, raise ValueError when reading invalid data. 
default: True :return: generator, every time yield (line number, conll item) """ + def parse_conll(sample): sample = list(map(list, zip(*sample))) sample = [sample[i] for i in indexes] @@ -89,14 +94,15 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): if len(f) <= 0: raise ValueError('empty field') return sample + with open(path, 'r', encoding=encoding) as f: sample = [] start = next(f).strip() - if start!='': + if start != '': sample.append(start.split()) for line_idx, line in enumerate(f, 1): line = line.strip() - if line=='': + if line == '': if len(sample): try: res = parse_conll(sample) diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index bd02158e..8ecdff25 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -1,12 +1,27 @@ +""" +.. todo:: + doc +""" + +__all__ = [ + "cached_path", + "get_filepath", + "get_cache_path", + "split_filename_suffix", + "get_from_cache", +] + import os +import re +import shutil +import tempfile from pathlib import Path from urllib.parse import urlparse -import re + import requests -import tempfile -from tqdm import tqdm -import shutil from requests import HTTPError +from tqdm import tqdm + from ..core import logger PRETRAINED_BERT_MODEL_DIR = { diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index f64a26e7..ec00d2b4 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -1,12 +1,24 @@ -from ...core.dataset import DataSet -from ...core.instance import Instance -from .loader import Loader -import warnings +"""undocumented""" + +__all__ = [ + "YelpLoader", + "YelpFullLoader", + "YelpPolarityLoader", + "IMDBLoader", + "SSTLoader", + "SST2Loader", +] + +import glob import os import random import shutil -import glob import time +import warnings + +from .loader import Loader +from ...core.dataset import DataSet +from ...core.instance import Instance class YelpLoader(Loader): @@ -58,7 +70,7 @@ class YelpLoader(Loader): class YelpFullLoader(YelpLoader): - def download(self, dev_ratio: float = 0.1, re_download:bool=False): + def download(self, dev_ratio: float = 0.1, re_download: bool = False): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 @@ -127,7 +139,7 @@ class YelpPolarityLoader(YelpLoader): if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 shutil.rmtree(data_dir) data_dir = self._get_dataset_path(dataset_name=dataset_name) - + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): if dev_ratio > 0: assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
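The Yelp loaders in the hunk above expose download(dev_ratio, re_download) on top of the inherited load(). A minimal, hypothetical usage sketch (the cache directory is a placeholder, and the assumption that download() returns that directory follows the `return data_dir` pattern the other loaders in this patch use):

```python
# Hypothetical sketch: exercising the Yelp loader API touched by this patch.
from fastNLP.io.loader.classification import YelpFullLoader

loader = YelpFullLoader()
# dev_ratio=0.1 carves 10% of train.csv into dev.csv on first download;
# re_download=True would wipe a previously cached copy first.
data_dir = loader.download(dev_ratio=0.1, re_download=False)
data_bundle = loader.load(data_dir)  # DataBundle holding 'train', 'dev', 'test' DataSets
for name, dataset in data_bundle.datasets.items():
    print(name, len(dataset))
```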
diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py index b5241cff..1bd1b448 100644 --- a/fastNLP/io/loader/conll.py +++ b/fastNLP/io/loader/conll.py @@ -1,15 +1,28 @@ -from typing import Dict, Union +"""undocumented""" + +__all__ = [ + "ConllLoader", + "Conll2003Loader", + "Conll2003NERLoader", + "OntoNotesNERLoader", + "CTBLoader", + "CNNERLoader", + "MsraNERLoader", + "WeiboNERLoader", + "PeopleDailyNERLoader" +] -from .loader import Loader -from ...core.dataset import DataSet -from ..file_reader import _read_conll -from ...core.instance import Instance -from ...core.const import Const import glob import os +import random import shutil import time -import random + +from .loader import Loader +from ..file_reader import _read_conll +from ...core.const import Const +from ...core.dataset import DataSet +from ...core.instance import Instance class ConllLoader(Loader): @@ -47,6 +60,7 @@ class ConllLoader(Loader): :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` """ + def __init__(self, headers, indexes=None, dropna=True): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): @@ -60,7 +74,7 @@ class ConllLoader(Loader): if len(indexes) != len(headers): raise ValueError self.indexes = indexes - + def _load(self, path): """ 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 @@ -101,12 +115,13 @@ class Conll2003Loader(ConllLoader): "[...]", "[...]", "[...]", "[...]" """ + def __init__(self): headers = [ 'raw_words', 'pos', 'chunk', 'ner', ] super(Conll2003Loader, self).__init__(headers=headers) - + def _load(self, path): """ 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 @@ -127,7 +142,7 @@ class Conll2003Loader(ConllLoader): ins = {h: data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds - + def download(self, output_dir=None): raise RuntimeError("conll2003 cannot be downloaded automatically.") @@ -158,12 +173,13 @@ class Conll2003NERLoader(ConllLoader): "[...]", "[...]" """ + def __init__(self): headers = [ 'raw_words', 'target', ] super().__init__(headers=headers, indexes=[0, 3]) - + def _load(self, path): """ 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 @@ -184,7 +200,7 @@ class Conll2003NERLoader(ConllLoader): ins = {h: data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds - + def download(self): raise RuntimeError("conll2003 cannot be downloaded automatically.") @@ -204,13 +220,13 @@ class OntoNotesNERLoader(ConllLoader): "[...]", "[...]" """ - + def __init__(self): super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10]) - - def _load(self, path:str): + + def _load(self, path: str): dataset = super()._load(path) - + def convert_to_bio(tags): bio_tags = [] flag = None @@ -227,7 +243,7 @@ class OntoNotesNERLoader(ConllLoader): flag = None bio_tags.append(bio_label) return bio_tags - + def convert_word(words): converted_words = [] for word in words: @@ -236,7 +252,7 @@ class OntoNotesNERLoader(ConllLoader): converted_words.append(word) continue # 以下是由于这些符号被转义了,再转回来 - tfrs = {'-LRB-':'(', + tfrs = {'-LRB-': '(', '-RRB-': ')', '-LSB-': '[', '-RSB-': ']', @@ -248,12 +264,12 @@ class OntoNotesNERLoader(ConllLoader): else: converted_words.append(word) return converted_words - + dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET) - + return dataset - + def download(self): raise 
RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " "https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") @@ -262,13 +278,13 @@ class OntoNotesNERLoader(ConllLoader): class CTBLoader(Loader): def __init__(self): super().__init__() - - def _load(self, path:str): + + def _load(self, path: str): pass class CNNERLoader(Loader): - def _load(self, path:str): + def _load(self, path: str): """ 支持加载形如以下格式的内容,一行两列,以空格隔开两个sample @@ -331,10 +347,11 @@ class MsraNERLoader(CNNERLoader): "[...]", "[...]" """ + def __init__(self): super().__init__() - - def download(self, dev_ratio:float=0.1, re_download:bool=False)->str: + + def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: """ 自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language Processing Bakeoff: Word Segmentation and Named Entity Recognition. @@ -356,7 +373,7 @@ class MsraNERLoader(CNNERLoader): if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 shutil.rmtree(data_dir) data_dir = self._get_dataset_path(dataset_name=dataset_name) - + if not os.path.exists(os.path.join(data_dir, 'dev.conll')): if dev_ratio > 0: assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." @@ -380,15 +397,15 @@ class MsraNERLoader(CNNERLoader): finally: if os.path.exists(os.path.join(data_dir, 'middle_file.conll')): os.remove(os.path.join(data_dir, 'middle_file.conll')) - + return data_dir class WeiboNERLoader(CNNERLoader): def __init__(self): super().__init__() - - def download(self)->str: + + def download(self) -> str: """ 自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for Chinese Social Media with Jointly Trained Embeddings. @@ -397,7 +414,7 @@ class WeiboNERLoader(CNNERLoader): """ dataset_name = 'weibo-ner' data_dir = self._get_dataset_path(dataset_name=dataset_name) - + return data_dir @@ -427,11 +444,12 @@ class PeopleDailyNERLoader(CNNERLoader): "[...]", "[...]" """ + def __init__(self): super().__init__() - + def download(self) -> str: dataset_name = 'peopledaily' data_dir = self._get_dataset_path(dataset_name=dataset_name) - + return data_dir diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py index 5195cc8e..0d6e35fa 100644 --- a/fastNLP/io/loader/csv.py +++ b/fastNLP/io/loader/csv.py @@ -1,7 +1,13 @@ +"""undocumented""" + +__all__ = [ + "CSVLoader", +] + +from .loader import Loader +from ..file_reader import _read_csv from ...core.dataset import DataSet from ...core.instance import Instance -from ..file_reader import _read_csv -from .loader import Loader class CSVLoader(Loader): diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py index fab7639c..2fbb1091 100644 --- a/fastNLP/io/loader/cws.py +++ b/fastNLP/io/loader/cws.py @@ -1,11 +1,18 @@ -from .loader import Loader -from ...core.dataset import DataSet -from ...core.instance import Instance +"""undocumented""" + +__all__ = [ + "CWSLoader" +] + import glob import os -import time -import shutil import random +import shutil +import time + +from .loader import Loader +from ...core.dataset import DataSet +from ...core.instance import Instance class CWSLoader(Loader): diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py index 8856b73a..012dee5a 100644 --- a/fastNLP/io/loader/json.py +++ b/fastNLP/io/loader/json.py @@ -1,7 +1,13 @@ +"""undocumented""" + +__all__ = [ + "JsonLoader" +] + +from .loader import Loader +from ..file_reader import _read_json from ...core.dataset import DataSet from 
...core.instance import Instance -from ..file_reader import _read_json -from .loader import Loader class JsonLoader(Loader): diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py index e7b419ac..22636a27 100644 --- a/fastNLP/io/loader/loader.py +++ b/fastNLP/io/loader/loader.py @@ -1,8 +1,15 @@ -from ...core.dataset import DataSet -from .. import DataBundle -from ..utils import check_loader_paths +"""undocumented""" + +__all__ = [ + "Loader" +] + from typing import Union, Dict + +from .. import DataBundle from ..file_utils import _get_dataset_url, get_cache_path, cached_path +from ..utils import check_loader_paths +from ...core.dataset import DataSet class Loader: diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 26455914..7f03ca3e 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -1,10 +1,21 @@ +"""undocumented""" + +__all__ = [ + "MNLILoader", + "SNLILoader", + "QNLILoader", + "RTELoader", + "QuoraLoader", +] + +import os import warnings -from .loader import Loader +from typing import Union, Dict + from .json import JsonLoader -from ...core.const import Const +from .loader import Loader from .. import DataBundle -import os -from typing import Union, Dict +from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance @@ -22,10 +33,11 @@ class MNLILoader(Loader): "...", "...","." """ + def __init__(self): super().__init__() - - def _load(self, path:str): + + def _load(self, path: str): ds = DataSet() with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header @@ -50,8 +62,8 @@ class MNLILoader(Loader): if raw_words1 and raw_words2 and target: ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) return ds - - def load(self, paths:str=None): + + def load(self, paths: str = None): """ :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, @@ -64,13 +76,13 @@ class MNLILoader(Loader): paths = self.download() if not os.path.isdir(paths): raise NotADirectoryError(f"{paths} is not a valid directory.") - - files = {'dev_matched':"dev_matched.tsv", - "dev_mismatched":"dev_mismatched.tsv", - "test_matched":"test_matched.tsv", - "test_mismatched":"test_mismatched.tsv", - "train":'train.tsv'} - + + files = {'dev_matched': "dev_matched.tsv", + "dev_mismatched": "dev_mismatched.tsv", + "test_matched": "test_matched.tsv", + "test_mismatched": "test_mismatched.tsv", + "train": 'train.tsv'} + datasets = {} for name, filename in files.items(): filepath = os.path.join(paths, filename) @@ -78,11 +90,11 @@ class MNLILoader(Loader): if 'test' not in name: raise FileNotFoundError(f"{name} not found in directory {filepath}.") datasets[name] = self._load(filepath) - + data_bundle = DataBundle(datasets=datasets) - + return data_bundle - + def download(self): """ 如果你使用了这个数据,请引用 @@ -106,14 +118,15 @@ class SNLILoader(JsonLoader): "...", "...", "." 
""" + def __init__(self): super().__init__(fields={ 'sentence1': Const.RAW_WORDS(0), 'sentence2': Const.RAW_WORDS(1), 'gold_label': Const.TARGET, }) - - def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: + + def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: """ 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 @@ -138,11 +151,11 @@ class SNLILoader(JsonLoader): paths = _paths else: raise NotADirectoryError(f"{paths} is not a valid directory.") - + datasets = {name: self._load(path) for name, path in paths.items()} data_bundle = DataBundle(datasets=datasets) return data_bundle - + def download(self): """ 如果您的文章使用了这份数据,请引用 @@ -169,12 +182,13 @@ class QNLILoader(JsonLoader): test数据集没有target列 """ + def __init__(self): super().__init__() - + def _load(self, path): ds = DataSet() - + with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if path.endswith("test.tsv"): @@ -198,7 +212,7 @@ class QNLILoader(JsonLoader): if raw_words1 and raw_words2 and target: ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) return ds - + def download(self): """ 如果您的实验使用到了该数据,请引用 @@ -225,12 +239,13 @@ class RTELoader(Loader): test数据集没有target列 """ + def __init__(self): super().__init__() - - def _load(self, path:str): + + def _load(self, path: str): ds = DataSet() - + with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if path.endswith("test.tsv"): @@ -254,7 +269,7 @@ class RTELoader(Loader): if raw_words1 and raw_words2 and target: ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) return ds - + def download(self): return self._get_dataset_path('rte') @@ -281,12 +296,13 @@ class QuoraLoader(Loader): "...","." """ + def __init__(self): super().__init__() - - def _load(self, path:str): + + def _load(self, path: str): ds = DataSet() - + with open(path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() @@ -298,6 +314,6 @@ class QuoraLoader(Loader): if raw_words1 and raw_words2 and target: ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) return ds - + def download(self): raise RuntimeError("Quora cannot be downloaded automatically.") diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index f42d5400..30c591a4 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -1,26 +1,39 @@ +"""undocumented""" + +__all__ = [ + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + 'IMDBPipe' +] + +import re + from nltk import Tree +from .pipe import Pipe +from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance from ..data_bundle import DataBundle -from ...core.vocabulary import Vocabulary -from ...core.const import Const from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader +from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance +from ...core.vocabulary import Vocabulary -from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance -from .pipe import Pipe -import re nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') + class _CLSPipe(Pipe): """ 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 """ - def __init__(self, tokenizer:str='spacy', lang='en'): + + def __init__(self, tokenizer: str = 'spacy', lang='en'): self.tokenizer = get_tokenizer(tokenizer, lang=lang) - + def _tokenize(self, data_bundle, 
field_name=Const.INPUT, new_field_name=None): """ 将DataBundle中的数据进行tokenize @@ -33,9 +46,9 @@ class _CLSPipe(Pipe): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) - + return data_bundle - + def _granularize(self, data_bundle, tag_map): """ 该函数对data_bundle中'target'列中的内容进行转换。 @@ -47,9 +60,9 @@ class _CLSPipe(Pipe): """ for name in list(data_bundle.datasets.keys()): dataset = data_bundle.get_dataset(name) - dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET, + dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET, new_field_name=Const.TARGET) - dataset.drop(lambda ins:ins[Const.TARGET] == -100) + dataset.drop(lambda ins: ins[Const.TARGET] == -100) data_bundle.set_dataset(dataset, name) return data_bundle @@ -69,7 +82,7 @@ def _clean_str(words): t = ''.join(tt) if t != '': words_collection.append(t) - + return words_collection @@ -89,19 +102,20 @@ class YelpFullPipe(_CLSPipe): 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'): + + def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): super().__init__(tokenizer=tokenizer, lang='en') self.lower = lower assert granularity in (2, 3, 5), "granularity can only be 2,3,5." self.granularity = granularity - - if granularity==2: + + if granularity == 2: self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} - elif granularity==3: - self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2} + elif granularity == 3: + self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2} else: self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): """ 将DataBundle中的数据进行tokenize @@ -116,7 +130,7 @@ class YelpFullPipe(_CLSPipe): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle): """ 传入的DataSet应该具备如下的结构 @@ -131,30 +145,30 @@ class YelpFullPipe(_CLSPipe): :param data_bundle: :return: """ - + # 复制一列words data_bundle = _add_words_field(data_bundle, lower=self.lower) - + # 进行tokenize data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) - + # 根据granularity设置tag data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) - + # 删除空行 data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) - + # index data_bundle = _indexize(data_bundle=data_bundle) - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) data_bundle.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths=None): """ @@ -179,27 +193,28 @@ class YelpPolarityPipe(_CLSPipe): :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - def __init__(self, lower:bool=False, tokenizer:str='spacy'): + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): super().__init__(tokenizer=tokenizer, lang='en') self.lower = lower - + def process(self, data_bundle): # 复制一列words data_bundle = _add_words_field(data_bundle, lower=self.lower) - + # 进行tokenize data_bundle = 
self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) # index data_bundle = _indexize(data_bundle=data_bundle) - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) data_bundle.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths=None): """ @@ -230,7 +245,7 @@ class SSTPipe(_CLSPipe): 0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - + def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): super().__init__(tokenizer=tokenizer, lang='en') self.subtree = subtree @@ -238,15 +253,15 @@ class SSTPipe(_CLSPipe): self.lower = lower assert granularity in (2, 3, 5), "granularity can only be 2,3,5." self.granularity = granularity - - if granularity==2: + + if granularity == 2: self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} - elif granularity==3: - self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2} + elif granularity == 3: + self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2} else: self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} - - def process(self, data_bundle:DataBundle): + + def process(self, data_bundle: DataBundle): """ 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 @@ -277,26 +292,26 @@ class SSTPipe(_CLSPipe): instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) ds.append(instance) data_bundle.set_dataset(ds, name) - + _add_words_field(data_bundle, lower=self.lower) - + # 进行tokenize data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) - + # 根据granularity设置tag data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) - + # index data_bundle = _indexize(data_bundle=data_bundle) - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) data_bundle.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths=None): data_bundle = SSTLoader().load(paths) return self.process(data_bundle=data_bundle) @@ -316,11 +331,12 @@ class SST2Pipe(_CLSPipe): :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ + def __init__(self, lower=False, tokenizer='spacy'): super().__init__(tokenizer=tokenizer, lang='en') self.lower = lower - - def process(self, data_bundle:DataBundle): + + def process(self, data_bundle: DataBundle): """ 可以处理的DataSet应该具备如下的结构 @@ -335,15 +351,15 @@ class SST2Pipe(_CLSPipe): :return: """ _add_words_field(data_bundle, self.lower) - + data_bundle = self._tokenize(data_bundle=data_bundle) - + src_vocab = Vocabulary() src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, - no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if name != 'train']) src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) - + tgt_vocab = Vocabulary(unknown=None, padding=None) tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) datasets = [] @@ -351,18 +367,18 @@ class SST2Pipe(_CLSPipe): if dataset.has_field(Const.TARGET): datasets.append(dataset) tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) - + data_bundle.set_vocab(src_vocab, Const.INPUT) data_bundle.set_vocab(tgt_vocab, Const.TARGET) - + for name, dataset in 
data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) data_bundle.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths=None): """ @@ -390,11 +406,12 @@ class IMDBPipe(_CLSPipe): :param bool lower: 是否将words列的数据小写。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - def __init__(self, lower:bool=False, tokenizer:str='spacy'): + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): super().__init__(tokenizer=tokenizer, lang='en') self.lower = lower - - def process(self, data_bundle:DataBundle): + + def process(self, data_bundle: DataBundle): """ 期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 @@ -409,25 +426,26 @@ class IMDBPipe(_CLSPipe): target列应该为str。 :return: DataBundle """ + # 替换
<br />
def replace_br(raw_words): raw_words = raw_words.replace("<br />
", ' ') return raw_words - + for name, dataset in data_bundle.datasets.items(): dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) - + _add_words_field(data_bundle, lower=self.lower) self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) _indexize(data_bundle) - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) dataset.set_input(Const.INPUT, Const.INPUT_LEN) dataset.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths=None): """ @@ -437,8 +455,5 @@ class IMDBPipe(_CLSPipe): # 读取数据 data_bundle = IMDBLoader().load(paths) data_bundle = self.process(data_bundle) - + return data_bundle - - - diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index 617d1236..2efec8e0 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -1,13 +1,25 @@ +"""undocumented""" + +__all__ = [ + "Conll2003NERPipe", + "Conll2003Pipe", + "OntoNotesNERPipe", + "MsraNERPipe", + "PeopleDailyPipe", + "WeiboNERPipe" +] + from .pipe import Pipe -from .. import DataBundle +from .utils import _add_chars_field +from .utils import _indexize, _add_words_field from .utils import iob2, iob2bioes -from ...core.const import Const +from .. import DataBundle from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader -from .utils import _indexize, _add_words_field -from .utils import _add_chars_field from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader +from ...core.const import Const from ...core.vocabulary import Vocabulary + class _NERPipe(Pipe): """ NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 @@ -20,14 +32,14 @@ class _NERPipe(Pipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 """ - + def __init__(self, encoding_type: str = 'bio', lower: bool = False): if encoding_type == 'bio': self.convert_tag = iob2 else: self.convert_tag = lambda words: iob2bioes(iob2(words)) self.lower = lower - + def process(self, data_bundle: DataBundle) -> DataBundle: """ 支持的DataSet的field为 @@ -46,21 +58,21 @@ class _NERPipe(Pipe): # 转换tag for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) - + _add_words_field(data_bundle, lower=self.lower) - + # index _indexize(data_bundle) - + input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] target_fields = [Const.TARGET, Const.INPUT_LEN] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle @@ -84,7 +96,7 @@ class Conll2003NERPipe(_NERPipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 """ - + def process_from_file(self, paths) -> DataBundle: """ @@ -94,7 +106,7 @@ class Conll2003NERPipe(_NERPipe): # 读取数据 data_bundle = Conll2003NERLoader().load(paths) data_bundle = self.process(data_bundle) - + return data_bundle @@ -125,8 +137,8 @@ class Conll2003Pipe(Pipe): else: self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) self.lower = lower - - def process(self, data_bundle)->DataBundle: + + def process(self, data_bundle) -> DataBundle: """ 输入的DataSet应该类似于如下的形式 @@ -145,9 +157,9 @@ class Conll2003Pipe(Pipe): dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD]) 
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk') dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner') - + _add_words_field(data_bundle, lower=self.lower) - + # index _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner']) # chunk中存在一些tag只在dev中出现,没在train中 @@ -155,18 +167,18 @@ class Conll2003Pipe(Pipe): tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk') tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk') data_bundle.set_vocab(tgt_vocab, 'chunk') - + input_fields = [Const.INPUT, Const.INPUT_LEN] target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths): """ @@ -194,7 +206,7 @@ class OntoNotesNERPipe(_NERPipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 """ - + def process_from_file(self, paths): data_bundle = OntoNotesNERLoader().load(paths) return self.process(data_bundle) @@ -211,13 +223,13 @@ class _CNNERPipe(Pipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 """ - + def __init__(self, encoding_type: str = 'bio'): if encoding_type == 'bio': self.convert_tag = iob2 else: self.convert_tag = lambda words: iob2bioes(iob2(words)) - + def process(self, data_bundle: DataBundle) -> DataBundle: """ 支持的DataSet的field为 @@ -239,21 +251,21 @@ class _CNNERPipe(Pipe): # 转换tag for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) - + _add_chars_field(data_bundle, lower=False) - + # index _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET) - + input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN] target_fields = [Const.TARGET, Const.INPUT_LEN] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.CHAR_INPUT) - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle @@ -272,6 +284,7 @@ class MsraNERPipe(_CNNERPipe): target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 """ + def process_from_file(self, paths=None) -> DataBundle: data_bundle = MsraNERLoader().load(paths) return self.process(data_bundle) @@ -291,6 +304,7 @@ class PeopleDailyPipe(_CNNERPipe): raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 """ + def process_from_file(self, paths=None) -> DataBundle: data_bundle = PeopleDailyNERLoader().load(paths) return self.process(data_bundle) @@ -312,6 +326,7 @@ class WeiboNERPipe(_CNNERPipe): :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 """ + def process_from_file(self, paths=None) -> DataBundle: data_bundle = WeiboNERLoader().load(paths) return self.process(data_bundle) diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index 4ca0219c..748cf10a 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -1,3 +1,9 @@ +"""undocumented""" + +__all__ = [ + "CWSPipe" +] + import re from itertools import chain diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index ffa6375b..699438c8 100644 --- a/fastNLP/io/pipe/matching.py +++ 
b/fastNLP/io/pipe/matching.py @@ -1,9 +1,25 @@ +"""undocumented""" + +__all__ = [ + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", +] from .pipe import Pipe from .utils import get_tokenizer +from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader from ...core.const import Const from ...core.vocabulary import Vocabulary -from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader class MatchingBertPipe(Pipe): @@ -24,12 +40,13 @@ class MatchingBertPipe(Pipe): :param bool lower: 是否将word小写化。 :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 """ - def __init__(self, lower=False, tokenizer: str='raw'): + + def __init__(self, lower=False, tokenizer: str = 'raw'): super().__init__() - + self.lower = bool(lower) self.tokenizer = get_tokenizer(tokenizer=tokenizer) - + def _tokenize(self, data_bundle, field_names, new_field_names): """ @@ -43,62 +60,62 @@ class MatchingBertPipe(Pipe): dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle): for dataset in data_bundle.datasets.values(): if dataset.has_field(Const.TARGET): dataset.drop(lambda x: x[Const.TARGET] == '-') - + for name, dataset in data_bundle.datasets.items(): dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), ) dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), ) - + if self.lower: for name, dataset in data_bundle.datasets.items(): dataset[Const.INPUTS(0)].lower() dataset[Const.INPUTS(1)].lower() - + data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)], [Const.INPUTS(0), Const.INPUTS(1)]) - + # concat两个words def concat(ins): words0 = ins[Const.INPUTS(0)] words1 = ins[Const.INPUTS(1)] words = words0 + ['[SEP]'] + words1 return words - + for name, dataset in data_bundle.datasets.items(): dataset.apply(concat, new_field_name=Const.INPUT) dataset.delete_field(Const.INPUTS(0)) dataset.delete_field(Const.INPUTS(1)) - + word_vocab = Vocabulary() word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], field_name=Const.INPUT, no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if 'train' not in name]) word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) - + target_vocab = Vocabulary(padding=None, unknown=None) target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if dataset.has_field(Const.TARGET)] target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) - + data_bundle.set_vocab(word_vocab, Const.INPUT) data_bundle.set_vocab(target_vocab, Const.TARGET) - + input_fields = [Const.INPUT, Const.INPUT_LEN] target_fields = [Const.TARGET] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) dataset.set_input(*input_fields, flag=True) for fields in target_fields: if dataset.has_field(fields): dataset.set_target(fields, flag=True) - + return data_bundle @@ -150,12 +167,13 @@ class MatchingPipe(Pipe): :param bool lower: 是否将所有raw_words转为小写。 :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. 
spacy是使用spacy切分,raw就是用空格切分。 """ - def __init__(self, lower=False, tokenizer: str='raw'): + + def __init__(self, lower=False, tokenizer: str = 'raw'): super().__init__() - + self.lower = bool(lower) self.tokenizer = get_tokenizer(tokenizer=tokenizer) - + def _tokenize(self, data_bundle, field_names, new_field_names): """ @@ -169,7 +187,7 @@ class MatchingPipe(Pipe): dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle): """ 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 @@ -186,35 +204,35 @@ class MatchingPipe(Pipe): """ data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], [Const.INPUTS(0), Const.INPUTS(1)]) - + for dataset in data_bundle.datasets.values(): if dataset.has_field(Const.TARGET): dataset.drop(lambda x: x[Const.TARGET] == '-') - + if self.lower: for name, dataset in data_bundle.datasets.items(): dataset[Const.INPUTS(0)].lower() dataset[Const.INPUTS(1)].lower() - + word_vocab = Vocabulary() word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], field_name=[Const.INPUTS(0), Const.INPUTS(1)], no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if 'train' not in name]) word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) - + target_vocab = Vocabulary(padding=None, unknown=None) target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if dataset.has_field(Const.TARGET)] target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) - + data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) data_bundle.set_vocab(target_vocab, Const.TARGET) - + input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)] target_fields = [Const.TARGET] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0)) dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1)) @@ -222,7 +240,7 @@ class MatchingPipe(Pipe): for fields in target_fields: if dataset.has_field(fields): dataset.set_target(fields, flag=True) - + return data_bundle @@ -254,4 +272,3 @@ class MNLIPipe(MatchingPipe): def process_from_file(self, paths=None): data_bundle = MNLILoader().load(paths) return self.process(data_bundle) - diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py index cc45dee4..a1435fd3 100644 --- a/fastNLP/io/pipe/pipe.py +++ b/fastNLP/io/pipe/pipe.py @@ -1,3 +1,9 @@ +"""undocumented""" + +__all__ = [ + "Pipe", +] + from .. 
import DataBundle diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py index 8facd8d9..f32f58b7 100644 --- a/fastNLP/io/pipe/utils.py +++ b/fastNLP/io/pipe/utils.py @@ -1,8 +1,18 @@ +"""undocumented""" + +__all__ = [ + "iob2", + "iob2bioes", + "get_tokenizer", +] + from typing import List -from ...core.vocabulary import Vocabulary + from ...core.const import Const +from ...core.vocabulary import Vocabulary + -def iob2(tags:List[str])->List[str]: +def iob2(tags: List[str]) -> List[str]: """ 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format @@ -25,7 +35,8 @@ def iob2(tags:List[str])->List[str]: tags[i] = "B" + tag[1:] return tags -def iob2bioes(tags:List[str])->List[str]: + +def iob2bioes(tags: List[str]) -> List[str]: """ 将iob的tag转换为bioes编码 :param tags: @@ -38,12 +49,12 @@ def iob2bioes(tags:List[str])->List[str]: else: split = tag.split('-')[0] if split == 'B': - if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': + if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I': new_tags.append(tag) else: new_tags.append(tag.replace('B-', 'S-')) elif split == 'I': - if i + 1List[str]: return new_tags -def get_tokenizer(tokenizer:str, lang='en'): +def get_tokenizer(tokenizer: str, lang='en'): """ :param str tokenizer: 获取tokenzier方法 @@ -97,13 +108,13 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con name != 'train']) src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) data_bundle.set_vocab(src_vocab, input_field_name) - + for target_field_name in target_field_names: tgt_vocab = Vocabulary(unknown=None, padding=None) tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name) tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name) data_bundle.set_vocab(tgt_vocab, target_field_name) - + return data_bundle @@ -116,7 +127,7 @@ def _add_words_field(data_bundle, lower=False): :return: 传入的DataBundle """ data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True) - + if lower: for name, dataset in data_bundle.datasets.items(): dataset[Const.INPUT].lower() @@ -132,7 +143,7 @@ def _add_chars_field(data_bundle, lower=False): :return: 传入的DataBundle """ data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True) - + if lower: for name, dataset in data_bundle.datasets.items(): dataset[Const.CHAR_INPUT].lower() @@ -147,6 +158,7 @@ def _drop_empty_instance(data_bundle, field_name): :param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 :return: 传入的DataBundle """ + def empty_instance(ins): if field_name: field_value = ins[field_name] @@ -157,10 +169,8 @@ def _drop_empty_instance(data_bundle, field_name): if field_value in ((), {}, [], ''): return True return False - + for name, dataset in data_bundle.datasets.items(): dataset.drop(empty_instance) - + return data_bundle - - diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py index faec2a55..e1de2ae7 100644 --- a/fastNLP/io/utils.py +++ b/fastNLP/io/utils.py @@ -1,10 +1,20 @@ -import os +""" +.. 
todo:: + doc +""" -from typing import Union, Dict +__all__ = [ + "check_loader_paths" +] + +import os from pathlib import Path +from typing import Union, Dict + from ..core import logger -def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: + +def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: """ 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: @@ -33,11 +43,13 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: path_pair = ('train', filename) if 'dev' in filename: if path_pair: - raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) + raise Exception( + "File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0])) path_pair = ('dev', filename) if 'test' in filename: if path_pair: - raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) + raise Exception( + "File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0])) path_pair = ('test', filename) if path_pair: files[path_pair[0]] = os.path.join(paths, path_pair[1]) @@ -46,7 +58,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: return files else: raise FileNotFoundError(f"{paths} is not a valid file path.") - + elif isinstance(paths, dict): if paths: if 'train' not in paths: @@ -65,6 +77,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: else: raise TypeError(f"paths only supports str and dict. not {type(paths)}.") + def get_tokenizer(): try: import spacy
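Taken together, the modules documented by this patch follow a two-stage pattern: a Loader reads files into a DataBundle, and a Pipe tokenizes, indexes and marks input/target fields. A hedged end-to-end sketch (file paths are placeholders that must point to real files; the calls mirror the signatures shown in the diff):

```python
# Hypothetical end-to-end sketch of the loader/pipe split documented by this patch.
from fastNLP.io.pipe.classification import IMDBPipe
from fastNLP.io.utils import check_loader_paths

# check_loader_paths normalizes a str or dict argument into a dict with at least a 'train' key
# and raises if any file is missing; the paths below are placeholders.
paths = check_loader_paths({'train': 'aclImdb/train.txt', 'dev': 'aclImdb/dev.txt'})

# IMDBPipe copies raw_words to words, tokenizes, builds vocabularies, adds seq_len,
# then sets words/seq_len as input and target as target.
pipe = IMDBPipe(lower=True, tokenizer='spacy')
data_bundle = pipe.process_from_file(paths)

train_data = data_bundle.get_dataset('train')  # DataSet with words, seq_len, target fields
print(len(train_data))
```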