diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 79af296b..fa6d90a2 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -3,7 +3,8 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户
 """
 __all__ = [
     "cache_results",
-    "seq_len_to_mask"
+    "seq_len_to_mask",
+    "Example",
 ]
 
 import _pickle
@@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
                                      'varargs'])
 
 
+class Example(dict):
+    """a dict can treat keys as attributes"""
+    def __getattr__(self, item):
+        try:
+            return self.__getitem__(item)
+        except KeyError:
+            raise AttributeError(item)
+
+    def __setattr__(self, key, value):
+        if key.startswith('__') and key.endswith('__'):
+            raise AttributeError(key)
+        self.__setitem__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            self.pop(item)
+        except KeyError:
+            raise AttributeError(item)
+
+    def __getstate__(self):
+        return self
+
+    def __setstate__(self, state):
+        self.update(state)
+
+
 def _prepare_cache_filepath(filepath):
     """
     检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径
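# Usage sketch (illustrative only, not part of the patch): the new `Example`
# helper above is a plain dict whose keys can also be read, written and
# deleted as attributes, and it round-trips through pickle via
# __getstate__/__setstate__. Assumes the patched fastNLP is importable; the
# option names below are made up.
import pickle

from fastNLP.core.utils import Example

opt = Example(lr=0.1, batch_size=32)
assert opt.lr == opt['lr'] == 0.1   # attribute lookup falls back to __getitem__
opt.n_epochs = 10                   # __setattr__ stores an ordinary dict key
del opt.batch_size                  # __delattr__ pops the key
assert dict(opt) == {'lr': 0.1, 'n_epochs': 10}

restored = pickle.loads(pickle.dumps(opt))
assert restored == opt              # still behaves like a normal mapping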
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 0abaa42b..40604deb 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -29,8 +29,12 @@ from nltk.tree import Tree
 from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json, _read_conll
-from typing import Union, Dict
+from typing import Union, Dict, Iterable
 import os
+from ..core.utils import Example
+from ..core import Vocabulary
+from ..io import EmbedLoader
+import numpy as np
 
 
 def _download_from_url(url, path):
@@ -39,7 +43,7 @@
     except:
         from ..core.utils import _pseudo_tqdm as tqdm
     import requests
-
+
     """Download file"""
     r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
     chunk_size = 16 * 1024
@@ -58,11 +62,11 @@
     import gzip
     import tarfile
     import os
-
+
     def unzip(src, dst):
         with zipfile.ZipFile(src, 'r') as f:
             f.extractall(dst)
-
+
     def ungz(src, dst):
         with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
             length = 16 * 1024 # 16KB
@@ -70,11 +74,11 @@
         while buf:
             uf.write(buf)
             buf = f.read(length)
-
+
     def untar(src, dst):
         with tarfile.open(src, 'r:gz') as f:
             f.extractall(dst)
-
+
     fn, ext = os.path.splitext(src)
     _, ext_2 = os.path.splitext(fn)
     if ext == '.zip':
@@ -87,6 +91,34 @@
         raise ValueError('unsupported file {}'.format(src))
 
 
+class VocabularyOption(Example):
+    def __init__(self,
+                 max_size=None,
+                 min_freq=None,
+                 padding='<pad>',
+                 unknown='<unk>'):
+        super().__init__(
+            max_size=max_size,
+            min_freq=min_freq,
+            padding=padding,
+            unknown=unknown
+        )
+
+
+class EmbeddingOption(Example):
+    def __init__(self,
+                 embed_filepath=None,
+                 dtype=np.float32,
+                 normalize=True,
+                 error='ignore'):
+        super().__init__(
+            embed_filepath=embed_filepath,
+            dtype=dtype,
+            normalize=normalize,
+            error=error
+        )
+
+
 class DataInfo:
     """
     经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
@@ -95,7 +127,7 @@
     :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
     :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
     :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
     """
-
+
     def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
         self.vocabs = vocabs or {}
         self.embeddings = embeddings or {}
@@ -106,21 +138,21 @@ class DataSetLoader:
     """
     别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
 
-    定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
-
+    定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
+
     开发者至少应该编写如下内容:
-
+
     - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
     - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
     - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
-
+
     **process 函数中可以 调用load 函数或 _load 函数**
-
+
     """
-
+
     def _download(self, url: str, path: str, uncompress=True) -> str:
         """
-
+
         从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。
 
         :param url: 下载的网站
@@ -136,7 +168,7 @@
             _uncompress(path, dst)
             return dst
         return path
-
+
     def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
         """
         从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
@@ -148,7 +180,7 @@
         if isinstance(paths, str):
             return self._load(paths)
         return {name: self._load(path) for name, path in paths.items()}
-
+
     def _load(self, path: str) -> DataSet:
         """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
 
@@ -156,16 +188,16 @@
         :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
         """
         raise NotImplementedError
-
+
     def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
         """
         对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。
-
+
         从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
         如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。
 
         返回的 :class:`DataInfo` 对象有如下属性:
-
+
         - vocabs: 由从数据集中获取的词表组成的字典,每个词表
         - embeddings: (可选) 数据集对应的词嵌入
         - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
@@ -183,12 +215,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):
 
     读取人民日报数据集
     """
-
+
     def __init__(self, pos=True, ner=True):
         super(PeopleDailyCorpusLoader, self).__init__()
         self.pos = pos
         self.ner = ner
-
+
     def _load(self, data_path):
         with open(data_path, "r", encoding="utf-8") as f:
             sents = f.readlines()
@@ -233,7 +265,7 @@
                     example.append(sent_ner)
             examples.append(example)
         return self.convert(examples)
-
+
     def convert(self, data):
         """
 
@@ -284,7 +316,7 @@ class ConllLoader(DataSetLoader):
     :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
     :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
     """
-
+
     def __init__(self, headers, indexes=None, dropna=False):
         super(ConllLoader, self).__init__()
         if not isinstance(headers, (list, tuple)):
@@ -298,7 +330,7 @@
         if len(indexes) != len(headers):
             raise ValueError
         self.indexes = indexes
-
+
     def _load(self, path):
         ds = DataSet()
         for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
@@ -316,7 +348,7 @@
     关于数据集的更多信息,参考:
     https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
     """
-
+
     def __init__(self):
         headers = [
            'tokens', 'pos', 'chunks', 'ner',
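# Usage sketch for the load() contract documented above (illustrative only,
# not part of the patch; the CoNLL-2003 file paths are placeholders and
# Conll2003Loader itself is unchanged by this diff):
from fastNLP.io.dataset_loader import Conll2003Loader

loader = Conll2003Loader()

# A single path returns a single DataSet ...
train_ds = loader.load('data/conll2003/train.txt')

# ... while a name->path dict returns a dict of DataSets under the same keys,
# which is the shape that DataSetLoader.process() wraps into DataInfo.datasets.
datasets = loader.load({
    'train': 'data/conll2003/train.txt',
    'dev': 'data/conll2003/dev.txt',
    'test': 'data/conll2003/test.txt',
})
print(sorted(datasets.keys()))  # ['dev', 'test', 'train']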
@@ -368,17 +400,17 @@
     :param subtree: 是否将数据展开为子树,扩充数据量.
         Default: ``False``
     :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
     """
-
+
     def __init__(self, subtree=False, fine_grained=False):
         self.subtree = subtree
-
+
         tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                  '3': 'positive', '4': 'very positive'}
         if not fine_grained:
             tag_v['0'] = tag_v['1']
             tag_v['4'] = tag_v['3']
         self.tag_v = tag_v
-
+
     def _load(self, path):
         """
@@ -395,7 +427,7 @@
         for words, tag in datas:
             ds.append(Instance(words=words, target=tag))
         return ds
-
+
     @staticmethod
     def _get_one(data, subtree):
         tree = Tree.fromstring(data)
@@ -403,6 +435,40 @@
             return [(t.leaves(), t.label()) for t in tree.subtrees()]
         return [(tree.leaves(), tree.label())]
+    def process(self,
+                paths,
+                train_ds: Iterable[str] = None,
+                src_vocab_op: VocabularyOption = None,
+                tgt_vocab_op: VocabularyOption = None,
+                embed_op: EmbeddingOption = None):
+        input_name, target_name = 'words', 'target'
+        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
+        tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
+
+        info = DataInfo(datasets=self.load(paths))
+        _train_ds = [info.datasets[name]
+                     for name in train_ds] if train_ds else info.datasets.values()
+
+        src_vocab.from_dataset(_train_ds, field_name=input_name)
+        tgt_vocab.from_dataset(_train_ds, field_name=target_name)
+        src_vocab.index_dataset(
+            info.datasets.values(),
+            field_name=input_name, new_field_name=input_name)
+        tgt_vocab.index_dataset(
+            info.datasets.values(),
+            field_name=target_name, new_field_name=target_name)
+        info.vocabs = {
+            input_name: src_vocab,
+            target_name: tgt_vocab
+        }
+
+        if embed_op is not None:
+            embed_op.vocab = src_vocab
+            init_emb = EmbedLoader.load_with_vocab(**embed_op)
+            info.embeddings[input_name] = init_emb
+
+        return info
+
 
 
 class JsonLoader(DataSetLoader):
     """
@@ -417,7 +483,7 @@
     :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
         Default: ``False``
     """
-
+
     def __init__(self, fields=None, dropna=False):
         super(JsonLoader, self).__init__()
         self.dropna = dropna
@@ -428,7 +494,7 @@
             for k, v in fields.items():
                 self.fields[k] = k if v is None else v
             self.fields_list = list(self.fields.keys())
-
+
     def _load(self, path):
         ds = DataSet()
         for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
@@ -452,7 +518,7 @@
 
     数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
     """
-
+
     def __init__(self):
         fields = {
             'sentence1_parse': 'words1',
@@ -460,14 +526,14 @@
             'gold_label': 'target',
         }
         super(SNLILoader, self).__init__(fields=fields)
-
+
     def _load(self, path):
         ds = super(SNLILoader, self)._load(path)
-
+
         def parse_tree(x):
             t = Tree.fromstring(x)
             return t.leaves()
-
+
         ds.apply(lambda ins: parse_tree(
             ins['words1']), new_field_name='words1')
         ds.apply(lambda ins: parse_tree(
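# Sketch of driving the new SSTLoader.process() added above (illustrative only,
# not part of the patch; the SST paths are placeholders). One point worth
# double-checking against the Vocabulary API: from_dataset()/index_dataset()
# take DataSet objects as positional *datasets arguments, so the list/values
# built inside process() may need to be unpacked with `*` when they are passed.
from fastNLP.io.dataset_loader import SSTLoader, VocabularyOption

loader = SSTLoader(subtree=False, fine_grained=False)
info = loader.process(
    paths={'train': 'data/sst/train.txt', 'dev': 'data/sst/dev.txt'},
    train_ds=['train'],                         # only the 'train' split feeds the vocabularies
    src_vocab_op=VocabularyOption(min_freq=2),  # forwarded as Vocabulary(**options)
    embed_op=None,                              # or EmbeddingOption(embed_filepath='...') to load vectors
)

print(info.datasets['train'][:2])  # 'words' / 'target' fields, already indexed
print(len(info.vocabs['words']))   # source vocabulary built from the train split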
@@ -488,12 +554,12 @@
     :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
         Default: ``False``
     """
-
+
     def __init__(self, headers=None, sep=",", dropna=False):
         self.headers = headers
         self.sep = sep
         self.dropna = dropna
-
+
     def _load(self, path):
         ds = DataSet()
         for idx, data in _read_csv(path, headers=self.headers,
@@ -508,7 +574,7 @@
     :param data: list of ([word], [pos], [heads], [head_tags])
     :return: list of ([word], [pos])
     """
-
+
     _processed = []
     for word_list, pos_list, _, _ in data:
         new_sample = []
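# End-to-end sketch of how the option objects feed EmbedLoader.load_with_vocab(),
# mirroring the tail of SSTLoader.process() above (illustrative only, not part
# of the patch; the GloVe-style text file path is a placeholder, and Vocabulary /
# EmbedLoader come from fastNLP's existing public API rather than from this diff).
import numpy as np

from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader
from fastNLP.io.dataset_loader import EmbeddingOption

vocab = Vocabulary()
vocab.update(['the', 'movie', 'was', 'great'])
vocab.build_vocab()

# EmbeddingOption is an Example (i.e. a dict), so an extra key such as `vocab`
# can be attached as an attribute and the whole bundle unpacked into the call,
# exactly the way process() does with `embed_op.vocab = src_vocab`.
embed_op = EmbeddingOption(embed_filepath='data/glove.6B.100d.txt', normalize=True)
embed_op.vocab = vocab
matrix = EmbedLoader.load_with_vocab(**embed_op)
assert isinstance(matrix, np.ndarray) and matrix.shape[0] == len(vocab)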