From 2550c96f80a9c5e52f910d112ac2cc659719849e Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Fri, 8 Apr 2022 21:33:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86io=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/__init__.py | 121 ++++ fastNLP/io/cws.py | 97 +++ fastNLP/io/data_bundle.py | 354 +++++++++++ fastNLP/io/embed_loader.py | 188 ++++++ fastNLP/io/file_reader.py | 136 ++++ fastNLP/io/file_utils.py | 578 +++++++++++++++++ fastNLP/io/loader/__init__.py | 107 ++++ fastNLP/io/loader/classification.py | 647 +++++++++++++++++++ fastNLP/io/loader/conll.py | 542 ++++++++++++++++ fastNLP/io/loader/coreference.py | 64 ++ fastNLP/io/loader/csv.py | 38 ++ fastNLP/io/loader/cws.py | 97 +++ fastNLP/io/loader/json.py | 45 ++ fastNLP/io/loader/loader.py | 94 +++ fastNLP/io/loader/matching.py | 577 +++++++++++++++++ fastNLP/io/loader/qa.py | 74 +++ fastNLP/io/loader/summarization.py | 63 ++ fastNLP/io/model_io.py | 71 +++ fastNLP/io/pipe/__init__.py | 80 +++ fastNLP/io/pipe/classification.py | 939 ++++++++++++++++++++++++++++ fastNLP/io/pipe/conll.py | 427 +++++++++++++ fastNLP/io/pipe/construct_graph.py | 286 +++++++++ fastNLP/io/pipe/coreference.py | 186 ++++++ fastNLP/io/pipe/cws.py | 282 +++++++++ fastNLP/io/pipe/matching.py | 545 ++++++++++++++++ fastNLP/io/pipe/pipe.py | 41 ++ fastNLP/io/pipe/qa.py | 144 +++++ fastNLP/io/pipe/summarization.py | 196 ++++++ fastNLP/io/pipe/utils.py | 224 +++++++ fastNLP/io/utils.py | 82 +++ 30 files changed, 7325 insertions(+) create mode 100644 fastNLP/io/__init__.py create mode 100644 fastNLP/io/cws.py create mode 100644 fastNLP/io/data_bundle.py create mode 100644 fastNLP/io/embed_loader.py create mode 100644 fastNLP/io/file_reader.py create mode 100644 fastNLP/io/file_utils.py create mode 100644 fastNLP/io/loader/__init__.py create mode 100644 fastNLP/io/loader/classification.py create mode 100644 fastNLP/io/loader/conll.py create mode 100644 fastNLP/io/loader/coreference.py create mode 100644 fastNLP/io/loader/csv.py create mode 100644 fastNLP/io/loader/cws.py create mode 100644 fastNLP/io/loader/json.py create mode 100644 fastNLP/io/loader/loader.py create mode 100644 fastNLP/io/loader/matching.py create mode 100644 fastNLP/io/loader/qa.py create mode 100644 fastNLP/io/loader/summarization.py create mode 100644 fastNLP/io/model_io.py create mode 100644 fastNLP/io/pipe/__init__.py create mode 100644 fastNLP/io/pipe/classification.py create mode 100644 fastNLP/io/pipe/conll.py create mode 100644 fastNLP/io/pipe/construct_graph.py create mode 100644 fastNLP/io/pipe/coreference.py create mode 100644 fastNLP/io/pipe/cws.py create mode 100644 fastNLP/io/pipe/matching.py create mode 100644 fastNLP/io/pipe/pipe.py create mode 100644 fastNLP/io/pipe/qa.py create mode 100644 fastNLP/io/pipe/summarization.py create mode 100644 fastNLP/io/pipe/utils.py create mode 100644 fastNLP/io/utils.py diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py new file mode 100644 index 00000000..75edf1c5 --- /dev/null +++ b/fastNLP/io/__init__.py @@ -0,0 +1,121 @@ +r""" +用于IO的模块, 具体包括: + +1. 用于读入 embedding 的 :mod:`EmbedLoader ` 类, + +2. 用于读入不同格式数据的 :mod:`Loader ` 类 + +3. 用于处理读入数据的 :mod:`Pipe ` 类 + +4. 
用于保存和载入模型的类, 参考 :mod:`model_io模块 ` + +这些类的使用方法如下: +""" +__all__ = [ + 'DataBundle', + + 'EmbedLoader', + + 'Loader', + + 'CLSBaseLoader', + 'AGsNewsLoader', + 'DBPediaLoader', + 'YelpFullLoader', + 'YelpPolarityLoader', + 'IMDBLoader', + 'SSTLoader', + 'SST2Loader', + "ChnSentiCorpLoader", + "THUCNewsLoader", + "WeiboSenti100kLoader", + + 'ConllLoader', + 'Conll2003Loader', + 'Conll2003NERLoader', + 'OntoNotesNERLoader', + 'CTBLoader', + "MsraNERLoader", + "WeiboNERLoader", + "PeopleDailyNERLoader", + + 'CSVLoader', + 'JsonLoader', + + 'CWSLoader', + + 'MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader", + "CNXNLILoader", + "BQCorpusLoader", + "LCQMCLoader", + + "CMRC2018Loader", + + "Pipe", + + "CLSBasePipe", + "AGsNewsPipe", + "DBPediaPipe", + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + "ChnSentiCorpPipe", + "THUCNewsPipe", + "WeiboSenti100kPipe", + + "Conll2003Pipe", + "Conll2003NERPipe", + "OntoNotesNERPipe", + "MsraNERPipe", + "PeopleDailyPipe", + "WeiboNERPipe", + + "CWSPipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + "MsraNERPipe", + "WeiboNERPipe", + "PeopleDailyPipe", + "Conll2003Pipe", + + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "CNXNLIBertPipe", + "BQCorpusBertPipe", + "LCQMCBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", + "LCQMCPipe", + "CNXNLIPipe", + "BQCorpusPipe", + "RenamePipe", + "GranularizePipe", + "MachingTruncatePipe", + + "CMRC2018BertPipe", + + 'ModelLoader', + 'ModelSaver', + +] + +from .data_bundle import DataBundle +from .embed_loader import EmbedLoader +from .loader import * +from .model_io import ModelLoader, ModelSaver +from .pipe import * diff --git a/fastNLP/io/cws.py b/fastNLP/io/cws.py new file mode 100644 index 00000000..d88d6a00 --- /dev/null +++ b/fastNLP/io/cws.py @@ -0,0 +1,97 @@ +r"""undocumented""" + +__all__ = [ + "CWSLoader" +] + +import glob +import os +import random +import shutil +import time + +from .loader import Loader +from fastNLP.core.dataset import DataSet, Instance + + +class CWSLoader(Loader): + r""" + CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: + + Example:: + + 上海 浦东 开发 与 法制 建设 同步 + 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) + ... + + 该Loader读取后的DataSet具有如下的结构 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." + + """ + + def __init__(self, dataset_name: str = None): + r""" + + :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None + """ + super().__init__() + datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} + if dataset_name in datanames: + self.dataset_name = datanames[dataset_name] + else: + self.dataset_name = None + + def _load(self, path: str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self, dev_ratio=0.1, re_download=False) -> str: + r""" + 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, + 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str + """ + if self.dataset_name is None: + return '' + data_dir = self._get_dataset_path(dataset_name=self.dataset_name) + modify_time = 0 + for filepath in glob.glob(os.path.join(data_dir, '*')): + modify_time = os.stat(filepath).st_mtime + break + if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=self.dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.txt')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + try: + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.txt')) + os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): + os.remove(os.path.join(data_dir, 'middle_file.txt')) + + return data_dir diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py new file mode 100644 index 00000000..5daee519 --- /dev/null +++ b/fastNLP/io/data_bundle.py @@ -0,0 +1,354 @@ +r""" +.. todo:: + doc +""" +__all__ = [ + 'DataBundle', +] + +from typing import Union, List, Callable + +from ..core.dataset import DataSet +from fastNLP.core.vocabulary import Vocabulary +# from ..core._logger import _logger + + +class DataBundle: + r""" + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 + Loader的load函数生成,可以通过以下的方法获取里面的内容 + + Example:: + + data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'}) + train_vocabs = data_bundle.vocabs['train'] + train_data = data_bundle.datasets['train'] + dev_data = data_bundle.datasets['train'] + + """ + + def __init__(self, vocabs=None, datasets=None): + r""" + + :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict + :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict。建议不要将相同的DataSet对象重复传入,可能会在 + 使用Pipe处理数据的时候遇到问题,若多个数据集确需一致,请手动deepcopy后传入。 + """ + self.vocabs = vocabs or {} + self.datasets = datasets or {} + + def set_vocab(self, vocab: Vocabulary, field_name: str): + r""" + 向DataBunlde中增加vocab + + :param ~fastNLP.Vocabulary vocab: 词表 + :param str field_name: 这个vocab对应的field名称 + :return: self + """ + assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports." + self.vocabs[field_name] = vocab + return self + + def set_dataset(self, dataset: DataSet, name: str): + r""" + + :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet + :param str name: dataset的名称 + :return: self + """ + assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports." + self.datasets[name] = dataset + return self + + def get_dataset(self, name: str) -> DataSet: + r""" + 获取名为name的dataset + + :param str name: dataset的名称,一般为'train', 'dev', 'test' + :return: DataSet + """ + if name in self.datasets.keys(): + return self.datasets[name] + else: + error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ + f'It should be one of {self.datasets.keys()}.' 
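+            # no DataSet is registered under `name`: report the available names, then raise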
+ print(error_msg) + raise KeyError(error_msg) + + def delete_dataset(self, name: str): + r""" + 删除名为name的DataSet + + :param str name: + :return: self + """ + self.datasets.pop(name, None) + return self + + def get_vocab(self, field_name: str) -> Vocabulary: + r""" + 获取field名为field_name对应的vocab + + :param str field_name: 名称 + :return: Vocabulary + """ + if field_name in self.vocabs.keys(): + return self.vocabs[field_name] + else: + error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ + f'It should be one of {self.vocabs.keys()}.' + print(error_msg) + raise KeyError(error_msg) + + def delete_vocab(self, field_name: str): + r""" + 删除vocab + :param str field_name: + :return: self + """ + self.vocabs.pop(field_name, None) + return self + + @property + def num_dataset(self): + return len(self.datasets) + + @property + def num_vocab(self): + return len(self.vocabs) + + def copy_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True): + r""" + 将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. + + :param str field_name: + :param str new_field_name: + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :return: self + """ + for name, dataset in self.datasets.items(): + if dataset.has_field(field_name=field_name): + dataset.copy_field(field_name=field_name, new_field_name=new_field_name) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name}.") + return self + + def rename_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True, rename_vocab=True): + r""" + 将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. + + :param str field_name: + :param str new_field_name: + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 + :return: self + """ + for name, dataset in self.datasets.items(): + if dataset.has_field(field_name=field_name): + dataset.rename_field(field_name=field_name, new_field_name=new_field_name) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name}.") + if rename_vocab: + if field_name in self.vocabs: + self.vocabs[new_field_name] = self.vocabs.pop(field_name) + + return self + + def delete_field(self, field_name: str, ignore_miss_dataset=True, delete_vocab=True): + r""" + 将DataBundle中所有DataSet中名为field_name的field删除掉. 
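+
+        Example::
+
+            # a minimal usage sketch; 'words' is a hypothetical field name
+            data_bundle.delete_field('words')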
+ + :param str field_name: + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 + :return: self + """ + for name, dataset in self.datasets.items(): + if dataset.has_field(field_name=field_name): + dataset.delete_field(field_name=field_name) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name}.") + if delete_vocab: + if field_name in self.vocabs: + self.vocabs.pop(field_name) + return self + + def iter_datasets(self) -> Union[str, DataSet]: + r""" + 迭代data_bundle中的DataSet + + Example:: + + for name, dataset in data_bundle.iter_datasets(): + pass + + :return: + """ + for name, dataset in self.datasets.items(): + yield name, dataset + + def get_dataset_names(self) -> List[str]: + r""" + 返回DataBundle中DataSet的名称 + + :return: + """ + return list(self.datasets.keys()) + + def get_vocab_names(self) -> List[str]: + r""" + 返回DataBundle中Vocabulary的名称 + + :return: + """ + return list(self.vocabs.keys()) + + def iter_vocabs(self): + r""" + 迭代data_bundle中的DataSet + + Example: + + for field_name, vocab in data_bundle.iter_vocabs(): + pass + + :return: + """ + for field_name, vocab in self.vocabs.items(): + yield field_name, vocab + + def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, + ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :method:`~fastNLP.DataSet.apply_field` 方法 + + :param callable func: input是instance中名为 `field_name` 的field的内容。 + :param str field_name: 传入func的是哪个field。 + :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 + 盖之前的field。如果为None则不创建新的field。 + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :param ignore_miss_dataset: + :param num_proc: + :param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 + :param show_progress_bar 是否显示tqdm进度条 + + """ + _progress_desc = progress_desc + for name, dataset in self.datasets.items(): + if _progress_desc: + progress_desc = _progress_desc + f' for `{name}`' + if dataset.has_field(field_name=field_name): + dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, + progress_desc=progress_desc, show_progress_bar=show_progress_bar) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name}.") + return self + + def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, + ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 + + .. 
note:: + ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply`` 区别的介绍。 + + :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param str field_name: 传入func的是哪个field。 + :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :param show_progress_bar: 是否显示tqdm进度条 + :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 + :param num_proc: + + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + + """ + res = {} + _progress_desc = progress_desc + for name, dataset in self.datasets.items(): + if _progress_desc: + progress_desc = _progress_desc + f' for `{name}`' + if dataset.has_field(field_name=field_name): + res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, + modify_fields=modify_fields, + show_progress_bar=show_progress_bar, progress_desc=progress_desc) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name} .") + return res + + def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, + progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 + + 对DataBundle中所有的dataset使用apply方法 + + :param callable func: input是instance中名为 `field_name` 的field的内容。 + :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 + 盖之前的field。如果为None则不创建新的field。 + :param _apply_field: + :param show_progress_bar: 是否显示tqd进度条 + :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 + :param num_proc + + """ + _progress_desc = progress_desc + for name, dataset in self.datasets.items(): + if _progress_desc: + progress_desc = _progress_desc + f' for `{name}`' + dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, + progress_desc=progress_desc, _apply_field=_apply_field) + return self + + def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, + progress_desc: str = '', show_progress_bar: bool = True): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 + + .. 
note:: + ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply`` 区别的介绍。 + + :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True + :param show_progress_bar: 是否显示tqd进度条 + :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 + :param num_proc + + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + """ + res = {} + _progress_desc = progress_desc + for name, dataset in self.datasets.items(): + if _progress_desc: + progress_desc = _progress_desc + f' for `{name}`' + res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, + show_progress_bar=show_progress_bar, progress_desc=progress_desc) + return res + + def set_pad_val(self, *field_names, val=0) -> None: + for _, ds in self.iter_datasets(): + ds.set_pad_val(*field_names, val=val) + + def set_input(self, *field_names) -> None: + for _, ds in self.iter_datasets(): + ds.set_input(*field_names) + + def __repr__(self) -> str: + _str = '' + if len(self.datasets): + _str += 'In total {} datasets:\n'.format(self.num_dataset) + for name, dataset in self.datasets.items(): + _str += '\t{} has {} instances.\n'.format(name, len(dataset)) + if len(self.vocabs): + _str += 'In total {} vocabs:\n'.format(self.num_vocab) + for name, vocab in self.vocabs.items(): + _str += '\t{} has {} entries.\n'.format(name, len(vocab)) + return _str + diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py new file mode 100644 index 00000000..fe084455 --- /dev/null +++ b/fastNLP/io/embed_loader.py @@ -0,0 +1,188 @@ +r""" +.. todo:: + doc +""" +__all__ = [ + "EmbedLoader", + "EmbeddingOption", +] + +import logging +import os +import warnings + +import numpy as np + +from fastNLP.core.utils.utils import Option +from fastNLP.core.vocabulary import Vocabulary + + +class EmbeddingOption(Option): + def __init__(self, + embed_filepath=None, + dtype=np.float32, + normalize=True, + error='ignore'): + super().__init__( + embed_filepath=embed_filepath, + dtype=dtype, + normalize=normalize, + error=error + ) + + +class EmbedLoader: + r""" + 用于读取预训练的embedding, 读取结果可直接载入为模型参数。 + """ + + def __init__(self): + super(EmbedLoader, self).__init__() + + @staticmethod + def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='', unknown='', normalize=True, + error='ignore', init_method=None): + r""" + 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 + word2vec(第一行只有两个元素)还是glove格式的数据。 + + :param str embed_filepath: 预训练的embedding的路径。 + :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 + 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 + :param dtype: 读出的embedding的类型 + :param str padding: 词表中padding的token + :param str unknown: 词表中unknown的token + :param bool normalize: 是否将每个vector归一化到norm为1 + :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 + 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 + :param callable init_method: 传入numpy.ndarray, 返回numpy.ndarray, 用以初始化embedding + :return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 + """ + assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." 
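+        # format detection note: once the file is opened, a first line with exactly two tokens is
+        # treated as a word2vec header (vocab size, dim); otherwise the file is assumed to be
+        # GloVe-style, so the first line is rewound and parsed as an ordinary vector line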
+ if not os.path.exists(embed_filepath): + raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) + with open(embed_filepath, 'r', encoding='utf-8') as f: + hit_flags = np.zeros(len(vocab), dtype=bool) + line = f.readline().strip() + parts = line.split() + start_idx = 0 + if len(parts) == 2: + dim = int(parts[1]) + start_idx += 1 + else: + dim = len(parts) - 1 + f.seek(0) + matrix = np.random.randn(len(vocab), dim).astype(dtype) + if init_method: + matrix = init_method(matrix) + for idx, line in enumerate(f, start_idx): + try: + parts = line.strip().split() + word = ''.join(parts[:-dim]) + nums = parts[-dim:] + # 对齐unk与pad + if word == padding and vocab.padding is not None: + word = vocab.padding + elif word == unknown and vocab.unknown is not None: + word = vocab.unknown + if word in vocab: + index = vocab.to_index(word) + matrix[index] = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) + hit_flags[index] = True + except Exception as e: + if error == 'ignore': + warnings.warn("Error occurred at the {} line.".format(idx)) + else: + logging.error("Error occurred at the {} line.".format(idx)) + raise e + total_hits = sum(hit_flags) + logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab))) + if init_method is None: + found_vectors = matrix[hit_flags] + if len(found_vectors) != 0: + mean = np.mean(found_vectors, axis=0, keepdims=True) + std = np.std(found_vectors, axis=0, keepdims=True) + unfound_vec_num = len(vocab) - total_hits + r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean + matrix[hit_flags == False] = r_vecs + + if normalize: + matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) + + return matrix + + @staticmethod + def load_without_vocab(embed_filepath, dtype=np.float32, padding='', unknown='', normalize=True, + error='ignore'): + r""" + 从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 + + :param str embed_filepath: 预训练的embedding的路径。 + :param dtype: 读出的embedding的类型 + :param str padding: 词表中的padding的token. 并以此用做vocab的padding。 + :param str unknown: 词表中的unknown的token. 
并以此用做vocab的unknown。 + :param bool normalize: 是否将每个vector归一化到norm为1 + :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 + 方在于词表有空行或者词表出现了维度不一致。 + :return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 + 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 + + """ + vocab = Vocabulary(padding=padding, unknown=unknown) + vec_dict = {} + found_unknown = False + found_pad = False + + with open(embed_filepath, 'r', encoding='utf-8') as f: + line = f.readline() + start = 1 + dim = -1 + if len(line.strip().split()) != 2: + f.seek(0) + start = 0 + for idx, line in enumerate(f, start=start): + try: + parts = line.strip().split() + if dim == -1: + dim = len(parts) - 1 + word = ''.join(parts[:-dim]) + nums = parts[-dim:] + vec = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) + vec_dict[word] = vec + vocab.add_word(word) + if unknown is not None and unknown == word: + found_unknown = True + if padding is not None and padding == word: + found_pad = True + except Exception as e: + if error == 'ignore': + warnings.warn("Error occurred at the {} line.".format(idx)) + pass + else: + logging.error("Error occurred at the {} line.".format(idx)) + raise e + if dim == -1: + raise RuntimeError("{} is an empty file.".format(embed_filepath)) + matrix = np.random.randn(len(vocab), dim).astype(dtype) + for key, vec in vec_dict.items(): + index = vocab.to_index(key) + matrix[index] = vec + + if ((unknown is not None) and (not found_unknown)) or ((padding is not None) and (not found_pad)): + start_idx = 0 + if padding is not None: + start_idx += 1 + if unknown is not None: + start_idx += 1 + + mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) + std = np.std(matrix[start_idx:], axis=0, keepdims=True) + if (unknown is not None) and (not found_unknown): + matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean + if (padding is not None) and (not found_pad): + matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean + + if normalize: + matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) + + return matrix, vocab diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py new file mode 100644 index 00000000..43460d19 --- /dev/null +++ b/fastNLP/io/file_reader.py @@ -0,0 +1,136 @@ +r"""undocumented +此模块用于给其它模块提供读取文件的函数,没有为用户提供 API +""" + +__all__ = [] + +import json +import csv + +# from ..core import log + + +def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): + r""" + Construct a generator to read csv items. + + :param path: file path + :param encoding: file's encoding, default: utf-8 + :param headers: file's headers, if None, make file's first line as headers. default: None + :param sep: separator for each column. default: ',' + :param dropna: weather to ignore and drop invalid data, + :if False, raise ValueError when reading invalid data. default: True + :return: generator, every time yield (line number, csv item) + """ + with open(path, 'r', encoding=encoding) as csv_file: + f = csv.reader(csv_file, delimiter=sep) + start_idx = 0 + if headers is None: + headers = next(f) + start_idx += 1 + elif not isinstance(headers, (list, tuple)): + raise TypeError("headers should be list or tuple, not {}." 
\ + .format(type(headers))) + for line_idx, line in enumerate(f, start_idx): + contents = line + if len(contents) != len(headers): + if dropna: + continue + else: + if "" in headers: + raise ValueError(("Line {} has {} parts, while header has {} parts.\n" + + "Please check the empty parts or unnecessary '{}'s in header.") + .format(line_idx, len(contents), len(headers), sep)) + else: + raise ValueError("Line {} has {} parts, while header has {} parts." \ + .format(line_idx, len(contents), len(headers))) + _dict = {} + for header, content in zip(headers, contents): + _dict[header] = content + yield line_idx, _dict + + +def _read_json(path, encoding='utf-8', fields=None, dropna=True): + r""" + Construct a generator to read json items. + + :param path: file path + :param encoding: file's encoding, default: utf-8 + :param fields: json object's fields that needed, if None, all fields are needed. default: None + :param dropna: weather to ignore and drop invalid data, + :if False, raise ValueError when reading invalid data. default: True + :return: generator, every time yield (line number, json item) + """ + if fields: + fields = set(fields) + with open(path, 'r', encoding=encoding) as f: + for line_idx, line in enumerate(f): + data = json.loads(line) + if fields is None: + yield line_idx, data + continue + _res = {} + for k, v in data.items(): + if k in fields: + _res[k] = v + if len(_res) < len(fields): + if dropna: + continue + else: + raise ValueError('invalid instance at line: {}'.format(line_idx)) + yield line_idx, _res + + +def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): + r""" + Construct a generator to read conll items. + + :param path: file path + :param encoding: file's encoding, default: utf-8 + :param sep: seperator + :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None + :param dropna: weather to ignore and drop invalid data, + :if False, raise ValueError when reading invalid data. default: True + :return: generator, every time yield (line number, conll item) + """ + + def parse_conll(sample): + sample = list(map(list, zip(*sample))) + sample = [sample[i] for i in indexes] + for f in sample: + if len(f) <= 0: + raise ValueError('empty field') + return sample + + with open(path, 'r', encoding=encoding) as f: + sample = [] + start = next(f).strip() + if start != '': + sample.append(start.split(sep)) if sep else sample.append(start.split()) + for line_idx, line in enumerate(f, 1): + line = line.strip() + if line == '': + if len(sample): + try: + res = parse_conll(sample) + sample = [] + yield line_idx, res + except Exception as e: + if dropna: + print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) + sample = [] + continue + raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) + elif line.startswith('#'): + continue + else: + sample.append(line.split(sep)) if sep else sample.append(line.split()) + if len(sample) > 0: + try: + res = parse_conll(sample) + yield line_idx, res + except Exception as e: + if dropna: + return + print('invalid instance ends at line: {}'.format(line_idx)) + raise e diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py new file mode 100644 index 00000000..b982af54 --- /dev/null +++ b/fastNLP/io/file_utils.py @@ -0,0 +1,578 @@ +r""" +.. 
todo:: + doc +""" + +__all__ = [ + "cached_path", + "get_filepath", + "get_cache_path", + "split_filename_suffix", + "get_from_cache", +] + +import os +import re +import shutil +import tempfile +from pathlib import Path +from urllib.parse import urlparse + +import requests +from requests import HTTPError + +from fastNLP.core.log import logger +from rich.progress import Progress, BarColumn, DownloadColumn, TimeRemainingColumn, TimeElapsedColumn + +PRETRAINED_BERT_MODEL_DIR = { + 'en': 'bert-base-cased.zip', + 'en-large-cased-wwm': 'bert-large-cased-wwm.zip', + 'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip', + + 'en-large-uncased': 'bert-large-uncased.zip', + 'en-large-cased': 'bert-large-cased.zip', + + 'en-base-uncased': 'bert-base-uncased.zip', + 'en-base-cased': 'bert-base-cased.zip', + + 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', + + 'en-distilbert-base-uncased': 'distilbert-base-uncased.zip', + + 'multi-base-cased': 'bert-base-multilingual-cased.zip', + 'multi-base-uncased': 'bert-base-multilingual-uncased.zip', + + 'cn': 'bert-chinese-wwm.zip', + 'cn-base': 'bert-base-chinese.zip', + 'cn-wwm': 'bert-chinese-wwm.zip', + 'cn-wwm-ext': "bert-chinese-wwm-ext.zip" +} + +PRETRAINED_GPT2_MODEL_DIR = { + 'en': 'gpt2.zip', + 'en-medium': 'gpt2-medium.zip', + 'en-large': 'gpt2-large.zip', + 'en-xl': 'gpt2-xl.zip' +} + +PRETRAINED_ROBERTA_MODEL_DIR = { + 'en': 'roberta-base.zip', + 'en-large': 'roberta-large.zip' +} + +PRETRAINED_ELMO_MODEL_DIR = { + 'en': 'elmo_en_Medium.zip', + 'en-small': "elmo_en_Small.zip", + 'en-original-5.5b': 'elmo_en_Original_5.5B.zip', + 'en-original': 'elmo_en_Original.zip', + 'en-medium': 'elmo_en_Medium.zip' +} + +PRETRAIN_STATIC_FILES = { + 'en': 'glove.840B.300d.zip', + + 'en-glove-6b-50d': 'glove.6B.50d.zip', + 'en-glove-6b-100d': 'glove.6B.100d.zip', + 'en-glove-6b-200d': 'glove.6B.200d.zip', + 'en-glove-6b-300d': 'glove.6B.300d.zip', + 'en-glove-42b-300d': 'glove.42B.300d.zip', + 'en-glove-840b-300d': 'glove.840B.300d.zip', + 'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', + 'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', + 'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', + 'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', + + 'en-word2vec-300d': "GoogleNews-vectors-negative300.txt.gz", + + 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", + 'en-fasttext-crawl': "crawl-300d-2M.vec.zip", + + 'cn': "tencent_cn.zip", + 'cn-tencent': "tencent_cn.zip", + 'cn-fasttext': "cc.zh.300.vec.gz", + 'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', + 'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip", + 'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip", + "cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip" +} + +DATASET_DIR = { + # Classification, English + 'aclImdb': "imdb.zip", + "yelp-review-full": "yelp_review_full.tar.gz", + "yelp-review-polarity": "yelp_review_polarity.tar.gz", + "sst-2": "SST-2.zip", + "sst": "SST.zip", + 'mr': 'mr.zip', + "R8": "R8.zip", + "R52": "R52.zip", + "20ng": "20ng.zip", + "ohsumed": "ohsumed.zip", + + # Classification, Chinese + "chn-senti-corp": "chn_senti_corp.zip", + "weibo-senti-100k": "WeiboSenti100k.zip", + "thuc-news": "THUCNews.zip", + + # Matching, English + "mnli": "MNLI.zip", + "snli": "SNLI.zip", + "qnli": "QNLI.zip", + "rte": "RTE.zip", + + # Matching, Chinese + "cn-xnli": "XNLI.zip", + + # Sequence Labeling, Chinese + "msra-ner": "MSRA_NER.zip", + "peopledaily": "peopledaily.zip", + "weibo-ner": "weibo_NER.zip", + + # Chinese Word Segmentation + "cws-pku": 
'cws_pku.zip', + "cws-cityu": "cws_cityu.zip", + "cws-as": 'cws_as.zip', + "cws-msra": 'cws_msra.zip', + + # Summarization, English + "ext-cnndm": "ext-cnndm.zip", + + # Question & answer, Chinese + "cmrc2018": "cmrc2018.zip" + +} + +PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, + "bert": PRETRAINED_BERT_MODEL_DIR, + "static": PRETRAIN_STATIC_FILES, + 'gpt2': PRETRAINED_GPT2_MODEL_DIR, + 'roberta': PRETRAINED_ROBERTA_MODEL_DIR} + +# 用于扩展fastNLP的下载 +FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt' +FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', + 'bert': 'fastnlp_bert_url.txt', + 'static': 'fastnlp_static_url.txt', + 'gpt2': 'fastnlp_gpt2_url.txt', + 'roberta': 'fastnlp_roberta_url.txt' + } + + +def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: + r""" + 给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, + + 1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir + 2. 如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name} + + 如果有该文件,就直接返回路径 + + 如果没有该文件,则尝试用传入的url下载 + + 或者文件名(可以是具体的文件名,也可以是文件夹),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 + 将文件放入到cache_dir中. + + :param str url_or_filename: 文件的下载url或者文件名称。 + :param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 + :param str name: 中间一层的名称。如embedding, dataset + :return: + """ + if cache_dir is None: + data_cache = Path(get_cache_path()) + else: + data_cache = cache_dir + + if name: + data_cache = os.path.join(data_cache, name) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ("http", "https"): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, Path(data_cache)) + elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists(): + # File, and it exists. + return Path(os.path.join(data_cache, url_or_filename)) + elif parsed.scheme == "": + # File, but it doesn't exist. 
+ raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache)) + else: + # Something unknown + raise ValueError( + "unable to parse {} as a URL or as a local path".format(url_or_filename) + ) + + +def get_filepath(filepath): + r""" + 如果filepath为文件夹, + + 如果内含多个文件, 返回filepath + + 如果只有一个文件, 返回filepath + filename + + 如果filepath为文件 + + 返回filepath + + :param str filepath: 路径 + :return: + """ + if os.path.isdir(filepath): + files = os.listdir(filepath) + if len(files) == 1: + return os.path.join(filepath, files[0]) + else: + return filepath + elif os.path.isfile(filepath): + return filepath + else: + raise FileNotFoundError(f"{filepath} is not a valid file or directory.") + + +def get_cache_path(): + r""" + 获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 + + :return str: 存放路径 + """ + if 'FASTNLP_CACHE_DIR' in os.environ: + fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') + if os.path.isdir(fastnlp_cache_dir): + return fastnlp_cache_dir + else: + raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.") + fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) + return fastnlp_cache_dir + + +def _get_base_url(name): + r""" + 根据name返回下载的url地址。 + + :param str name: 支持dataset和embedding两种 + :return: + """ + # 返回的URL结尾必须是/ + environ_name = "FASTNLP_{}_URL".format(name.upper()) + + if environ_name in os.environ: + url = os.environ[environ_name] + if url.endswith('/'): + return url + else: + return url + '/' + else: + URLS = { + 'embedding': "http://download.fastnlp.top/embedding/", + "dataset": "http://download.fastnlp.top/dataset/" + } + if name.lower() not in URLS: + raise KeyError(f"{name} is not recognized.") + return URLS[name.lower()] + + +def _get_embedding_url(embed_type, name): + r""" + 给定embedding类似和名称,返回下载url + + :param str embed_type: 支持static, bert, elmo。即embedding的类型 + :param str name: embedding的名称, 例如en, cn, based等 + :return: str, 下载的url地址 + """ + # 从扩展中寻找下载的url + _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None) + if _filename: + url = _read_extend_url_file(_filename, name) + if url: + return url + embed_map = PRETRAIN_MAP.get(embed_type, None) + if embed_map: + filename = embed_map.get(name, None) + if filename: + url = _get_base_url('embedding') + filename + return url + raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys()))) + else: + raise KeyError(f"There is no {embed_type}. 
Only supports bert, elmo, static, gpt2, roberta") + + +def _read_extend_url_file(filename, name) -> str: + r""" + filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 + + :param str filename: 在默认的路径下寻找file这个文件 + :param str name: 需要寻找的资源的名称 + :return: str,None + """ + cache_dir = get_cache_path() + filepath = os.path.join(cache_dir, filename) + if os.path.exists(filepath): + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + if len(parts) == 2: + if name == parts[0]: + return parts[1] + return None + + +def _get_dataset_url(name, dataset_dir: dict = None): + r""" + 给定dataset的名称,返回下载url + + :param str name: 给定dataset的名称,比如imdb, sst-2等 + :return: str + """ + # 从扩展中寻找下载的url + url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name) + if url: + return url + + dataset_dir = DATASET_DIR if dataset_dir is None else dataset_dir + filename = dataset_dir.get(name, None) + if filename: + url = _get_base_url('dataset') + filename + return url + else: + raise KeyError(f"There is no {name}.") + + +def split_filename_suffix(filepath): + r""" + 给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 + + :param filepath: 文件路径 + :return: filename, suffix + """ + filename = os.path.basename(filepath) + if filename.endswith('.tar.gz'): + return filename[:-7], '.tar.gz' + return os.path.splitext(filename) + + +def get_from_cache(url: str, cache_dir: Path = None) -> Path: + r""" + 尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 + 文件解压,将解压后的文件全部放在cache_dir文件夹中。 + + 如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 + + :param url: 资源的 url + :param cache_dir: cache 目录 + :return: 路径 + """ + cache_dir.mkdir(parents=True, exist_ok=True) + + filename = re.sub(r".+/", "", url) + dir_name, suffix = split_filename_suffix(filename) + + # 寻找与它名字匹配的内容, 而不关心后缀 + match_dir_name = match_file(dir_name, cache_dir) + if match_dir_name: + dir_name = match_dir_name + cache_path = cache_dir / dir_name + + # get cache path to put the file + if cache_path.exists(): + return get_filepath(cache_path) + + # make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 + # response = requests.head(url, headers={"User-Agent": "fastNLP"}) + # if response.status_code != 200: + # raise IOError( + # f"HEAD request failed for url {url} with status code {response.status_code}." + # ) + + # add ETag to filename if it exists + # etag = response.headers.get("ETag") + + if not cache_path.exists(): + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
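+        # rough flow of the block below: stream the response into a temporary file while a rich
+        # Progress bar tracks the downloaded bytes, uncompress .zip/.tar.gz/.gz archives into a
+        # temporary directory, copy the result into cache_path, and clean the temporaries up in
+        # the finally block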
+ # GET file object + req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) + if req.status_code == 200: + success = False + fd, temp_filename = tempfile.mkstemp() + uncompress_temp_dir = None + try: + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + # progress = tqdm(unit="B", total=total, unit_scale=1) + progress = Progress( + BarColumn(), + TimeElapsedColumn(), + "/", + TimeRemainingColumn(), + DownloadColumn() + ) + task = progress.add_task(total=total, description='download') + progress.start() + logger.info("%s not found in cache, downloading to %s" % (url, temp_filename)) + + with open(temp_filename, "wb") as temp_file: + for chunk in req.iter_content(chunk_size=1024 * 16): + if chunk: # filter out keep-alive new chunks + progress.update(task, advance=len(chunk)) + temp_file.write(chunk) + progress.stop() + progress.remove_task(task) + logger.info(f"Finish download from {url}") + + # 开始解压 + if suffix in ('.zip', '.tar.gz', '.gz'): + uncompress_temp_dir = tempfile.mkdtemp() + logger.info(f"Start to uncompress file to {uncompress_temp_dir}") + if suffix == '.zip': + unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + elif suffix == '.gz': + ungzip_file(temp_filename, uncompress_temp_dir, dir_name) + else: + untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) + filenames = os.listdir(uncompress_temp_dir) + if len(filenames) == 1: + if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): + uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) + + cache_path.mkdir(parents=True, exist_ok=True) + logger.info("Finish un-compressing file.") + else: + uncompress_temp_dir = temp_filename + cache_path = str(cache_path) + suffix + + # 复制到指定的位置 + logger.info(f"Copy file to {cache_path}") + if os.path.isdir(uncompress_temp_dir): + for filename in os.listdir(uncompress_temp_dir): + if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): + shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename) + else: + shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename) + else: + shutil.copyfile(uncompress_temp_dir, cache_path) + success = True + except Exception as e: + logger.info(e) + raise e + finally: + if not success: + if cache_path.exists(): + if cache_path.is_file(): + os.remove(cache_path) + else: + shutil.rmtree(cache_path) + os.close(fd) + os.remove(temp_filename) + if uncompress_temp_dir is None: + pass + elif os.path.isdir(uncompress_temp_dir): + shutil.rmtree(uncompress_temp_dir) + elif os.path.isfile(uncompress_temp_dir): + os.remove(uncompress_temp_dir) + return get_filepath(cache_path) + else: + raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.") + + +def unzip_file(file: Path, to: Path): + # unpack and write out in CoNLL column-like format + from zipfile import ZipFile + + with ZipFile(file, "r") as zipObj: + # Extract all the contents of zip file in current directory + zipObj.extractall(to) + + +def untar_gz_file(file: Path, to: Path): + import tarfile + + with tarfile.open(file, 'r:gz') as tar: + tar.extractall(to) + + +def ungzip_file(file: str, to: str, filename: str): + import gzip + + g_file = gzip.GzipFile(file) + with open(os.path.join(to, filename), 'wb+') as f: + f.write(g_file.read()) + g_file.close() + + +def match_file(dir_name: str, cache_dir: Path) -> str: + r""" + 匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 + 如果找到了两个匹配的结果将报错. 
如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 + + :param dir_name: 需要匹配的名称 + :param cache_dir: 在该目录下找匹配dir_name是否存在 + :return str: 做为匹配结果的字符串 + """ + files = os.listdir(cache_dir) + matched_filenames = [] + for file_name in files: + if re.match(dir_name + '$', file_name) or re.match(dir_name + '\\..*', file_name): + matched_filenames.append(file_name) + if len(matched_filenames) == 0: + return '' + elif len(matched_filenames) == 1: + return matched_filenames[-1] + else: + raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") + + +def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'): + if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: + model_url = _get_embedding_url('bert', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') + # 检查是否存在 + elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): + model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) + else: + logger.info(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") + raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") + return str(model_dir) + + +def _get_gpt2_dir(model_dir_or_name: str = 'en'): + if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR: + model_url = _get_embedding_url('gpt2', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') + # 检查是否存在 + elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): + model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) + else: + logger.info(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") + raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") + return str(model_dir) + + +def _get_roberta_dir(model_dir_or_name: str = 'en'): + if model_dir_or_name.lower() in PRETRAINED_ROBERTA_MODEL_DIR: + model_url = _get_embedding_url('roberta', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') + # 检查是否存在 + elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): + model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) + else: + logger.info(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.") + raise ValueError(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.") + return str(model_dir) + + +def _get_file_name_base_on_postfix(dir_path, postfix): + r""" + 在dir_path中寻找后缀为postfix的文件. + :param dir_path: str, 文件夹 + :param postfix: 形如".bin", ".json"等 + :return: str,文件的路径 + """ + files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) + if len(files) == 0: + raise FileNotFoundError(f"There is no file endswith {postfix} file in {dir_path}") + elif len(files) > 1: + raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") + return os.path.join(dir_path, files[0]) diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py new file mode 100644 index 00000000..5ea9378b --- /dev/null +++ b/fastNLP/io/loader/__init__.py @@ -0,0 +1,107 @@ +r""" +Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 +三个方法: ``__init__`` , ``_load`` , ``loads`` . 
其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, +读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; +``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: + +0.传入None + 将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 + +1.传入一个文件的 path + 返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.get_dataset('train')`` 获取 + +2.传入一个文件夹目录 + 将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为:: + + | + +-train.txt + +-dev.txt + +-test.txt + +-other.txt + + 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , + ``data_bundle.get_dataset('dev')`` , + ``data_bundle.get_dataset('test')`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为:: + + | + +-train.txt + +-dev.txt + + 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , + ``data_bundle.get_dataset('dev')`` 获取对应的 dataset。 + +3.传入一个字典 + 字典的的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径:: + + paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} + + 在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , ``data_bundle.get_dataset('dev')`` , + ``data_bundle.get_dataset('test')`` 来获取对应的 `dataset` + +fastNLP 目前提供了如下的 Loader + + + +""" + +__all__ = [ + 'Loader', + + 'CLSBaseLoader', + 'YelpFullLoader', + 'YelpPolarityLoader', + 'AGsNewsLoader', + 'DBPediaLoader', + 'IMDBLoader', + 'SSTLoader', + 'SST2Loader', + "ChnSentiCorpLoader", + "THUCNewsLoader", + "WeiboSenti100kLoader", + "MRLoader", + "R8Loader", "R52Loader", "OhsumedLoader", "NG20Loader", + + 'ConllLoader', + 'Conll2003Loader', + 'Conll2003NERLoader', + 'OntoNotesNERLoader', + 'CTBLoader', + "MsraNERLoader", + "PeopleDailyNERLoader", + "WeiboNERLoader", + + 'CSVLoader', + 'JsonLoader', + + 'CWSLoader', + + 'MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader", + "CNXNLILoader", + "BQCorpusLoader", + "LCQMCLoader", + + "CoReferenceLoader", + + "CMRC2018Loader" +] + +from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \ + SSTLoader, SST2Loader, DBPediaLoader, \ + ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, \ + MRLoader, R8Loader, R52Loader, OhsumedLoader, NG20Loader +from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader +from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader +from .coreference import CoReferenceLoader +from .csv import CSVLoader +from .cws import CWSLoader +from .json import JsonLoader +from .loader import Loader +from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \ + LCQMCLoader +from .qa import CMRC2018Loader + + diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py new file mode 100644 index 00000000..0b1a670b --- /dev/null +++ b/fastNLP/io/loader/classification.py @@ -0,0 +1,647 @@ +r"""undocumented""" + +__all__ = [ + "CLSBaseLoader", + "YelpFullLoader", + "YelpPolarityLoader", + "AGsNewsLoader", + "DBPediaLoader", + "IMDBLoader", + "SSTLoader", + "SST2Loader", + "ChnSentiCorpLoader", + "THUCNewsLoader", + "WeiboSenti100kLoader", + + "MRLoader", + "R8Loader", + "R52Loader", + "OhsumedLoader", + "NG20Loader", +] + +import glob +import os +import random +import shutil +import time +import warnings + +from .loader import Loader +from fastNLP.core.dataset import Instance, DataSet + + +# from 
...core._logger import log + + +class CLSBaseLoader(Loader): + r""" + 文本分类Loader的一个基类 + + 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 + + Example:: + + "1","I got 'new' tires from the..." + "1","Don't waste your time..." + + 读取的DataSet将具备以下的数据结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." + + """ + + def __init__(self, sep=',', has_header=False): + super().__init__() + self.sep = sep + self.has_header = has_header + + def _load(self, path: str): + ds = DataSet() + try: + with open(path, 'r', encoding='utf-8') as f: + read_header = self.has_header + for line in f: + if read_header: + read_header = False + continue + line = line.strip() + sep_index = line.index(self.sep) + target = line[:sep_index] + raw_words = line[sep_index + 1:] + if target.startswith("\""): + target = target[1:] + if target.endswith("\""): + target = target[:-1] + if raw_words.endswith("\""): + raw_words = raw_words[:-1] + if raw_words.startswith('"'): + raw_words = raw_words[1:] + raw_words = raw_words.replace('""', '"') # 替换双引号 + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + except Exception as e: + print(f'Load file `{path}` failed for `{e}`') + return ds + + +def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix='csv'): + if dev_ratio == 0.0: + return data_dir + modify_time = 0 + for filepath in glob.glob(os.path.join(data_dir, '*')): + modify_time = os.stat(filepath).st_mtime + break + if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 + shutil.rmtree(data_dir) + data_dir = Loader()._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, f'dev.{suffix}')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + try: + with open(os.path.join(data_dir, f'train.{suffix}'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, f'middle_file.{suffix}'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, f'dev.{suffix}'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, f'train.{suffix}')) + os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), os.path.join(data_dir, f'train.{suffix}')) + finally: + if os.path.exists(os.path.join(data_dir, f'middle_file.{suffix}')): + os.remove(os.path.join(data_dir, f'middle_file.{suffix}')) + + return data_dir + + +class AGsNewsLoader(CLSBaseLoader): + def download(self): + r""" + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + :return: str, 数据集的目录地址 + """ + return self._get_dataset_path(dataset_name='ag-news') + + +class DBPediaLoader(CLSBaseLoader): + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + r""" + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = 'dbpedia' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class IMDBLoader(CLSBaseLoader): + r""" + 原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 + + Example:: + + neg Alan Rickman & Emma... + neg I have seen this... + + IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 + 读取的DataSet具备以下的结构: + + .. csv-table:: + :header: "raw_words", "target" + + "Alan Rickman & Emma... ", "neg" + "I have seen this... ", "neg" + "...", "..." + + """ + + def __init__(self): + super().__init__(sep='\t') + + def download(self, dev_ratio: float = 0.0, re_download=False): + r""" + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + http://www.aclweb.org/anthology/P11-1015 + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后不从train中切分dev + + :param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = 'aclImdb' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='txt') + return data_dir + + +class SSTLoader(Loader): + r""" + 原始数据中内容应该为: + + Example:: + + (2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)... + (3 (3 (2 If) (3 (2 you) (3 (2 sometimes)... + + 读取之后的DataSet具有以下的结构 + + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + :header: "raw_words" + + "(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." + "(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." + "..." + + raw_words列是str。 + + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + r""" + 从path读取SST文件 + + :param str path: 文件路径 + :return: DataSet + """ + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self): + r""" + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf + + :return: str, 数据集的目录地址 + """ + output_dir = self._get_dataset_path(dataset_name='sst') + return output_dir + + +class YelpFullLoader(CLSBaseLoader): + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + r""" + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = 'yelp-review-full' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class YelpPolarityLoader(CLSBaseLoader): + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + r""" + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. 
Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = 'yelp-review-polarity' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class SST2Loader(Loader): + r""" + 原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label + + Example:: + + sentence label + it 's a charming and often affecting journey . 1 + unflinchingly bleak and desperate 0 + + 读取之后DataSet将如下所示 + + .. csv-table:: + :header: "raw_words", "target" + + "it 's a charming and often affecting journey .", "1" + "unflinchingly bleak and desperate", "0" + "..." + + test的DataSet没有target列。 + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + r"""从path读取SST2文件 + + :param str path: 数据路径 + :return: DataSet + """ + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if 'test' in os.path.split(path)[1]: + warnings.warn("SST2's test file has no target.") + for line in f: + line = line.strip() + if line: + sep_index = line.index('\t') + raw_words = line[sep_index + 1:] + index = int(line[: sep_index]) + if raw_words: + ds.append(Instance(raw_words=raw_words, index=index)) + else: + for line in f: + line = line.strip() + if line: + raw_words = line[:-2] + target = line[-1] + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + return ds + + def download(self): + r""" + 自动下载数据集,如果你使用了该数据集,请引用以下的文章 + https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf + :return: + """ + output_dir = self._get_dataset_path(dataset_name='sst-2') + return output_dir + + +class ChnSentiCorpLoader(Loader): + r""" + 支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 + 一个制表符之后认为是句子 + + Example:: + + label text_a + 1 基金痛所有投资项目一样,必须先要有所了解... + 1 系统很好装,LED屏是不错,就是16比9的比例... + + 读取后的DataSet具有以下的field + + .. csv-table:: + :header: "raw_chars", "target" + + "基金痛所有投资项目一样,必须先要有所了解...", "1" + "系统很好装,LED屏是不错,就是16比9的比例...", "1" + "..." + + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + r""" + 从path中读取数据 + + :param path: + :return: + """ + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + f.readline() + for line in f: + line = line.strip() + tab_index = line.index('\t') + if tab_index != -1: + target = line[:tab_index] + raw_chars = line[tab_index + 1:] + if raw_chars: + ds.append(Instance(raw_chars=raw_chars, target=target)) + return ds + + def download(self) -> str: + r""" + 自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 + https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 + + :return: + """ + output_dir = self._get_dataset_path('chn-senti-corp') + return output_dir + + +class THUCNewsLoader(Loader): + r""" + 数据集简介:document-level分类任务,新闻10分类 + 原始数据内容为:每行一个sample,第一个 "\\t" 之前为target,第一个 "\\t" 之后为raw_words + + Example:: + + 体育 调查-您如何评价热火客场胜绿军总分3-1夺赛点?... + + 读取后的Dataset将具有以下数据结构: + + .. csv-table:: + :header: "raw_words", "target" + + "调查-您如何评价热火客场胜绿军总分3-1夺赛点?...", "体育" + "...", "..." 
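A minimal usage sketch; it assumes the cached directory returned by ``download()`` contains files whose names include train/dev/test, which is what ``load()`` looks for.

    Example::

        loader = THUCNewsLoader()
        data_dir = loader.download()             # fetch and cache the corpus, returns the cache directory
        data_bundle = loader.load(data_dir)      # one DataSet per train/dev/test file found in data_dir
        print(data_bundle.get_dataset('train')[0])   # Instance with raw_chars and target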
+ + """ + + def __init__(self): + super(THUCNewsLoader, self).__init__() + + def _load(self, path: str = None): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + sep_index = line.index('\t') + raw_chars = line[sep_index + 1:] + target = line[:sep_index] + if raw_chars: + ds.append(Instance(raw_chars=raw_chars, target=target)) + return ds + + def download(self) -> str: + r""" + 自动下载数据,该数据取自 + + http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews + + :return: + """ + output_dir = self._get_dataset_path('thuc-news') + return output_dir + + +class WeiboSenti100kLoader(Loader): + r""" + 别名: + 数据集简介:微博sentiment classification,二分类 + + Example:: + + label text + 1 多谢小莲,好运满满[爱你] + 1 能在他乡遇老友真不赖,哈哈,珠儿,我也要用... + + 读取后的Dataset将具有以下数据结构: + + .. csv-table:: + :header: "raw_chars", "target" + + "多谢小莲,好运满满[爱你]", "1" + "能在他乡遇老友真不赖,哈哈,珠儿,我也要用...", "1" + "...", "..." + + """ + + def __init__(self): + super(WeiboSenti100kLoader, self).__init__() + + def _load(self, path: str = None): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + next(f) + for line in f: + line = line.strip() + target = line[0] + raw_chars = line[1:] + if raw_chars: + ds.append(Instance(raw_chars=raw_chars, target=target)) + return ds + + def download(self) -> str: + r""" + 自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ + 在 https://arxiv.org/abs/1906.08101 有使用 + :return: + """ + output_dir = self._get_dataset_path('weibo-senti-100k') + return output_dir + + +class MRLoader(CLSBaseLoader): + def __init__(self): + super(MRLoader, self).__init__() + + def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: + r""" + 自动下载数据集 + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = r'mr' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class R8Loader(CLSBaseLoader): + def __init__(self): + super(R8Loader, self).__init__() + + def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: + r""" + 自动下载数据集 + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = r'R8' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class R52Loader(CLSBaseLoader): + def __init__(self): + super(R52Loader, self).__init__() + + def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: + r""" + 自动下载数据集 + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = r'R52' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class NG20Loader(CLSBaseLoader): + def __init__(self): + super(NG20Loader, self).__init__() + + def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: + r""" + 自动下载数据集 + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = r'20ng' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class OhsumedLoader(CLSBaseLoader): + def __init__(self): + super(OhsumedLoader, self).__init__() + + def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: + r""" + 自动下载数据集 + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = r'ohsumed' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py new file mode 100644 index 00000000..e099331f --- /dev/null +++ b/fastNLP/io/loader/conll.py @@ -0,0 +1,542 @@ +r"""undocumented""" + +__all__ = [ + "ConllLoader", + "Conll2003Loader", + "Conll2003NERLoader", + "OntoNotesNERLoader", + "CTBLoader", + "CNNERLoader", + "MsraNERLoader", + "WeiboNERLoader", + "PeopleDailyNERLoader" +] + +import glob +import os +import random +import shutil +import time + +from .loader import Loader +from ..file_reader import _read_conll +# from ...core.const import Const +from fastNLP.core.dataset import DataSet, Instance + + +class ConllLoader(Loader): + r""" + ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: + + Example:: + + # 文件中的内容 + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... 
+ + # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 + dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field + dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') + + ConllLoader返回的DataSet的field由传入的headers确定。 + + 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + """ + + def __init__(self, headers, sep=None, indexes=None, dropna=True): + r""" + + :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list sep: 指定分隔符,默认为制表符 + :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` + """ + super(ConllLoader, self).__init__() + if not isinstance(headers, (list, tuple)): + raise TypeError( + 'invalid headers: {}, should be list of strings'.format(headers)) + self.headers = headers + self.dropna = dropna + self.sep=sep + if indexes is None: + self.indexes = list(range(len(self.headers))) + else: + if len(indexes) != len(headers): + raise ValueError + self.indexes = indexes + + def _load(self, path): + r""" + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + +class Conll2003Loader(ConllLoader): + r""" + 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 + + Example:: + + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + 返回的DataSet的内容为 + + .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 + :header: "raw_words", "pos", "chunk", "ner" + + "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]", "[...]", "[...]" + + """ + + def __init__(self): + headers = [ + 'raw_words', 'pos', 'chunk', 'ner', + ] + super(Conll2003Loader, self).__init__(headers=headers) + + def _load(self, path): + r""" + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + doc_start = False + for i, h in enumerate(self.headers): + field = data[i] + if str(field[0]).startswith('-DOCSTART-'): + doc_start = True + break + if doc_start: + continue + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + def download(self, output_dir=None): + raise RuntimeError("conll2003 cannot be downloaded automatically.") + + +class Conll2003NERLoader(ConllLoader): + r""" + 用于读取conll2003任务的NER数据。每一行有4列内容,空行意味着隔开两个句子 + + 支持读取的内容如下 + Example:: + + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + 返回的DataSet的内容为 + + .. 
csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码 + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + """ + + def __init__(self): + headers = [ + 'raw_words', 'target', + ] + super().__init__(headers=headers, indexes=[0, 3]) + + def _load(self, path): + r""" + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + doc_start = False + for i, h in enumerate(self.headers): + field = data[i] + if str(field[0]).startswith('-DOCSTART-'): + doc_start = True + break + if doc_start: + continue + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + if len(ds) == 0: + raise RuntimeError("No data found {}.".format(path)) + return ds + + def download(self): + raise RuntimeError("conll2003 cannot be downloaded automatically.") + + +class OntoNotesNERLoader(ConllLoader): + r""" + 用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 + https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 + + 读取的数据格式为: + + Example:: + + bc/msnbc/00/msnbc_0000 0 0 Hi UH (TOP(FRAG(INTJ*) - - - Dan_Abrams * - + bc/msnbc/00/msnbc_0000 0 1 everyone NN (NP*) - - - Dan_Abrams * - + ... + + 返回的DataSet的内容为 + + .. csv-table:: + :header: "raw_words", "target" + + "['Hi', 'everyone', '.']", "['O', 'O', 'O']" + "['first', 'up', 'on', 'the', 'docket']", "['O', 'O', 'O', 'O', 'O']" + "[...]", "[...]" + + """ + + def __init__(self): + super().__init__(headers=['raw_words', 'target'], indexes=[3, 10]) + + def _load(self, path: str): + dataset = super()._load(path) + + def convert_to_bio(tags): + bio_tags = [] + flag = None + for tag in tags: + label = tag.strip("()*") + if '(' in tag: + bio_label = 'B-' + label + flag = label + elif flag: + bio_label = 'I-' + flag + else: + bio_label = 'O' + if ')' in tag: + flag = None + bio_tags.append(bio_label) + return bio_tags + + def convert_word(words): + converted_words = [] + for word in words: + word = word.replace('/.', '.') # 有些结尾的.是/.形式的 + if not word.startswith('-'): + converted_words.append(word) + continue + # 以下是由于这些符号被转义了,再转回来 + tfrs = {'-LRB-': '(', + '-RRB-': ')', + '-LSB-': '[', + '-RSB-': ']', + '-LCB-': '{', + '-RCB-': '}' + } + if word in tfrs: + converted_words.append(tfrs[word]) + else: + converted_words.append(word) + return converted_words + + dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words') + dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target') + + return dataset + + def download(self): + raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " + "https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") + + +class CTBLoader(Loader): + r""" + 支持加载的数据应该具备以下格式, 其中第二列为词语,第四列为pos tag,第七列为依赖树的head,第八列为依赖树的label + + Example:: + + 1 印度 _ NR NR _ 3 nn _ _ + 2 海军 _ NN NN _ 3 nn _ _ + 3 参谋长 _ NN NN _ 5 nsubjpass _ _ + 4 被 _ SB SB _ 5 pass _ _ + 5 解职 _ VV VV _ 0 root _ _ + + 1 新华社 _ NR NR _ 7 dep _ _ + 2 新德里 _ NR NR _ 7 dep _ _ + 3 12月 _ NT NT _ 7 dep _ _ + ... + + 读取之后DataSet具备的格式为 + + .. 
csv-table:: + :header: "raw_words", "pos", "dep_head", "dep_label" + + "[印度, 海军, ...]", "[NR, NN, SB, ...]", "[3, 3, ...]", "[nn, nn, ...]" + "[新华社, 新德里, ...]", "[NR, NR, NT, ...]", "[7, 7, 7, ...]", "[dep, dep, dep, ...]" + "[...]", "[...]", "[...]", "[...]" + + """ + def __init__(self): + super().__init__() + headers = [ + 'raw_words', 'pos', 'dep_head', 'dep_label', + ] + indexes = [ + 1, 3, 6, 7, + ] + self.loader = ConllLoader(headers=headers, indexes=indexes) + + def _load(self, path: str): + dataset = self.loader._load(path) + return dataset + + def download(self): + r""" + 由于版权限制,不能提供自动下载功能。可参考 + + https://catalog.ldc.upenn.edu/LDC2013T21 + + :return: + """ + raise RuntimeError("CTB cannot be downloaded automatically.") + + +class CNNERLoader(Loader): + def _load(self, path: str): + r""" + 支持加载形如以下格式的内容,一行两列,以空格隔开两个sample + + Example:: + + 我 O + 们 O + 变 O + 而 O + 以 O + 书 O + 会 O + ... + + :param str path: 文件路径 + :return: DataSet,包含raw_words列和target列 + """ + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + raw_chars = [] + target = [] + for line in f: + line = line.strip() + if line: + parts = line.split() + if len(parts) == 1: # 网上下载的数据有一些列少tag,默认补充O + parts.append('O') + raw_chars.append(parts[0]) + target.append(parts[1]) + else: + if raw_chars: + ds.append(Instance(raw_chars=raw_chars, target=target)) + raw_chars = [] + target = [] + return ds + + +class MsraNERLoader(CNNERLoader): + r""" + 读取MSRA-NER数据,数据中的格式应该类似与下列的内容 + + Example:: + + 把 O + 欧 B-LOC + + 美 B-LOC + 、 O + + 港 B-LOC + 台 B-LOC + + 流 O + 行 O + + 的 O + + 食 O + + ... + + 读取后的DataSet包含以下的field + + .. csv-table:: + :header: "raw_chars", "target" + + "['把', '欧'] ", "['O', 'B-LOC']" + "['美', '、']", "['B-LOC', 'O']" + "[...]", "[...]" + + """ + + def __init__(self): + super().__init__() + + def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: + r""" + 自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language + Processing Bakeoff: Word Segmentation and Named Entity Recognition. + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.conll, test.conll, + dev.conll三个文件。 + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + :return: + """ + dataset_name = 'msra-ner' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + modify_time = 0 + for filepath in glob.glob(os.path.join(data_dir, '*')): + modify_time = os.stat(filepath).st_mtime + break + if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.conll')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
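        # The block below carves a dev set out of train.conll sample by sample (a sample is the
        # group of lines between two blank lines): each sample is written to dev.conll with
        # probability dev_ratio, the rest go to middle_file.conll, which is then renamed back to
        # train.conll. The temporary file keeps train.conll intact if the split fails partway.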
+ try: + with open(os.path.join(data_dir, 'train.conll'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.conll'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.conll'), 'w', encoding='utf-8') as f2: + lines = [] # 一个sample包含很多行 + for line in f: + line = line.strip() + if line: + lines.append(line) + else: + if random.random() < dev_ratio: + f2.write('\n'.join(lines) + '\n\n') + else: + f1.write('\n'.join(lines) + '\n\n') + lines.clear() + os.remove(os.path.join(data_dir, 'train.conll')) + os.renames(os.path.join(data_dir, 'middle_file.conll'), os.path.join(data_dir, 'train.conll')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.conll')): + os.remove(os.path.join(data_dir, 'middle_file.conll')) + + return data_dir + + +class WeiboNERLoader(CNNERLoader): + r""" + 读取WeiboNER数据,数据中的格式应该类似与下列的内容 + + Example:: + + 老 B-PER.NOM + 百 I-PER.NOM + 姓 I-PER.NOM + + 心 O + + ... + + 读取后的DataSet包含以下的field + + .. csv-table:: + + :header: "raw_chars", "target" + + "['老', '百', '姓']", "['B-PER.NOM', 'I-PER.NOM', 'I-PER.NOM']" + "['心']", "['O']" + "[...]", "[...]" + + """ + def __init__(self): + super().__init__() + + def download(self) -> str: + r""" + 自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for + Chinese Social Media with Jointly Trained Embeddings. + + :return: str + """ + dataset_name = 'weibo-ner' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + return data_dir + + +class PeopleDailyNERLoader(CNNERLoader): + r""" + 支持加载的数据格式如下 + + Example:: + + 中 B-ORG + 共 I-ORG + 中 I-ORG + 央 I-ORG + + 致 O + 中 B-ORG + ... + + 读取后的DataSet包含以下的field + + .. csv-table:: target列是基于BIO的编码方式 + :header: "raw_chars", "target" + + "['中', '共', '中', '央']", "['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG']" + "[...]", "[...]" + + """ + + def __init__(self): + super().__init__() + + def download(self) -> str: + dataset_name = 'peopledaily' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + return data_dir diff --git a/fastNLP/io/loader/coreference.py b/fastNLP/io/loader/coreference.py new file mode 100644 index 00000000..66a39749 --- /dev/null +++ b/fastNLP/io/loader/coreference.py @@ -0,0 +1,64 @@ +r"""undocumented""" + +__all__ = [ + "CoReferenceLoader", +] + +from ...core.dataset import DataSet +from ..file_reader import _read_json +from fastNLP.core.dataset import Instance +# from ...core.const import Const +from .json import JsonLoader + + +class CoReferenceLoader(JsonLoader): + r""" + 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 + + Example:: + + {"doc_key": "bc/cctv/00/cctv_0000_0", + "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], 
+ "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], + "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]] + } + + 读取预处理好的Conll2012数据,数据结构如下: + + .. csv-table:: + :header: "raw_words1", "raw_words2", "raw_words3", "raw_words4" + + "bc/cctv/00/cctv_0000_0", "[['Speaker#1', 'Speaker#1', 'Speaker#1...", "[[[70, 70], [485, 486], [500, 500], [7...", "[['In', 'the', 'summer', 'of', '2005',..." + "...", "...", "...", "..." + + """ + def __init__(self, fields=None, dropna=False): + super().__init__(fields, dropna) + self.fields = {"doc_key": "raw_words1", "speakers": "raw_words2", "clusters": "raw_words3", + "sentences": "raw_words4"} + + def _load(self, path): + r""" + 加载数据 + :param path: 数据文件路径,文件为json + + :return: + """ + dataset = DataSet() + for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): + if self.fields: + ins = {self.fields[k]: v for k, v in d.items()} + else: + ins = d + dataset.append(Instance(**ins)) + return dataset + + def download(self): + r""" + 由于版权限制,不能提供自动下载功能。可参考 + + https://www.aclweb.org/anthology/W12-4501 + + :return: + """ + raise RuntimeError("CoReference cannot be downloaded automatically.") diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py new file mode 100644 index 00000000..debd5222 --- /dev/null +++ b/fastNLP/io/loader/csv.py @@ -0,0 +1,38 @@ +r"""undocumented""" + +__all__ = [ + "CSVLoader", +] + +from .loader import Loader +from ..file_reader import _read_csv +from fastNLP.core.dataset import DataSet, Instance + + +class CSVLoader(Loader): + r""" + 读取CSV格式的数据集, 返回 ``DataSet`` 。 + + """ + + def __init__(self, headers=None, sep=",", dropna=False): + r""" + + :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 + 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` + :param str sep: CSV文件中列与列之间的分隔符. Default: "," + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``False`` + """ + super().__init__() + self.headers = headers + self.sep = sep + self.dropna = dropna + + def _load(self, path): + ds = DataSet() + for idx, data in _read_csv(path, headers=self.headers, + sep=self.sep, dropna=self.dropna): + ds.append(Instance(**data)) + return ds + diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py new file mode 100644 index 00000000..d88d6a00 --- /dev/null +++ b/fastNLP/io/loader/cws.py @@ -0,0 +1,97 @@ +r"""undocumented""" + +__all__ = [ + "CWSLoader" +] + +import glob +import os +import random +import shutil +import time + +from .loader import Loader +from fastNLP.core.dataset import DataSet, Instance + + +class CWSLoader(Loader): + r""" + CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: + + Example:: + + 上海 浦东 开发 与 法制 建设 同步 + 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) + ... + + 该Loader读取后的DataSet具有如下的结构 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." 
+ + """ + + def __init__(self, dataset_name: str = None): + r""" + + :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None + """ + super().__init__() + datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} + if dataset_name in datanames: + self.dataset_name = datanames[dataset_name] + else: + self.dataset_name = None + + def _load(self, path: str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self, dev_ratio=0.1, re_download=False) -> str: + r""" + 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, + 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str + """ + if self.dataset_name is None: + return '' + data_dir = self._get_dataset_path(dataset_name=self.dataset_name) + modify_time = 0 + for filepath in glob.glob(os.path.join(data_dir, '*')): + modify_time = os.stat(filepath).st_mtime + break + if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=self.dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.txt')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + try: + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.txt')) + os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): + os.remove(os.path.join(data_dir, 'middle_file.txt')) + + return data_dir diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py new file mode 100644 index 00000000..e5648a26 --- /dev/null +++ b/fastNLP/io/loader/json.py @@ -0,0 +1,45 @@ +r"""undocumented""" + +__all__ = [ + "JsonLoader" +] + +from .loader import Loader +from ..file_reader import _read_json +from fastNLP.core.dataset import DataSet, Instance + + +class JsonLoader(Loader): + r""" + 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` + + 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 + + :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name + ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , + `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 + ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 
+ Default: ``False`` + """ + + def __init__(self, fields=None, dropna=False): + super(JsonLoader, self).__init__() + self.dropna = dropna + self.fields = None + self.fields_list = None + if fields: + self.fields = {} + for k, v in fields.items(): + self.fields[k] = k if v is None else v + self.fields_list = list(self.fields.keys()) + + def _load(self, path): + ds = DataSet() + for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): + if self.fields: + ins = {self.fields[k]: v for k, v in d.items()} + else: + ins = d + ds.append(Instance(**ins)) + return ds diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py new file mode 100644 index 00000000..135a9d74 --- /dev/null +++ b/fastNLP/io/loader/loader.py @@ -0,0 +1,94 @@ +r"""undocumented""" + +__all__ = [ + "Loader" +] + +from typing import Union, Dict + +from fastNLP.io.data_bundle import DataBundle +from fastNLP.io.file_utils import _get_dataset_url, get_cache_path, cached_path +from fastNLP.io.utils import check_loader_paths +from fastNLP.core.dataset import DataSet + + +class Loader: + r""" + 各种数据 Loader 的基类,提供了 API 的参考. + Loader支持以下的三个函数 + + - download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。 + - _load() 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` 。返回的DataSet的内容可以通过每个Loader的文档判断出。 + - load() 函数:将文件分别读取为DataSet,然后将多个DataSet放入到一个DataBundle中并返回 + + """ + + def __init__(self): + pass + + def _load(self, path: str) -> DataSet: + r""" + 给定一个路径,返回读取的DataSet。 + + :param str path: 路径 + :return: DataSet + """ + raise NotImplementedError + + def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + r""" + 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + + :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: + + 0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 + + 1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: + + data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train + # dev、 test等有所变化,可以通过以下的方式取出DataSet + tr_data = data_bundle.get_dataset('train') + te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 + + 2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: + + paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} + data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" + dev_data = data_bundle.get_dataset('dev') + + 3.传入文件路径:: + + data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' + tr_data = data_bundle.get_dataset('train') # 取出DataSet + + :return: 返回的 :class:`~fastNLP.io.DataBundle` + """ + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self) -> str: + r""" + 自动下载该数据集 + + :return: 下载后解压目录 + """ + raise NotImplementedError(f"{self.__class__} cannot download data automatically.") + + @staticmethod + def _get_dataset_path(dataset_name): + r""" + 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) + + :param str dataset_name: 数据集的名称 + :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 + """ + + default_cache_path = get_cache_path() + url = _get_dataset_url(dataset_name) + output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') + + return output_dir diff --git a/fastNLP/io/loader/matching.py 
b/fastNLP/io/loader/matching.py new file mode 100644 index 00000000..4c798f8b --- /dev/null +++ b/fastNLP/io/loader/matching.py @@ -0,0 +1,577 @@ +r"""undocumented""" + +__all__ = [ + "MNLILoader", + "SNLILoader", + "QNLILoader", + "RTELoader", + "QuoraLoader", + "BQCorpusLoader", + "CNXNLILoader", + "LCQMCLoader" +] + +import os +import warnings +from typing import Union, Dict + +from .csv import CSVLoader +from .json import JsonLoader +from .loader import Loader +from fastNLP.io.data_bundle import DataBundle +from ..utils import check_loader_paths +# from ...core.const import Const +from fastNLP.core.dataset import DataSet, Instance + + +class MNLILoader(Loader): + r""" + 读取的数据格式为: + + Example:: + + index promptID pairID genre sentence1_binary_parse sentence2_binary_parse sentence1_parse sentence2_parse sentence1 sentence2 label1 gold_label + 0 31193 31193n government ( ( Conceptually ( cream skimming ) ) ... + 1 101457 101457e telephone ( you ( ( know ( during ( ( ( the season ) and ) ( i guess ) ) )... + ... + + 读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,words0是sentence1, words1是sentence2, target是gold_label, 测试集中没 + 有target列。 + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "Conceptually cream ...", "Product and geography...", "neutral" + "you know during the ...", "You lose the things to the...", "entailment" + "...", "...", "..." + + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): + warnings.warn("MNLI's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[8] + raw_words2 = parts[9] + idx = int(parts[0]) + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, index=idx)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[8] + raw_words2 = parts[9] + target = parts[-1] + idx = int(parts[0]) + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target, index=idx)) + return ds + + def load(self, paths: str = None): + r""" + + :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, + test_mismatched.tsv, train.tsv文件夹 + :return: DataBundle + """ + if paths: + paths = os.path.abspath(os.path.expanduser(paths)) + else: + paths = self.download() + if not os.path.isdir(paths): + raise NotADirectoryError(f"{paths} is not a valid directory.") + + files = {'dev_matched': "dev_matched.tsv", + "dev_mismatched": "dev_mismatched.tsv", + "test_matched": "test_matched.tsv", + "test_mismatched": "test_mismatched.tsv", + "train": 'train.tsv'} + + datasets = {} + for name, filename in files.items(): + filepath = os.path.join(paths, filename) + if not os.path.isfile(filepath): + if 'test' not in name: + raise FileNotFoundError(f"{name} not found in directory {filepath}.") + datasets[name] = self._load(filepath) + + data_bundle = DataBundle(datasets=datasets) + + return data_bundle + + def download(self): + r""" + 如果你使用了这个数据,请引用 + + https://www.nyu.edu/projects/bowman/multinli/paper.pdf + :return: + """ + output_dir = self._get_dataset_path('mnli') + return output_dir + + +class SNLILoader(JsonLoader): + r""" + 文件每一行是一个sample,每一行都为一个json对象,其数据格式为: + + Example:: + + {"annotator_labels": ["neutral", "entailment", "neutral", "neutral", 
"neutral"], "captionID": "4705552913.jpg#2", + "gold_label": "neutral", "pairID": "4705552913.jpg#2r1n", + "sentence1": "Two women are embracing while holding to go packages.", + "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", + "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. .)))", + "sentence2": "The sisters are hugging goodbye while holding to go packages after just eating lunch.", + "sentence2_binary_parse": "( ( The sisters ) ( ( are ( ( hugging goodbye ) ( while ( holding ( to ( ( go packages ) ( after ( just ( eating lunch ) ) ) ) ) ) ) ) ) . ) )", + "sentence2_parse": "(ROOT (S (NP (DT The) (NNS sisters)) (VP (VBP are) (VP (VBG hugging) (NP (UH goodbye)) (PP (IN while) (S (VP (VBG holding) (S (VP (TO to) (VP (VB go) (NP (NNS packages)) (PP (IN after) (S (ADVP (RB just)) (VP (VBG eating) (NP (NN lunch))))))))))))) (. .)))" + } + + 读取之后的DataSet中的field情况为 + + .. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field + :header: "target", "raw_words1", "raw_words2", + + "neutral ", "Two women are embracing while holding..", "The sisters are hugging goodbye..." + "entailment", "Two women are embracing while holding...", "Two woman are holding packages." + "...", "...", "..." + + """ + + def __init__(self): + super().__init__(fields={ + 'sentence1': 'raw_words1', + 'sentence2': 'raw_words2', + 'gold_label': 'target', + }) + + def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + r""" + 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据Loader初始化时传入的field决定。 + + :param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl + 和snli_1.0_test.jsonl三个文件。 + + :return: 返回的 :class:`~fastNLP.io.DataBundle` + """ + _paths = {} + if paths is None: + paths = self.download() + if paths: + if os.path.isdir(paths): + if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')): + raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl') + for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']: + filepath = os.path.join(paths, filename) + _paths[filename.split('_')[-1].split('.')[0]] = filepath + paths = _paths + else: + raise NotADirectoryError(f"{paths} is not a valid directory.") + + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + r""" + 如果您的文章使用了这份数据,请引用 + + http://nlp.stanford.edu/pubs/snli_paper.pdf + + :return: str + """ + return self._get_dataset_path('snli') + + +class QNLILoader(JsonLoader): + r""" + 第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、问题、句子和标签构成(以制表符分割),数据结构如下: + + Example:: + + index question sentence label + 0 What came into force after the new constitution was herald? As of that day, the new constitution heralding the Second Republic came into force. entailment + + QNLI数据集的Loader, + 加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "What came into force after the new...", "As of that day...", "entailment" + "...","." 
+ + test数据集没有target列 + + """ + + def __init__(self): + super().__init__() + + def _load(self, path): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("QNLI's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + r""" + 如果您的实验使用到了该数据,请引用 + + https://arxiv.org/pdf/1809.05053.pdf + + :return: + """ + return self._get_dataset_path('qnli') + + +class RTELoader(Loader): + r""" + 第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、句子1、句子2和标签构成(以制表符分割),数据结构如下: + + Example:: + + index sentence1 sentence2 label + 0 Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation. Christopher Reeve had an accident. not_entailment + + RTE数据的loader + 加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" + "...","..." + + test数据集没有target列 + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("RTE's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + r""" + 如果您的实验使用到了该数据,请引用GLUE Benchmark + + https://openreview.net/pdf?id=rJ4km2R5t7 + + :return: + """ + return self._get_dataset_path('rte') + + +class QuoraLoader(Loader): + r""" + Quora matching任务的数据集Loader + + 支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 + + Example:: + + 1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970 + 0 Is honey a viable alternative to sugar for diabetics ? How would you compare the United States ' euthanasia laws to Denmark ? 90348 + ... + + 加载的DataSet将具备以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "How do I get funding for my web based...", "How do I get seed funding...","1" + "Is honey a viable alternative ...", "How would you compare the United...","0" + "...","...","..." 
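Because the corpus cannot be downloaded automatically, here is a sketch of loading it from local files (the paths below are hypothetical):

    Example::

        loader = QuoraLoader()
        paths = {'train': '/path/to/quora/train.tsv', 'dev': '/path/to/quora/dev.tsv'}
        data_bundle = loader.load(paths)
        print(data_bundle.get_dataset('train')[0])   # Instance with raw_words1, raw_words2, target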
+ + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[0] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + r""" + 由于版权限制,不能提供自动下载功能。可参考 + + https://www.kaggle.com/c/quora-question-pairs/data + + :return: + """ + raise RuntimeError("Quora cannot be downloaded automatically.") + + +class CNXNLILoader(Loader): + r""" + 数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize + 原始数据数据为: + + Example:: + + premise hypo label + 我们 家里 有 一个 但 我 没 找到 我 可以 用 的 时间 我们 家里 有 一个 但 我 从来 没有 时间 使用 它 . entailment + + dev和test中的数据为csv或json格式,包括十多个field,这里只取与以上三个field中的数据 + 读取后的Dataset将具有以下数据结构: + + .. csv-table:: + :header: "raw_chars1", "raw_chars2", "target" + + "我们 家里 有 一个 但 我 没 找到 我 可以 用 的 时间", "我们 家里 有 一个 但 我 从来 没有 时间 使用 它 .", "0" + "...", "...", "..." + + """ + + def __init__(self): + super(CNXNLILoader, self).__init__() + + def _load(self, path: str = None): + ds_all = DataSet() + with open(path, 'r', encoding='utf-8') as f: + head_name_list = f.readline().strip().split('\t') + sentence1_index = head_name_list.index('sentence1') + sentence2_index = head_name_list.index('sentence2') + gold_label_index = head_name_list.index('gold_label') + language_index = head_name_list.index(('language')) + + for line in f: + line = line.strip() + raw_instance = line.split('\t') + sentence1 = raw_instance[sentence1_index] + sentence2 = raw_instance[sentence2_index] + gold_label = raw_instance[gold_label_index] + language = raw_instance[language_index] + if sentence1: + ds_all.append(Instance(sentence1=sentence1, sentence2=sentence2, gold_label=gold_label, language=language)) + + ds_zh = DataSet() + for i in ds_all: + if i['language'] == 'zh': + ds_zh.append(Instance(raw_chars1=i['sentence1'], raw_chars2=i['sentence2'], target=i['gold_label'])) + + return ds_zh + + def _load_train(self, path: str = None): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + next(f) + for line in f: + raw_instance = line.strip().split('\t') + premise = "".join(raw_instance[0].split())# 把已经分好词的premise和hypo强制还原为character segmentation + hypo = "".join(raw_instance[1].split()) + label = "".join(raw_instance[-1].split()) + if premise: + ds.append(Instance(premise=premise, hypo=hypo, label=label)) + + ds.rename_field('label', 'target') + ds.rename_field('premise', 'raw_chars1') + ds.rename_field('hypo', 'raw_chars2') + ds.apply(lambda i: "".join(i['raw_chars1'].split()), new_field_name='raw_chars1') + ds.apply(lambda i: "".join(i['raw_chars2'].split()), new_field_name='raw_chars2') + return ds + + def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + datasets = {} + for name, path in paths.items(): + if name == 'train': + datasets[name] = self._load_train(path) + else: + datasets[name] = self._load(path) + + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self) -> str: + r""" + 自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 + 在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf + https://arxiv.org/pdf/1809.05053.pdf 有使用 + :return: + """ + output_dir = self._get_dataset_path('cn-xnli') + 
return output_dir + + +class BQCorpusLoader(Loader): + r""" + 别名: + 数据集简介:句子对二分类任务(判断是否具有相同的语义) + 原始数据结构为: + + Example:: + + sentence1,sentence2,label + 综合评分不足什么原因,综合评估的依据,0 + 什么时候我能使用微粒贷,你就赶快给我开通就行了,0 + + 读取后的Dataset将具有以下数据结构: + + .. csv-table:: + :header: "raw_chars1", "raw_chars2", "target" + + "综合评分不足什么原因", "综合评估的依据", "0" + "什么时候我能使用微粒贷", "你就赶快给我开通就行了", "0" + "...", "...", "..." + + """ + + def __init__(self): + super(BQCorpusLoader, self).__init__() + + def _load(self, path: str = None): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + next(f) + for line in f: + line = line.strip() + target = line[-1] + sep_index = line.index(',') + raw_chars1 = line[:sep_index] + raw_chars2 = line[sep_index + 1:] + + if raw_chars1: + ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) + return ds + + def download(self): + r""" + 由于版权限制,不能提供自动下载功能。可参考 + + https://github.com/ymcui/Chinese-BERT-wwm + + :return: + """ + raise RuntimeError("BQCorpus cannot be downloaded automatically.") + + +class LCQMCLoader(Loader): + r""" + 数据集简介:句对匹配(question matching) + + 原始数据为: + + Example:: + + 喜欢打篮球的男生喜欢什么样的女生 爱打篮球的男生喜欢什么样的女生 1 + 你帮我设计小说的封面吧 谁能帮我给小说设计个封面? 0 + + + 读取后的Dataset将具有以下的数据结构 + + .. csv-table:: + :header: "raw_chars1", "raw_chars2", "target" + + "喜欢打篮球的男生喜欢什么样的女生", "爱打篮球的男生喜欢什么样的女生", "1" + "你帮我设计小说的封面吧", "妇可以戴耳机听音乐吗?", "0" + "...", "...", "..." + + + """ + + def __init__(self): + super(LCQMCLoader, self).__init__() + + def _load(self, path: str = None): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + line_segments = line.split('\t') + assert len(line_segments) == 3 + + target = line_segments[-1] + + raw_chars1 = line_segments[0] + raw_chars2 = line_segments[1] + + if raw_chars1: + ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) + return ds + + def download(self): + r""" + 由于版权限制,不能提供自动下载功能。可参考 + + https://github.com/ymcui/Chinese-BERT-wwm + + :return: + """ + raise RuntimeError("LCQMC cannot be downloaded automatically.") + + diff --git a/fastNLP/io/loader/qa.py b/fastNLP/io/loader/qa.py new file mode 100644 index 00000000..a3140b01 --- /dev/null +++ b/fastNLP/io/loader/qa.py @@ -0,0 +1,74 @@ +r""" +该文件中的Loader主要用于读取问答式任务的数据 + +""" + + +from .loader import Loader +import json +from fastNLP.core.dataset import DataSet, Instance + +__all__ = ['CMRC2018Loader'] + + +class CMRC2018Loader(Loader): + r""" + 请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 + + 读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 + + .. csv-table:: + :header:"title", "context", "question", "answers", "answer_starts", "id" + + "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" + "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" + "...", "...", "...","...", ".", "..." + + 其中title是文本的标题,多条记录可能是相同的title;id是该问题的id,具备唯一性 + + 验证集DataSet将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案) + + .. csv-table:: + :header: "title", "context", "question", "answers", "answer_starts", "id" + + "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "《战国无双3》是由哪两个公司合作开发的?", "['光荣和ω-force', '光荣和ω-force', '光荣和ω-force']", "[30, 30, 30]", "DEV_0_QUERY_0" + "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", "['村雨城', '村雨城', '任天堂游戏谜之村雨城']", "[226, 226, 219]", "DEV_0_QUERY_1" + "...", "...", "...","...", ".", "..." 
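A sketch of inspecting the loaded fields; it assumes the cached files are picked up as the train/dev splits by ``load()``:

    Example::

        data_bundle = CMRC2018Loader().load()    # downloads and caches the data if necessary
        ins = data_bundle.get_dataset('dev')[0]
        print(ins['question'], ins['answers'], ins['answer_starts'])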
+ + 其中answer_starts是从0开始的index。例如"我来自a复旦大学?",其中"复"的开始index为4。另外"Russell评价说"中的说的index为9, 因为 + 英文和数字都直接按照character计量的。 + """ + def __init__(self): + super().__init__() + + def _load(self, path: str) -> DataSet: + with open(path, 'r', encoding='utf-8') as f: + data = json.load(f)['data'] + ds = DataSet() + for entry in data: + title = entry['title'] + para = entry['paragraphs'][0] + context = para['context'] + qas = para['qas'] + for qa in qas: + question = qa['question'] + ans = qa['answers'] + answers = [] + answer_starts = [] + id = qa['id'] + for an in ans: + answers.append(an['text']) + answer_starts.append(an['answer_start']) + ds.append(Instance(title=title, context=context, question=question, answers=answers, + answer_starts=answer_starts,id=id)) + return ds + + def download(self) -> str: + r""" + 如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. + + :return: + """ + output_dir = self._get_dataset_path('cmrc2018') + return output_dir + diff --git a/fastNLP/io/loader/summarization.py b/fastNLP/io/loader/summarization.py new file mode 100644 index 00000000..3fe5f7a3 --- /dev/null +++ b/fastNLP/io/loader/summarization.py @@ -0,0 +1,63 @@ +r"""undocumented""" + +__all__ = [ + "ExtCNNDMLoader" +] + +import os +from typing import Union, Dict + +from ..data_bundle import DataBundle +from ..utils import check_loader_paths +from .json import JsonLoader + + +class ExtCNNDMLoader(JsonLoader): + r""" + 读取之后的DataSet中的field情况为 + + .. csv-table:: + :header: "text", "summary", "label", "publication" + + ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" + ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" + ["..."], ["..."], [], "cnndm" + + """ + + def __init__(self, fields=None): + fields = fields or {"text": None, "summary": None, "label": None, "publication": None} + super(ExtCNNDMLoader, self).__init__(fields=fields) + + def load(self, paths: Union[str, Dict[str, str]] = None): + r""" + 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 + + :param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl + test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe` + 当中需要用到)。 + + :return: 返回 :class:`~fastNLP.io.DataBundle` + """ + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + if ('train' in paths) and ('test' not in paths): + paths['test'] = paths['train'] + paths.pop('train') + + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + r""" + 如果你使用了这个数据,请引用 + + https://arxiv.org/pdf/1506.03340.pdf + :return: + """ + output_dir = self._get_dataset_path('ext-cnndm') + return output_dir diff --git a/fastNLP/io/model_io.py b/fastNLP/io/model_io.py new file mode 100644 index 00000000..30a8ef33 --- /dev/null +++ b/fastNLP/io/model_io.py @@ -0,0 +1,71 @@ +r""" +用于载入和保存模型 +""" +__all__ = [ + "ModelLoader", + "ModelSaver" +] + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + + +class ModelLoader: + r""" + 用于读取模型 + """ + + def __init__(self): + super(ModelLoader, self).__init__() + + @staticmethod + def load_pytorch(empty_model, model_path): + r""" + 从 ".pkl" 文件读取 PyTorch 模型 + + :param empty_model: 初始化参数的 PyTorch 模型 + :param str model_path: 模型保存的路径 + """ + empty_model.load_state_dict(torch.load(model_path)) + + 
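    # Usage sketch (MyModel and the checkpoint path are hypothetical): build a model with the
    # same architecture first, then fill its parameters in from the ".pkl" file:
    #
    #     model = MyModel()
    #     ModelLoader.load_pytorch(model, "./save/model_ckpt_100.pkl")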
@staticmethod + def load_pytorch_model(model_path): + r""" + 读取整个模型 + + :param str model_path: 模型保存的路径 + """ + return torch.load(model_path) + + +class ModelSaver(object): + r""" + 用于保存模型 + + Example:: + + saver = ModelSaver("./save/model_ckpt_100.pkl") + saver.save_pytorch(model) + + """ + + def __init__(self, save_path): + r""" + + :param save_path: 模型保存的路径 + """ + self.save_path = save_path + + def save_pytorch(self, model, param_only=True): + r""" + 把 PyTorch 模型存入 ".pkl" 文件 + + :param model: PyTorch 模型 + :param bool param_only: 是否只保存模型的参数(否则保存整个模型) + + """ + if param_only is True: + torch.save(model.state_dict(), self.save_path) + else: + torch.save(model, self.save_path) diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py new file mode 100644 index 00000000..35965ca3 --- /dev/null +++ b/fastNLP/io/pipe/__init__.py @@ -0,0 +1,80 @@ +r""" +Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 +``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; +``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 +``process(data_bundle)`` 或者 ``process_from_file(paths)`` 的返回 `data_bundle` 中的 :class:`~fastNLP.DataSet` +一般都包含原文与转换为index的输入以及转换为index的target;除了 :class:`~fastNLP.DataSet` 之外, +`data_bundle` 还会包含将field转为index时所建立的词表。 + +""" +__all__ = [ + "Pipe", + + "CWSPipe", + + "CLSBasePipe", + "AGsNewsPipe", + "DBPediaPipe", + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + "ChnSentiCorpPipe", + "THUCNewsPipe", + "WeiboSenti100kPipe", + "MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + "MsraNERPipe", + "WeiboNERPipe", + "PeopleDailyPipe", + "Conll2003Pipe", + + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "CNXNLIBertPipe", + "BQCorpusBertPipe", + "LCQMCBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", + "LCQMCPipe", + "CNXNLIPipe", + "BQCorpusPipe", + "RenamePipe", + "GranularizePipe", + "MachingTruncatePipe", + + "CoReferencePipe", + + "CMRC2018BertPipe", + + "R52PmiGraphPipe", + "R8PmiGraphPipe", + "OhsumedPmiGraphPipe", + "NG20PmiGraphPipe", + "MRPmiGraphPipe" +] + +from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ + WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe +from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe +from .conll import Conll2003Pipe +from .coreference import CoReferencePipe +from .cws import CWSPipe +from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ + MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ + LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe +from .pipe import Pipe +from .qa import CMRC2018BertPipe + +from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py new file mode 100644 index 00000000..0e5915a9 --- /dev/null +++ b/fastNLP/io/pipe/classification.py @@ -0,0 +1,939 @@ +r"""undocumented""" + +__all__ = [ + "CLSBasePipe", + "AGsNewsPipe", + "DBPediaPipe", + "YelpFullPipe", + 
"YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + 'IMDBPipe', + "ChnSentiCorpPipe", + "THUCNewsPipe", + "WeiboSenti100kPipe", + "MRPipe", "R8Pipe", "R52Pipe", "OhsumedPipe", "NG20Pipe" +] + +import re +import warnings + +try: + from nltk import Tree +except: + # only nltk in some versions can run + pass + +from .pipe import Pipe +from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize +from ..data_bundle import DataBundle +from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader +from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \ + AGsNewsLoader, DBPediaLoader, MRLoader, R52Loader, R8Loader, OhsumedLoader, NG20Loader +# from ...core._logger import log +# from ...core.const import Const +from fastNLP.core.dataset import DataSet, Instance + + +class CLSBasePipe(Pipe): + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', lang='en'): + super().__init__() + self.lower = lower + self.tokenizer = get_tokenizer(tokenizer, lang=lang) + + def _tokenize(self, data_bundle, field_name='words', new_field_name=None): + r""" + 将DataBundle中的数据进行tokenize + + :param DataBundle data_bundle: + :param str field_name: + :param str new_field_name: + :return: 传入的DataBundle对象 + """ + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) + + return data_bundle + + def process(self, data_bundle: DataBundle): + r""" + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." + + :param data_bundle: + :return: + """ + # 复制一列words + data_bundle = _add_words_field(data_bundle, lower=self.lower) + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name='words') + # 建立词表并index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len('words') + + data_bundle.set_input('words', 'seq_len', 'target') + + return data_bundle + + def process_from_file(self, paths) -> DataBundle: + r""" + 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + + :param paths: + :return: DataBundle + """ + raise NotImplementedError + + +class YelpFullPipe(CLSBasePipe): + r""" + 处理YelpFull的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . 
+ + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否对输入进行小写化。 + :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 + 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + assert granularity in (2, 3, 5), "granularity can only be 2,3,5." + self.granularity = granularity + + if granularity == 2: + self.tag_map = {"1": "negative", "2": "negative", "4": "positive", "5": "positive"} + elif granularity == 3: + self.tag_map = {"1": "negative", "2": "negative", "3": "medium", "4": "positive", "5": "positive"} + else: + self.tag_map = None + + def process(self, data_bundle): + r""" + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." + + :param data_bundle: + :return: + """ + if self.tag_map is not None: + data_bundle = _granularize(data_bundle, self.tag_map) + + data_bundle = super().process(data_bundle) + + return data_bundle + + def process_from_file(self, paths=None): + r""" + + :param paths: + :return: DataBundle + """ + data_bundle = YelpFullLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class YelpPolarityPipe(CLSBasePipe): + r""" + 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + + def process_from_file(self, paths=None): + r""" + + :param str paths: + :return: DataBundle + """ + data_bundle = YelpPolarityLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class AGsNewsPipe(CLSBasePipe): + r""" + 处理AG's News的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... 
", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + + def process_from_file(self, paths=None): + r""" + :param str paths: + :return: DataBundle + """ + data_bundle = AGsNewsLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class DBPediaPipe(CLSBasePipe): + r""" + 处理DBPedia的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + + def process_from_file(self, paths=None): + r""" + :param str paths: + :return: DataBundle + """ + data_bundle = DBPediaLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class SSTPipe(CLSBasePipe): + r""" + 经过该Pipe之后,DataSet中具备的field如下所示 + + .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a lovely film with lovely perfor...", 1, "[187, 6, 5, 132, 120, 70, 132, 188, 25...", 13 + "No one goes unindicted here , which is...", 0, "[191, 126, 192, 193, 194, 4, 195, 17, ...", 13 + "...", ., "[...]", . 
+ + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): + r""" + + :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` + :param bool train_subtree: 是否将train集通过子树扩展数据。 + :param bool lower: 是否对输入进行小写化。 + :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将 + 0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.subtree = subtree + self.train_tree = train_subtree + self.lower = lower + assert granularity in (2, 3, 5), "granularity can only be 2,3,5." + self.granularity = granularity + + if granularity == 2: + self.tag_map = {"0": "negative", "1": "negative", "3": "positive", "4": "positive"} + elif granularity == 3: + self.tag_map = {"0": "negative", "1": "negative", "2": "medium", "3": "positive", "4": "positive"} + else: + self.tag_map = None + + def process(self, data_bundle: DataBundle): + r""" + 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 + + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + :header: "raw_words" + + "(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." + "(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." + "..." + + :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 + :return: + """ + # 先取出subtree + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + ds = DataSet() + use_subtree = self.subtree or (name == 'train' and self.train_tree) + for ins in dataset: + raw_words = ins['raw_words'] + tree = Tree.fromstring(raw_words) + if use_subtree: + for t in tree.subtrees(): + raw_words = " ".join(t.leaves()) + instance = Instance(raw_words=raw_words, target=t.label()) + ds.append(instance) + else: + instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) + ds.append(instance) + data_bundle.set_dataset(ds, name) + + # 根据granularity设置tag + data_bundle = _granularize(data_bundle, tag_map=self.tag_map) + + data_bundle = super().process(data_bundle) + + return data_bundle + + def process_from_file(self, paths=None): + data_bundle = SSTLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class SST2Pipe(CLSBasePipe): + r""" + 加载SST2的数据, 处理完成之后DataSet将拥有以下的field + + .. csv-table:: + :header: "raw_words", "target", "words", "seq_len" + + "it 's a charming and often affecting j... ", 1, "[19, 9, 6, 111, 5, 112, 113, 114, 3]", 9 + "unflinchingly bleak and desperate", 0, "[115, 116, 5, 117]", 4 + "...", "...", ., . 
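
    A minimal sketch of the typical call (with ``paths=None`` the data is downloaded and cached automatically, as noted in ``process_from_file``)::

        from fastNLP.io import SST2Pipe

        data_bundle = SST2Pipe(lower=True, tokenizer='raw').process_from_file()
        # prints a table like the one shown below
        data_bundle.get_dataset('train').print_field_meta()
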
+ + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower=False, tokenizer='spacy'): + r""" + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + + def process_from_file(self, paths=None): + r""" + + :param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 + :return: DataBundle + """ + data_bundle = SST2Loader().load(paths) + return self.process(data_bundle) + + +class IMDBPipe(CLSBasePipe): + r""" + 经过本Pipe处理后DataSet将如下 + + .. csv-table:: 输出DataSet的field + :header: "raw_words", "target", "words", "seq_len" + + "Bromwell High is a cartoon ... ", 0, "[3, 5, 6, 9, ...]", 20 + "Story of a man who has ...", 1, "[20, 43, 9, 10, ...]", 31 + "...", ., "[...]", . + + 其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; + words列被设置为input; target列被设置为target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle: DataBundle): + r""" + 期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 + + .. csv-table:: 输入DataSet的field + :header: "raw_words", "target" + + "Bromwell High is a cartoon ... ", "pos" + "Story of a man who has ...", "neg" + "...", "..." + + :param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str, + target列应该为str。 + :return: DataBundle + """ + + # 替换
+ def replace_br(raw_words): + raw_words = raw_words.replace("<br />
", ' ') + return raw_words + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words') + + data_bundle = super().process(data_bundle) + + return data_bundle + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = IMDBLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class ChnSentiCorpPipe(Pipe): + r""" + 处理之后的DataSet有以下的结构 + + .. csv-table:: + :header: "raw_chars", "target", "chars", "seq_len" + + "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", 1, "[2, 3, 4, 5, ...]", 31 + "<荐书> 推荐所有喜欢<红楼>...", 1, "[10, 21, ....]", 25 + "..." + + 其中chars, seq_len是input,target是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, bigrams=False, trigrams=False): + r""" + + :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 + 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('bigrams')获取. + :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('trigrams')获取. + """ + super().__init__() + + self.bigrams = bigrams + self.trigrams = trigrams + + def _tokenize(self, data_bundle): + r""" + 将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 + + :param data_bundle: + :return: + """ + data_bundle.apply_field(list, field_name='chars', new_field_name='chars') + return data_bundle + + def process(self, data_bundle: DataBundle): + r""" + 可以处理的DataSet应该具备以下的field + + .. csv-table:: + :header: "raw_chars", "target" + + "這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" + "<荐书> 推荐所有喜欢<红楼>...", "1" + "..." 
+ + :param data_bundle: + :return: + """ + _add_chars_field(data_bundle, lower=False) + + data_bundle = self._tokenize(data_bundle) + + input_field_names = ['chars'] + if self.bigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name='chars', new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name='chars', new_field_name='trigrams') + input_field_names.append('trigrams') + + # index + _indexize(data_bundle, input_field_names, 'target') + + input_fields = ['target', 'seq_len'] + input_field_names + target_fields = ['target'] + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len('chars') + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = ChnSentiCorpLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class THUCNewsPipe(CLSBasePipe): + r""" + 处理之后的DataSet有以下的结构 + + .. csv-table:: + :header: "raw_chars", "target", "chars", "seq_len" + + "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 + "..." + + 其中chars, seq_len是input,target是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 + 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('bigrams')获取. + :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('trigrams')获取. + """ + + def __init__(self, bigrams=False, trigrams=False): + super().__init__() + + self.bigrams = bigrams + self.trigrams = trigrams + + def _chracter_split(self, sent): + return list(sent) + # return [w for w in sent] + + def _raw_split(self, sent): + return sent.split() + + def _tokenize(self, data_bundle, field_name='words', new_field_name=None): + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle: DataBundle): + r""" + 可处理的DataSet应具备如下的field + + .. csv-table:: + :header: "raw_words", "target" + + "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育" + "...", "..." 
+ + :param data_bundle: + :return: + """ + # 根据granularity设置tag + tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} + data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map) + + # clean,lower + + # CWS(tokenize) + data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') + + input_field_names = ['chars'] + + # n-grams + if self.bigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name='chars', new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name='chars', new_field_name='trigrams') + input_field_names.append('trigrams') + + # index + data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') + + # add length + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(field_name='chars', new_field_name='seq_len') + + input_fields = ['target', 'seq_len'] + input_field_names + target_fields = ['target'] + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + def process_from_file(self, paths=None): + r""" + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + data_loader = THUCNewsLoader() # 此处需要实例化一个data_loader,否则传入load()的参数为None + data_bundle = data_loader.load(paths) + data_bundle = self.process(data_bundle) + return data_bundle + + +class WeiboSenti100kPipe(CLSBasePipe): + r""" + 处理之后的DataSet有以下的结构 + + .. csv-table:: + :header: "raw_chars", "target", "chars", "seq_len" + + "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", 0, "[0, 690, 18, ...]", 56 + "..." + + 其中chars, seq_len是input,target是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 + 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('bigrams')获取. + :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('trigrams')获取. + """ + + def __init__(self, bigrams=False, trigrams=False): + super().__init__() + + self.bigrams = bigrams + self.trigrams = trigrams + + def _chracter_split(self, sent): + return list(sent) + + def _tokenize(self, data_bundle, field_name='words', new_field_name=None): + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle: DataBundle): + r""" + 可处理的DataSet应具备以下的field + + .. csv-table:: + :header: "raw_chars", "target" + + "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0" + "...", "..." 
+ + :param data_bundle: + :return: + """ + # clean,lower + + # CWS(tokenize) + data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') + + input_field_names = ['chars'] + + # n-grams + if self.bigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name='chars', new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name='chars', new_field_name='trigrams') + input_field_names.append('trigrams') + + # index + data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') + + # add length + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(field_name='chars', new_field_name='seq_len') + + input_fields = ['target', 'seq_len'] + input_field_names + target_fields = ['target'] + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + def process_from_file(self, paths=None): + r""" + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + data_loader = WeiboSenti100kLoader() # 此处需要实例化一个data_loader,否则传入load()的参数为None + data_bundle = data_loader.load(paths) + data_bundle = self.process(data_bundle) + return data_bundle + +class MRPipe(CLSBasePipe): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = MRLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class R8Pipe(CLSBasePipe): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = R8Loader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class R52Pipe(CLSBasePipe): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = R52Loader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class OhsumedPipe(CLSBasePipe): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 
支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = OhsumedLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class NG20Pipe(CLSBasePipe): + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + r""" + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process_from_file(self, paths=None): + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = NG20Loader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle \ No newline at end of file diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py new file mode 100644 index 00000000..bb3b1d51 --- /dev/null +++ b/fastNLP/io/pipe/conll.py @@ -0,0 +1,427 @@ +r"""undocumented""" + +__all__ = [ + "Conll2003NERPipe", + "Conll2003Pipe", + "OntoNotesNERPipe", + "MsraNERPipe", + "PeopleDailyPipe", + "WeiboNERPipe" +] + +from .pipe import Pipe +from .utils import _add_chars_field +from .utils import _indexize, _add_words_field +from .utils import iob2, iob2bioes +from fastNLP.io.data_bundle import DataBundle +from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader +from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader +# from ...core.const import Const +from ...core.vocabulary import Vocabulary + + +class _NERPipe(Pipe): + r""" + NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 + Vocabulary转换为index。 + + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 + """ + + def __init__(self, encoding_type: str = 'bio', lower: bool = False): + r""" + + :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + """ + if encoding_type == 'bio': + self.convert_tag = iob2 + elif encoding_type == 'bioes': + self.convert_tag = lambda words: iob2bioes(iob2(words)) + else: + raise ValueError("encoding_type only supports `bio` and `bioes`.") + self.lower = lower + + def process(self, data_bundle: DataBundle) -> DataBundle: + r""" + 支持的DataSet的field为 + + .. 
csv-table:: + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]在传入DataBundle基础上原位修改。 + :return DataBundle: + """ + # 转换tag + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') + + _add_words_field(data_bundle, lower=self.lower) + + # index + _indexize(data_bundle) + + input_fields = ['target', 'words', 'seq_len'] + target_fields = ['target', 'seq_len'] + + for name, dataset in data_bundle.iter_datasets(): + dataset.add_seq_len('words') + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + +class Conll2003NERPipe(_NERPipe): + r""" + Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 + Vocabulary转换为index。 + 经过该Pipe过后,DataSet中的内容如下所示 + + .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "target", "words", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 + "[...]", "[...]", "[...]", . + + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def process_from_file(self, paths) -> DataBundle: + r""" + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = Conll2003NERLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class Conll2003Pipe(Pipe): + r""" + 经过该Pipe后,DataSet中的内容如下 + + .. csv-table:: + :header: "raw_words" , "pos", "chunk", "ner", "words", "seq_len" + + "[Nadim, Ladki]", "[0, 0]", "[1, 2]", "[1, 2]", "[2, 3]", 2 + "[AL-AIN, United, Arab, ...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", "[4, 5, 6,...]", 6 + "[...]", "[...]", "[...]", "[...]", "[...]", . 
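
    A minimal usage sketch (the path is a placeholder for CoNLL-2003 files in the four-column ``raw_words/pos/chunk/ner`` layout read by ``ConllLoader``)::

        from fastNLP.io import Conll2003Pipe

        pipe = Conll2003Pipe(chunk_encoding_type='bioes', ner_encoding_type='bio', lower=True)
        data_bundle = pipe.process_from_file('/path/to/conll2003')  # placeholder path
        print(data_bundle.get_vocab('ner'))
        print(data_bundle.get_dataset('train')[0])
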
+ + 其中words, seq_len是input; pos, chunk, ner, seq_len是target + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+-------+-------+-------+-------+---------+ + | field_names | raw_words | pos | chunk | ner | words | seq_len | + +-------------+-----------+-------+-------+-------+-------+---------+ + | is_input | False | False | False | False | True | True | + | is_target | False | True | True | True | False | True | + | ignore_type | | False | False | False | False | False | + | pad_value | | 0 | 0 | 0 | 0 | 0 | + +-------------+-----------+-------+-------+-------+-------+---------+ + + + """ + def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): + r""" + + :param str chunk_encoding_type: 支持bioes, bio。 + :param str ner_encoding_type: 支持bioes, bio。 + :param bool lower: 是否将words列小写化后再建立词表 + """ + if chunk_encoding_type == 'bio': + self.chunk_convert_tag = iob2 + elif chunk_encoding_type == 'bioes': + self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags)) + else: + raise ValueError("chunk_encoding_type only supports `bio` and `bioes`.") + if ner_encoding_type == 'bio': + self.ner_convert_tag = iob2 + elif ner_encoding_type == 'bioes': + self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) + else: + raise ValueError("ner_encoding_type only supports `bio` and `bioes`.") + self.lower = lower + + def process(self, data_bundle) -> DataBundle: + r""" + 输入的DataSet应该类似于如下的形式 + + .. csv-table:: + :header: "raw_words", "pos", "chunk", "ner" + + "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[NNP, NNP...]", "[B-NP, B-NP, ...]", "[B-LOC, B-LOC,...]" + "[...]", "[...]", "[...]", "[...]", . + + :param data_bundle: + :return: 传入的DataBundle + """ + # 转换tag + for name, dataset in data_bundle.datasets.items(): + dataset.drop(lambda x: "-DOCSTART-" in x['raw_words']) + dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk') + dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner') + + _add_words_field(data_bundle, lower=self.lower) + + # index + _indexize(data_bundle, input_field_names='words', target_field_names=['pos', 'ner']) + # chunk中存在一些tag只在dev中出现,没在train中 + tgt_vocab = Vocabulary(unknown=None, padding=None) + tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk') + tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk') + data_bundle.set_vocab(tgt_vocab, 'chunk') + + input_fields = ['words', 'seq_len'] + target_fields = ['pos', 'ner', 'chunk', 'seq_len'] + + for name, dataset in data_bundle.iter_datasets(): + dataset.add_seq_len('words') + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + def process_from_file(self, paths): + r""" + + :param paths: + :return: + """ + data_bundle = ConllLoader(headers=['raw_words', 'pos', 'chunk', 'ner']).load(paths) + return self.process(data_bundle) + + +class OntoNotesNERPipe(_NERPipe): + r""" + 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 + + .. csv-table:: + :header: "raw_words", "target", "words", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 + "[...]", "[...]", "[...]", . 
+ + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def process_from_file(self, paths): + data_bundle = OntoNotesNERLoader().load(paths) + return self.process(data_bundle) + + +class _CNNERPipe(Pipe): + r""" + 中文NER任务的处理Pipe, 该Pipe会(1)复制raw_chars列,并命名为chars; (2)在chars, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将chars,target列根据相应的 + Vocabulary转换为index。 + + raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。 + + """ + + def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): + r""" + + :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 + 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('bigrams')获取. + :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 + data_bundle.get_vocab('trigrams')获取. + """ + if encoding_type == 'bio': + self.convert_tag = iob2 + elif encoding_type == 'bioes': + self.convert_tag = lambda words: iob2bioes(iob2(words)) + else: + raise ValueError("encoding_type only supports `bio` and `bioes`.") + + self.bigrams = bigrams + self.trigrams = trigrams + + def process(self, data_bundle: DataBundle) -> DataBundle: + r""" + 支持的DataSet的field为 + + .. 
csv-table:: + :header: "raw_chars", "target" + + "[相, 比, 之, 下,...]", "[O, O, O, O, ...]" + "[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]" + "[...]", "[...]" + + raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int], + 是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + + :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。在传入DataBundle基础上原位修改。 + :return: DataBundle + """ + # 转换tag + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') + + _add_chars_field(data_bundle, lower=False) + + input_field_names = ['chars'] + if self.bigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name='chars', new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name='chars', new_field_name='trigrams') + input_field_names.append('trigrams') + + # index + _indexize(data_bundle, input_field_names, 'target') + + input_fields = ['target', 'seq_len'] + input_field_names + target_fields = ['target', 'seq_len'] + + for name, dataset in data_bundle.iter_datasets(): + dataset.add_seq_len('chars') + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + +class MsraNERPipe(_CNNERPipe): + r""" + 处理MSRA-NER的数据,处理之后的DataSet的field情况为 + + .. csv-table:: + :header: "raw_chars", "target", "chars", "seq_len" + + "[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 + "[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 + "[...]", "[...]", "[...]", . + + raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def process_from_file(self, paths=None) -> DataBundle: + data_bundle = MsraNERLoader().load(paths) + return self.process(data_bundle) + + +class PeopleDailyPipe(_CNNERPipe): + r""" + 处理people daily的ner的数据,处理之后的DataSet的field情况为 + + .. csv-table:: + :header: "raw_chars", "target", "chars", "seq_len" + + "[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 + "[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 + "[...]", "[...]", "[...]", . 
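
    A minimal sketch (``paths=None`` is assumed to rely on the loader's download; ``bigrams=True`` adds the indexed bigram field described in ``_CNNERPipe``)::

        from fastNLP.io import PeopleDailyPipe

        data_bundle = PeopleDailyPipe(encoding_type='bioes', bigrams=True).process_from_file()
        print(data_bundle.get_dataset('train')[0])
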
+ + raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def process_from_file(self, paths=None) -> DataBundle: + data_bundle = PeopleDailyNERLoader().load(paths) + return self.process(data_bundle) + + +class WeiboNERPipe(_CNNERPipe): + r""" + 处理weibo的ner的数据,处理之后的DataSet的field情况为 + + .. csv-table:: + :header: "raw_chars", "chars", "target", "seq_len" + + "['老', '百', '姓']", "[4, 3, 3]", "[38, 39, 40]", 3 + "['心']", "[0]", "[41]", 1 + "[...]", "[...]", "[...]", . + + raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def process_from_file(self, paths=None) -> DataBundle: + data_bundle = WeiboNERLoader().load(paths) + return self.process(data_bundle) diff --git a/fastNLP/io/pipe/construct_graph.py b/fastNLP/io/pipe/construct_graph.py new file mode 100644 index 00000000..1448765e --- /dev/null +++ b/fastNLP/io/pipe/construct_graph.py @@ -0,0 +1,286 @@ +__all__ = [ + 'MRPmiGraphPipe', + 'R8PmiGraphPipe', + 'R52PmiGraphPipe', + 'OhsumedPmiGraphPipe', + 'NG20PmiGraphPipe' +] +try: + import networkx as nx + from sklearn.feature_extraction.text import CountVectorizer + from sklearn.feature_extraction.text import TfidfTransformer + from sklearn.pipeline import Pipeline +except: + pass +from collections import defaultdict +import itertools +import math +import numpy as np + +from ..data_bundle import DataBundle +# from ...core.const import Const +from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader +from fastNLP.core.utils import f_rich_progress + + +def _get_windows(content_lst: list, window_size: int): + r""" + 滑动窗口处理文本,获取词频和共现词语的词频 + :param content_lst: + :param window_size: + :return: 词频,共现词频,窗口化后文本段的数量 + """ + word_window_freq = defaultdict(int) # w(i) 单词在窗口单位内出现的次数 + word_pair_count = defaultdict(int) # w(i, j) + windows_len = 0 + task_id = f_rich_progress.add_task(description="Split by window", total=len(content_lst)) + for words in content_lst: + windows = list() + + if isinstance(words, str): + words = words.split() + length = len(words) + + if length <= window_size: + windows.append(words) + else: + for j in range(length - window_size + 1): + window = words[j: j + window_size] + windows.append(list(set(window))) + + for window in windows: + for word in window: + word_window_freq[word] += 1 + + for word_pair in itertools.combinations(window, 2): + word_pair_count[word_pair] += 1 + + windows_len += 
len(windows) + + f_rich_progress.update(task_id, advance=1) + f_rich_progress.destroy_task(task_id) + return word_window_freq, word_pair_count, windows_len + + +def _cal_pmi(W_ij, W, word_freq_i, word_freq_j): + r""" + params: w_ij:为词语i,j的共现词频 + w:文本数量 + word_freq_i: 词语i的词频 + word_freq_j: 词语j的词频 + return: 词语i,j的tfidf值 + """ + p_i = word_freq_i / W + p_j = word_freq_j / W + p_i_j = W_ij / W + pmi = math.log(p_i_j / (p_i * p_j)) + + return pmi + + +def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold): + r""" + params: windows_len: 文本段数量 + word_pair_count: 词共现频率字典 + word_window_freq: 词频率字典 + threshold: 阈值 + return 词语pmi的list列表,其中元素为[word1, word2, pmi] + """ + word_pmi_lst = list() + task_id = f_rich_progress.add_task(description="Calculate pmi between words", total=len(word_pair_count)) + for word_pair, W_i_j in word_pair_count.items(): + word_freq_1 = word_window_freq[word_pair[0]] + word_freq_2 = word_window_freq[word_pair[1]] + + pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2) + if pmi <= threshold: + continue + word_pmi_lst.append([word_pair[0], word_pair[1], pmi]) + + f_rich_progress.update(task_id, advance=1) + f_rich_progress.destory_task(task_id) + return word_pmi_lst + + +class GraphBuilderBase: + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + self.graph = nx.Graph() + self.word2id = dict() + self.graph_type = graph_type + self.window_size = widow_size + self.doc_node_num = 0 + self.tr_doc_index = None + self.te_doc_index = None + self.dev_doc_index = None + self.doc = None + self.threshold = threshold + + def _get_doc_edge(self, data_bundle: DataBundle): + r""" + 对输入的DataBundle进行处理,然后生成文档-单词的tfidf值 + :param: data_bundle中的文本若为英文,形式为[ 'This is the first document.'],若为中文则为['他 喜欢 吃 苹果'] + : return 返回带有具有tfidf边文档-单词稀疏矩阵 + """ + tr_doc = list(data_bundle.get_dataset("train").get_field('raw_words')) + val_doc = list(data_bundle.get_dataset("dev").get_field('raw_words')) + te_doc = list(data_bundle.get_dataset("test").get_field('raw_words')) + doc = tr_doc + val_doc + te_doc + self.doc = doc + self.tr_doc_index = [ind for ind in range(len(tr_doc))] + self.dev_doc_index = [ind + len(tr_doc) for ind in range(len(val_doc))] + self.te_doc_index = [ind + len(tr_doc) + len(val_doc) for ind in range(len(te_doc))] + text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)), + ('tfidf', + TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))]) + + tfidf_vec = text_tfidf.fit_transform(doc) + self.doc_node_num = tfidf_vec.shape[0] + vocab_lst = text_tfidf['count'].get_feature_names() + for ind, word in enumerate(vocab_lst): + self.word2id[word] = ind + for ind, row in enumerate(tfidf_vec): + for col_index, value in zip(row.indices, row.data): + self.graph.add_edge(ind, self.doc_node_num + col_index, weight=value) + return nx.to_scipy_sparse_matrix(self.graph) + + def _get_word_edge(self): + word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size) + pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold) + for edge_item in pmi_edge_lst: + word_indx1 = self.doc_node_num + self.word2id[edge_item[0]] + word_indx2 = self.doc_node_num + self.word2id[edge_item[1]] + if word_indx1 == word_indx2: + continue + self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2]) + + def build_graph(self, data_bundle: DataBundle): + r""" + 对输入的DataBundle进行处理,然后返回该scipy_sparse_matrix类型的邻接矩阵。 + + :param ~fastNLP.DataBundle data_bundle: 
需要处理的DataBundle对象 + :return: + """ + raise NotImplementedError + + def build_graph_from_file(self, path: str): + r""" + 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + + :param path: + :return: scipy_sparse_matrix + """ + raise NotImplementedError + + +class MRPmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r""" + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + """ + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), ( + self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = MRLoader().load(path) + return self.build_graph(data_bundle) + + +class R8PmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r""" + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + """ + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), ( + self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = R8Loader().load(path) + return self.build_graph(data_bundle) + + +class R52PmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r""" + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵;训练集,验证集,测试集,在图中的index. + """ + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), ( + self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = R52Loader().load(path) + return self.build_graph(data_bundle) + + +class OhsumedPmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r""" + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. 
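
        A minimal sketch (the path is a placeholder; the returned matrix holds the doc-word tfidf edges and the word-word PMI edges built above, and needs ``networkx``/``scikit-learn``)::

            from fastNLP.io.loader.classification import OhsumedLoader
            from fastNLP.io.pipe import OhsumedPmiGraphPipe

            data_bundle = OhsumedLoader().load('/path/to/ohsumed')  # placeholder path
            # note: the window-size parameter is spelled `widow_size` in this class
            pipe = OhsumedPmiGraphPipe(widow_size=10, threshold=0.)
            adj, (tr_idx, dev_idx, te_idx) = pipe.build_graph(data_bundle)
            print(adj.shape, len(tr_idx), len(dev_idx), len(te_idx))
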
+ """ + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), ( + self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + data_bundle = OhsumedLoader().load(path) + return self.build_graph(data_bundle) + + +class NG20PmiGraphPipe(GraphBuilderBase): + + def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): + super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) + + def build_graph(self, data_bundle: DataBundle): + r""" + params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. + return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + """ + self._get_doc_edge(data_bundle) + self._get_word_edge() + return nx.to_scipy_sparse_matrix(self.graph, + nodelist=list(range(self.graph.number_of_nodes())), + weight='weight', dtype=np.float32, format='csr'), ( + self.tr_doc_index, self.dev_doc_index, self.te_doc_index) + + def build_graph_from_file(self, path: str): + r""" + param: path->数据集的路径. + return: 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + """ + data_bundle = NG20Loader().load(path) + return self.build_graph(data_bundle) diff --git a/fastNLP/io/pipe/coreference.py b/fastNLP/io/pipe/coreference.py new file mode 100644 index 00000000..6d35cd1b --- /dev/null +++ b/fastNLP/io/pipe/coreference.py @@ -0,0 +1,186 @@ +r"""undocumented""" + +__all__ = [ + "CoReferencePipe" +] + +import collections + +import numpy as np + +from fastNLP.core.vocabulary import Vocabulary +from .pipe import Pipe +from ..data_bundle import DataBundle +from ..loader.coreference import CoReferenceLoader + + +# from ...core.const import Const + + +class CoReferencePipe(Pipe): + r""" + 对Coreference resolution问题进行处理,得到文章种类/说话者/字符级信息/序列长度。 + + 处理完成后数据包含文章类别、speaker信息、句子信息、句子对应的index、char、句子长度、target: + + .. csv-table:: + :header: "words1", "words2","words3","words4","chars","seq_len","target" + + "bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" + "[...]", "[...]","[...]","[...]","[...]","[...]","[...]" + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_chars | target | chars | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | True | True | True | + | is_target | False | True | False | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, config): + super().__init__() + self.config = config + + def process(self, data_bundle: DataBundle): + r""" + 对load进来的数据进一步处理原始数据包含:raw_key,raw_speaker,raw_words,raw_clusters + + .. 
csv-table:: + :header: "raw_key", "raw_speaker","raw_words","raw_clusters" + + "bc/cctv/00/cctv_0000_0", "[[Speaker#1, Speaker#1],[]]","[['I','am'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" + "bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" + "[...]", "[...]","[...]","[...]" + + + :param data_bundle: + :return: + """ + genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} + vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name='raw_words4') + vocab.build_vocab() + word2id = vocab.word2idx + data_bundle.set_vocab(vocab, 'words1') + if self.config.char_path: + char_dict = get_char_dict(self.config.char_path) + else: + char_set = set() + for i, w in enumerate(word2id): + if i < 2: + continue + for c in w: + char_set.add(c) + + char_dict = collections.defaultdict(int) + char_dict.update({c: i for i, c in enumerate(char_set)}) + + for name, ds in data_bundle.iter_datasets(): + # genre + ds.apply(lambda x: genres[x['raw_words1'][:2]], new_field_name='words1') + + # speaker_ids_np + ds.apply(lambda x: speaker2numpy(x['raw_words2'], self.config.max_sentences, is_train=name == 'train'), + new_field_name='words2') + + # sentences + ds.rename_field('raw_words4', 'words3') + + # doc_np + ds.apply(lambda x: doc2numpy(x['words3'], word2id, char_dict, max(self.config.filter), + self.config.max_sentences, is_train=name == 'train')[0], + new_field_name='words4') + # char_index + ds.apply(lambda x: doc2numpy(x['words3'], word2id, char_dict, max(self.config.filter), + self.config.max_sentences, is_train=name == 'train')[1], + new_field_name='chars') + # seq len + ds.apply(lambda x: doc2numpy(x['words3'], word2id, char_dict, max(self.config.filter), + self.config.max_sentences, is_train=name == 'train')[2], + new_field_name='seq_len') + + # clusters + ds.rename_field('raw_words3', 'target') + + ds.set_input('words1', 'words2', 'words3', 'words4', 'chars', 'seq_len', 'target') + + return data_bundle + + def process_from_file(self, paths): + bundle = CoReferenceLoader().load(paths) + return self.process(bundle) + + +# helper + +def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train): + docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train) + assert max(length) == max_len + assert char_index.shape[0] == len(length) + assert char_index.shape[1] == max_len + doc_np = np.zeros((len(docvec), max_len), int) + for i in range(len(docvec)): + for j in range(len(docvec[i])): + doc_np[i][j] = docvec[i][j] + return doc_np, char_index, length + + +def _doc2vec(doc, word2id, char_dict, max_filter, max_sentences, is_train): + max_len = 0 + max_word_length = 0 + docvex = [] + length = [] + if is_train: + sent_num = min(max_sentences, len(doc)) + else: + sent_num = len(doc) + + for i in range(sent_num): + sent = doc[i] + length.append(len(sent)) + if (len(sent) > max_len): + max_len = len(sent) + sent_vec = [] + for j, word in enumerate(sent): + if len(word) > max_word_length: + max_word_length = len(word) + if word in word2id: + sent_vec.append(word2id[word]) + else: + sent_vec.append(word2id["UNK"]) + docvex.append(sent_vec) + + char_index = np.zeros((sent_num, max_len, max_word_length), dtype=int) + for i in range(sent_num): + sent = doc[i] + for j, word in enumerate(sent): + char_index[i, j, :len(word)] = [char_dict[c] for c in word] + + return docvex, char_index, length, max_len + + +def speaker2numpy(speakers_raw, max_sentences, 
is_train): + if is_train and len(speakers_raw) > max_sentences: + speakers_raw = speakers_raw[0:max_sentences] + speakers = flatten(speakers_raw) + speaker_dict = {s: i for i, s in enumerate(set(speakers))} + speaker_ids = np.array([speaker_dict[s] for s in speakers]) + return speaker_ids + + +# 展平 +def flatten(l): + return [item for sublist in l for item in sublist] + + +def get_char_dict(path): + vocab = [""] + with open(path) as f: + vocab.extend(c.strip() for c in f.readlines()) + char_dict = collections.defaultdict(int) + char_dict.update({c: i for i, c in enumerate(vocab)}) + return char_dict diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py new file mode 100644 index 00000000..9ef19097 --- /dev/null +++ b/fastNLP/io/pipe/cws.py @@ -0,0 +1,282 @@ +r"""undocumented""" + +__all__ = [ + "CWSPipe" +] + +import re +from itertools import chain + +from .pipe import Pipe +from .utils import _indexize +from fastNLP.io.data_bundle import DataBundle +from fastNLP.io.loader import CWSLoader +# from ...core.const import Const + + +def _word_lens_to_bmes(word_lens): + r""" + + :param list word_lens: List[int], 每个词语的长度 + :return: List[str], BMES的序列 + """ + tags = [] + for word_len in word_lens: + if word_len == 1: + tags.append('S') + else: + tags.append('B') + tags.extend(['M'] * (word_len - 2)) + tags.append('E') + return tags + + +def _word_lens_to_segapp(word_lens): + r""" + + :param list word_lens: List[int], 每个词语的长度 + :return: List[str], BMES的序列 + """ + tags = [] + for word_len in word_lens: + if word_len == 1: + tags.append('SEG') + else: + tags.extend(['APP'] * (word_len - 1)) + tags.append('SEG') + return tags + + +def _alpha_span_to_special_tag(span): + r""" + 将span替换成特殊的字符 + + :param str span: + :return: + """ + if 'oo' == span.lower(): # speical case when represent 2OO8 + return span + if len(span) == 1: + return span + else: + return '' + + +def _find_and_replace_alpha_spans(line): + r""" + 传入原始句子,替换其中的字母为特殊标记 + + :param str line:原始数据 + :return: str + """ + new_line = '' + pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])' + prev_end = 0 + for match in re.finditer(pattern, line): + start, end = match.span() + span = line[start:end] + new_line += line[prev_end:start] + _alpha_span_to_special_tag(span) + prev_end = end + new_line += line[prev_end:] + return new_line + + +def _digit_span_to_special_tag(span): + r""" + + :param str span: 需要替换的str + :return: + """ + if span[0] == '0' and len(span) > 2: + return '' + decimal_point_count = 0 # one might have more than one decimal pointers + for idx, char in enumerate(span): + if char == '.' or char == '﹒' or char == '·': + decimal_point_count += 1 + if span[-1] == '.' or span[-1] == '﹒' or span[ + -1] == '·': # last digit being decimal point means this is not a number + if decimal_point_count == 1: + return span + else: + return '' + if decimal_point_count == 1: + return '' + elif decimal_point_count > 1: + return '' + else: + return '' + + +def _find_and_replace_digit_spans(line): + r""" + only consider words start with number, contains '.', characters. + + If ends with space, will be processed + + If ends with Chinese character, will be processed + + If ends with or contains english char, not handled. 
+ + floats are replaced by + + otherwise unkdgt + """ + new_line = '' + pattern = r'\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])' + prev_end = 0 + for match in re.finditer(pattern, line): + start, end = match.span() + span = line[start:end] + new_line += line[prev_end:start] + _digit_span_to_special_tag(span) + prev_end = end + new_line += line[prev_end:] + return new_line + + +class CWSPipe(Pipe): + r""" + 对CWS数据进行预处理, 处理之后的数据,具备以下的结构 + + .. csv-table:: + :header: "raw_words", "chars", "target", "seq_len" + + "共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", 13 + "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20 + "...", "[...]","[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+-------+--------+---------+ + | field_names | raw_words | chars | target | seq_len | + +-------------+-----------+-------+--------+---------+ + | is_input | False | True | True | True | + | is_target | False | False | True | True | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+-------+--------+---------+ + + """ + + def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): + r""" + + :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None + :param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp + 的tag为[seg, app, seg, app, app, app, seg, ...] + :param bool replace_num_alpha: 是否将数字和字母用特殊字符替换。 + :param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] + :param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] + """ + if encoding_type == 'bmes': + self.word_lens_to_tags = _word_lens_to_bmes + else: + self.word_lens_to_tags = _word_lens_to_segapp + + self.dataset_name = dataset_name + self.bigrams = bigrams + self.trigrams = trigrams + self.replace_num_alpha = replace_num_alpha + + def _tokenize(self, data_bundle): + r""" + 将data_bundle中的'chars'列切分成一个一个的word. + 例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] + + :param data_bundle: + :return: + """ + def split_word_into_chars(raw_chars): + words = raw_chars.split() + chars = [] + for word in words: + char = [] + subchar = [] + for c in word: + if c == '<': + if subchar: + char.extend(subchar) + subchar = [] + subchar.append(c) + continue + if c == '>' and len(subchar)>0 and subchar[0] == '<': + subchar.append(c) + char.append(''.join(subchar)) + subchar = [] + continue + if subchar: + subchar.append(c) + else: + char.append(c) + char.extend(subchar) + chars.append(char) + return chars + + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(split_word_into_chars, field_name='chars', + new_field_name='chars') + return data_bundle + + def process(self, data_bundle: DataBundle) -> DataBundle: + r""" + 可以处理的DataSet需要包含raw_words列 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." 
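+
+        A minimal usage sketch (the file path below is an illustrative placeholder; the loader and
+        pipe are the ones provided by this module)::
+
+            from fastNLP.io import CWSLoader, CWSPipe
+
+            # load raw space-separated sentences, then build chars/target/seq_len fields
+            data_bundle = CWSLoader().load({'train': '/path/to/cws/train.txt'})
+            data_bundle = CWSPipe(bigrams=True).process(data_bundle)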
+ + :param data_bundle: + :return: + """ + data_bundle.copy_field('raw_words', 'chars') + + if self.replace_num_alpha: + data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars') + data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars') + + self._tokenize(data_bundle) + + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name='chars', + new_field_name='target') + dataset.apply_field(lambda chars: list(chain(*chars)), field_name='chars', + new_field_name='chars') + input_field_names = ['chars'] + if self.bigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])], + field_name='chars', new_field_name='bigrams') + input_field_names.append('bigrams') + if self.trigrams: + for name, dataset in data_bundle.iter_datasets(): + dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in + zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], + field_name='chars', new_field_name='trigrams') + input_field_names.append('trigrams') + + _indexize(data_bundle, input_field_names, 'target') + + input_fields = ['target', 'seq_len'] + input_field_names + target_fields = ['target', 'seq_len'] + for name, dataset in data_bundle.iter_datasets(): + dataset.add_seq_len('chars') + + data_bundle.set_input(*input_fields, *target_fields) + + return data_bundle + + def process_from_file(self, paths=None) -> DataBundle: + r""" + + :param str paths: + :return: + """ + if self.dataset_name is None and paths is None: + raise RuntimeError( + "You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") + if self.dataset_name is not None and paths is not None: + raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") + data_bundle = CWSLoader(self.dataset_name).load(paths) + return self.process(data_bundle) diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py new file mode 100644 index 00000000..76626894 --- /dev/null +++ b/fastNLP/io/pipe/matching.py @@ -0,0 +1,545 @@ +r"""undocumented""" + +__all__ = [ + "MatchingBertPipe", + "RTEBertPipe", + "SNLIBertPipe", + "QuoraBertPipe", + "QNLIBertPipe", + "MNLIBertPipe", + "CNXNLIBertPipe", + "BQCorpusBertPipe", + "LCQMCBertPipe", + "MatchingPipe", + "RTEPipe", + "SNLIPipe", + "QuoraPipe", + "QNLIPipe", + "MNLIPipe", + "LCQMCPipe", + "CNXNLIPipe", + "BQCorpusPipe", + "RenamePipe", + "GranularizePipe", + "MachingTruncatePipe", +] + +import warnings + +from .pipe import Pipe +from .utils import get_tokenizer +from ..data_bundle import DataBundle +from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, \ + LCQMCLoader +# from ...core._logger import log +# from ...core.const import Const +from ...core.vocabulary import Vocabulary + + +class MatchingBertPipe(Pipe): + r""" + Matching任务的Bert pipe,输出的DataSet将包含以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target", "words", "seq_len" + + "The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", 10 + "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5 + "...", "...", ., "[...]", . + + words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 + words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, + 如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). 
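+
+    A short usage sketch via one of the concrete subclasses below (assuming the RTE data can be
+    found or downloaded by RTELoader)::
+
+        from fastNLP.io import RTEBertPipe
+
+        # produces the fields listed in the table above ('words' is premise [SEP] hypothesis)
+        data_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file()
+        print(data_bundle)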
+ + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+------------+------------+--------+-------+---------+ + | field_names | raw_words1 | raw_words2 | target | words | seq_len | + +-------------+------------+------------+--------+-------+---------+ + | is_input | False | False | False | True | True | + | is_target | False | False | True | False | False | + | ignore_type | | | False | False | False | + | pad_value | | | 0 | 0 | 0 | + +-------------+------------+------------+--------+-------+---------+ + + """ + + def __init__(self, lower=False, tokenizer: str = 'raw'): + r""" + + :param bool lower: 是否将word小写化。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + super().__init__() + + self.lower = bool(lower) + self.tokenizer = get_tokenizer(tokenize_method=tokenizer) + + def _tokenize(self, data_bundle, field_names, new_field_names): + r""" + + :param DataBundle data_bundle: DataBundle. + :param list field_names: List[str], 需要tokenize的field名称 + :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 + :return: 输入的DataBundle对象 + """ + for name, dataset in data_bundle.iter_datasets(): + for field_name, new_field_name in zip(field_names, new_field_names): + dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, + new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + r""" + 输入的data_bundle中的dataset需要具有以下结构: + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" + "...","..." + + :param data_bundle: + :return: + """ + for dataset in data_bundle.datasets.values(): + if dataset.has_field('target'): + dataset.drop(lambda x: x['target'] == '-') + + for name, dataset in data_bundle.datasets.items(): + dataset.copy_field('raw_words1', 'words1', ) + dataset.copy_field('raw_words2', 'words2', ) + + if self.lower: + for name, dataset in data_bundle.datasets.items(): + dataset['words1'].lower() + dataset['words2'].lower() + + data_bundle = self._tokenize(data_bundle, ['words1', 'words2'], + ['words1', 'words2']) + + # concat两个words + def concat(ins): + words0 = ins['words1'] + words1 = ins['words2'] + words = words0 + ['[SEP]'] + words1 + return words + + for name, dataset in data_bundle.datasets.items(): + dataset.apply(concat, new_field_name='words') + dataset.delete_field('words1') + dataset.delete_field('words2') + + word_vocab = Vocabulary() + word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], + field_name='words', + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if + 'train' not in name]) + word_vocab.index_dataset(*data_bundle.datasets.values(), field_name='words') + + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], + field_name='target', + no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() + if ('train' not in name) and (ds.has_field('target'))] + ) + if len(target_vocab._no_create_word) > 0: + warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ + f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ + f"data set but not in train data set!." 
+ warnings.warn(warn_msg) + print(warn_msg) + + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if + dataset.has_field('target')] + target_vocab.index_dataset(*has_target_datasets, field_name='target') + + data_bundle.set_vocab(word_vocab, 'words') + data_bundle.set_vocab(target_vocab, 'target') + + input_fields = ['words', 'seq_len'] + target_fields = ['target'] + + for name, dataset in data_bundle.iter_datasets(): + dataset.add_seq_len('words') + dataset.set_input(*input_fields) + for fields in target_fields: + if dataset.has_field(fields): + dataset.set_input(fields) + + return data_bundle + + +class RTEBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = RTELoader().load(paths) + return self.process(data_bundle) + + +class SNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = SNLILoader().load(paths) + return self.process(data_bundle) + + +class QuoraBertPipe(MatchingBertPipe): + def process_from_file(self, paths): + data_bundle = QuoraLoader().load(paths) + return self.process(data_bundle) + + +class QNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = QNLILoader().load(paths) + return self.process(data_bundle) + + +class MNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = MNLILoader().load(paths) + return self.process(data_bundle) + + +class MatchingPipe(Pipe): + r""" + Matching任务的Pipe。输出的DataSet将包含以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2" + + "The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", "[10, 20, 6]", 10, 13 + "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7 + "...", "...", ., "[...]", "[...]", ., . + + words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target + 和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 + 的形参名进行传参)。 + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+------------+------------+--------+--------+--------+----------+----------+ + | field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 | + +-------------+------------+------------+--------+--------+--------+----------+----------+ + | is_input | False | False | False | True | True | True | True | + | is_target | False | False | True | False | False | False | False | + | ignore_type | | | False | False | False | False | False | + | pad_value | | | 0 | 0 | 0 | 0 | 0 | + +-------------+------------+------------+--------+--------+--------+----------+----------+ + + """ + + def __init__(self, lower=False, tokenizer: str = 'raw'): + r""" + + :param bool lower: 是否将所有raw_words转为小写。 + :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 + """ + super().__init__() + + self.lower = bool(lower) + self.tokenizer = get_tokenizer(tokenize_method=tokenizer) + + def _tokenize(self, data_bundle, field_names, new_field_names): + r""" + + :param ~fastNLP.DataBundle data_bundle: DataBundle. 
+ :param list field_names: List[str], 需要tokenize的field名称 + :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 + :return: 输入的DataBundle对象 + """ + for name, dataset in data_bundle.iter_datasets(): + for field_name, new_field_name in zip(field_names, new_field_names): + dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, + new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + r""" + 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "entailment" + "This site includes a...", "The Government Executive...", "not_entailment" + "...", "..." + + :param ~fastNLP.DataBundle data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 + :return: data_bundle + """ + data_bundle = self._tokenize(data_bundle, ['raw_words1', 'raw_words2'], + ['words1', 'words2']) + + for dataset in data_bundle.datasets.values(): + if dataset.has_field('target'): + dataset.drop(lambda x: x['target'] == '-') + + if self.lower: + for name, dataset in data_bundle.datasets.items(): + dataset['words1'].lower() + dataset['words2'].lower() + + word_vocab = Vocabulary() + word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], + field_name=['words1', 'words2'], + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if + 'train' not in name]) + word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=['words1', 'words2']) + + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], + field_name='target', + no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() + if ('train' not in name) and (ds.has_field('target'))] + ) + if len(target_vocab._no_create_word) > 0: + warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ + f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ + f"data set but not in train data set!." 
warnings.warn(warn_msg)
+            print(warn_msg)
+
+        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
+                               dataset.has_field('target')]
+        target_vocab.index_dataset(*has_target_datasets, field_name='target')
+
+        data_bundle.set_vocab(word_vocab, 'words1')
+        data_bundle.set_vocab(target_vocab, 'target')
+
+        input_fields = ['words1', 'words2', 'seq_len1', 'seq_len2']
+        target_fields = ['target']
+
+        for name, dataset in data_bundle.datasets.items():
+            dataset.add_seq_len('words1', 'seq_len1')
+            dataset.add_seq_len('words2', 'seq_len2')
+            dataset.set_input(*input_fields)
+            for fields in target_fields:
+                if dataset.has_field(fields):
+                    dataset.set_input(fields)
+
+        return data_bundle
+
+
+class RTEPipe(MatchingPipe):
+    def process_from_file(self, paths=None):
+        data_bundle = RTELoader().load(paths)
+        return self.process(data_bundle)
+
+
+class SNLIPipe(MatchingPipe):
+    def process_from_file(self, paths=None):
+        data_bundle = SNLILoader().load(paths)
+        return self.process(data_bundle)
+
+
+class QuoraPipe(MatchingPipe):
+    def process_from_file(self, paths):
+        data_bundle = QuoraLoader().load(paths)
+        return self.process(data_bundle)
+
+
+class QNLIPipe(MatchingPipe):
+    def process_from_file(self, paths=None):
+        data_bundle = QNLILoader().load(paths)
+        return self.process(data_bundle)
+
+
+class MNLIPipe(MatchingPipe):
+    def process_from_file(self, paths=None):
+        data_bundle = MNLILoader().load(paths)
+        return self.process(data_bundle)
+
+
+class LCQMCPipe(MatchingPipe):
+    def __init__(self, tokenizer='cn-char'):
+        super().__init__(tokenizer=tokenizer)
+
+    def process_from_file(self, paths=None):
+        data_bundle = LCQMCLoader().load(paths)
+        data_bundle = RenamePipe().process(data_bundle)
+        data_bundle = self.process(data_bundle)
+        data_bundle = RenamePipe().process(data_bundle)
+        return data_bundle
+
+
+class CNXNLIPipe(MatchingPipe):
+    def __init__(self, tokenizer='cn-char'):
+        super().__init__(tokenizer=tokenizer)
+
+    def process_from_file(self, paths=None):
+        data_bundle = CNXNLILoader().load(paths)
+        data_bundle = GranularizePipe(task='XNLI').process(data_bundle)
+        data_bundle = RenamePipe().process(data_bundle)  # 使中文数据的field名与MatchingPipe所需的field名一致
+        data_bundle = self.process(data_bundle)
+        data_bundle = RenamePipe().process(data_bundle)
+        return data_bundle
+
+
+class BQCorpusPipe(MatchingPipe):
+    def __init__(self, tokenizer='cn-char'):
+        super().__init__(tokenizer=tokenizer)
+
+    def process_from_file(self, paths=None):
+        data_bundle = BQCorpusLoader().load(paths)
+        data_bundle = RenamePipe().process(data_bundle)
+        data_bundle = self.process(data_bundle)
+        data_bundle = RenamePipe().process(data_bundle)
+        return data_bundle
+
+
+class RenamePipe(Pipe):
+    def __init__(self, task='cn-nli'):
+        super().__init__()
+        self.task = task
+
+    def process(self, data_bundle: DataBundle):  # rename field name for Chinese Matching dataset
+        if (self.task == 'cn-nli'):
+            for name, dataset in data_bundle.datasets.items():
+                if (dataset.has_field('raw_chars1')):
+                    dataset.rename_field('raw_chars1', 'raw_words1')  # RAW_CHARS->RAW_WORDS
+                    dataset.rename_field('raw_chars2', 'raw_words2')
+                elif (dataset.has_field('words1')):
+                    dataset.rename_field('words1', 'chars1')  # WORDS->CHARS
+                    dataset.rename_field('words2', 'chars2')
+                    dataset.rename_field('raw_words1', 'raw_chars1')
+                    dataset.rename_field('raw_words2', 'raw_chars2')
+                else:
+                    raise RuntimeError(
+                        "field name of dataset is not qualified. 
It should have either RAW_CHARS or WORDS")
+        elif (self.task == 'cn-nli-bert'):
+            for name, dataset in data_bundle.datasets.items():
+                if (dataset.has_field('raw_chars1')):
+                    dataset.rename_field('raw_chars1', 'raw_words1')  # RAW_CHARS->RAW_WORDS
+                    dataset.rename_field('raw_chars2', 'raw_words2')
+                elif (dataset.has_field('raw_words1')):
+                    dataset.rename_field('raw_words1', 'raw_chars1')
+                    dataset.rename_field('raw_words2', 'raw_chars2')
+                    dataset.rename_field('words', 'chars')
+                else:
+                    raise RuntimeError(
+                        "field name of dataset is not qualified. It should have either RAW_CHARS or RAW_WORDS"
+                    )
+        else:
+            raise RuntimeError(
+                "Only support task='cn-nli' or 'cn-nli-bert'"
+            )
+
+        return data_bundle
+
+
+class GranularizePipe(Pipe):
+    def __init__(self, task=None):
+        super().__init__()
+        self.task = task
+
+    def _granularize(self, data_bundle, tag_map):
+        r"""
+        该函数对data_bundle中'target'列中的内容进行转换。
+
+        :param data_bundle:
+        :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance,
+            且将"1"认为是第0类。
+        :return: 传入的data_bundle
+        """
+        for name in list(data_bundle.datasets.keys()):
+            dataset = data_bundle.get_dataset(name)
+            dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target',
+                                new_field_name='target')
+            dataset.drop(lambda ins: ins['target'] == -100)
+            data_bundle.set_dataset(dataset, name)
+        return data_bundle
+
+    def process(self, data_bundle: DataBundle):
+        task_tag_dict = {
+            'XNLI': {'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2}
+        }
+        if self.task in task_tag_dict:
+            data_bundle = self._granularize(data_bundle=data_bundle, tag_map=task_tag_dict[self.task])
+        else:
+            raise RuntimeError(f"Only support {task_tag_dict.keys()} task_tag_map.")
+        return data_bundle
+
+
+class MachingTruncatePipe(Pipe):  # truncate sentence for bert, modify seq_len
+    def __init__(self):
+        super().__init__()
+
+    def process(self, data_bundle: DataBundle):
+        # 尚未实现截断逻辑, 这里直接原样返回data_bundle; 实际的截断请使用下方的TruncateBertPipe
+        for name, dataset in data_bundle.datasets.items():
+            pass
+        return data_bundle
+
+
+class LCQMCBertPipe(MatchingBertPipe):
+    def __init__(self, tokenizer='cn-char'):
+        super().__init__(tokenizer=tokenizer)
+
+    def process_from_file(self, paths=None):
+        data_bundle = LCQMCLoader().load(paths)
+        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
+        data_bundle = self.process(data_bundle)
+        data_bundle = TruncateBertPipe(task='cn').process(data_bundle)
+        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
+        return data_bundle
+
+
+class BQCorpusBertPipe(MatchingBertPipe):
+    def __init__(self, tokenizer='cn-char'):
+        super().__init__(tokenizer=tokenizer)
+
+    def process_from_file(self, paths=None):
+        data_bundle = BQCorpusLoader().load(paths)
+        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
+        data_bundle = self.process(data_bundle)
+        data_bundle = TruncateBertPipe(task='cn').process(data_bundle)
+        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
+        return data_bundle
+
+
+class CNXNLIBertPipe(MatchingBertPipe):
+    def __init__(self, tokenizer='cn-char'):
+        super().__init__(tokenizer=tokenizer)
+
+    def process_from_file(self, paths=None):
+        data_bundle = CNXNLILoader().load(paths)
+        data_bundle = GranularizePipe(task='XNLI').process(data_bundle)
+        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
+        data_bundle = self.process(data_bundle)
+        data_bundle = TruncateBertPipe(task='cn').process(data_bundle)
+        data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)
+        return data_bundle
+
+
+class 
TruncateBertPipe(Pipe): + def __init__(self, task='cn'): + super().__init__() + self.task = task + + def _truncate(self, sentence_index:list, sep_index_vocab): + # 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index + sep_index_words = sentence_index.index(sep_index_vocab) + words_before_sep = sentence_index[:sep_index_words] + words_after_sep = sentence_index[sep_index_words:] # 注意此部分包括了[SEP] + if self.task == 'cn': + # 中文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过250 + words_before_sep = words_before_sep[:250] + words_after_sep = words_after_sep[:250] + elif self.task == 'en': + # 英文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过215 + words_before_sep = words_before_sep[:215] + words_after_sep = words_after_sep[:215] + else: + raise RuntimeError("Only support 'cn' or 'en' task.") + + return words_before_sep + words_after_sep + + def process(self, data_bundle: DataBundle) -> DataBundle: + for name in data_bundle.datasets.keys(): + dataset = data_bundle.get_dataset(name) + sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') + dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words') + + # truncate之后需要更新seq_len + dataset.add_seq_len(field_name='words') + return data_bundle + diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py new file mode 100644 index 00000000..4916bf09 --- /dev/null +++ b/fastNLP/io/pipe/pipe.py @@ -0,0 +1,41 @@ +r"""undocumented""" + +__all__ = [ + "Pipe", +] + +from fastNLP.io.data_bundle import DataBundle + + +class Pipe: + r""" + Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe + 文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 + + 一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或raw_chars进行tokenize以切分成不同的词或字; + (2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target列建立词表并将target列转为index; + + Pipe中提供了两个方法 + + -process()函数,输入为DataBundle + -process_from_file()函数,输入为对应Loader的load函数可接受的类型。 + + """ + + def process(self, data_bundle: DataBundle) -> DataBundle: + r""" + 对输入的DataBundle进行处理,然后返回该DataBundle。 + + :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 + :return: DataBundle + """ + raise NotImplementedError + + def process_from_file(self, paths: str) -> DataBundle: + r""" + 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + + :param str paths: + :return: DataBundle + """ + raise NotImplementedError diff --git a/fastNLP/io/pipe/qa.py b/fastNLP/io/pipe/qa.py new file mode 100644 index 00000000..4e2a977c --- /dev/null +++ b/fastNLP/io/pipe/qa.py @@ -0,0 +1,144 @@ +r""" +本文件中的Pipe主要用于处理问答任务的数据。 + +""" + +from copy import deepcopy + +from .pipe import Pipe +from fastNLP.io.data_bundle import DataBundle +from ..loader.qa import CMRC2018Loader +from .utils import get_tokenizer +from fastNLP.core.dataset import DataSet +from fastNLP.core.vocabulary import Vocabulary + +__all__ = ['CMRC2018BertPipe'] + + +def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): + r""" + 处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 + + 会新增field: context_len(int), raw_words(list[str]), target_start(int), target_end(int)其中target_start + 与target_end是与raw_chars等长的。其中target_start和target_end是前闭后闭的区间。 + + :param DataBundle data_bundle: 类似["a", "b", "[SEP]", "c", ] + :return: + """ + tokenizer = get_tokenizer('cn-char', lang='cn') + for name in list(data_bundle.datasets.keys()): + ds = data_bundle.get_dataset(name) + 
data_bundle.delete_dataset(name) + new_ds = DataSet() + for ins in ds: + new_ins = deepcopy(ins) + context = ins['context'] + question = ins['question'] + + cnt_lst = tokenizer(context) + q_lst = tokenizer(question) + + answer_start = -1 + + if len(cnt_lst) + len(q_lst) + 3 > max_len: # 预留开头的[CLS]和[SEP]和中间的[sep] + if 'answer_starts' in ins and 'answers' in ins: + answer_start = int(ins['answer_starts'][0]) + answer = ins['answers'][0] + answer_end = answer_start + len(answer) + if answer_end > max_len - 3 - len(q_lst): + span_start = answer_end + 3 + len(q_lst) - max_len + span_end = answer_end + else: + span_start = 0 + span_end = max_len - 3 - len(q_lst) + cnt_lst = cnt_lst[span_start:span_end] + answer_start = int(ins['answer_starts'][0]) + answer_start -= span_start + answer_end = answer_start + len(ins['answers'][0]) + else: + cnt_lst = cnt_lst[:max_len - len(q_lst) - 3] + else: + if 'answer_starts' in ins and 'answers' in ins: + answer_start = int(ins['answer_starts'][0]) + answer_end = answer_start + len(ins['answers'][0]) + + tokens = cnt_lst + ['[SEP]'] + q_lst + new_ins['context_len'] = len(cnt_lst) + new_ins[concat_field_name] = tokens + + if answer_start != -1: + new_ins['target_start'] = answer_start + new_ins['target_end'] = answer_end - 1 + + new_ds.append(new_ins) + data_bundle.set_dataset(new_ds, name) + + return data_bundle + + +class CMRC2018BertPipe(Pipe): + r""" + 处理之后的DataSet将新增以下的field(传入的field仍然保留) + + .. csv-table:: + :header: "context_len", "raw_chars", "target_start", "target_end", "chars" + + 492, ['范', '廷', '颂... ], 30, 34, "[21, 25, ...]" + 491, ['范', '廷', '颂... ], 41, 61, "[21, 25, ...]" + + ".", "...", "...","...", "..." + + raw_words列是context与question拼起来的结果(连接的地方加入了[SEP]),words是转为index的值, target_start为答案start的index,target_end为答案end的index + (闭区间);context_len指示的是words列中context的长度。 + + 其中各列的meta信息如下: + + .. code:: + + +-------------+-------------+-----------+--------------+------------+-------+---------+ + | field_names | context_len | raw_chars | target_start | target_end | chars | answers | + +-------------+-------------+-----------+--------------+------------+-------+---------| + | is_input | False | False | False | False | True | False | + | is_target | True | True | True | True | False | True | + | ignore_type | False | True | False | False | False | True | + | pad_value | 0 | 0 | 0 | 0 | 0 | 0 | + +-------------+-------------+-----------+--------------+------------+-------+---------+ + + """ + + def __init__(self, max_len=510): + super().__init__() + self.max_len = max_len + + def process(self, data_bundle: DataBundle) -> DataBundle: + r""" + 传入的DataSet应该具备以下的field + + .. csv-table:: + :header:"title", "context", "question", "answers", "answer_starts", "id" + + "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" + "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" + "...", "...", "...","...", ".", "..." 
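+
+        A minimal usage sketch (the path below is an illustrative placeholder for the CMRC2018 data)::
+
+            from fastNLP.io import CMRC2018BertPipe
+
+            # produces the chars/target_start/target_end/context_len fields described above
+            data_bundle = CMRC2018BertPipe(max_len=510).process_from_file('/path/to/cmrc2018')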
+ + :param data_bundle: + :return: + """ + data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars') + + src_vocab = Vocabulary() + src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], + field_name='raw_chars', + no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() + if 'train' not in name] + ) + src_vocab.index_dataset(*data_bundle.datasets.values(), field_name='raw_chars', new_field_name='chars') + data_bundle.set_vocab(src_vocab, 'chars') + + data_bundle.set_input('chars', 'raw_chars', 'answers', 'target_start', 'target_end', 'context_len') + + return data_bundle + + def process_from_file(self, paths=None) -> DataBundle: + data_bundle = CMRC2018Loader().load(paths) + return self.process(data_bundle) diff --git a/fastNLP/io/pipe/summarization.py b/fastNLP/io/pipe/summarization.py new file mode 100644 index 00000000..d5ef4c7e --- /dev/null +++ b/fastNLP/io/pipe/summarization.py @@ -0,0 +1,196 @@ +r"""undocumented""" +import os +import numpy as np + +from .pipe import Pipe +from .utils import _drop_empty_instance +from ..loader.summarization import ExtCNNDMLoader +from ..data_bundle import DataBundle +# from ...core.const import Const +from ...core.vocabulary import Vocabulary +# from ...core._logger import log + + +WORD_PAD = "[PAD]" +WORD_UNK = "[UNK]" +DOMAIN_UNK = "X" +TAG_UNK = "X" + + +class ExtCNNDMPipe(Pipe): + r""" + 对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: + + .. csv-table:: + :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" + + """ + def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): + r""" + + :param vocab_size: int, 词表大小 + :param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 + :param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 + :param vocab_path: str, 外部词表路径 + :param domain: bool, 是否需要建立domain词表 + """ + self.vocab_size = vocab_size + self.vocab_path = vocab_path + self.sent_max_len = sent_max_len + self.doc_max_timesteps = doc_max_timesteps + self.domain = domain + + def process(self, data_bundle: DataBundle): + r""" + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "text", "summary", "label", "publication" + + ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" + ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" + ["..."], ["..."], [], "cnndm" + + :param data_bundle: + :return: 处理得到的数据包括 + .. csv-table:: + :header: "text_wd", "words", "seq_len", "target" + + [["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] + [["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] + [[""],...,[""]], [[],...,[]], [], [] + """ + + if self.vocab_path is None: + error_msg = 'vocab file is not defined!' 
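+            # process() expects an external vocab file passed in at construction time;
+            # process_from_file() fills in self.vocab_path automatically from the data directory.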
+ print(error_msg) + raise RuntimeError(error_msg) + data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') + data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') + data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') + data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target') + + data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') + # db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") + + # pad document + data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') + data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') + data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') + + data_bundle = _drop_empty_instance(data_bundle, "label") + + # set input and target + data_bundle.set_input('words', 'seq_len', 'target', 'seq_len') + + # print("[INFO] Load existing vocab from %s!" % self.vocab_path) + word_list = [] + with open(self.vocab_path, 'r', encoding='utf8') as vocab_f: + cnt = 2 # pad and unk + for line in vocab_f: + pieces = line.split("\t") + word_list.append(pieces[0]) + cnt += 1 + if cnt > self.vocab_size: + break + vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) + vocabs.add_word_lst(word_list) + vocabs.build_vocab() + data_bundle.set_vocab(vocabs, "vocab") + + if self.domain is True: + domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) + domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") + data_bundle.set_vocab(domaindict, "domain") + + return data_bundle + + def process_from_file(self, paths=None): + r""" + :param paths: dict or string + :return: DataBundle + """ + loader = ExtCNNDMLoader() + if self.vocab_path is None: + if paths is None: + paths = loader.download() + if not os.path.isdir(paths): + error_msg = 'vocab file is not defined!' 
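+                    # here `paths` is expected to be the (downloaded) data directory that
+                    # contains a `vocab` file, from which self.vocab_path is derived below.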
+ print(error_msg) + raise RuntimeError(error_msg) + self.vocab_path = os.path.join(paths, 'vocab') + db = loader.load(paths=paths) + db = self.process(db) + for ds in db.datasets.values(): + db.get_vocab("vocab").index_dataset(ds, field_name='words', new_field_name='words') + + return db + + +def _lower_text(text_list): + return [text.lower() for text in text_list] + + +def _split_list(text_list): + return [text.split() for text in text_list] + + +def _convert_label(label, sent_len): + np_label = np.zeros(sent_len, dtype=int) + if label != []: + np_label[np.array(label)] = 1 + return np_label.tolist() + + +def _pad_sent(text_wd, sent_max_len): + pad_text_wd = [] + for sent_wd in text_wd: + if len(sent_wd) < sent_max_len: + pad_num = sent_max_len - len(sent_wd) + sent_wd.extend([WORD_PAD] * pad_num) + else: + sent_wd = sent_wd[:sent_max_len] + pad_text_wd.append(sent_wd) + return pad_text_wd + + +def _token_mask(text_wd, sent_max_len): + token_mask_list = [] + for sent_wd in text_wd: + token_num = len(sent_wd) + if token_num < sent_max_len: + mask = [1] * token_num + [0] * (sent_max_len - token_num) + else: + mask = [1] * sent_max_len + token_mask_list.append(mask) + return token_mask_list + + +def _pad_label(label, doc_max_timesteps): + text_len = len(label) + if text_len < doc_max_timesteps: + pad_label = label + [0] * (doc_max_timesteps - text_len) + else: + pad_label = label[:doc_max_timesteps] + return pad_label + + +def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): + text_len = len(text_wd) + if text_len < doc_max_timesteps: + padding = [WORD_PAD] * sent_max_len + pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) + else: + pad_text = text_wd[:doc_max_timesteps] + return pad_text + + +def _sent_mask(text_wd, doc_max_timesteps): + text_len = len(text_wd) + if text_len < doc_max_timesteps: + sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) + else: + sent_mask = [1] * doc_max_timesteps + return sent_mask + + diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py new file mode 100644 index 00000000..b3e05f0c --- /dev/null +++ b/fastNLP/io/pipe/utils.py @@ -0,0 +1,224 @@ +r"""undocumented""" + +__all__ = [ + "iob2", + "iob2bioes", + "get_tokenizer", +] + +from typing import List +import warnings + +# from ...core.const import Const +from ...core.vocabulary import Vocabulary +# from ...core._logger import log +from pkg_resources import parse_version + + +def iob2(tags: List[str]) -> List[str]: + r""" + 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 + https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format + + :param tags: 需要转换的tags + """ + for i, tag in enumerate(tags): + if tag == "O": + continue + split = tag.split("-") + if len(split) != 2 or split[0] not in ["I", "B"]: + raise TypeError("The encoding schema is not a valid IOB type.") + if split[0] == "B": + continue + elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 + tags[i] = "B" + tag[1:] + elif tags[i - 1][1:] == tag[1:]: + continue + else: # conversion IOB1 to IOB2 + tags[i] = "B" + tag[1:] + return tags + + +def iob2bioes(tags: List[str]) -> List[str]: + r""" + 将iob的tag转换为bioes编码 + :param tags: + :return: + """ + new_tags = [] + for i, tag in enumerate(tags): + if tag == 'O': + new_tags.append(tag) + else: + split = tag.split('-')[0] + if split == 'B': + if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I': + new_tags.append(tag) + else: + new_tags.append(tag.replace('B-', 'S-')) + elif split == 'I': + if i + 1 < len(tags) and tags[i 
+ 1].split('-')[0] == 'I': + new_tags.append(tag) + else: + new_tags.append(tag.replace('I-', 'E-')) + else: + raise TypeError("Invalid IOB format.") + return new_tags + + +def get_tokenizer(tokenize_method: str, lang='en'): + r""" + + :param str tokenize_method: 获取tokenzier方法 + :param str lang: 语言,当前仅支持en + :return: 返回tokenize函数 + """ + tokenizer_dict = { + 'spacy': None, + 'raw': _raw_split, + 'cn-char': _cn_char_split, + } + if tokenize_method == 'spacy': + import spacy + spacy.prefer_gpu() + if lang != 'en': + raise RuntimeError("Spacy only supports en right right.") + if parse_version(spacy.__version__) >= parse_version('3.0'): + en = spacy.load('en_core_web_sm') + else: + en = spacy.load(lang) + tokenizer = lambda x: [w.text for w in en.tokenizer(x)] + elif tokenize_method in tokenizer_dict: + tokenizer = tokenizer_dict[tokenize_method] + else: + raise RuntimeError(f"Only support {tokenizer_dict.keys()} tokenizer.") + return tokenizer + + +def _cn_char_split(sent): + return [chars for chars in sent] + + +def _raw_split(sent): + return sent.split() + + +def _indexize(data_bundle, input_field_names='words', target_field_names='target'): + r""" + 在dataset中的field_name列建立词表,'target'列建立词表,并把词表加入到data_bundle中。 + + :param ~fastNLP.DataBundle data_bundle: + :param: str,list input_field_names: + :param: str,list target_field_names: 这一列的vocabulary没有unknown和padding + :return: + """ + if isinstance(input_field_names, str): + input_field_names = [input_field_names] + if isinstance(target_field_names, str): + target_field_names = [target_field_names] + for input_field_name in input_field_names: + src_vocab = Vocabulary() + src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], + field_name=input_field_name, + no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() + if ('train' not in name) and (ds.has_field(input_field_name))] + ) + src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) + data_bundle.set_vocab(src_vocab, input_field_name) + + for target_field_name in target_field_names: + tgt_vocab = Vocabulary(unknown=None, padding=None) + tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], + field_name=target_field_name, + no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() + if ('train' not in name) and (ds.has_field(target_field_name))] + ) + if len(tgt_vocab._no_create_word) > 0: + warn_msg = f"There are {len(tgt_vocab._no_create_word)} `{target_field_name}` labels" \ + f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ + f"data set but not in train data set!.\n" \ + f"These label(s) are {tgt_vocab._no_create_word}" + warnings.warn(warn_msg) + # log.warning(warn_msg) + tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name) + data_bundle.set_vocab(tgt_vocab, target_field_name) + + return data_bundle + + +def _add_words_field(data_bundle, lower=False): + r""" + 给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化 + + :param data_bundle: + :param bool lower:是否要小写化 + :return: 传入的DataBundle + """ + data_bundle.copy_field(field_name='raw_words', new_field_name='words', ignore_miss_dataset=True) + + if lower: + for name, dataset in data_bundle.datasets.items(): + dataset['words'].lower() + return data_bundle + + +def _add_chars_field(data_bundle, lower=False): + r""" + 给data_bundle中的dataset中复制一列chars. 
并根据lower参数判断是否需要小写化 + + :param data_bundle: + :param bool lower:是否要小写化 + :return: 传入的DataBundle + """ + data_bundle.copy_field(field_name='raw_chars', new_field_name='chars', ignore_miss_dataset=True) + + if lower: + for name, dataset in data_bundle.datasets.items(): + dataset['chars'].lower() + return data_bundle + + +def _drop_empty_instance(data_bundle, field_name): + r""" + 删除data_bundle的DataSet中存在的某个field为空的情况 + + :param ~fastNLP.DataBundle data_bundle: + :param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 + :return: 传入的DataBundle + """ + + def empty_instance(ins): + if field_name: + field_value = ins[field_name] + if field_value in ((), {}, [], ''): + return True + return False + for _, field_value in ins.items(): + if field_value in ((), {}, [], ''): + return True + return False + + for name, dataset in data_bundle.datasets.items(): + dataset.drop(empty_instance) + + return data_bundle + + +def _granularize(data_bundle, tag_map): + r""" + 该函数对data_bundle中'target'列中的内容进行转换。 + + :param data_bundle: + :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, + 且将"1"认为是第0类。 + :return: 传入的data_bundle + """ + if tag_map is None: + return data_bundle + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', + new_field_name='target') + dataset.drop(lambda ins: ins['target'] == -100) + data_bundle.set_dataset(dataset, name) + return data_bundle diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py new file mode 100644 index 00000000..79806794 --- /dev/null +++ b/fastNLP/io/utils.py @@ -0,0 +1,82 @@ +r""" +.. todo:: + doc +""" + +__all__ = [ + "check_loader_paths" +] + +import os +from pathlib import Path +from typing import Union, Dict + +# from ..core import log + + +def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: + r""" + 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: + + { + 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 + 'test': 'xxx' # 可能有,也可能没有 + ... + } + + 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 + + :param str paths: 路径. 
可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 + 中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :return: + """ + if isinstance(paths, (str, Path)): + paths = os.path.abspath(os.path.expanduser(paths)) + if os.path.isfile(paths): + return {'train': paths} + elif os.path.isdir(paths): + filenames = os.listdir(paths) + filenames.sort() + files = {} + for filename in filenames: + path_pair = None + if 'train' in filename: + path_pair = ('train', filename) + if 'dev' in filename: + if path_pair: + raise Exception( + "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0])) + path_pair = ('dev', filename) + if 'test' in filename: + if path_pair: + raise Exception( + "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0])) + path_pair = ('test', filename) + if path_pair: + if path_pair[0] in files: + raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the " + f"filepath for `{path_pair[0]}`.") + files[path_pair[0]] = os.path.join(paths, path_pair[1]) + if 'train' not in files: + raise KeyError(f"There is no train file in {paths}.") + return files + else: + raise FileNotFoundError(f"{paths} is not a valid file path.") + + elif isinstance(paths, dict): + if paths: + if 'train' not in paths: + raise KeyError("You have to include `train` in your dict.") + for key, value in paths.items(): + if isinstance(key, str) and isinstance(value, str): + value = os.path.abspath(os.path.expanduser(value)) + if not os.path.exists(value): + raise TypeError(f"{value} is not a valid path.") + paths[key] = value + else: + raise TypeError("All keys and values in paths should be str.") + return paths + else: + raise ValueError("Empty paths is not allowed.") + else: + raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
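+
+
+# A small illustrative sketch of what check_loader_paths returns (paths are placeholders):
+#
+#     check_loader_paths('/data/conll/train.txt')
+#     # -> {'train': '/data/conll/train.txt'}
+#
+#     check_loader_paths({'train': '/data/conll/train.txt', 'dev': '/data/conll/dev.txt'})
+#     # -> the same dict, with each value expanded to an absolute path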