From 014e9786c7abbbb3c043c3a1db19e703ad338659 Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 14 Aug 2019 15:35:05 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E5=88=86=E7=B1=BBDataSetLoader=E4=B8=AD?= =?UTF-8?q?=E7=9A=84Loader=E5=8A=9F=E8=83=BDPipe=E5=8A=9F=E8=83=BD;=202.?= =?UTF-8?q?=20=E5=A2=9E=E5=8A=A0=E6=95=B0=E6=8D=AE=E9=9B=86=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E4=B8=8B=E8=BD=BD;=203.=E4=BF=AE=E5=A4=8Dvocabulary?= =?UTF-8?q?=E4=B8=AD=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .travis.yml | 3 + fastNLP/core/batch.py | 12 + fastNLP/core/const.py | 26 +- fastNLP/core/dataset.py | 33 +- fastNLP/core/field.py | 10 +- fastNLP/core/instance.py | 7 + fastNLP/core/utils.py | 21 + fastNLP/core/vocabulary.py | 63 +-- fastNLP/embeddings/__init__.py | 3 +- fastNLP/embeddings/bert_embedding.py | 16 +- fastNLP/embeddings/elmo_embedding.py | 8 +- fastNLP/embeddings/static_embedding.py | 12 +- fastNLP/io/base_loader.py | 89 +++- fastNLP/io/data_loader/conll.py | 116 +++-- fastNLP/io/data_loader/matching.py | 2 +- fastNLP/io/data_loader/mtl.py | 4 +- fastNLP/io/data_loader/sst.py | 10 +- fastNLP/io/data_loader/yelp.py | 4 +- fastNLP/io/dataset_loader.py | 22 - fastNLP/io/file_reader.py | 10 +- fastNLP/io/file_utils.py | 277 +++++++---- fastNLP/io/loader/__init__.py | 30 ++ fastNLP/io/loader/classification.py | 369 +++++++++++++++ fastNLP/io/loader/conll.py | 264 +++++++++++ fastNLP/io/loader/csv.py | 32 ++ fastNLP/io/loader/cws.py | 41 ++ fastNLP/io/loader/json.py | 40 ++ fastNLP/io/loader/loader.py | 75 +++ fastNLP/io/loader/matching.py | 309 ++++++++++++ fastNLP/io/pipe/__init__.py | 8 + fastNLP/io/pipe/classification.py | 444 ++++++++++++++++++ fastNLP/io/pipe/conll.py | 149 ++++++ fastNLP/io/pipe/matching.py | 254 ++++++++++ fastNLP/io/pipe/pipe.py | 9 + fastNLP/io/pipe/utils.py | 142 ++++++ fastNLP/io/utils.py | 14 +- test/embeddings/__init__.py | 0 .../encoder => embeddings}/test_bert.py | 0 test/embeddings/test_elmo_embedding.py | 21 + test/io/loader/test_classification_loader.py | 19 + test/io/loader/test_matching_loader.py | 22 + test/io/pipe/test_classification.py | 13 + test/io/pipe/test_matching.py | 26 + 43 files changed, 2802 insertions(+), 227 deletions(-) create mode 100644 fastNLP/io/loader/__init__.py create mode 100644 fastNLP/io/loader/classification.py create mode 100644 fastNLP/io/loader/conll.py create mode 100644 fastNLP/io/loader/csv.py create mode 100644 fastNLP/io/loader/cws.py create mode 100644 fastNLP/io/loader/json.py create mode 100644 fastNLP/io/loader/loader.py create mode 100644 fastNLP/io/loader/matching.py create mode 100644 fastNLP/io/pipe/__init__.py create mode 100644 fastNLP/io/pipe/classification.py create mode 100644 fastNLP/io/pipe/conll.py create mode 100644 fastNLP/io/pipe/matching.py create mode 100644 fastNLP/io/pipe/pipe.py create mode 100644 fastNLP/io/pipe/utils.py create mode 100644 test/embeddings/__init__.py rename test/{modules/encoder => embeddings}/test_bert.py (100%) create mode 100644 test/embeddings/test_elmo_embedding.py create mode 100644 test/io/loader/test_classification_loader.py create mode 100644 test/io/loader/test_matching_loader.py create mode 100644 test/io/pipe/test_classification.py create mode 100644 test/io/pipe/test_matching.py diff --git a/.travis.yml b/.travis.yml index 210d158a..856ec9c8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,9 @@ language: python python: - "3.6" + +env + - TRAVIS=1 # command to install dependencies install: - pip install --quiet -r 
requirements.txt diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 538f583a..8d97783e 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -48,6 +48,11 @@ class DataSetGetter: return len(self.dataset) def collate_fn(self, batch: list): + """ + + :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] + :return: + """ # TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 batch_x = {n:[] for n in self.inputs.keys()} batch_y = {n:[] for n in self.targets.keys()} @@ -208,6 +213,13 @@ class OnlineDataIter(BatchIter): def _to_tensor(batch, field_dtype): + """ + + :param batch: np.array() + :param field_dtype: 数据类型 + :return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor, + 返回的batch就是原来的数据,且flag为False + """ try: if field_dtype is not None and isinstance(field_dtype, type)\ and issubclass(field_dtype, Number) \ diff --git a/fastNLP/core/const.py b/fastNLP/core/const.py index 89ff51a2..27e8d1cb 100644 --- a/fastNLP/core/const.py +++ b/fastNLP/core/const.py @@ -7,12 +7,14 @@ class Const: 具体列表:: - INPUT 模型的序列输入 words(复数words1, words2) - CHAR_INPUT 模型character输入 chars(复数chars1, chars2) - INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2) - OUTPUT 模型输出 pred(复数pred1, pred2) - TARGET 真实目标 target(复数target1,target2) - LOSS 损失函数 loss (复数loss1,loss2) + INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, ) + CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2) + INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2) + OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2) + TARGET 真实目标 target(具有多列target时,依次使用target1,target2) + LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2) + RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2) + RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2) """ INPUT = 'words' @@ -21,6 +23,8 @@ class Const: OUTPUT = 'pred' TARGET = 'target' LOSS = 'loss' + RAW_WORD = 'raw_words' + RAW_CHAR = 'raw_chars' @staticmethod def INPUTS(i): @@ -34,6 +38,16 @@ class Const: i = int(i) + 1 return Const.CHAR_INPUT + str(i) + @staticmethod + def RAW_WORDS(i): + i = int(i) + 1 + return Const.RAW_WORD + str(i) + + @staticmethod + def RAW_CHARS(i): + i = int(i) + 1 + return Const.RAW_CHAR + str(i) + @staticmethod def INPUT_LENS(i): """得到第 i 个 ``INPUT_LEN`` 的命名""" diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 2955eff6..0f98ed1f 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -291,6 +291,7 @@ import _pickle as pickle import warnings import numpy as np +from copy import deepcopy from .field import AutoPadder from .field import FieldArray @@ -298,6 +299,7 @@ from .instance import Instance from .utils import _get_func_signature from .field import AppendToTargetOrInputException from .field import SetInputOrTargetException +from .const import Const class DataSet(object): """ @@ -349,7 +351,11 @@ class DataSet(object): self.idx]) assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) return self.dataset.field_arrays[item][self.idx] - + + def items(self): + ins = self.dataset[self.idx] + return ins.items() + def __repr__(self): return self.dataset[self.idx].__repr__() @@ -497,6 +503,7 @@ class DataSet(object): else: for field in self.field_arrays.values(): field.pop(index) + return self def delete_field(self, field_name): """ @@ -505,7 +512,22 @@ class DataSet(object): :param str field_name: 需要删除的field的名称. 
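A minimal sketch of the naming helpers added to Const above; the return values follow directly from the class definition (numbering starts from 1 when a DataSet holds several columns of the same kind), and the top-level import assumes Const keeps being re-exported from the fastNLP package as it is today.

    from fastNLP import Const

    Const.RAW_WORD        # 'raw_words'
    Const.RAW_WORDS(0)    # 'raw_words1'
    Const.RAW_WORDS(1)    # 'raw_words2'
    Const.RAW_CHARS(0)    # 'raw_chars1'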
""" self.field_arrays.pop(field_name) - + return self + + def copy_field(self, field_name, new_field_name): + """ + 深度copy名为field_name的field到new_field_name + + :param str field_name: 需要copy的field。 + :param str new_field_name: copy生成的field名称 + :return: self + """ + if not self.has_field(field_name): + raise KeyError(f"Field:{field_name} not found in DataSet.") + fieldarray = deepcopy(self.get_field(field_name)) + self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray) + return self + def has_field(self, field_name): """ 判断DataSet中是否有名为field_name这个field @@ -701,7 +723,7 @@ class DataSet(object): results.append(func(ins[field_name])) except Exception as e: if idx != -1: - print("Exception happens at the `{}`th instance.".format(idx)) + print("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) raise e if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(_get_func_signature(func=func))) @@ -766,10 +788,11 @@ class DataSet(object): results = [] for idx, ins in enumerate(self._inner_iter()): results.append(func(ins)) - except Exception as e: + except BaseException as e: if idx != -1: print("Exception happens at the `{}`th instance.".format(idx)) raise e + # results = [func(ins) for ins in self._inner_iter()] if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(_get_func_signature(func=func))) @@ -779,7 +802,7 @@ class DataSet(object): return results - def add_seq_len(self, field_name:str, new_field_name='seq_len'): + def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): """ 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index d7d3bb8b..65bd9be4 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -7,6 +7,7 @@ from typing import Any from abc import abstractmethod from copy import deepcopy from collections import Counter +from .utils import _is_iterable class SetInputOrTargetException(Exception): def __init__(self, msg, index=None, field_name=None): @@ -443,15 +444,6 @@ def _get_ele_type_and_dim(cell:Any, dim=0): raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") -def _is_iterable(value): - # 检查是否是iterable的, duck typing - try: - iter(value) - return True - except BaseException as e: - return False - - class Padder: """ 别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder` diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 5408522e..9a5d9edf 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -35,6 +35,13 @@ class Instance(object): :param Any field: 新增field的内容 """ self.fields[field_name] = field + + def items(self): + """ + 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value + :return: + """ + return self.fields.items() def __getitem__(self, name): if name in self.fields: diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 8483f9f2..4ce382f3 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -4,6 +4,7 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户 __all__ = [ "cache_results", "seq_len_to_mask", + "get_seq_len" ] import _pickle @@ -730,3 +731,23 @@ def iob2bioes(tags: List[str]) -> List[str]: else: raise TypeError("Invalid IOB format.") return new_tags + + +def _is_iterable(value): + # 检查是否是iterable的, duck typing + try: + iter(value) + return True + except BaseException as e: 
+ return False + + +def get_seq_len(words, pad_value=0): + """ + 给定batch_size x max_len的words矩阵,返回句子长度 + + :param words: batch_size x max_len + :return: (batch_size,) + """ + mask = words.ne(pad_value) + return mask.sum(dim=-1) diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 9ce59a8c..a51c3f92 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -4,12 +4,12 @@ __all__ = [ ] from functools import wraps -from collections import Counter, defaultdict +from collections import Counter from .dataset import DataSet from .utils import Option from functools import partial import numpy as np - +from .utils import _is_iterable class VocabularyOption(Option): def __init__(self, @@ -131,11 +131,11 @@ class Vocabulary(object): """ 在新加入word时,检查_no_create_word的设置。 - :param str, List[str] word: + :param str List[str] word: :param bool no_create_entry: :return: """ - if isinstance(word, str): + if isinstance(word, str) or not _is_iterable(word): word = [word] for w in word: if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): @@ -257,35 +257,45 @@ class Vocabulary(object): vocab.index_dataset(train_data, dev_data, test_data, field_name='words') :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 - :param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. - 目前仅支持 ``str`` , ``List[str]`` , ``List[List[str]]`` - :param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. - Default: ``None`` + :param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. + 目前支持 ``str`` , ``List[str]`` + :param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. + Default: ``None``. """ - def index_instance(ins): + def index_instance(field): """ 有几种情况, str, 1d-list, 2d-list :param ins: :return: """ - field = ins[field_name] - if isinstance(field, str): + if isinstance(field, str) or not _is_iterable(field): return self.to_index(field) - elif isinstance(field, list): - if not isinstance(field[0], list): + else: + if isinstance(field[0], str) or not _is_iterable(field[0]): return [self.to_index(w) for w in field] else: - if isinstance(field[0][0], list): + if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): raise RuntimeError("Only support field with 2 dimensions.") return [[self.to_index(c) for c in w] for w in field] - - if new_field_name is None: - new_field_name = field_name + + new_field_name = new_field_name or field_name + + if type(new_field_name) == type(field_name): + if isinstance(new_field_name, list): + assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \ + "field_name." + elif isinstance(new_field_name, str): + field_name = [field_name] + new_field_name = [new_field_name] + else: + raise TypeError("field_name and new_field_name can only be str or List[str].") + for idx, dataset in enumerate(datasets): if isinstance(dataset, DataSet): try: - dataset.apply(index_instance, new_field_name=new_field_name) + for f_n, n_f_n in zip(field_name, new_field_name): + dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) except Exception as e: print("When processing the `{}` dataset, the following error occurred.".format(idx)) raise e @@ -306,9 +316,8 @@ class Vocabulary(object): :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 :param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . - 构建词典所使用的 field(s), 支持一个或多个field - 若有多个 DataSet, 每个DataSet都必须有这些field. 
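The DataSet/Instance helpers introduced above (copy_field, items, add_seq_len with Const.INPUT_LEN as its default output field) and the new get_seq_len utility can be combined as in the following sketch; the toy instance is made up and the import paths assume the package layout used in this patch.

    import torch
    from fastNLP import DataSet, Instance
    from fastNLP.core.utils import get_seq_len

    ds = DataSet()
    ds.append(Instance(raw_words=['Nadim', 'Ladki'], target=['B-PER', 'I-PER']))

    ds.copy_field('raw_words', 'words')    # deep copy of the field under a new name
    ds.add_seq_len('raw_words')            # len() of every cell, stored in 'seq_len' (Const.INPUT_LEN)

    for field_name, field_value in ds[0].items():   # Instance.items() yields (name, value) pairs
        print(field_name, field_value)

    # get_seq_len infers sentence lengths from a padded batch_size x max_len tensor
    words = torch.tensor([[2, 3, 4], [5, 6, 0]])
    print(get_seq_len(words, pad_value=0))          # tensor([3, 2])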
- 目前仅支持的field结构: ``str`` , ``List[str]`` , ``list[List[str]]`` + 构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构 + : ``str`` , ``List[str]`` :param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain 的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev 中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 @@ -326,14 +335,14 @@ class Vocabulary(object): def construct_vocab(ins, no_create_entry=False): for fn in field_name: field = ins[fn] - if isinstance(field, str): + if isinstance(field, str) or not _is_iterable(field): self.add_word(field, no_create_entry=no_create_entry) - elif isinstance(field, (list, np.ndarray)): - if not isinstance(field[0], (list, np.ndarray)): + else: + if isinstance(field[0], str) or not _is_iterable(field[0]): for word in field: self.add_word(word, no_create_entry=no_create_entry) else: - if isinstance(field[0][0], (list, np.ndarray)): + if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): raise RuntimeError("Only support field with 2 dimensions.") for words in field: for word in words: @@ -343,8 +352,8 @@ class Vocabulary(object): if isinstance(dataset, DataSet): try: dataset.apply(construct_vocab) - except Exception as e: - print("When processing the `{}` dataset, the following error occurred.".format(idx)) + except BaseException as e: + print("When processing the `{}` dataset, the following error occurred:".format(idx)) raise e else: raise TypeError("Only DataSet type is allowed.") diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py index 2bfb2960..4f90ac63 100644 --- a/fastNLP/embeddings/__init__.py +++ b/fastNLP/embeddings/__init__.py @@ -10,6 +10,7 @@ __all__ = [ "StaticEmbedding", "ElmoEmbedding", "BertEmbedding", + "BertWordPieceEncoder", "StackEmbedding", "LSTMCharEmbedding", "CNNCharEmbedding", @@ -20,7 +21,7 @@ __all__ = [ from .embedding import Embedding from .static_embedding import StaticEmbedding from .elmo_embedding import ElmoEmbedding -from .bert_embedding import BertEmbedding +from .bert_embedding import BertEmbedding, BertWordPieceEncoder from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding from .stack_embedding import StackEmbedding from .utils import get_embeddings \ No newline at end of file diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py index 1fadd491..261007ae 100644 --- a/fastNLP/embeddings/bert_embedding.py +++ b/fastNLP/embeddings/bert_embedding.py @@ -8,7 +8,7 @@ import numpy as np from itertools import chain from ..core.vocabulary import Vocabulary -from ..io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR +from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer from .contextual_embedding import ContextualEmbedding @@ -60,10 +60,8 @@ class BertEmbedding(ContextualEmbedding): # 根据model_dir_or_name检查是否存在并下载 if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: - PRETRAIN_URL = _get_base_url('bert') - model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + model_url = _get_embedding_url('bert', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_dir = 
os.path.expanduser(os.path.abspath(model_dir_or_name)) @@ -133,11 +131,9 @@ class BertWordPieceEncoder(nn.Module): pooled_cls: bool = False, requires_grad: bool=False): super().__init__() - if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: - PRETRAIN_URL = _get_base_url('bert') - model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: + model_url = _get_embedding_url('bert', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_dir = model_dir_or_name diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py index 53adfd62..590aba74 100644 --- a/fastNLP/embeddings/elmo_embedding.py +++ b/fastNLP/embeddings/elmo_embedding.py @@ -8,7 +8,7 @@ import json import codecs from ..core.vocabulary import Vocabulary -from ..io.file_utils import cached_path, _get_base_url, PRETRAINED_ELMO_MODEL_DIR +from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder from .contextual_embedding import ContextualEmbedding @@ -53,10 +53,8 @@ class ElmoEmbedding(ContextualEmbedding): # 根据model_dir_or_name检查是否存在并下载 if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: - PRETRAIN_URL = _get_base_url('elmo') - model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_dir = model_dir_or_name diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py index b78e63e8..d44d7087 100644 --- a/fastNLP/embeddings/static_embedding.py +++ b/fastNLP/embeddings/static_embedding.py @@ -7,7 +7,7 @@ import numpy as np import warnings from ..core.vocabulary import Vocabulary -from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_base_url, cached_path +from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path from .embedding import TokenEmbedding from ..modules.utils import _get_file_name_base_on_postfix @@ -60,10 +60,8 @@ class StaticEmbedding(TokenEmbedding): embedding_dim = int(embedding_dim) model_path = None elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: - PRETRAIN_URL = _get_base_url('static') - model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_path = cached_path(model_url) + model_url = _get_embedding_url('static', model_dir_or_name.lower()) + model_path = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): model_path = model_dir_or_name @@ -84,8 +82,8 @@ class StaticEmbedding(TokenEmbedding): if lowered_word not in lowered_vocab.word_count: lowered_vocab.add_word(lowered_word) lowered_vocab._no_create_word[lowered_word] += 1 - print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered " - f"words.") + print(f"All word in the vocab have been lowered before finding pretrained vectors. 
There are {len(vocab)} " + f"words, {len(lowered_vocab)} unique lowered words.") if model_path: embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) else: diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index 5d61c16a..01232627 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -5,10 +5,10 @@ __all__ = [ ] import _pickle as pickle -import os from typing import Union, Dict import os from ..core.dataset import DataSet +from ..core.vocabulary import Vocabulary class BaseLoader(object): @@ -111,7 +111,10 @@ def _uncompress(src, dst): class DataBundle: """ - 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。 + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 + DataSetLoader的load函数生成,可以通过以下的方法获取里面的内容 + + Example:: :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict @@ -121,6 +124,88 @@ class DataBundle: self.vocabs = vocabs or {} self.datasets = datasets or {} + def set_vocab(self, vocab, field_name): + """ + 向DataBunlde中增加vocab + + :param Vocabulary vocab: 词表 + :param str field_name: 这个vocab对应的field名称 + :return: + """ + assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports." + self.vocabs[field_name] = vocab + + def set_dataset(self, dataset, name): + """ + + :param DataSet dataset: 传递给DataBundle的DataSet + :param str name: dataset的名称 + :return: + """ + self.datasets[name] = dataset + + def get_dataset(self, name:str): + """ + 获取名为name的dataset + + :param str name: dataset的名称,一般为'train', 'dev', 'test' + :return: DataSet + """ + return self.datasets[name] + + def get_vocab(self, field_name:str): + """ + 获取field名为field_name对应的vocab + + :param str field_name: 名称 + :return: Vocabulary + """ + return self.vocabs[field_name] + + def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): + """ + 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: + + data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True + data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False + + :param str field_names: field的名称 + :param bool flag: 将field_name的input状态设置为flag + :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 + 行的数据进行类型和维度推断本列的数据的类型和维度。 + :param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 + """ + for field_name in field_names: + for name, dataset in self.datasets.items(): + if not ignore_miss_field and not dataset.has_field(field_name): + raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") + if not dataset.has_field(field_name): + continue + else: + dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) + + def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): + """ + 将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: + + data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True + data_bundle.set_target('target', flag=False) # 将target这个field的input属性设置为False + + :param str field_names: field的名称 + :param bool flag: 将field_name的target状态设置为flag + :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 + 行的数据进行类型和维度推断本列的数据的类型和维度。 + :param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 + """ + for field_name in field_names: + for name, dataset in 
self.datasets.items(): + if not ignore_miss_field and not dataset.has_field(field_name): + raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") + if not dataset.has_field(field_name): + continue + else: + dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) + def __repr__(self): _str = 'In total {} datasets:\n'.format(len(self.datasets)) for name, dataset in self.datasets.items(): diff --git a/fastNLP/io/data_loader/conll.py b/fastNLP/io/data_loader/conll.py index 9b2402a2..0285173c 100644 --- a/fastNLP/io/data_loader/conll.py +++ b/fastNLP/io/data_loader/conll.py @@ -3,38 +3,47 @@ from ...core.dataset import DataSet from ...core.instance import Instance from ..base_loader import DataSetLoader from ..file_reader import _read_conll - +from typing import Union, Dict +from ..utils import check_loader_paths +from ..base_loader import DataBundle class ConllLoader(DataSetLoader): """ 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` - 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 - 该符号在conll 2003中被用为文档分割符。 - - 列号从0开始, 每列对应内容为:: - - Column Type - 0 Document ID - 1 Part number - 2 Word number - 3 Word itself - 4 Part-of-Speech - 5 Parse bit - 6 Predicate lemma - 7 Predicate Frameset ID - 8 Word sense - 9 Speaker/Author - 10 Named Entities - 11:N Predicate Arguments - N Coreference - - :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 - :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` - :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` + 该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: + + Example:: + + # 文件中的内容 + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... 
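Putting the DataBundle accessors defined above together with the Vocabulary changes earlier in this patch (no_create_entry_dataset in from_dataset, list-style field names in index_dataset), a small end-to-end sketch; the sentences are invented and `from fastNLP.io import DataBundle` assumes the class stays re-exported from fastNLP.io.

    from fastNLP import DataSet, Instance, Vocabulary
    from fastNLP.io import DataBundle   # defined in fastNLP/io/base_loader.py

    train_ds, test_ds = DataSet(), DataSet()
    train_ds.append(Instance(words=['I', 'love', 'it'], target='pos'))
    train_ds.append(Instance(words=['it', 'is', 'bad'], target='neg'))
    test_ds.append(Instance(words=['so', 'bad'], target='neg'))
    for ds in (train_ds, test_ds):
        ds.add_seq_len('words')

    word_vocab = Vocabulary()
    # words that only occur in test are added with no_create_entry=True, so a pretrained
    # embedding loaded later can still supply their vectors
    word_vocab.from_dataset(train_ds, field_name='words', no_create_entry_dataset=[test_ds])
    word_vocab.index_dataset(train_ds, test_ds, field_name='words')

    target_vocab = Vocabulary(unknown=None, padding=None)
    target_vocab.from_dataset(train_ds, field_name='target')
    target_vocab.index_dataset(train_ds, test_ds, field_name='target')

    data_bundle = DataBundle(datasets={'train': train_ds})
    data_bundle.set_dataset(test_ds, 'test')
    data_bundle.set_vocab(word_vocab, 'words')
    data_bundle.set_vocab(target_vocab, 'target')

    data_bundle.set_input('words', 'seq_len')   # applied to every DataSet in the bundle
    data_bundle.set_target('target')
    print(data_bundle)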
+ + # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 + dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field + dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') + + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')中DataSet的raw_words + 列与pos列的内容都是List[str] + + 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` """ - def __init__(self, headers, indexes=None, dropna=False): + def __init__(self, headers, indexes=None, dropna=True): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): raise TypeError( @@ -49,25 +58,74 @@ class ConllLoader(DataSetLoader): self.indexes = indexes def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ ds = DataSet() for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): ins = {h: data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds + def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ConllLoader初始化时传入的headers决定。 + + :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 + (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 + 名包含'train'、 'dev'、 'test'则会报错 + + Example:: + data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train, dev, test等有所变化 + # 可以通过以下的方式取出DataSet + tr_data = data_bundle.datasets['train'] + te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 + + (2) 传入文件path + + Example:: + data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' + tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet + + (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test + + Example:: + paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} + data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" + dev_data = data_bundle.datasets['dev'] + + :return: :class:`~fastNLP.DataSet` 类的对象或 :class:`~fastNLP.io.DataBundle` 的字典 + """ + paths = check_loader_paths(paths) + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + class Conll2003Loader(ConllLoader): """ 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` - 读取Conll2003数据 + 该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003 + 找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + 返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。 + + .. csv-table:: Conll2003Loader处理之 :header: "raw_words", "words", "target", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5 + "[...]", "[...]", "[...]", . 
- 关于数据集的更多信息,参考: - https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ def __init__(self): headers = [ - 'tokens', 'pos', 'chunks', 'ner', + 'raw_words', 'pos', 'chunks', 'ner', ] super(Conll2003Loader, self).__init__(headers=headers) diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py index 481b5056..1242b432 100644 --- a/fastNLP/io/data_loader/matching.py +++ b/fastNLP/io/data_loader/matching.py @@ -121,7 +121,7 @@ class MatchingLoader(DataSetLoader): PRETRAIN_URL = _get_base_url('bert') model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) + model_dir = cached_path(model_url, name='embedding') # 检查是否存在 elif os.path.isdir(bert_tokenizer): model_dir = bert_tokenizer diff --git a/fastNLP/io/data_loader/mtl.py b/fastNLP/io/data_loader/mtl.py index cbca413d..20824958 100644 --- a/fastNLP/io/data_loader/mtl.py +++ b/fastNLP/io/data_loader/mtl.py @@ -5,7 +5,7 @@ from ..base_loader import DataBundle from ..dataset_loader import CSVLoader from ...core.vocabulary import Vocabulary, VocabularyOption from ...core.const import Const -from ..utils import check_dataloader_paths +from ..utils import check_loader_paths class MTL16Loader(CSVLoader): @@ -38,7 +38,7 @@ class MTL16Loader(CSVLoader): src_vocab_opt: VocabularyOption = None, tgt_vocab_opt: VocabularyOption = None,): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) datasets = {} info = DataBundle() for name, path in paths.items(): diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py index 6c06a9ce..c2e0eca1 100644 --- a/fastNLP/io/data_loader/sst.py +++ b/fastNLP/io/data_loader/sst.py @@ -8,7 +8,7 @@ from ...core.vocabulary import VocabularyOption, Vocabulary from ...core.dataset import DataSet from ...core.const import Const from ...core.instance import Instance -from ..utils import check_dataloader_paths, get_tokenizer +from ..utils import check_loader_paths, get_tokenizer class SSTLoader(DataSetLoader): @@ -67,7 +67,7 @@ class SSTLoader(DataSetLoader): paths, train_subtree=True, src_vocab_op: VocabularyOption = None, tgt_vocab_op: VocabularyOption = None,): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) input_name, target_name = 'words', 'target' src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ @@ -129,7 +129,7 @@ class SST2Loader(CSVLoader): tgt_vocab_opt: VocabularyOption = None, char_level_op=False): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) datasets = {} info = DataBundle() for name, path in paths.items(): @@ -155,7 +155,9 @@ class SST2Loader(CSVLoader): for dataset in datasets.values(): dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) - src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) + src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[ + dataset for name, dataset in datasets.items() if name!='train' + ]) src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) tgt_vocab = Vocabulary(unknown=None, padding=None) \ diff --git a/fastNLP/io/data_loader/yelp.py b/fastNLP/io/data_loader/yelp.py index 333fcab0..15533b04 100644 --- a/fastNLP/io/data_loader/yelp.py +++ b/fastNLP/io/data_loader/yelp.py @@ -8,7 +8,7 @@ 
from ...core.instance import Instance from ...core.vocabulary import VocabularyOption, Vocabulary from ..base_loader import DataBundle, DataSetLoader from typing import Union, Dict -from ..utils import check_dataloader_paths, get_tokenizer +from ..utils import check_loader_paths, get_tokenizer class YelpLoader(DataSetLoader): @@ -62,7 +62,7 @@ class YelpLoader(DataSetLoader): src_vocab_op: VocabularyOption = None, tgt_vocab_op: VocabularyOption = None, char_level_op=False): - paths = check_dataloader_paths(paths) + paths = check_loader_paths(paths) info = DataBundle(datasets=self.load(paths)) src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) tgt_vocab = Vocabulary(unknown=None, padding=None) \ diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index ad6bbdc1..3e3ac575 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -114,25 +114,3 @@ def _cut_long_sentence(sent, max_sample_length=200): else: cutted_sentence.append(sent) return cutted_sentence - - -def _add_seg_tag(data): - """ - - :param data: list of ([word], [pos], [heads], [head_tags]) - :return: list of ([word], [pos]) - """ - - _processed = [] - for word_list, pos_list, _, _ in data: - new_sample = [] - for word, pos in zip(word_list, pos_list): - if len(word) == 1: - new_sample.append((word, 'S-' + pos)) - else: - new_sample.append((word[0], 'B-' + pos)) - for c in word[1:-1]: - new_sample.append((c, 'M-' + pos)) - new_sample.append((word[-1], 'E-' + pos)) - _processed.append(list(map(list, zip(*new_sample)))) - return _processed diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 0ae0a319..6aa89b80 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -2,7 +2,7 @@ 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API """ import json - +import warnings def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): """ @@ -91,7 +91,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): with open(path, 'r', encoding=encoding) as f: sample = [] start = next(f).strip() - if '-DOCSTART-' not in start and start!='': + if start!='': sample.append(start.split()) for line_idx, line in enumerate(f, 1): line = line.strip() @@ -103,13 +103,13 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): yield line_idx, res except Exception as e: if dropna: + warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx)) continue - raise ValueError('invalid instance ends at line: {}'.format(line_idx)) + raise ValueError('Invalid instance ends at line: {}'.format(line_idx)) elif line.startswith('#'): continue else: - if not line.startswith('-DOCSTART-'): - sample.append(line.split()) + sample.append(line.split()) if len(sample) > 0: try: res = parse_conll(sample) diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index 4be1360b..b465ed9b 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -7,7 +7,7 @@ import requests import tempfile from tqdm import tqdm import shutil -import hashlib +from requests import HTTPError PRETRAINED_BERT_MODEL_DIR = { @@ -23,15 +23,25 @@ PRETRAINED_BERT_MODEL_DIR = { 'cn': 'bert-base-chinese-29d0a84a.zip', 'cn-base': 'bert-base-chinese-29d0a84a.zip', - - 'multilingual': 'bert-base-multilingual-cased.zip', - 'multilingual-base-uncased': 'bert-base-multilingual-uncased.zip', - 'multilingual-base-cased': 'bert-base-multilingual-cased.zip', + 'bert-base-chinese': 'bert-base-chinese.zip', + 'bert-base-cased': 
'bert-base-cased.zip', + 'bert-base-cased-finetuned-mrpc': 'bert-base-cased-finetuned-mrpc.zip', + 'bert-large-cased-wwm': 'bert-large-cased-wwm.zip', + 'bert-large-uncased': 'bert-large-uncased.zip', + 'bert-large-cased': 'bert-large-cased.zip', + 'bert-base-uncased': 'bert-base-uncased.zip', + 'bert-large-uncased-wwm': 'bert-large-uncased-wwm.zip', + 'bert-chinese-wwm': 'bert-chinese-wwm.zip', + 'bert-base-multilingual-cased': 'bert-base-multilingual-cased.zip', + 'bert-base-multilingual-uncased': 'bert-base-multilingual-uncased.zip', } PRETRAINED_ELMO_MODEL_DIR = { 'en': 'elmo_en-d39843fe.tar.gz', - 'en-small': "elmo_en_Small.zip" + 'en-small': "elmo_en_Small.zip", + 'en-original-5.5b': 'elmo_en_Original_5.5B.zip', + 'en-original': 'elmo_en_Original.zip', + 'en-medium': 'elmo_en_Medium.zip' } PRETRAIN_STATIC_FILES = { @@ -42,34 +52,68 @@ PRETRAIN_STATIC_FILES = { 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", 'cn': "tencent_cn-dab24577.tar.gz", 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", + 'sgns-literature-word':'sgns.literature.word.txt.zip', + 'glove-42b-300d': 'glove.42B.300d.zip', + 'glove-6b-50d': 'glove.6B.50d.zip', + 'glove-6b-100d': 'glove.6B.100d.zip', + 'glove-6b-200d': 'glove.6B.200d.zip', + 'glove-6b-300d': 'glove.6B.300d.zip', + 'glove-840b-300d': 'glove.840B.300d.zip', + 'glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', + 'glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', + 'glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', + 'glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip' +} + + +DATASET_DIR = { + 'aclImdb': "imdb.zip", + "yelp-review-full":"yelp_review_full.tar.gz", + "yelp-review-polarity": "yelp_review_polarity.tar.gz", + "mnli": "MNLI.zip", + "snli": "SNLI.zip", + "qnli": "QNLI.zip", + "sst-2": "SST-2.zip", + "sst": "SST.zip", + "rte": "RTE.zip" } -def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: +def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path: """ - 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 + 给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, + (1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir + (2)如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name} + 如果有该文件,就直接返回路径 + 如果没有该文件,则尝试用传入的url下载 + + 或者文件名(可以是具体的文件名,也可以是文件夹),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 将文件放入到cache_dir中. - :param url_or_filename: 文件的下载url或者文件路径 - :param cache_dir: 文件的缓存文件夹 + :param str url_or_filename: 文件的下载url或者文件名称。 + :param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 + :param str name: 中间一层的名称。如embedding, dataset :return: """ if cache_dir is None: - dataset_cache = Path(get_default_cache_path()) + data_cache = Path(get_default_cache_path()) else: - dataset_cache = cache_dir + data_cache = cache_dir + + if name: + data_cache = os.path.join(data_cache, name) parsed = urlparse(url_or_filename) if parsed.scheme in ("http", "https"): # URL, so get it from the cache (downloading if necessary) - return get_from_cache(url_or_filename, dataset_cache) - elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): + return get_from_cache(url_or_filename, Path(data_cache)) + elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists(): # File, and it exists. - return Path(url_or_filename) + return Path(os.path.join(data_cache, url_or_filename)) elif parsed.scheme == "": # File, but it doesn't exist. 
- raise FileNotFoundError("file {} not found".format(url_or_filename)) + raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache)) else: # Something unknown raise ValueError( @@ -79,8 +123,12 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: def get_filepath(filepath): """ - 如果filepath中只有一个文件,则直接返回对应的全路径. - :param filepath: + 如果filepath为文件夹, + 如果内含多个文件, 返回filepath + 如果只有一个文件, 返回filepath + filename + 如果filepath为文件 + 返回filepath + :param str filepath: 路径 :return: """ if os.path.isdir(filepath): @@ -89,14 +137,17 @@ def get_filepath(filepath): return os.path.join(filepath, files[0]) else: return filepath - return filepath + elif os.path.isfile(filepath): + return filepath + else: + raise FileNotFoundError(f"{filepath} is not a valid file or directory.") def get_default_cache_path(): """ 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 - :return: + :return: str """ if 'FASTNLP_CACHE_DIR' in os.environ: fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') @@ -109,17 +160,66 @@ def get_default_cache_path(): def _get_base_url(name): + """ + 根据name返回下载的url地址。 + + :param str name: 支持dataset和embedding两种 + :return: + """ # 返回的URL结尾必须是/ - if 'FASTNLP_BASE_URL' in os.environ: - fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] - if fastnlp_base_url.endswith('/'): - return fastnlp_base_url + environ_name = "FASTNLP_{}_URL".format(name.upper()) + + if environ_name in os.environ: + url = os.environ[environ_name] + if url.endswith('/'): + return url else: - return fastnlp_base_url + '/' + return url + '/' else: - # TODO 替换 - dbbrain_url = "http://dbcloud.irocn.cn:8989/api/public/dl/" - return dbbrain_url + URLS = { + 'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/", + "dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/" + } + if name.lower() not in URLS: + raise KeyError(f"{name} is not recognized.") + return URLS[name.lower()] + + +def _get_embedding_url(type, name): + """ + 给定embedding类似和名称,返回下载url + + :param str type: 支持static, bert, elmo。即embedding的类型 + :param str name: embedding的名称, 例如en, cn, based等 + :return: str, 下载的url地址 + """ + PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, + "bert": PRETRAINED_BERT_MODEL_DIR, + "static":PRETRAIN_STATIC_FILES} + map = PRETRAIN_MAP.get(type, None) + if map: + filename = map.get(name, None) + if filename: + url = _get_base_url('embedding') + filename + return url + raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys()))) + else: + raise KeyError(f"There is no {type}. 
Only supports bert, elmo, static") + + +def _get_dataset_url(name): + """ + 给定dataset的名称,返回下载url + + :param str name: 给定dataset的名称,比如imdb, sst-2等 + :return: str + """ + filename = DATASET_DIR.get(name, None) + if filename: + url = _get_base_url('dataset') + filename + return url + else: + raise KeyError(f"There is no {name}.") def split_filename_suffix(filepath): @@ -136,9 +236,9 @@ def split_filename_suffix(filepath): def get_from_cache(url: str, cache_dir: Path = None) -> Path: """ - 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 - 如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 - + 尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 + 文件解压,将解压后的文件全部放在cache_dir文件夹中。 + 如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 """ cache_dir.mkdir(parents=True, exist_ok=True) @@ -173,63 +273,68 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: # GET file object req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) - content_length = req.headers.get("Content-Length") - total = int(content_length) if content_length is not None else None - progress = tqdm(unit="B", total=total) - with open(temp_filename, "wb") as temp_file: - for chunk in req.iter_content(chunk_size=1024): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - temp_file.write(chunk) - progress.close() - print(f"Finish download from {url}.") - - # 开始解压 - delete_temp_dir = None - if suffix in ('.zip', '.tar.gz'): - uncompress_temp_dir = tempfile.mkdtemp() - delete_temp_dir = uncompress_temp_dir - print(f"Start to uncompress file to {uncompress_temp_dir}") - if suffix == '.zip': - unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + if req.status_code==200: + content_length = req.headers.get("Content-Length") + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total, unit_scale=1) + with open(temp_filename, "wb") as temp_file: + for chunk in req.iter_content(chunk_size=1024*16): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + print(f"Finish download from {url}.") + + # 开始解压 + delete_temp_dir = None + if suffix in ('.zip', '.tar.gz'): + uncompress_temp_dir = tempfile.mkdtemp() + delete_temp_dir = uncompress_temp_dir + print(f"Start to uncompress file to {uncompress_temp_dir}") + if suffix == '.zip': + unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) + else: + untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) + filenames = os.listdir(uncompress_temp_dir) + if len(filenames)==1: + if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): + uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) + + cache_path.mkdir(parents=True, exist_ok=True) + print("Finish un-compressing file.") else: - untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) - filenames = os.listdir(uncompress_temp_dir) - if len(filenames)==1: - if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): - uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) - - cache_path.mkdir(parents=True, exist_ok=True) - print("Finish un-compressing file.") + uncompress_temp_dir = temp_filename + cache_path = str(cache_path) + suffix + success = False + try: + # 复制到指定的位置 + print(f"Copy file to {cache_path}") + if os.path.isdir(uncompress_temp_dir): + for filename in os.listdir(uncompress_temp_dir): + if os.path.isdir(os.path.join(uncompress_temp_dir, 
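The cache helpers above can also be driven directly; a short sketch using keys that appear in the tables in this file ('en-small' for ELMo, 'sst-2' for datasets). Files land under ~/.fastNLP/{name}/ (or under FASTNLP_CACHE_DIR when that environment variable is set) and are only downloaded on the first call.

    from fastNLP.io.file_utils import cached_path, _get_embedding_url, _get_dataset_url

    # resolve the registered url of a pretrained embedding and fetch/cache it
    elmo_dir = cached_path(_get_embedding_url('elmo', 'en-small'), name='embedding')

    # dataset archives are cached in a separate sub-directory
    sst2_dir = cached_path(_get_dataset_url('sst-2'), name='dataset')
    print(elmo_dir, sst2_dir)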
filename)): + shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path/filename) + else: + shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) + else: + shutil.copyfile(uncompress_temp_dir, cache_path) + success = True + except Exception as e: + print(e) + raise e + finally: + if not success: + if cache_path.exists(): + if cache_path.is_file(): + os.remove(cache_path) + else: + shutil.rmtree(cache_path) + if delete_temp_dir: + shutil.rmtree(delete_temp_dir) + os.close(fd) + os.remove(temp_filename) + return get_filepath(cache_path) else: - uncompress_temp_dir = temp_filename - cache_path = str(cache_path) + suffix - success = False - try: - # 复制到指定的位置 - print(f"Copy file to {cache_path}") - if os.path.isdir(uncompress_temp_dir): - for filename in os.listdir(uncompress_temp_dir): - shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) - else: - shutil.copyfile(uncompress_temp_dir, cache_path) - success = True - except Exception as e: - print(e) - raise e - finally: - if not success: - if cache_path.exists(): - if cache_path.is_file(): - os.remove(cache_path) - else: - shutil.rmtree(cache_path) - if delete_temp_dir: - shutil.rmtree(delete_temp_dir) - os.close(fd) - os.remove(temp_filename) - - return get_filepath(cache_path) + raise HTTPError(f"Fail to download from {url}.") def unzip_file(file: Path, to: Path): diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py new file mode 100644 index 00000000..8e436532 --- /dev/null +++ b/fastNLP/io/loader/__init__.py @@ -0,0 +1,30 @@ + +""" +Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle`中。所有的Loader都支持以下的 + 三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field + ; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法 + 支持以下三种类型的参数 + + Example:: + (0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 + (1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取 + (2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。 + 假设某个目录下的文件为 + -train.txt + -dev.txt + -test.txt + -other.txt + Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'], + data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。 + 假设某个目录下的文件为 + -train.txt + -dev.txt + Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取 + 对应的DataSet。 + (3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。 + paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} + Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'], + data_bundle.datasets['test'] + +""" + diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py new file mode 100644 index 00000000..dd85b4fe --- /dev/null +++ b/fastNLP/io/loader/classification.py @@ -0,0 +1,369 @@ +from ...core.dataset import DataSet +from ...core.instance import Instance +from .loader import Loader +import warnings +import os +import random +import shutil +import numpy as np + +class YelpLoader(Loader): + """ + 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` + + 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 + + Example:: + "1","I got 'new' tires from the..." + "1","Don't waste your time..." + + 读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。 + 读取的DataSet将具备以下的数据结构 + + .. 
csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." + + """ + + def __init__(self): + super(YelpLoader, self).__init__() + + def _load(self, path: str=None): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + sep_index = line.index(',') + target = line[:sep_index] + raw_words = line[sep_index + 1:] + if target.startswith("\""): + target = target[1:] + if target.endswith("\""): + target = target[:-1] + if raw_words.endswith("\""): + raw_words = raw_words[:-1] + if raw_words.startswith('"'): + raw_words = raw_words[1:] + raw_words = raw_words.replace('""', '"') # 替换双引号 + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + return ds + + +class YelpFullLoader(YelpLoader): + def download(self, dev_ratio: float = 0.1, seed: int = 0): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv, + dev.csv三个文件。 + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param int seed: 划分dev时的随机数种子 + :return: str, 数据集的目录地址 + """ + + dataset_name = 'yelp-review-full' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 + re_download = True + if dev_ratio>0: + dev_line_count = 0 + tr_line_count = 0 + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: + for line in f1: + tr_line_count += 1 + for line in f2: + dev_line_count += 1 + if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + re_download = True + else: + re_download = False + if re_download: + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + random.seed(int(seed)) + try: + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.csv')) + os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): + os.remove(os.path.join(data_dir, 'middle_file.csv')) + + return data_dir + + +class YelpPolarityLoader(YelpLoader): + def download(self, dev_ratio: float = 0.1, seed: int = 0): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev + + :param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据. 
如果为0,则不划分dev + :param int seed: 划分dev时的随机数种子 + :return: str, 数据集的目录地址 + """ + dataset_name = 'yelp-review-polarity' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 + re_download = True + if dev_ratio>0: + dev_line_count = 0 + tr_line_count = 0 + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: + for line in f1: + tr_line_count += 1 + for line in f2: + dev_line_count += 1 + if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + re_download = True + else: + re_download = False + if re_download: + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + random.seed(int(seed)) + try: + with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.csv')) + os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): + os.remove(os.path.join(data_dir, 'middle_file.csv')) + + return data_dir + + +class IMDBLoader(Loader): + """ + 别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader` + + IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 + DataSet具备以下的结构: + + .. csv-table:: + :header: "raw_words", "target" + + "Bromwell High is a cartoon ... ", "pos" + "Story of a man who has ...", "neg" + "...", "..." + + """ + + def __init__(self): + super(IMDBLoader, self).__init__() + + def _load(self, path: str): + dataset = DataSet() + with open(path, 'r', encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split('\t') + target = parts[0] + words = parts[1] + if words: + dataset.append(Instance(raw_words=words, target=target)) + + if len(dataset) == 0: + raise RuntimeError(f"{path} has no valid data.") + + return dataset + + def download(self, dev_ratio: float = 0.1, seed: int = 0): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + http://www.aclweb.org/anthology/P11-1015 + + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev + + :param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 
如果为0,则不划分dev + :param int seed: 划分dev时的随机数种子 + :return: str, 数据集的目录地址 + """ + dataset_name = 'aclImdb' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 + re_download = True + if dev_ratio>0: + dev_line_count = 0 + tr_line_count = 0 + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'r', encoding='utf-8') as f2: + for line in f1: + tr_line_count += 1 + for line in f2: + dev_line_count += 1 + if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): + re_download = True + else: + re_download = False + if re_download: + shutil.rmtree(data_dir) + data_dir = self._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, 'dev.csv')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." + random.seed(int(seed)) + try: + with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, 'train.txt')) + os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) + finally: + if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): + os.remove(os.path.join(data_dir, 'middle_file.txt')) + + return data_dir + + +class SSTLoader(Loader): + """ + 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader` + + 读取之后的DataSet具有以下的结构 + + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + :header: "raw_words" + + "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." + "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." + "..." + + raw_words列是str。 + + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + """ + 从path读取SST文件 + + :param str path: 文件路径 + :return: DataSet + """ + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf + + :return: str, 数据集的目录地址 + """ + output_dir = self._get_dataset_path(dataset_name='sst') + return output_dir + + +class SST2Loader(Loader): + """ + 数据SST2的Loader + 读取之后DataSet将如下所示 + + .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + :header: "raw_words", "target" + + "it 's a charming and often affecting...", "1" + "unflinchingly bleak and...", "0" + "..." 
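A usage sketch for the download()/load() pair of the classification loaders above, using YelpFullLoader; the import path assumes the classes are re-exported from fastNLP.io.loader (otherwise import from fastNLP.io.loader.classification), and load() picking up train.csv/dev.csv/test.csv from the returned directory relies on the directory-loading behaviour described in fastNLP/io/loader/__init__.py.

    from fastNLP.io.loader import YelpFullLoader

    loader = YelpFullLoader()
    data_dir = loader.download(dev_ratio=0.1, seed=0)   # fetch yelp-review-full, carve 10% of train.csv into dev.csv
    data_bundle = loader.load(data_dir)                 # DataBundle with 'train', 'dev' and 'test'
    print(data_bundle.get_dataset('train')[0])          # fields: raw_words (str), target (str)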
+ + test的DataSet没有target列。 + """ + + def __init__(self): + super().__init__() + + def _load(self, path: str): + """ + 从path读取SST2文件 + + :param str path: 数据路径 + :return: DataSet + """ + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if 'test' in os.path.split(path)[1]: + warnings.warn("SST2's test file has no target.") + for line in f: + line = line.strip() + if line: + sep_index = line.index('\t') + raw_words = line[sep_index + 1:] + if raw_words: + ds.append(Instance(raw_words=raw_words)) + else: + for line in f: + line = line.strip() + if line: + raw_words = line[:-2] + target = line[-1] + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + return ds + + def download(self): + """ + 自动下载数据集,如果你使用了该数据集,请引用以下的文章 + + https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf + + :return: + """ + output_dir = self._get_dataset_path(dataset_name='sst-2') + return output_dir diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py new file mode 100644 index 00000000..43790c15 --- /dev/null +++ b/fastNLP/io/loader/conll.py @@ -0,0 +1,264 @@ +from typing import Dict, Union + +from .loader import Loader +from ... import DataSet +from ..file_reader import _read_conll +from ... import Instance +from .. import DataBundle +from ..utils import check_loader_paths +from ... import Const + + +class ConllLoader(Loader): + """ + 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` + + ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: + + Example:: + + # 文件中的内容 + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 + dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 + dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') + # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field + dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') + + ConllLoader返回的DataSet的field由传入的headers确定。 + + 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + + :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` + + """ + def __init__(self, headers, indexes=None, dropna=True): + super(ConllLoader, self).__init__() + if not isinstance(headers, (list, tuple)): + raise TypeError( + 'invalid headers: {}, should be list of strings'.format(headers)) + self.headers = headers + self.dropna = dropna + if indexes is None: + self.indexes = list(range(len(self.headers))) + else: + if len(indexes) != len(headers): + raise ValueError + self.indexes = indexes + + def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + +class Conll2003Loader(ConllLoader): + """ + 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 + + 
Example:: + + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + 返回的DataSet的内容为 + + .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 + :header: "raw_words", "pos", "chunk", "ner" + + "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]", "[...]", "[...]" + + """ + def __init__(self): + headers = [ + 'raw_words', 'pos', 'chunk', 'ner', + ] + super(Conll2003Loader, self).__init__(headers=headers) + + def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + doc_start = False + for i, h in enumerate(self.headers): + field = data[i] + if str(field[0]).startswith('-DOCSTART-'): + doc_start = True + break + if doc_start: + continue + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + def download(self, output_dir=None): + raise RuntimeError("conll2003 cannot be downloaded automatically.") + + +class Conll2003NERLoader(ConllLoader): + """ + 用于读取conll2003任务的NER数据。 + + Example:: + + Nadim NNP B-NP B-PER + Ladki NNP I-NP I-PER + + AL-AIN NNP B-NP B-LOC + United NNP B-NP B-LOC + Arab NNP I-NP I-LOC + Emirates NNPS I-NP I-LOC + 1996-12-06 CD I-NP O + ... + + 返回的DataSet的内容为 + + .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码 + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + """ + def __init__(self): + headers = [ + 'raw_words', 'target', + ] + super().__init__(headers=headers, indexes=[0, 3]) + + def _load(self, path): + """ + 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 + + :param str path: 文件的路径 + :return: DataSet + """ + ds = DataSet() + for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + doc_start = False + for i, h in enumerate(self.headers): + field = data[i] + if str(field[0]).startswith('-DOCSTART-'): + doc_start = True + break + if doc_start: + continue + ins = {h: data[i] for i, h in enumerate(self.headers)} + ds.append(Instance(**ins)) + return ds + + def download(self): + raise RuntimeError("conll2003 cannot be downloaded automatically.") + + +class OntoNotesNERLoader(ConllLoader): + """ + 用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 + https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 + + 返回的DataSet的内容为 + + .. 
csv-table:: 下面是使用OntoNoteNERLoader读取的DataSet所具备的结构, target列是BIO编码 + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + """ + + def __init__(self): + super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10]) + + def _load(self, path:str): + dataset = super()._load(path) + + def convert_to_bio(tags): + bio_tags = [] + flag = None + for tag in tags: + label = tag.strip("()*") + if '(' in tag: + bio_label = 'B-' + label + flag = label + elif flag: + bio_label = 'I-' + flag + else: + bio_label = 'O' + if ')' in tag: + flag = None + bio_tags.append(bio_label) + return bio_tags + + def convert_word(words): + converted_words = [] + for word in words: + word = word.replace('/.', '.') # 有些结尾的.是/.形式的 + if not word.startswith('-'): + converted_words.append(word) + continue + # 以下是由于这些符号被转义了,再转回来 + tfrs = {'-LRB-':'(', + '-RRB-': ')', + '-LSB-': '[', + '-RSB-': ']', + '-LCB-': '{', + '-RCB-': '}' + } + if word in tfrs: + converted_words.append(tfrs[word]) + else: + converted_words.append(word) + return converted_words + + dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) + dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET) + + return dataset + + def download(self): + raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " + "https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") + + +class CTBLoader(Loader): + def __init__(self): + super().__init__() + + def _load(self, path:str): + pass diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py new file mode 100644 index 00000000..166f912b --- /dev/null +++ b/fastNLP/io/loader/csv.py @@ -0,0 +1,32 @@ +from ...core.dataset import DataSet +from ...core.instance import Instance +from ..file_reader import _read_csv +from .loader import Loader + + +class CSVLoader(Loader): + """ + 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` + + 读取CSV格式的数据集, 返回 ``DataSet`` 。 + + :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 + 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` + :param str sep: CSV文件中列与列之间的分隔符. Default: "," + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``False`` + """ + + def __init__(self, headers=None, sep=",", dropna=False): + super().__init__() + self.headers = headers + self.sep = sep + self.dropna = dropna + + def _load(self, path): + ds = DataSet() + for idx, data in _read_csv(path, headers=self.headers, + sep=self.sep, dropna=self.dropna): + ds.append(Instance(**data)) + return ds + diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py new file mode 100644 index 00000000..46c07f28 --- /dev/null +++ b/fastNLP/io/loader/cws.py @@ -0,0 +1,41 @@ + +from .loader import Loader +from ...core import DataSet, Instance + + +class CWSLoader(Loader): + """ + 分词任务数据加载器, + SigHan2005的数据可以用xxx下载并预处理 + + CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: + + Example:: + + 上海 浦东 开发 与 法制 建设 同步 + 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) + ... + + 该Loader读取后的DataSet具有如下的结构 + + .. csv-table:: + :header: "raw_words" + + "上海 浦东 开发 与 法制 建设 同步" + "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" + "..." 
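+
+    A minimal reading sketch; the file path below is hypothetical and should point to a
+    whitespace-segmented corpus file such as the SIGHAN2005 training data:
+
+    Example::
+
+        dataset = CWSLoader()._load('/path/to/cws_train.txt')  # hypothetical path
+        print(dataset[0]['raw_words'])
+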
+ """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + ds.append(Instance(raw_words=line)) + return ds + + def download(self, output_dir=None): + raise RuntimeError("You can refer {} for sighan2005's data downloading.") diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py new file mode 100644 index 00000000..8856b73a --- /dev/null +++ b/fastNLP/io/loader/json.py @@ -0,0 +1,40 @@ +from ...core.dataset import DataSet +from ...core.instance import Instance +from ..file_reader import _read_json +from .loader import Loader + + +class JsonLoader(Loader): + """ + 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` + + 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 + + :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name + ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , + `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 + ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``False`` + """ + + def __init__(self, fields=None, dropna=False): + super(JsonLoader, self).__init__() + self.dropna = dropna + self.fields = None + self.fields_list = None + if fields: + self.fields = {} + for k, v in fields.items(): + self.fields[k] = k if v is None else v + self.fields_list = list(self.fields.keys()) + + def _load(self, path): + ds = DataSet() + for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): + if self.fields: + ins = {self.fields[k]: v for k, v in d.items()} + else: + ins = d + ds.append(Instance(**ins)) + return ds diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py new file mode 100644 index 00000000..4cf5bcf3 --- /dev/null +++ b/fastNLP/io/loader/loader.py @@ -0,0 +1,75 @@ +from ... import DataSet +from .. 
import DataBundle +from ..utils import check_loader_paths +from typing import Union, Dict +import os +from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path + +class Loader: + def __init__(self): + pass + + def _load(self, path:str) -> DataSet: + raise NotImplementedError + + def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ConllLoader初始化时传入的headers决定。 + + :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 + (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 + + (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 + 名包含'train'、 'dev'、 'test'则会报错 + + Example:: + + data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 + # dev、 test等有所变化,可以通过以下的方式取出DataSet + tr_data = data_bundle.datasets['train'] + te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 + + (2) 传入文件路径 + + Example:: + + data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' + tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet + + (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test + + Example:: + + paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} + data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" + dev_data = data_bundle.datasets['dev'] + + :return: 返回的:class:`~fastNLP.io.DataBundle` + """ + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + raise NotImplementedError(f"{self.__class__} cannot download data automatically.") + + def _get_dataset_path(self, dataset_name): + """ + 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 + + :param str dataset_name: 数据集的名称 + :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 + """ + + default_cache_path = get_default_cache_path() + url = _get_dataset_url(dataset_name) + output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') + + return output_dir + + diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py new file mode 100644 index 00000000..eff98ba3 --- /dev/null +++ b/fastNLP/io/loader/matching.py @@ -0,0 +1,309 @@ + +import warnings +from .loader import Loader +from .json import JsonLoader +from ...core import Const +from .. import DataBundle +import os +from typing import Union, Dict +from ...core import DataSet +from ...core import Instance + +__all__ = ['MNLILoader', + "QuoraLoader", + "SNLILoader", + "QNLILoader", + "RTELoader"] + + +class MNLILoader(Loader): + """ + 读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,words0是sentence1, words1是sentence2, target是gold_label, 测试集中没 + 有target列。 + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "neutral" + "This site includes a...", "The Government Executive...", "contradiction" + "...", "...","." 
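+
+    A minimal usage sketch; no path is passed, so the data is downloaded and cached
+    automatically, and the returned DataBundle contains the 'train', 'dev_matched',
+    'dev_mismatched', 'test_matched' and 'test_mismatched' DataSets:
+
+    Example::
+
+        data_bundle = MNLILoader().load()
+        train_data = data_bundle.datasets['train']
+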
+ + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("RTE's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[8] + raw_words2 = parts[9] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[8] + raw_words2 = parts[9] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def load(self, paths:str=None): + """ + + :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, + test_mismatched.tsv, train.tsv文件夹 + :return: DataBundle + """ + if paths: + paths = os.path.abspath(os.path.expanduser(paths)) + else: + paths = self.download() + if not os.path.isdir(paths): + raise NotADirectoryError(f"{paths} is not a valid directory.") + + files = {'dev_matched':"dev_matched.tsv", + "dev_mismatched":"dev_mismatched.tsv", + "test_matched":"test_matched.tsv", + "test_mismatched":"test_mismatched.tsv", + "train":'train.tsv'} + + datasets = {} + for name, filename in files.items(): + filepath = os.path.join(paths, filename) + if not os.path.isfile(filepath): + if 'test' not in name: + raise FileNotFoundError(f"{name} not found in directory {filepath}.") + datasets[name] = self._load(filepath) + + data_bundle = DataBundle(datasets=datasets) + + return data_bundle + + def download(self): + """ + 如果你使用了这个数据,请引用 + + https://www.nyu.edu/projects/bowman/multinli/paper.pdf + :return: + """ + output_dir = self._get_dataset_path('mnli') + return output_dir + + +class SNLILoader(JsonLoader): + """ + 读取之后的DataSet中的field情况为 + + .. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "neutral" + "This site includes a...", "The Government Executive...", "entailment" + "...", "...", "." 
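+
+    A minimal usage sketch; with no path given, the snli_1.0_*.jsonl files are downloaded
+    and cached automatically:
+
+    Example::
+
+        data_bundle = SNLILoader().load()
+        print(data_bundle.datasets['train'][0])
+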
+ + """ + def __init__(self): + super().__init__(fields={ + 'sentence1': Const.RAW_WORDS(0), + 'sentence2': Const.RAW_WORDS(1), + 'gold_label': Const.TARGET, + }) + + def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 + + 读取的field根据ConllLoader初始化时传入的headers决定。 + + :param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl + 和snli_1.0_test.jsonl三个文件。 + + :return: 返回的:class:`~fastNLP.io.DataBundle` + """ + _paths = {} + if paths is None: + paths = self.download() + if paths: + if os.path.isdir(paths): + if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')): + raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}") + _paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl') + for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']: + filepath = os.path.join(paths, filename) + _paths[filename.split('_')[-1].split('.')[0]] = filepath + paths = _paths + else: + raise NotADirectoryError(f"{paths} is not a valid directory.") + + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + """ + 如果您的文章使用了这份数据,请引用 + + http://nlp.stanford.edu/pubs/snli_paper.pdf + + :return: str + """ + return self._get_dataset_path('snli') + + +class QNLILoader(JsonLoader): + """ + QNLI数据集的Loader, + 加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "What came into force after the new...", "As of that day...", "entailment" + "What is the first major...", "The most important tributaries", "not_entailment" + "...","." + + test数据集没有target列 + + """ + def __init__(self): + super().__init__() + + def _load(self, path): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("QNLI's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + """ + 如果您的实验使用到了该数据,请引用 + + TODO 补充 + + :return: + """ + return self._get_dataset_path('qnli') + + +class RTELoader(Loader): + """ + RTE数据的loader + 加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" + "Yet, we now are discovering that...", "Bacteria is winning...", "entailment" + "...","." 
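+
+    A minimal usage sketch; with no path given, the RTE files are downloaded and cached
+    automatically, and the raw sentence pair of an instance can be inspected directly:
+
+    Example::
+
+        data_bundle = RTELoader().load()
+        print(data_bundle.datasets['train'][0]['raw_words1'])
+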
+ + test数据集没有target列 + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + f.readline() # 跳过header + if path.endswith("test.tsv"): + warnings.warn("RTE's test file has no target.") + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + if raw_words1 and raw_words2: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + else: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[-1] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + return self._get_dataset_path('rte') + + +class QuoraLoader(Loader): + """ + Quora matching任务的数据集Loader + + 支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 + + Example:: + + 1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970 + 1 How can I stop my depression ? What can I do to stop being depressed ? 339556 + ... + + 加载的DataSet将具备以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "What should I do to avoid...", "1" + "How do I not sleep in a boring class...", "0" + "...","." + + """ + def __init__(self): + super().__init__() + + def _load(self, path:str): + ds = DataSet() + + with open(path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + parts = line.split('\t') + raw_words1 = parts[1] + raw_words2 = parts[2] + target = parts[0] + if raw_words1 and raw_words2 and target: + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + return ds + + def download(self): + raise RuntimeError("Quora cannot be downloaded automatically.") diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py new file mode 100644 index 00000000..0cf8d949 --- /dev/null +++ b/fastNLP/io/pipe/__init__.py @@ -0,0 +1,8 @@ + + +""" +Pipe用于处理数据,所有的Pipe都包含一个process(DataBundle)方法,传入一个DataBundle对象, 在传入DataBundle上进行原位修改,并将其返回; + process_from_file(paths)传入的文件路径,返回一个DataBundle。process(DataBundle)或者process_from_file(paths)的返回DataBundle + 中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。 + +""" \ No newline at end of file diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py new file mode 100644 index 00000000..a64e5328 --- /dev/null +++ b/fastNLP/io/pipe/classification.py @@ -0,0 +1,444 @@ + +from nltk import Tree + +from ..base_loader import DataBundle +from ...core.vocabulary import Vocabulary +from ...core.const import Const +from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader +from ...core import DataSet, Instance + +from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance +from .pipe import Pipe +import re +nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') +from ...core import cache_results + +class _CLSPipe(Pipe): + """ + 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 + + """ + def __init__(self, tokenizer:str='spacy', lang='en'): + self.tokenizer = get_tokenizer(tokenizer, lang=lang) + + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): + """ + 将DataBundle中的数据进行tokenize + + :param DataBundle data_bundle: + :param str field_name: + :param str new_field_name: + :return: 
传入的DataBundle对象 + """ + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) + + return data_bundle + + def _granularize(self, data_bundle, tag_map): + """ + 该函数对data_bundle中'target'列中的内容进行转换。 + + :param data_bundle: + :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, + 且将"1"认为是第0类。 + :return: 传入的data_bundle + """ + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET, + new_field_name=Const.TARGET) + dataset.drop(lambda ins:ins[Const.TARGET] == -100) + data_bundle.set_dataset(dataset, name) + return data_bundle + + +def _clean_str(words): + """ + heavily borrowed from github + https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb + :param sentence: is a str + :return: + """ + words_collection = [] + for word in words: + if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: + continue + tt = nonalpnum.split(word) + t = ''.join(tt) + if t != '': + words_collection.append(t) + + return words_collection + + +class YelpFullPipe(_CLSPipe): + """ + 处理YelpFull的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a ...", "[4, 2, 10, ...]", 0, 10 + "Offers that ...", "[20, 40, ...]", 1, 21 + "...", "[...]", ., . + + :param bool lower: 是否对输入进行小写化。 + :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 + 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + assert granularity in (2, 3, 5), "granularity can only be 2,3,5." + self.granularity = granularity + + if granularity==2: + self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} + elif granularity==3: + self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2} + else: + self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} + + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): + """ + 将DataBundle中的数据进行tokenize + + :param DataBundle data_bundle: + :param str field_name: + :param str new_field_name: + :return: 传入的DataBundle对象 + """ + new_field_name = new_field_name or field_name + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) + dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + """ + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." 
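+
+        A minimal end-to-end sketch; no path is passed, so the raw Yelp data is downloaded
+        and cached automatically, and ``granularity=2`` is just one possible setting:
+
+        Example::
+
+            data_bundle = YelpFullPipe(granularity=2).process_from_file()
+            print(data_bundle.get_dataset('train'))
+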
+ + :param data_bundle: + :return: + """ + + # 复制一列words + data_bundle = _add_words_field(data_bundle, lower=self.lower) + + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + + # 根据granularity设置tag + data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) + + # 删除空行 + data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) + + # index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param paths: + :return: DataBundle + """ + data_bundle = YelpFullLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class YelpPolarityPipe(_CLSPipe): + """ + 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a ...", "[4, 2, 10, ...]", 0, 10 + "Offers that ...", "[20, 40, ...]", 1, 21 + "...", "[...]", ., . + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + def __init__(self, lower:bool=False, tokenizer:str='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle): + # 复制一列words + data_bundle = _add_words_field(data_bundle, lower=self.lower) + + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + # index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param str paths: + :return: DataBundle + """ + data_bundle = YelpPolarityLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class SSTPipe(_CLSPipe): + """ + 别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe` + + 经过该Pipe之后,DataSet中具备的field如下所示 + + .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field + :header: "raw_words", "words", "target", "seq_len" + + "It 's a ...", "[4, 2, 10, ...]", 0, 16 + "Offers that ...", "[20, 40, ...]", 1, 18 + "...", "[...]", ., . + + :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` + :param bool train_subtree: 是否将train集通过子树扩展数据。 + :param bool lower: 是否对输入进行小写化。 + :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将 + 0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + + def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.subtree = subtree + self.train_tree = train_subtree + self.lower = lower + assert granularity in (2, 3, 5), "granularity can only be 2,3,5." + self.granularity = granularity + + if granularity==2: + self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} + elif granularity==3: + self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2} + else: + self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} + + def process(self, data_bundle:DataBundle): + """ + 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 + + .. 
csv-table:: + :header: "raw_words" + + "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." + "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." + "..." + + :param DataBundle data_bundle: 需要处理的DataBundle对象 + :return: + """ + # 先取出subtree + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + ds = DataSet() + use_subtree = self.subtree or (name == 'train' and self.train_tree) + for ins in dataset: + raw_words = ins['raw_words'] + tree = Tree.fromstring(raw_words) + if use_subtree: + for t in tree.subtrees(): + raw_words = " ".join(t.leaves()) + instance = Instance(raw_words=raw_words, target=t.label()) + ds.append(instance) + else: + instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) + ds.append(instance) + data_bundle.set_dataset(ds, name) + + _add_words_field(data_bundle, lower=self.lower) + + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + + # 根据granularity设置tag + data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) + + # index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + data_bundle = SSTLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class SST2Pipe(_CLSPipe): + """ + 加载SST2的数据, 处理完成之后DataSet将拥有以下的field + + .. csv-table:: + :header: "raw_words", "words", "target", "seq_len" + + "it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43 + "unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21 + "...", "...", ., . + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + def __init__(self, lower=False, tokenizer='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle:DataBundle): + """ + 可以处理的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "it 's a charming and... ", 1 + "unflinchingly bleak and...", 1 + "...", "..." + + :param data_bundle: + :return: + """ + _add_words_field(data_bundle, self.lower) + + data_bundle = self._tokenize(data_bundle=data_bundle) + + src_vocab = Vocabulary() + src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, + no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if + name != 'train']) + src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) + + tgt_vocab = Vocabulary(unknown=None, padding=None) + tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) + datasets = [] + for name, dataset in data_bundle.datasets.items(): + if dataset.has_field(Const.TARGET): + datasets.append(dataset) + tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) + + data_bundle.set_vocab(src_vocab, Const.INPUT) + data_bundle.set_vocab(tgt_vocab, Const.TARGET) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 + :return: DataBundle + """ + data_bundle = SST2Loader().load(paths) + return self.process(data_bundle) + + +class IMDBPipe(_CLSPipe): + """ + 经过本Pipe处理后DataSet将如下 + + .. 
csv-table:: 输出DataSet的field + :header: "raw_words", "words", "target", "seq_len" + + "Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20 + "Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31 + "...", "[...]", ., . + + 其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; + words列被设置为input; target列被设置为target。 + + :param bool lower: 是否将words列的数据小写。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + """ + def __init__(self, lower:bool=False, tokenizer:str='spacy'): + super().__init__(tokenizer=tokenizer, lang='en') + self.lower = lower + + def process(self, data_bundle:DataBundle): + """ + 期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 + + .. csv-table:: 输入DataSet的field + :header: "raw_words", "target" + + "Bromwell High is a cartoon ... ", "pos" + "Story of a man who has ...", "neg" + "...", "..." + + :param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str, + target列应该为str。 + :return:DataBundle + """ + # 替换
<br /> + def replace_br(raw_words): + raw_words = raw_words.replace("<br />
", ' ') + return raw_words + + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) + + _add_words_field(data_bundle, lower=self.lower) + self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) + _indexize(data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + dataset.set_input(Const.INPUT, Const.INPUT_LEN) + dataset.set_target(Const.TARGET) + + return data_bundle + + def process_from_file(self, paths=None): + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = IMDBLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + + diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py new file mode 100644 index 00000000..4f780614 --- /dev/null +++ b/fastNLP/io/pipe/conll.py @@ -0,0 +1,149 @@ +from .pipe import Pipe +from .. import DataBundle +from .utils import iob2, iob2bioes +from ... import Const +from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader +from .utils import _indexize, _add_words_field + + +class _NERPipe(Pipe): + """ + NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 + Vocabulary转换为index。 + + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 + + :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 + """ + def __init__(self, encoding_type:str='bio', lower:bool=False, target_pad_val=0): + if encoding_type == 'bio': + self.convert_tag = iob2 + else: + self.convert_tag = iob2bioes + self.lower = lower + self.target_pad_val = int(target_pad_val) + + def process(self, data_bundle:DataBundle)->DataBundle: + """ + 支持的DataSet的field为 + + .. 
csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "target" + + "[Nadim, Ladki]", "[B-PER, I-PER]" + "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" + "[...]", "[...]" + + + :param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。 + 在传入DataBundle基础上原位修改。 + :return: DataBundle + + Example:: + + data_bundle = Conll2003Loader().load('/path/to/conll2003/') + data_bundle = Conll2003NERPipe().process(data_bundle) + + # 获取train + tr_data = data_bundle.get_dataset('train') + + # 获取target这个field的词表 + target_vocab = data_bundle.get_vocab('target') + # 获取words这个field的词表 + word_vocab = data_bundle.get_vocab('words') + + """ + # 转换tag + for name, dataset in data_bundle.datasets.items(): + dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) + + _add_words_field(data_bundle, lower=self.lower) + + # index + _indexize(data_bundle) + + input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] + target_fields = [Const.TARGET, Const.INPUT_LEN] + + for name, dataset in data_bundle.datasets.items(): + dataset.set_pad_val(Const.TARGET, self.target_pad_val) + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(*input_fields) + data_bundle.set_target(*target_fields) + + return data_bundle + + def process_from_file(self, paths) -> DataBundle: + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = Conll2003NERLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class Conll2003NERPipe(_NERPipe): + """ + Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 + (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 + Vocabulary转换为index。 + 经过该Pipe过后,DataSet中的内容如下所示 + + .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "words", "target", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 10 + "[...]", "[...]", "[...]", . + + raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 + target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 + + :param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。 + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 + """ + + def process_from_file(self, paths) -> DataBundle: + """ + + :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 + :return: DataBundle + """ + # 读取数据 + data_bundle = Conll2003NERLoader().load(paths) + data_bundle = self.process(data_bundle) + + return data_bundle + + +class OntoNotesNERPipe(_NERPipe): + """ + 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 + + .. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader + :header: "raw_words", "words", "target", "seq_len" + + "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 + "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6 + "[...]", "[...]", "[...]", . 
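+
+    A minimal usage sketch; the directory path below is hypothetical and should contain
+    OntoNotes data already converted to conll format (see the OntoNotesNERLoader notes):
+
+    Example::
+
+        data_bundle = OntoNotesNERPipe(encoding_type='bioes').process_from_file('/path/to/OntoNotes/')
+        tr_data = data_bundle.get_dataset('train')
+        target_vocab = data_bundle.get_vocab('target')
+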
+ + + :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。 + :param bool delete_unused_fields: 是否删除NER任务中用不到的field。 + :param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。 + """ + + def process_from_file(self, paths): + data_bundle = OntoNotesNERLoader().load(paths) + return self.process(data_bundle) + diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py new file mode 100644 index 00000000..76a0eaf7 --- /dev/null +++ b/fastNLP/io/pipe/matching.py @@ -0,0 +1,254 @@ +import math + +from .pipe import Pipe +from .utils import get_tokenizer +from ...core import Const +from ...core import Vocabulary +from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader + +class MatchingBertPipe(Pipe): + """ + Matching任务的Bert pipe,输出的DataSet将包含以下的field + + .. csv-table:: + :header: "raw_words1", "raw_words2", "words", "target", "seq_len" + + "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 + "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 + "...", "...", "[...]", ., . + + words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 + words列被设置为input,target列被设置为target. + + :param bool lower: 是否将word小写化。 + :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 + :param int max_concat_sent_length: 如果concat后的句子长度超过了该值,则合并后的句子将被截断到这个长度,截断时同时对premise + 和hypothesis按比例截断。 + """ + def __init__(self, lower=False, tokenizer:str='raw', max_concat_sent_length:int=480): + super().__init__() + + self.lower = bool(lower) + self.tokenizer = get_tokenizer(tokenizer=tokenizer) + self.max_concat_sent_length = int(max_concat_sent_length) + + def _tokenize(self, data_bundle, field_names, new_field_names): + """ + + :param DataBundle data_bundle: DataBundle. 
+ :param list field_names: List[str], 需要tokenize的field名称 + :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 + :return: 输入的DataBundle对象 + """ + for name, dataset in data_bundle.datasets.items(): + for field_name, new_field_name in zip(field_names, new_field_names): + dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name, + new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + for name, dataset in data_bundle.datasets.items(): + dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0)) + dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1)) + + if self.lower: + for name, dataset in data_bundle.datasets.items(): + dataset[Const.INPUTS(0)].lower() + dataset[Const.INPUTS(1)].lower() + + data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)], + [Const.INPUTS(0), Const.INPUTS(1)]) + + # concat两个words + def concat(ins): + words0 = ins[Const.INPUTS(0)] + words1 = ins[Const.INPUTS(1)] + len0 = len(words0) + len1 = len(words1) + if len0 + len1 > self.max_concat_sent_length: + ratio = self.max_concat_sent_length / (len0 + len1) + len0 = math.floor(ratio * len0) + len1 = math.floor(ratio * len1) + words0 = words0[:len0] + words1 = words1[:len1] + + words = words0 + ['[SEP]'] + words1 + return words + for name, dataset in data_bundle.datasets.items(): + dataset.apply(concat, new_field_name=Const.INPUT) + dataset.delete_field(Const.INPUTS(0)) + dataset.delete_field(Const.INPUTS(1)) + + word_vocab = Vocabulary() + word_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if + name != 'train']) + word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) + + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) + has_target_datasets = [] + for name, dataset in data_bundle.datasets.items(): + if dataset.has_field(Const.TARGET): + has_target_datasets.append(dataset) + target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) + + data_bundle.set_vocab(word_vocab, Const.INPUT) + data_bundle.set_vocab(target_vocab, Const.TARGET) + + input_fields = [Const.INPUT, Const.INPUT_LEN] + target_fields = [Const.TARGET] + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + dataset.set_input(*input_fields, flag=True) + dataset.set_target(*target_fields, flag=True) + + return data_bundle + + +class RTEBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = RTELoader().load(paths) + return self.process(data_bundle) + + +class SNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = SNLILoader().load(paths) + return self.process(data_bundle) + + +class QuoraBertPipe(MatchingBertPipe): + def process_from_file(self, paths): + data_bundle = QuoraLoader().load(paths) + return self.process(data_bundle) + + +class QNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = QNLILoader().load(paths) + return self.process(data_bundle) + + +class MNLIBertPipe(MatchingBertPipe): + def process_from_file(self, paths=None): + data_bundle = MNLILoader().load(paths) + return self.process(data_bundle) + + +class MatchingPipe(Pipe): + """ + Matching任务的Pipe。输出的DataSet将包含以下的field + + .. 
csv-table:: + :header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" + + "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 + "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 + "...", "...", "[...]", "[...]", ., ., . + + words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target。 + + :param bool lower: 是否将所有raw_words转为小写。 + :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 + """ + def __init__(self, lower=False, tokenizer:str='raw'): + super().__init__() + + self.lower = bool(lower) + self.tokenizer = get_tokenizer(tokenizer=tokenizer) + + def _tokenize(self, data_bundle, field_names, new_field_names): + """ + + :param DataBundle data_bundle: DataBundle. + :param list field_names: List[str], 需要tokenize的field名称 + :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 + :return: 输入的DataBundle对象 + """ + for name, dataset in data_bundle.datasets.items(): + for field_name, new_field_name in zip(field_names, new_field_names): + dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name, + new_field_name=new_field_name) + return data_bundle + + def process(self, data_bundle): + """ + 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 + + .. csv-table:: + :header: "raw_words1", "raw_words2", "target" + + "The new rights are...", "Everyone really likes..", "entailment" + "This site includes a...", "The Government Executive...", "not_entailment" + "...", "..." + + :param data_bundle: + :return: + """ + data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], + [Const.INPUTS(0), Const.INPUTS(1)]) + + if self.lower: + for name, dataset in data_bundle.datasets.items(): + dataset[Const.INPUTS(0)].lower() + dataset[Const.INPUTS(1)].lower() + + word_vocab = Vocabulary() + word_vocab.from_dataset(data_bundle.datasets['train'], field_name=[Const.INPUTS(0), Const.INPUTS(1)], + no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if + name != 'train']) + word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) + + target_vocab = Vocabulary(padding=None, unknown=None) + target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) + has_target_datasets = [] + for name, dataset in data_bundle.datasets.items(): + if dataset.has_field(Const.TARGET): + has_target_datasets.append(dataset) + target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) + + data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) + data_bundle.set_vocab(target_vocab, Const.TARGET) + + input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LEN(0), Const.INPUT_LEN(1)] + target_fields = [Const.TARGET] + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LEN(0)) + dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LEN(1)) + dataset.set_input(*input_fields, flag=True) + dataset.set_target(*target_fields, flag=True) + + return data_bundle + + +class RTEPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = RTELoader().load(paths) + return self.process(data_bundle) + + +class SNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = SNLILoader().load(paths) + return self.process(data_bundle) + + +class QuoraPipe(MatchingPipe): + def process_from_file(self, paths): + 
data_bundle = QuoraLoader().load(paths) + return self.process(data_bundle) + +class QNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = QNLILoader().load(paths) + return self.process(data_bundle) + + +class MNLIPipe(MatchingPipe): + def process_from_file(self, paths=None): + data_bundle = MNLILoader().load(paths) + return self.process(data_bundle) + diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py new file mode 100644 index 00000000..14c3866a --- /dev/null +++ b/fastNLP/io/pipe/pipe.py @@ -0,0 +1,9 @@ + +from .. import DataBundle + +class Pipe: + def process(self, data_bundle:DataBundle)->DataBundle: + raise NotImplementedError + + def process_from_file(self, paths)->DataBundle: + raise NotImplementedError diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py new file mode 100644 index 00000000..59bee96e --- /dev/null +++ b/fastNLP/io/pipe/utils.py @@ -0,0 +1,142 @@ +from typing import List +from ...core import Vocabulary +from ...core import Const + +def iob2(tags:List[str])->List[str]: + """ + 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format + + :param tags: 需要转换的tags + """ + for i, tag in enumerate(tags): + if tag == "O": + continue + split = tag.split("-") + if len(split) != 2 or split[0] not in ["I", "B"]: + raise TypeError("The encoding schema is not a valid IOB type.") + if split[0] == "B": + continue + elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 + tags[i] = "B" + tag[1:] + elif tags[i - 1][1:] == tag[1:]: + continue + else: # conversion IOB1 to IOB2 + tags[i] = "B" + tag[1:] + return tags + +def iob2bioes(tags:List[str])->List[str]: + """ + 将iob的tag转换为bioes编码 + :param tags: + :return: + """ + new_tags = [] + for i, tag in enumerate(tags): + if tag == 'O': + new_tags.append(tag) + else: + split = tag.split('-')[0] + if split == 'B': + if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': + new_tags.append(tag) + else: + new_tags.append(tag.replace('B-', 'S-')) + elif split == 'I': + if i + 1Dict[str, str]: +def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: """ 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果 { @@ -11,13 +12,14 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: 'test': 'xxx' # 可能有,也可能没有 ... } - 如果paths为不合法的,将直接进行raise相应的错误 + 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 - :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 + :param str paths: 路径. 
可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名 中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 :return: """ - if isinstance(paths, str): + if isinstance(paths, (str, Path)): + paths = os.path.abspath(os.path.expanduser(paths)) if os.path.isfile(paths): return {'train': paths} elif os.path.isdir(paths): @@ -37,6 +39,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: path_pair = ('test', filename) if path_pair: files[path_pair[0]] = os.path.join(paths, path_pair[1]) + if 'train' not in files: + raise KeyError(f"There is no train file in {paths}.") return files else: raise FileNotFoundError(f"{paths} is not a valid file path.") @@ -47,8 +51,10 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: raise KeyError("You have to include `train` in your dict.") for key, value in paths.items(): if isinstance(key, str) and isinstance(value, str): + value = os.path.abspath(os.path.expanduser(value)) if not os.path.isfile(value): raise TypeError(f"{value} is not a valid file.") + paths[key] = value else: raise TypeError("All keys and values in paths should be str.") return paths diff --git a/test/embeddings/__init__.py b/test/embeddings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/modules/encoder/test_bert.py b/test/embeddings/test_bert.py similarity index 100% rename from test/modules/encoder/test_bert.py rename to test/embeddings/test_bert.py diff --git a/test/embeddings/test_elmo_embedding.py b/test/embeddings/test_elmo_embedding.py new file mode 100644 index 00000000..a087f0a4 --- /dev/null +++ b/test/embeddings/test_elmo_embedding.py @@ -0,0 +1,21 @@ + +import unittest +from fastNLP import Vocabulary +from fastNLP.embeddings import ElmoEmbedding +import torch +import os + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestDownload(unittest.TestCase): + def test_download_small(self): + # import os + vocab = Vocabulary().add_word_lst("This is a test .".split()) + elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small') + words = torch.LongTensor([[0, 1, 2]]) + print(elmo_embed(words).size()) + + +# 首先保证所有权重可以加载;上传权重;验证可以下载 + + + diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py new file mode 100644 index 00000000..28f08921 --- /dev/null +++ b/test/io/loader/test_classification_loader.py @@ -0,0 +1,19 @@ + +import unittest +from fastNLP.io.loader.classification import YelpFullLoader +from fastNLP.io.loader.classification import YelpPolarityLoader +from fastNLP.io.loader.classification import IMDBLoader +from fastNLP.io.loader.classification import SST2Loader +from fastNLP.io.loader.classification import SSTLoader +import os + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestDownload(unittest.TestCase): + def test_download(self): + for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: + loader().download() + + def test_load(self): + for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: + data_bundle = loader().load() + print(data_bundle) diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py new file mode 100644 index 00000000..5c1a91f1 --- /dev/null +++ b/test/io/loader/test_matching_loader.py @@ -0,0 +1,22 @@ + +import unittest +from fastNLP.io.loader.matching import RTELoader +from fastNLP.io.loader.matching import QNLILoader +from fastNLP.io.loader.matching import 
SNLILoader +from fastNLP.io.loader.matching import QuoraLoader +from fastNLP.io.loader.matching import MNLILoader +import os + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestDownload(unittest.TestCase): + def test_download(self): + for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: + loader().download() + with self.assertRaises(Exception): + QuoraLoader().load() + + def test_load(self): + for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: + data_bundle = loader().load() + print(data_bundle) + diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py new file mode 100644 index 00000000..39dc71e0 --- /dev/null +++ b/test/io/pipe/test_classification.py @@ -0,0 +1,13 @@ +import unittest +import os + +from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: + with self.subTest(pipe=pipe): + print(pipe) + data_bundle = pipe(tokenizer='raw').process_from_file() + print(data_bundle) diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py new file mode 100644 index 00000000..c057bb0c --- /dev/null +++ b/test/io/pipe/test_matching.py @@ -0,0 +1,26 @@ + +import unittest +import os + +from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe +from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe + + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: + with self.subTest(pipe=pipe): + print(pipe) + data_bundle = pipe(tokenizer='raw').process_from_file() + print(data_bundle) + + +@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") +class TestBertPipe(unittest.TestCase): + def test_process_from_file(self): + for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: + with self.subTest(pipe=pipe): + print(pipe) + data_bundle = pipe(tokenizer='raw').process_from_file() + print(data_bundle)