Browse Source

1. 分类DataSetLoader中的Loader功能Pipe功能; 2. 增加数据集自动下载; 3.修复vocabulary中的bug

tags/v0.4.10
yh 6 years ago
parent
commit
014e9786c7
43 changed files with 2802 additions and 227 deletions
  1. +3
    -0
      .travis.yml
  2. +12
    -0
      fastNLP/core/batch.py
  3. +20
    -6
      fastNLP/core/const.py
  4. +28
    -5
      fastNLP/core/dataset.py
  5. +1
    -9
      fastNLP/core/field.py
  6. +7
    -0
      fastNLP/core/instance.py
  7. +21
    -0
      fastNLP/core/utils.py
  8. +36
    -27
      fastNLP/core/vocabulary.py
  9. +2
    -1
      fastNLP/embeddings/__init__.py
  10. +6
    -10
      fastNLP/embeddings/bert_embedding.py
  11. +3
    -5
      fastNLP/embeddings/elmo_embedding.py
  12. +5
    -7
      fastNLP/embeddings/static_embedding.py
  13. +87
    -2
      fastNLP/io/base_loader.py
  14. +87
    -29
      fastNLP/io/data_loader/conll.py
  15. +1
    -1
      fastNLP/io/data_loader/matching.py
  16. +2
    -2
      fastNLP/io/data_loader/mtl.py
  17. +6
    -4
      fastNLP/io/data_loader/sst.py
  18. +2
    -2
      fastNLP/io/data_loader/yelp.py
  19. +0
    -22
      fastNLP/io/dataset_loader.py
  20. +5
    -5
      fastNLP/io/file_reader.py
  21. +191
    -86
      fastNLP/io/file_utils.py
  22. +30
    -0
      fastNLP/io/loader/__init__.py
  23. +369
    -0
      fastNLP/io/loader/classification.py
  24. +264
    -0
      fastNLP/io/loader/conll.py
  25. +32
    -0
      fastNLP/io/loader/csv.py
  26. +41
    -0
      fastNLP/io/loader/cws.py
  27. +40
    -0
      fastNLP/io/loader/json.py
  28. +75
    -0
      fastNLP/io/loader/loader.py
  29. +309
    -0
      fastNLP/io/loader/matching.py
  30. +8
    -0
      fastNLP/io/pipe/__init__.py
  31. +444
    -0
      fastNLP/io/pipe/classification.py
  32. +149
    -0
      fastNLP/io/pipe/conll.py
  33. +254
    -0
      fastNLP/io/pipe/matching.py
  34. +9
    -0
      fastNLP/io/pipe/pipe.py
  35. +142
    -0
      fastNLP/io/pipe/utils.py
  36. +10
    -4
      fastNLP/io/utils.py
  37. +0
    -0
      test/embeddings/__init__.py
  38. +0
    -0
      test/embeddings/test_bert.py
  39. +21
    -0
      test/embeddings/test_elmo_embedding.py
  40. +19
    -0
      test/io/loader/test_classification_loader.py
  41. +22
    -0
      test/io/loader/test_matching_loader.py
  42. +13
    -0
      test/io/pipe/test_classification.py
  43. +26
    -0
      test/io/pipe/test_matching.py

+ 3
- 0
.travis.yml View File

@@ -1,6 +1,9 @@
language: python
python:
- "3.6"

env
- TRAVIS=1
# command to install dependencies
install:
- pip install --quiet -r requirements.txt


+ 12
- 0
fastNLP/core/batch.py View File

@@ -48,6 +48,11 @@ class DataSetGetter:
return len(self.dataset)

def collate_fn(self, batch: list):
"""

:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
:return:
"""
# TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景
batch_x = {n:[] for n in self.inputs.keys()}
batch_y = {n:[] for n in self.targets.keys()}
@@ -208,6 +213,13 @@ class OnlineDataIter(BatchIter):


def _to_tensor(batch, field_dtype):
"""

:param batch: np.array()
:param field_dtype: 数据类型
:return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor,
返回的batch就是原来的数据,且flag为False
"""
try:
if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \


+ 20
- 6
fastNLP/core/const.py View File

@@ -7,12 +7,14 @@ class Const:
具体列表::

INPUT 模型的序列输入 words(复数words1, words2)
CHAR_INPUT 模型character输入 chars(复数chars1, chars2)
INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2)
OUTPUT 模型输出 pred(复数pred1, pred2)
TARGET 真实目标 target(复数target1,target2)
LOSS 损失函数 loss (复数loss1,loss2)
INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, )
CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2)
INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2)
OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2)
TARGET 真实目标 target(具有多列target时,依次使用target1,target2)
LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2)
RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2)
RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2)

"""
INPUT = 'words'
@@ -21,6 +23,8 @@ class Const:
OUTPUT = 'pred'
TARGET = 'target'
LOSS = 'loss'
RAW_WORD = 'raw_words'
RAW_CHAR = 'raw_chars'

@staticmethod
def INPUTS(i):
@@ -34,6 +38,16 @@ class Const:
i = int(i) + 1
return Const.CHAR_INPUT + str(i)

@staticmethod
def RAW_WORDS(i):
i = int(i) + 1
return Const.RAW_WORD + str(i)

@staticmethod
def RAW_CHARS(i):
i = int(i) + 1
return Const.RAW_CHAR + str(i)

@staticmethod
def INPUT_LENS(i):
"""得到第 i 个 ``INPUT_LEN`` 的命名"""


+ 28
- 5
fastNLP/core/dataset.py View File

@@ -291,6 +291,7 @@ import _pickle as pickle
import warnings

import numpy as np
from copy import deepcopy

from .field import AutoPadder
from .field import FieldArray
@@ -298,6 +299,7 @@ from .instance import Instance
from .utils import _get_func_signature
from .field import AppendToTargetOrInputException
from .field import SetInputOrTargetException
from .const import Const

class DataSet(object):
"""
@@ -349,7 +351,11 @@ class DataSet(object):
self.idx])
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
return self.dataset.field_arrays[item][self.idx]

def items(self):
ins = self.dataset[self.idx]
return ins.items()

def __repr__(self):
return self.dataset[self.idx].__repr__()
@@ -497,6 +503,7 @@ class DataSet(object):
else:
for field in self.field_arrays.values():
field.pop(index)
return self
def delete_field(self, field_name):
"""
@@ -505,7 +512,22 @@ class DataSet(object):
:param str field_name: 需要删除的field的名称.
"""
self.field_arrays.pop(field_name)
return self

def copy_field(self, field_name, new_field_name):
"""
深度copy名为field_name的field到new_field_name

:param str field_name: 需要copy的field。
:param str new_field_name: copy生成的field名称
:return: self
"""
if not self.has_field(field_name):
raise KeyError(f"Field:{field_name} not found in DataSet.")
fieldarray = deepcopy(self.get_field(field_name))
self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray)
return self

def has_field(self, field_name):
"""
判断DataSet中是否有名为field_name这个field
@@ -701,7 +723,7 @@ class DataSet(object):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
print("Exception happens at the `{}`th instance.".format(idx))
print("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -766,10 +788,11 @@ class DataSet(object):
results = []
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
except Exception as e:
except BaseException as e:
if idx != -1:
print("Exception happens at the `{}`th instance.".format(idx))
raise e

# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -779,7 +802,7 @@ class DataSet(object):
return results

def add_seq_len(self, field_name:str, new_field_name='seq_len'):
def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN):
"""
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。



+ 1
- 9
fastNLP/core/field.py View File

@@ -7,6 +7,7 @@ from typing import Any
from abc import abstractmethod
from copy import deepcopy
from collections import Counter
from .utils import _is_iterable

class SetInputOrTargetException(Exception):
def __init__(self, msg, index=None, field_name=None):
@@ -443,15 +444,6 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")


def _is_iterable(value):
# 检查是否是iterable的, duck typing
try:
iter(value)
return True
except BaseException as e:
return False


class Padder:
"""
别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder`


+ 7
- 0
fastNLP/core/instance.py View File

@@ -35,6 +35,13 @@ class Instance(object):
:param Any field: 新增field的内容
"""
self.fields[field_name] = field

def items(self):
"""
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value
:return:
"""
return self.fields.items()
def __getitem__(self, name):
if name in self.fields:


+ 21
- 0
fastNLP/core/utils.py View File

@@ -4,6 +4,7 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户
__all__ = [
"cache_results",
"seq_len_to_mask",
"get_seq_len"
]

import _pickle
@@ -730,3 +731,23 @@ def iob2bioes(tags: List[str]) -> List[str]:
else:
raise TypeError("Invalid IOB format.")
return new_tags


def _is_iterable(value):
# 检查是否是iterable的, duck typing
try:
iter(value)
return True
except BaseException as e:
return False


def get_seq_len(words, pad_value=0):
"""
给定batch_size x max_len的words矩阵,返回句子长度

:param words: batch_size x max_len
:return: (batch_size,)
"""
mask = words.ne(pad_value)
return mask.sum(dim=-1)

+ 36
- 27
fastNLP/core/vocabulary.py View File

@@ -4,12 +4,12 @@ __all__ = [
]

from functools import wraps
from collections import Counter, defaultdict
from collections import Counter
from .dataset import DataSet
from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable

class VocabularyOption(Option):
def __init__(self,
@@ -131,11 +131,11 @@ class Vocabulary(object):
"""
在新加入word时,检查_no_create_word的设置。

:param str, List[str] word:
:param str List[str] word:
:param bool no_create_entry:
:return:
"""
if isinstance(word, str):
if isinstance(word, str) or not _is_iterable(word):
word = [word]
for w in word:
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0):
@@ -257,35 +257,45 @@ class Vocabulary(object):
vocab.index_dataset(train_data, dev_data, test_data, field_name='words')

:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集
:param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field.
目前支持 ``str`` , ``List[str]`` , ``List[List[str]]``
:param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field.
Default: ``None``
:param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field.
目前支持 ``str`` , ``List[str]``
:param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field.
Default: ``None``.
"""
def index_instance(ins):
def index_instance(field):
"""
有几种情况, str, 1d-list, 2d-list
:param ins:
:return:
"""
field = ins[field_name]
if isinstance(field, str):
if isinstance(field, str) or not _is_iterable(field):
return self.to_index(field)
elif isinstance(field, list):
if not isinstance(field[0], list):
else:
if isinstance(field[0], str) or not _is_iterable(field[0]):
return [self.to_index(w) for w in field]
else:
if isinstance(field[0][0], list):
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
raise RuntimeError("Only support field with 2 dimensions.")
return [[self.to_index(c) for c in w] for w in field]
if new_field_name is None:
new_field_name = field_name

new_field_name = new_field_name or field_name

if type(new_field_name) == type(field_name):
if isinstance(new_field_name, list):
assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
"field_name."
elif isinstance(new_field_name, str):
field_name = [field_name]
new_field_name = [new_field_name]
else:
raise TypeError("field_name and new_field_name can only be str or List[str].")

for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
try:
dataset.apply(index_instance, new_field_name=new_field_name)
for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
@@ -306,9 +316,8 @@ class Vocabulary(object):

:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集
:param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` .
构建词典所使用的 field(s), 支持一个或多个field
若有多个 DataSet, 每个DataSet都必须有这些field.
目前仅支持的field结构: ``str`` , ``List[str]`` , ``list[List[str]]``
构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构
: ``str`` , ``List[str]``
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。
@@ -326,14 +335,14 @@ class Vocabulary(object):
def construct_vocab(ins, no_create_entry=False):
for fn in field_name:
field = ins[fn]
if isinstance(field, str):
if isinstance(field, str) or not _is_iterable(field):
self.add_word(field, no_create_entry=no_create_entry)
elif isinstance(field, (list, np.ndarray)):
if not isinstance(field[0], (list, np.ndarray)):
else:
if isinstance(field[0], str) or not _is_iterable(field[0]):
for word in field:
self.add_word(word, no_create_entry=no_create_entry)
else:
if isinstance(field[0][0], (list, np.ndarray)):
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
raise RuntimeError("Only support field with 2 dimensions.")
for words in field:
for word in words:
@@ -343,8 +352,8 @@ class Vocabulary(object):
if isinstance(dataset, DataSet):
try:
dataset.apply(construct_vocab)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
except BaseException as e:
print("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")


+ 2
- 1
fastNLP/embeddings/__init__.py View File

@@ -10,6 +10,7 @@ __all__ = [
"StaticEmbedding",
"ElmoEmbedding",
"BertEmbedding",
"BertWordPieceEncoder",
"StackEmbedding",
"LSTMCharEmbedding",
"CNNCharEmbedding",
@@ -20,7 +21,7 @@ __all__ = [
from .embedding import Embedding
from .static_embedding import StaticEmbedding
from .elmo_embedding import ElmoEmbedding
from .bert_embedding import BertEmbedding
from .bert_embedding import BertEmbedding, BertWordPieceEncoder
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
from .stack_embedding import StackEmbedding
from .utils import get_embeddings

+ 6
- 10
fastNLP/embeddings/bert_embedding.py View File

@@ -8,7 +8,7 @@ import numpy as np
from itertools import chain

from ..core.vocabulary import Vocabulary
from ..io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
from .contextual_embedding import ContextualEmbedding

@@ -60,10 +60,8 @@ class BertEmbedding(ContextualEmbedding):

# 根据model_dir_or_name检查是否存在并下载
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# 检查是否存在
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name))
@@ -133,11 +131,9 @@ class BertWordPieceEncoder(nn.Module):
pooled_cls: bool = False, requires_grad: bool=False):
super().__init__()

if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# 检查是否存在
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
model_dir = model_dir_or_name


+ 3
- 5
fastNLP/embeddings/elmo_embedding.py View File

@@ -8,7 +8,7 @@ import json
import codecs

from ..core.vocabulary import Vocabulary
from ..io.file_utils import cached_path, _get_base_url, PRETRAINED_ELMO_MODEL_DIR
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
from .contextual_embedding import ContextualEmbedding

@@ -53,10 +53,8 @@ class ElmoEmbedding(ContextualEmbedding):

# 根据model_dir_or_name检查是否存在并下载
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
PRETRAIN_URL = _get_base_url('elmo')
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
model_url = _get_embedding_url('elmo', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding')
# 检查是否存在
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
model_dir = model_dir_or_name


+ 5
- 7
fastNLP/embeddings/static_embedding.py View File

@@ -7,7 +7,7 @@ import numpy as np
import warnings

from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_base_url, cached_path
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
from .embedding import TokenEmbedding
from ..modules.utils import _get_file_name_base_on_postfix

@@ -60,10 +60,8 @@ class StaticEmbedding(TokenEmbedding):
embedding_dim = int(embedding_dim)
model_path = None
elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
PRETRAIN_URL = _get_base_url('static')
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_path = cached_path(model_url)
model_url = _get_embedding_url('static', model_dir_or_name.lower())
model_path = cached_path(model_url, name='embedding')
# 检查是否存在
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))):
model_path = model_dir_or_name
@@ -84,8 +82,8 @@ class StaticEmbedding(TokenEmbedding):
if lowered_word not in lowered_vocab.word_count:
lowered_vocab.add_word(lowered_word)
lowered_vocab._no_create_word[lowered_word] += 1
print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered "
f"words.")
print(f"All word in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} "
f"words, {len(lowered_vocab)} unique lowered words.")
if model_path:
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
else:


+ 87
- 2
fastNLP/io/base_loader.py View File

@@ -5,10 +5,10 @@ __all__ = [
]

import _pickle as pickle
import os
from typing import Union, Dict
import os
from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary


class BaseLoader(object):
@@ -111,7 +111,10 @@ def _uncompress(src, dst):

class DataBundle:
"""
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种
DataSetLoader的load函数生成,可以通过以下的方法获取里面的内容

Example::

:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
@@ -121,6 +124,88 @@ class DataBundle:
self.vocabs = vocabs or {}
self.datasets = datasets or {}

def set_vocab(self, vocab, field_name):
"""
向DataBunlde中增加vocab

:param Vocabulary vocab: 词表
:param str field_name: 这个vocab对应的field名称
:return:
"""
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports."
self.vocabs[field_name] = vocab

def set_dataset(self, dataset, name):
"""

:param DataSet dataset: 传递给DataBundle的DataSet
:param str name: dataset的名称
:return:
"""
self.datasets[name] = dataset

def get_dataset(self, name:str):
"""
获取名为name的dataset

:param str name: dataset的名称,一般为'train', 'dev', 'test'
:return: DataSet
"""
return self.datasets[name]

def get_vocab(self, field_name:str):
"""
获取field名为field_name对应的vocab

:param str field_name: 名称
:return: Vocabulary
"""
return self.vocabs[field_name]

def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True):
"""
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作::

data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True
data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False

:param str field_names: field的名称
:param bool flag: 将field_name的input状态设置为flag
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
行的数据进行类型和维度推断本列的数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错
"""
for field_name in field_names:
for name, dataset in self.datasets.items():
if not ignore_miss_field and not dataset.has_field(field_name):
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
if not dataset.has_field(field_name):
continue
else:
dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)

def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True):
"""
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作::

data_bundle.set_target('target', 'seq_len') # 将words和target这两个field的input属性设置为True
data_bundle.set_target('target', flag=False) # 将target这个field的input属性设置为False

:param str field_names: field的名称
:param bool flag: 将field_name的target状态设置为flag
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
行的数据进行类型和维度推断本列的数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错
"""
for field_name in field_names:
for name, dataset in self.datasets.items():
if not ignore_miss_field and not dataset.has_field(field_name):
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
if not dataset.has_field(field_name):
continue
else:
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)

def __repr__(self):
_str = 'In total {} datasets:\n'.format(len(self.datasets))
for name, dataset in self.datasets.items():


+ 87
- 29
fastNLP/io/data_loader/conll.py View File

@@ -3,38 +3,47 @@ from ...core.dataset import DataSet
from ...core.instance import Instance
from ..base_loader import DataSetLoader
from ..file_reader import _read_conll

from typing import Union, Dict
from ..utils import check_loader_paths
from ..base_loader import DataBundle

class ConllLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader`

读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为
该符号在conll 2003中被用为文档分割符。

列号从0开始, 每列对应内容为::

Column Type
0 Document ID
1 Part number
2 Word number
3 Word itself
4 Part-of-Speech
5 Parse bit
6 Predicate lemma
7 Predicate Frameset ID
8 Word sense
9 Speaker/Author
10 Named Entities
11:N Predicate Arguments
N Coreference

:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示:

Example::

# 文件中的内容
Nadim NNP B-NP B-PER
Ladki NNP I-NP I-PER

AL-AIN NNP B-NP B-LOC
United NNP B-NP B-LOC
Arab NNP I-NP I-LOC
Emirates NNPS I-NP I-LOC
1996-12-06 CD I-NP O
...

# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll')
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll')

dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')中DataSet的raw_words
列与pos列的内容都是List[str]

数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。

:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
"""

def __init__(self, headers, indexes=None, dropna=False):
def __init__(self, headers, indexes=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
raise TypeError(
@@ -49,25 +58,74 @@ class ConllLoader(DataSetLoader):
self.indexes = indexes

def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。

:param str path: 文件的路径
:return: DataSet
"""
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds

def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle:
"""
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。

读取的field根据ConllLoader初始化时传入的headers决定。

:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件
名包含'train'、 'dev'、 'test'则会报错

Example::
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train, dev, test等有所变化
# 可以通过以下的方式取出DataSet
tr_data = data_bundle.datasets['train']
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段

(2) 传入文件path

Example::
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train'
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet

(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test

Example::
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test"
dev_data = data_bundle.datasets['dev']

:return: :class:`~fastNLP.DataSet` 类的对象或 :class:`~fastNLP.io.DataBundle` 的字典
"""
paths = check_loader_paths(paths)
datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle


class Conll2003Loader(ConllLoader):
"""
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader`

读取Conll2003数据
该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003
找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。

返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。

.. csv-table:: Conll2003Loader处理之 :header: "raw_words", "words", "target", "seq_len"

"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5
"[...]", "[...]", "[...]", .

关于数据集的更多信息,参考:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""

def __init__(self):
headers = [
'tokens', 'pos', 'chunks', 'ner',
'raw_words', 'pos', 'chunks', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)

+ 1
- 1
fastNLP/io/data_loader/matching.py View File

@@ -121,7 +121,7 @@ class MatchingLoader(DataSetLoader):
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
model_dir = cached_path(model_url, name='embedding')
# 检查是否存在
elif os.path.isdir(bert_tokenizer):
model_dir = bert_tokenizer


+ 2
- 2
fastNLP/io/data_loader/mtl.py View File

@@ -5,7 +5,7 @@ from ..base_loader import DataBundle
from ..dataset_loader import CSVLoader
from ...core.vocabulary import Vocabulary, VocabularyOption
from ...core.const import Const
from ..utils import check_dataloader_paths
from ..utils import check_loader_paths


class MTL16Loader(CSVLoader):
@@ -38,7 +38,7 @@ class MTL16Loader(CSVLoader):
src_vocab_opt: VocabularyOption = None,
tgt_vocab_opt: VocabularyOption = None,):

paths = check_dataloader_paths(paths)
paths = check_loader_paths(paths)
datasets = {}
info = DataBundle()
for name, path in paths.items():


+ 6
- 4
fastNLP/io/data_loader/sst.py View File

@@ -8,7 +8,7 @@ from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.const import Const
from ...core.instance import Instance
from ..utils import check_dataloader_paths, get_tokenizer
from ..utils import check_loader_paths, get_tokenizer


class SSTLoader(DataSetLoader):
@@ -67,7 +67,7 @@ class SSTLoader(DataSetLoader):
paths, train_subtree=True,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,):
paths = check_dataloader_paths(paths)
paths = check_loader_paths(paths)
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
@@ -129,7 +129,7 @@ class SST2Loader(CSVLoader):
tgt_vocab_opt: VocabularyOption = None,
char_level_op=False):

paths = check_dataloader_paths(paths)
paths = check_loader_paths(paths)
datasets = {}
info = DataBundle()
for name, path in paths.items():
@@ -155,7 +155,9 @@ class SST2Loader(CSVLoader):
for dataset in datasets.values():
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[
dataset for name, dataset in datasets.items() if name!='train'
])
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

tgt_vocab = Vocabulary(unknown=None, padding=None) \


+ 2
- 2
fastNLP/io/data_loader/yelp.py View File

@@ -8,7 +8,7 @@ from ...core.instance import Instance
from ...core.vocabulary import VocabularyOption, Vocabulary
from ..base_loader import DataBundle, DataSetLoader
from typing import Union, Dict
from ..utils import check_dataloader_paths, get_tokenizer
from ..utils import check_loader_paths, get_tokenizer


class YelpLoader(DataSetLoader):
@@ -62,7 +62,7 @@ class YelpLoader(DataSetLoader):
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
char_level_op=False):
paths = check_dataloader_paths(paths)
paths = check_loader_paths(paths)
info = DataBundle(datasets=self.load(paths))
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \


+ 0
- 22
fastNLP/io/dataset_loader.py View File

@@ -114,25 +114,3 @@ def _cut_long_sentence(sent, max_sample_length=200):
else:
cutted_sentence.append(sent)
return cutted_sentence


def _add_seg_tag(data):
"""

:param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos])
"""

_processed = []
for word_list, pos_list, _, _ in data:
new_sample = []
for word, pos in zip(word_list, pos_list):
if len(word) == 1:
new_sample.append((word, 'S-' + pos))
else:
new_sample.append((word[0], 'B-' + pos))
for c in word[1:-1]:
new_sample.append((c, 'M-' + pos))
new_sample.append((word[-1], 'E-' + pos))
_processed.append(list(map(list, zip(*new_sample))))
return _processed

+ 5
- 5
fastNLP/io/file_reader.py View File

@@ -2,7 +2,7 @@
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
"""
import json
import warnings

def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
"""
@@ -91,7 +91,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
if '-DOCSTART-' not in start and start!='':
if start!='':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line.strip()
@@ -103,13 +103,13 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx))
continue
raise ValueError('invalid instance ends at line: {}'.format(line_idx))
raise ValueError('Invalid instance ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
if not line.startswith('-DOCSTART-'):
sample.append(line.split())
sample.append(line.split())
if len(sample) > 0:
try:
res = parse_conll(sample)


+ 191
- 86
fastNLP/io/file_utils.py View File

@@ -7,7 +7,7 @@ import requests
import tempfile
from tqdm import tqdm
import shutil
import hashlib
from requests import HTTPError


PRETRAINED_BERT_MODEL_DIR = {
@@ -23,15 +23,25 @@ PRETRAINED_BERT_MODEL_DIR = {

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased.zip',
'bert-base-chinese': 'bert-base-chinese.zip',
'bert-base-cased': 'bert-base-cased.zip',
'bert-base-cased-finetuned-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
'bert-large-cased-wwm': 'bert-large-cased-wwm.zip',
'bert-large-uncased': 'bert-large-uncased.zip',
'bert-large-cased': 'bert-large-cased.zip',
'bert-base-uncased': 'bert-base-uncased.zip',
'bert-large-uncased-wwm': 'bert-large-uncased-wwm.zip',
'bert-chinese-wwm': 'bert-chinese-wwm.zip',
'bert-base-multilingual-cased': 'bert-base-multilingual-cased.zip',
'bert-base-multilingual-uncased': 'bert-base-multilingual-uncased.zip',
}

PRETRAINED_ELMO_MODEL_DIR = {
'en': 'elmo_en-d39843fe.tar.gz',
'en-small': "elmo_en_Small.zip"
'en-small': "elmo_en_Small.zip",
'en-original-5.5b': 'elmo_en_Original_5.5B.zip',
'en-original': 'elmo_en_Original.zip',
'en-medium': 'elmo_en_Medium.zip'
}

PRETRAIN_STATIC_FILES = {
@@ -42,34 +52,68 @@ PRETRAIN_STATIC_FILES = {
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip",
'cn': "tencent_cn-dab24577.tar.gz",
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
'sgns-literature-word':'sgns.literature.word.txt.zip',
'glove-42b-300d': 'glove.42B.300d.zip',
'glove-6b-50d': 'glove.6B.50d.zip',
'glove-6b-100d': 'glove.6B.100d.zip',
'glove-6b-200d': 'glove.6B.200d.zip',
'glove-6b-300d': 'glove.6B.300d.zip',
'glove-840b-300d': 'glove.840B.300d.zip',
'glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip',
'glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip',
'glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip',
'glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip'
}


DATASET_DIR = {
'aclImdb': "imdb.zip",
"yelp-review-full":"yelp_review_full.tar.gz",
"yelp-review-polarity": "yelp_review_polarity.tar.gz",
"mnli": "MNLI.zip",
"snli": "SNLI.zip",
"qnli": "QNLI.zip",
"sst-2": "SST-2.zip",
"sst": "SST.zip",
"rte": "RTE.zip"
}


def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
"""
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并
给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件,
(1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir
(2)如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name}
如果有该文件,就直接返回路径
如果没有该文件,则尝试用传入的url下载

或者文件名(可以是具体的文件名,也可以是文件夹),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并
将文件放入到cache_dir中.

:param url_or_filename: 文件的下载url或者文件路径
:param cache_dir: 文件的缓存文件夹
:param str url_or_filename: 文件的下载url或者文件名称。
:param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径
:param str name: 中间一层的名称。如embedding, dataset
:return:
"""
if cache_dir is None:
dataset_cache = Path(get_default_cache_path())
data_cache = Path(get_default_cache_path())
else:
dataset_cache = cache_dir
data_cache = cache_dir

if name:
data_cache = os.path.join(data_cache, name)

parsed = urlparse(url_or_filename)

if parsed.scheme in ("http", "https"):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, dataset_cache)
elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists():
return get_from_cache(url_or_filename, Path(data_cache))
elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists():
# File, and it exists.
return Path(url_or_filename)
return Path(os.path.join(data_cache, url_or_filename))
elif parsed.scheme == "":
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))
raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache))
else:
# Something unknown
raise ValueError(
@@ -79,8 +123,12 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:

def get_filepath(filepath):
"""
如果filepath中只有一个文件,则直接返回对应的全路径.
:param filepath:
如果filepath为文件夹,
如果内含多个文件, 返回filepath
如果只有一个文件, 返回filepath + filename
如果filepath为文件
返回filepath
:param str filepath: 路径
:return:
"""
if os.path.isdir(filepath):
@@ -89,14 +137,17 @@ def get_filepath(filepath):
return os.path.join(filepath, files[0])
else:
return filepath
return filepath
elif os.path.isfile(filepath):
return filepath
else:
raise FileNotFoundError(f"{filepath} is not a valid file or directory.")


def get_default_cache_path():
"""
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。

:return:
:return: str
"""
if 'FASTNLP_CACHE_DIR' in os.environ:
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR')
@@ -109,17 +160,66 @@ def get_default_cache_path():


def _get_base_url(name):
"""
根据name返回下载的url地址。

:param str name: 支持dataset和embedding两种
:return:
"""
# 返回的URL结尾必须是/
if 'FASTNLP_BASE_URL' in os.environ:
fastnlp_base_url = os.environ['FASTNLP_BASE_URL']
if fastnlp_base_url.endswith('/'):
return fastnlp_base_url
environ_name = "FASTNLP_{}_URL".format(name.upper())

if environ_name in os.environ:
url = os.environ[environ_name]
if url.endswith('/'):
return url
else:
return fastnlp_base_url + '/'
return url + '/'
else:
# TODO 替换
dbbrain_url = "http://dbcloud.irocn.cn:8989/api/public/dl/"
return dbbrain_url
URLS = {
'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/",
"dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/"
}
if name.lower() not in URLS:
raise KeyError(f"{name} is not recognized.")
return URLS[name.lower()]


def _get_embedding_url(type, name):
"""
给定embedding类似和名称,返回下载url

:param str type: 支持static, bert, elmo。即embedding的类型
:param str name: embedding的名称, 例如en, cn, based等
:return: str, 下载的url地址
"""
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
"bert": PRETRAINED_BERT_MODEL_DIR,
"static":PRETRAIN_STATIC_FILES}
map = PRETRAIN_MAP.get(type, None)
if map:
filename = map.get(name, None)
if filename:
url = _get_base_url('embedding') + filename
return url
raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
else:
raise KeyError(f"There is no {type}. Only supports bert, elmo, static")


def _get_dataset_url(name):
"""
给定dataset的名称,返回下载url

:param str name: 给定dataset的名称,比如imdb, sst-2等
:return: str
"""
filename = DATASET_DIR.get(name, None)
if filename:
url = _get_base_url('dataset') + filename
return url
else:
raise KeyError(f"There is no {name}.")


def split_filename_suffix(filepath):
@@ -136,9 +236,9 @@ def split_filename_suffix(filepath):

def get_from_cache(url: str, cache_dir: Path = None) -> Path:
"""
尝试在cache_dir中寻找url定义的资源; 如果没有找到则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。
如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的
文件解压,将解压后的文件全部放在cache_dir文件夹中
如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。
"""
cache_dir.mkdir(parents=True, exist_ok=True)

@@ -173,63 +273,68 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:

# GET file object
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"})
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print(f"Finish download from {url}.")

# 开始解压
delete_temp_dir = None
if suffix in ('.zip', '.tar.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
delete_temp_dir = uncompress_temp_dir
print(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
if req.status_code==200:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024*16):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print(f"Finish download from {url}.")

# 开始解压
delete_temp_dir = None
if suffix in ('.zip', '.tar.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
delete_temp_dir = uncompress_temp_dir
print(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
else:
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
filenames = os.listdir(uncompress_temp_dir)
if len(filenames)==1:
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])

cache_path.mkdir(parents=True, exist_ok=True)
print("Finish un-compressing file.")
else:
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
filenames = os.listdir(uncompress_temp_dir)
if len(filenames)==1:
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])

cache_path.mkdir(parents=True, exist_ok=True)
print("Finish un-compressing file.")
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix
success = False
try:
# 复制到指定的位置
print(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
for filename in os.listdir(uncompress_temp_dir):
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)):
shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
else:
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
else:
shutil.copyfile(uncompress_temp_dir, cache_path)
success = True
except Exception as e:
print(e)
raise e
finally:
if not success:
if cache_path.exists():
if cache_path.is_file():
os.remove(cache_path)
else:
shutil.rmtree(cache_path)
if delete_temp_dir:
shutil.rmtree(delete_temp_dir)
os.close(fd)
os.remove(temp_filename)
return get_filepath(cache_path)
else:
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix
success = False
try:
# 复制到指定的位置
print(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
for filename in os.listdir(uncompress_temp_dir):
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
else:
shutil.copyfile(uncompress_temp_dir, cache_path)
success = True
except Exception as e:
print(e)
raise e
finally:
if not success:
if cache_path.exists():
if cache_path.is_file():
os.remove(cache_path)
else:
shutil.rmtree(cache_path)
if delete_temp_dir:
shutil.rmtree(delete_temp_dir)
os.close(fd)
os.remove(temp_filename)

return get_filepath(cache_path)
raise HTTPError(f"Fail to download from {url}.")


def unzip_file(file: Path, to: Path):


+ 30
- 0
fastNLP/io/loader/__init__.py View File

@@ -0,0 +1,30 @@

"""
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle`中。所有的Loader都支持以下的
三个方法: __init__(),_load(), loads(). 其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式,读取后Dataset中field
; _load(path)方法传入一个文件路径读取单个文件,并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle, load()方法
支持以下三种类型的参数

Example::
(0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。
(1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取
(2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。
假设某个目录下的文件为
-train.txt
-dev.txt
-test.txt
-other.txt
Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'],
data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。
假设某个目录下的文件为
-train.txt
-dev.txt
Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取
对应的DataSet。
(3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。
paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'],
data_bundle.datasets['test']

"""


+ 369
- 0
fastNLP/io/loader/classification.py View File

@@ -0,0 +1,369 @@
from ...core.dataset import DataSet
from ...core.instance import Instance
from .loader import Loader
import warnings
import os
import random
import shutil
import numpy as np

class YelpLoader(Loader):
"""
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader`

原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。

Example::
"1","I got 'new' tires from the..."
"1","Don't waste your time..."

读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。
读取的DataSet将具备以下的数据结构

.. csv-table::
:header: "raw_words", "target"

"I got 'new' tires from them and... ", "1"
"Don't waste your time. We had two...", "1"
"...", "..."

"""

def __init__(self):
super(YelpLoader, self).__init__()

def _load(self, path: str=None):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
sep_index = line.index(',')
target = line[:sep_index]
raw_words = line[sep_index + 1:]
if target.startswith("\""):
target = target[1:]
if target.endswith("\""):
target = target[:-1]
if raw_words.endswith("\""):
raw_words = raw_words[:-1]
if raw_words.startswith('"'):
raw_words = raw_words[1:]
raw_words = raw_words.replace('""', '"') # 替换双引号
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
return ds


class YelpFullLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, seed: int = 0):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv,
dev.csv三个文件。

:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param int seed: 划分dev时的随机数种子
:return: str, 数据集的目录地址
"""

dataset_name = 'yelp-review-full'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载
re_download = True
if dev_ratio>0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2:
for line in f1:
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)

if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
random.seed(int(seed))
try:
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.csv'))
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
os.remove(os.path.join(data_dir, 'middle_file.csv'))

return data_dir


class YelpPolarityLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, seed: int = 0):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev

:param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据. 如果为0,则不划分dev
:param int seed: 划分dev时的随机数种子
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-polarity'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求
re_download = True
if dev_ratio>0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2:
for line in f1:
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)

if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
random.seed(int(seed))
try:
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.csv'))
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
os.remove(os.path.join(data_dir, 'middle_file.csv'))

return data_dir


class IMDBLoader(Loader):
"""
别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader`

IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签
DataSet具备以下的结构:

.. csv-table::
:header: "raw_words", "target"

"Bromwell High is a cartoon ... ", "pos"
"Story of a man who has ...", "neg"
"...", "..."

"""

def __init__(self):
super(IMDBLoader, self).__init__()

def _load(self, path: str):
dataset = DataSet()
with open(path, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split('\t')
target = parts[0]
words = parts[1]
if words:
dataset.append(Instance(raw_words=words, target=target))

if len(dataset) == 0:
raise RuntimeError(f"{path} has no valid data.")

return dataset

def download(self, dev_ratio: float = 0.1, seed: int = 0):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

http://www.aclweb.org/anthology/P11-1015

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev

:param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev
:param int seed: 划分dev时的随机数种子
:return: str, 数据集的目录地址
"""
dataset_name = 'aclImdb'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求
re_download = True
if dev_ratio>0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.txt'), 'r', encoding='utf-8') as f2:
for line in f1:
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)

if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
random.seed(int(seed))
try:
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.txt'))
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
os.remove(os.path.join(data_dir, 'middle_file.txt'))

return data_dir


class SSTLoader(Loader):
"""
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader`

读取之后的DataSet具有以下的结构

.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field
:header: "raw_words"

"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..."
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
"..."

raw_words列是str。

"""

def __init__(self):
super().__init__()

def _load(self, path: str):
"""
从path读取SST文件

:param str path: 文件路径
:return: DataSet
"""
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
ds.append(Instance(raw_words=line))
return ds

def download(self):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf

:return: str, 数据集的目录地址
"""
output_dir = self._get_dataset_path(dataset_name='sst')
return output_dir


class SST2Loader(Loader):
"""
数据SST2的Loader
读取之后DataSet将如下所示

.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field
:header: "raw_words", "target"

"it 's a charming and often affecting...", "1"
"unflinchingly bleak and...", "0"
"..."

test的DataSet没有target列。
"""

def __init__(self):
super().__init__()

def _load(self, path: str):
"""
从path读取SST2文件

:param str path: 数据路径
:return: DataSet
"""
ds = DataSet()

with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if 'test' in os.path.split(path)[1]:
warnings.warn("SST2's test file has no target.")
for line in f:
line = line.strip()
if line:
sep_index = line.index('\t')
raw_words = line[sep_index + 1:]
if raw_words:
ds.append(Instance(raw_words=raw_words))
else:
for line in f:
line = line.strip()
if line:
raw_words = line[:-2]
target = line[-1]
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
return ds

def download(self):
"""
自动下载数据集,如果你使用了该数据集,请引用以下的文章

https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf

:return:
"""
output_dir = self._get_dataset_path(dataset_name='sst-2')
return output_dir

+ 264
- 0
fastNLP/io/loader/conll.py View File

@@ -0,0 +1,264 @@
from typing import Dict, Union

from .loader import Loader
from ... import DataSet
from ..file_reader import _read_conll
from ... import Instance
from .. import DataBundle
from ..utils import check_loader_paths
from ... import Const


class ConllLoader(Loader):
"""
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader`

ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示:

Example::

# 文件中的内容
Nadim NNP B-NP B-PER
Ladki NNP I-NP I-PER

AL-AIN NNP B-NP B-LOC
United NNP B-NP B-LOC
Arab NNP I-NP I-LOC
Emirates NNPS I-NP I-LOC
1996-12-06 CD I-NP O
...

# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll')
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll')

ConllLoader返回的DataSet的field由传入的headers确定。

数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。

:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``

"""
def __init__(self, headers, indexes=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
raise TypeError(
'invalid headers: {}, should be list of strings'.format(headers))
self.headers = headers
self.dropna = dropna
if indexes is None:
self.indexes = list(range(len(self.headers)))
else:
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes

def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。

:param str path: 文件的路径
:return: DataSet
"""
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds


class Conll2003Loader(ConllLoader):
"""
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。

Example::

Nadim NNP B-NP B-PER
Ladki NNP I-NP I-PER

AL-AIN NNP B-NP B-LOC
United NNP B-NP B-LOC
Arab NNP I-NP I-LOC
Emirates NNPS I-NP I-LOC
1996-12-06 CD I-NP O
...

返回的DataSet的内容为

.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。
:header: "raw_words", "pos", "chunk", "ner"

"[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]"
"[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
"[...]", "[...]", "[...]", "[...]"

"""
def __init__(self):
headers = [
'raw_words', 'pos', 'chunk', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)

def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。

:param str path: 文件的路径
:return: DataSet
"""
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
doc_start = False
for i, h in enumerate(self.headers):
field = data[i]
if str(field[0]).startswith('-DOCSTART-'):
doc_start = True
break
if doc_start:
continue
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds

def download(self, output_dir=None):
raise RuntimeError("conll2003 cannot be downloaded automatically.")


class Conll2003NERLoader(ConllLoader):
"""
用于读取conll2003任务的NER数据。

Example::

Nadim NNP B-NP B-PER
Ladki NNP I-NP I-PER

AL-AIN NNP B-NP B-LOC
United NNP B-NP B-LOC
Arab NNP I-NP I-LOC
Emirates NNPS I-NP I-LOC
1996-12-06 CD I-NP O
...

返回的DataSet的内容为

.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码
:header: "raw_words", "target"

"[Nadim, Ladki]", "[B-PER, I-PER]"
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
"[...]", "[...]"

"""
def __init__(self):
headers = [
'raw_words', 'target',
]
super().__init__(headers=headers, indexes=[0, 3])

def _load(self, path):
"""
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。

:param str path: 文件的路径
:return: DataSet
"""
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
doc_start = False
for i, h in enumerate(self.headers):
field = data[i]
if str(field[0]).startswith('-DOCSTART-'):
doc_start = True
break
if doc_start:
continue
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds

def download(self):
raise RuntimeError("conll2003 cannot be downloaded automatically.")


class OntoNotesNERLoader(ConllLoader):
"""
用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考
https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。

返回的DataSet的内容为

.. csv-table:: 下面是使用OntoNoteNERLoader读取的DataSet所具备的结构, target列是BIO编码
:header: "raw_words", "target"

"[Nadim, Ladki]", "[B-PER, I-PER]"
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
"[...]", "[...]"

"""

def __init__(self):
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10])

def _load(self, path:str):
dataset = super()._load(path)

def convert_to_bio(tags):
bio_tags = []
flag = None
for tag in tags:
label = tag.strip("()*")
if '(' in tag:
bio_label = 'B-' + label
flag = label
elif flag:
bio_label = 'I-' + flag
else:
bio_label = 'O'
if ')' in tag:
flag = None
bio_tags.append(bio_label)
return bio_tags

def convert_word(words):
converted_words = []
for word in words:
word = word.replace('/.', '.') # 有些结尾的.是/.形式的
if not word.startswith('-'):
converted_words.append(word)
continue
# 以下是由于这些符号被转义了,再转回来
tfrs = {'-LRB-':'(',
'-RRB-': ')',
'-LSB-': '[',
'-RSB-': ']',
'-LCB-': '{',
'-RCB-': '}'
}
if word in tfrs:
converted_words.append(tfrs[word])
else:
converted_words.append(word)
return converted_words

dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET)

return dataset

def download(self):
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer "
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.")


class CTBLoader(Loader):
def __init__(self):
super().__init__()

def _load(self, path:str):
pass

+ 32
- 0
fastNLP/io/loader/csv.py View File

@@ -0,0 +1,32 @@
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..file_reader import _read_csv
from .loader import Loader


class CSVLoader(Loader):
"""
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`

读取CSV格式的数据集, 返回 ``DataSet`` 。

:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None``
:param str sep: CSV文件中列与列之间的分隔符. Default: ","
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False``
"""

def __init__(self, headers=None, sep=",", dropna=False):
super().__init__()
self.headers = headers
self.sep = sep
self.dropna = dropna

def _load(self, path):
ds = DataSet()
for idx, data in _read_csv(path, headers=self.headers,
sep=self.sep, dropna=self.dropna):
ds.append(Instance(**data))
return ds


+ 41
- 0
fastNLP/io/loader/cws.py View File

@@ -0,0 +1,41 @@

from .loader import Loader
from ...core import DataSet, Instance


class CWSLoader(Loader):
"""
分词任务数据加载器,
SigHan2005的数据可以用xxx下载并预处理

CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如:

Example::

上海 浦东 开发 与 法制 建设 同步
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
...

该Loader读取后的DataSet具有如下的结构

.. csv-table::
:header: "raw_words"

"上海 浦东 开发 与 法制 建设 同步"
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
"..."
"""
def __init__(self):
super().__init__()

def _load(self, path:str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
ds.append(Instance(raw_words=line))
return ds

def download(self, output_dir=None):
raise RuntimeError("You can refer {} for sighan2005's data downloading.")

+ 40
- 0
fastNLP/io/loader/json.py View File

@@ -0,0 +1,40 @@
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..file_reader import _read_json
from .loader import Loader


class JsonLoader(Loader):
"""
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader`

读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象

:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` ,
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False``
"""

def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__()
self.dropna = dropna
self.fields = None
self.fields_list = None
if fields:
self.fields = {}
for k, v in fields.items():
self.fields[k] = k if v is None else v
self.fields_list = list(self.fields.keys())

def _load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
ds.append(Instance(**ins))
return ds

+ 75
- 0
fastNLP/io/loader/loader.py View File

@@ -0,0 +1,75 @@
from ... import DataSet
from .. import DataBundle
from ..utils import check_loader_paths
from typing import Union, Dict
import os
from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path

class Loader:
def __init__(self):
pass

def _load(self, path:str) -> DataSet:
raise NotImplementedError

def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
"""
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。

读取的field根据ConllLoader初始化时传入的headers决定。

:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。

(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件
名包含'train'、 'dev'、 'test'则会报错

Example::

data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、
# dev、 test等有所变化,可以通过以下的方式取出DataSet
tr_data = data_bundle.datasets['train']
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段

(2) 传入文件路径

Example::

data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train'
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet

(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test

Example::

paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test"
dev_data = data_bundle.datasets['dev']

:return: 返回的:class:`~fastNLP.io.DataBundle`
"""
if paths is None:
paths = self.download()
paths = check_loader_paths(paths)
datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle

def download(self):
raise NotImplementedError(f"{self.__class__} cannot download data automatically.")

def _get_dataset_path(self, dataset_name):
"""
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存

:param str dataset_name: 数据集的名称
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。
"""

default_cache_path = get_default_cache_path()
url = _get_dataset_url(dataset_name)
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')

return output_dir



+ 309
- 0
fastNLP/io/loader/matching.py View File

@@ -0,0 +1,309 @@

import warnings
from .loader import Loader
from .json import JsonLoader
from ...core import Const
from .. import DataBundle
import os
from typing import Union, Dict
from ...core import DataSet
from ...core import Instance

__all__ = ['MNLILoader',
"QuoraLoader",
"SNLILoader",
"QNLILoader",
"RTELoader"]


class MNLILoader(Loader):
"""
读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,words0是sentence1, words1是sentence2, target是gold_label, 测试集中没
有target列。

.. csv-table::
:header: "raw_words1", "raw_words2", "target"

"The new rights are...", "Everyone really likes..", "neutral"
"This site includes a...", "The Government Executive...", "contradiction"
"...", "...","."

"""
def __init__(self):
super().__init__()

def _load(self, path:str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
warnings.warn("RTE's test file has no target.")
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[8]
raw_words2 = parts[9]
if raw_words1 and raw_words2:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
else:
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[8]
raw_words2 = parts[9]
target = parts[-1]
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds

def load(self, paths:str=None):
"""

:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
test_mismatched.tsv, train.tsv文件夹
:return: DataBundle
"""
if paths:
paths = os.path.abspath(os.path.expanduser(paths))
else:
paths = self.download()
if not os.path.isdir(paths):
raise NotADirectoryError(f"{paths} is not a valid directory.")

files = {'dev_matched':"dev_matched.tsv",
"dev_mismatched":"dev_mismatched.tsv",
"test_matched":"test_matched.tsv",
"test_mismatched":"test_mismatched.tsv",
"train":'train.tsv'}

datasets = {}
for name, filename in files.items():
filepath = os.path.join(paths, filename)
if not os.path.isfile(filepath):
if 'test' not in name:
raise FileNotFoundError(f"{name} not found in directory {filepath}.")
datasets[name] = self._load(filepath)

data_bundle = DataBundle(datasets=datasets)

return data_bundle

def download(self):
"""
如果你使用了这个数据,请引用

https://www.nyu.edu/projects/bowman/multinli/paper.pdf
:return:
"""
output_dir = self._get_dataset_path('mnli')
return output_dir


class SNLILoader(JsonLoader):
"""
读取之后的DataSet中的field情况为

.. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field
:header: "raw_words1", "raw_words2", "target"

"The new rights are...", "Everyone really likes..", "neutral"
"This site includes a...", "The Government Executive...", "entailment"
"...", "...", "."

"""
def __init__(self):
super().__init__(fields={
'sentence1': Const.RAW_WORDS(0),
'sentence2': Const.RAW_WORDS(1),
'gold_label': Const.TARGET,
})

def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
"""
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。

读取的field根据ConllLoader初始化时传入的headers决定。

:param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl
和snli_1.0_test.jsonl三个文件。

:return: 返回的:class:`~fastNLP.io.DataBundle`
"""
_paths = {}
if paths is None:
paths = self.download()
if paths:
if os.path.isdir(paths):
if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')):
raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}")
_paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl')
for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']:
filepath = os.path.join(paths, filename)
_paths[filename.split('_')[-1].split('.')[0]] = filepath
paths = _paths
else:
raise NotADirectoryError(f"{paths} is not a valid directory.")

datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle

def download(self):
"""
如果您的文章使用了这份数据,请引用

http://nlp.stanford.edu/pubs/snli_paper.pdf

:return: str
"""
return self._get_dataset_path('snli')


class QNLILoader(JsonLoader):
"""
QNLI数据集的Loader,
加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label

.. csv-table::
:header: "raw_words1", "raw_words2", "target"

"What came into force after the new...", "As of that day...", "entailment"
"What is the first major...", "The most important tributaries", "not_entailment"
"...","."

test数据集没有target列

"""
def __init__(self):
super().__init__()

def _load(self, path):
ds = DataSet()

with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
warnings.warn("QNLI's test file has no target.")
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[1]
raw_words2 = parts[2]
if raw_words1 and raw_words2:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
else:
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[1]
raw_words2 = parts[2]
target = parts[-1]
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds

def download(self):
"""
如果您的实验使用到了该数据,请引用

TODO 补充

:return:
"""
return self._get_dataset_path('qnli')


class RTELoader(Loader):
"""
RTE数据的loader
加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label

.. csv-table::
:header: "raw_words1", "raw_words2", "target"

"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment"
"Yet, we now are discovering that...", "Bacteria is winning...", "entailment"
"...","."

test数据集没有target列
"""
def __init__(self):
super().__init__()

def _load(self, path:str):
ds = DataSet()

with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
warnings.warn("RTE's test file has no target.")
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[1]
raw_words2 = parts[2]
if raw_words1 and raw_words2:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
else:
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[1]
raw_words2 = parts[2]
target = parts[-1]
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds

def download(self):
return self._get_dataset_path('rte')


class QuoraLoader(Loader):
"""
Quora matching任务的数据集Loader

支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子

Example::

1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970
1 How can I stop my depression ? What can I do to stop being depressed ? 339556
...

加载的DataSet将具备以下的field

.. csv-table::
:header: "raw_words1", "raw_words2", "target"

"What should I do to avoid...", "1"
"How do I not sleep in a boring class...", "0"
"...","."

"""
def __init__(self):
super().__init__()

def _load(self, path:str):
ds = DataSet()

with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[1]
raw_words2 = parts[2]
target = parts[0]
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds

def download(self):
raise RuntimeError("Quora cannot be downloaded automatically.")

+ 8
- 0
fastNLP/io/pipe/__init__.py View File

@@ -0,0 +1,8 @@


"""
Pipe用于处理数据,所有的Pipe都包含一个process(DataBundle)方法,传入一个DataBundle对象, 在传入DataBundle上进行原位修改,并将其返回;
process_from_file(paths)传入的文件路径,返回一个DataBundle。process(DataBundle)或者process_from_file(paths)的返回DataBundle
中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。

"""

+ 444
- 0
fastNLP/io/pipe/classification.py View File

@@ -0,0 +1,444 @@

from nltk import Tree

from ..base_loader import DataBundle
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core import DataSet, Instance

from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
from ...core import cache_results

class _CLSPipe(Pipe):
"""
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列

"""
def __init__(self, tokenizer:str='spacy', lang='en'):
self.tokenizer = get_tokenizer(tokenizer, lang=lang)

def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize

:param DataBundle data_bundle:
:param str field_name:
:param str new_field_name:
:return: 传入的DataBundle对象
"""
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)

return data_bundle

def _granularize(self, data_bundle, tag_map):
"""
该函数对data_bundle中'target'列中的内容进行转换。

:param data_bundle:
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance,
且将"1"认为是第0类。
:return: 传入的data_bundle
"""
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET,
new_field_name=Const.TARGET)
dataset.drop(lambda ins:ins[Const.TARGET] == -100)
data_bundle.set_dataset(dataset, name)
return data_bundle


def _clean_str(words):
"""
heavily borrowed from github
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
:param sentence: is a str
:return:
"""
words_collection = []
for word in words:
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
continue
tt = nonalpnum.split(word)
t = ''.join(tt)
if t != '':
words_collection.append(t)

return words_collection


class YelpFullPipe(_CLSPipe):
"""
处理YelpFull的数据, 处理之后DataSet中的内容如下

.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field
:header: "raw_words", "words", "target", "seq_len"

"It 's a ...", "[4, 2, 10, ...]", 0, 10
"Offers that ...", "[20, 40, ...]", 1, 21
"...", "[...]", ., .

:param bool lower: 是否对输入进行小写化。
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将
1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity

if granularity==2:
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
elif granularity==3:
self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2}
else:
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}

def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize

:param DataBundle data_bundle:
:param str field_name:
:param str new_field_name:
:return: 传入的DataBundle对象
"""
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
return data_bundle

def process(self, data_bundle):
"""
传入的DataSet应该具备如下的结构

.. csv-table::
:header: "raw_words", "target"

"I got 'new' tires from them and... ", "1"
"Don't waste your time. We had two...", "1"
"...", "..."

:param data_bundle:
:return:
"""

# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)

# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)

# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)

# 删除空行
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)

# index
data_bundle = _indexize(data_bundle=data_bundle)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths=None):
"""

:param paths:
:return: DataBundle
"""
data_bundle = YelpFullLoader().load(paths)
return self.process(data_bundle=data_bundle)


class YelpPolarityPipe(_CLSPipe):
"""
处理YelpPolarity的数据, 处理之后DataSet中的内容如下

.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field
:header: "raw_words", "words", "target", "seq_len"

"It 's a ...", "[4, 2, 10, ...]", 0, 10
"Offers that ...", "[20, 40, ...]", 1, 21
"...", "[...]", ., .

:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
def __init__(self, lower:bool=False, tokenizer:str='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower

def process(self, data_bundle):
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)

# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths=None):
"""

:param str paths:
:return: DataBundle
"""
data_bundle = YelpPolarityLoader().load(paths)
return self.process(data_bundle=data_bundle)


class SSTPipe(_CLSPipe):
"""
别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe`

经过该Pipe之后,DataSet中具备的field如下所示

.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field
:header: "raw_words", "words", "target", "seq_len"

"It 's a ...", "[4, 2, 10, ...]", 0, 16
"Offers that ...", "[20, 40, ...]", 1, 18
"...", "[...]", ., .

:param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False``
:param bool train_subtree: 是否将train集通过子树扩展数据。
:param bool lower: 是否对输入进行小写化。
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""

def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.subtree = subtree
self.train_tree = train_subtree
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity

if granularity==2:
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
elif granularity==3:
self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2}
else:
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

def process(self, data_bundle:DataBundle):
"""
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与

.. csv-table::
:header: "raw_words"

"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..."
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
"..."

:param DataBundle data_bundle: 需要处理的DataBundle对象
:return:
"""
# 先取出subtree
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
ds = DataSet()
use_subtree = self.subtree or (name == 'train' and self.train_tree)
for ins in dataset:
raw_words = ins['raw_words']
tree = Tree.fromstring(raw_words)
if use_subtree:
for t in tree.subtrees():
raw_words = " ".join(t.leaves())
instance = Instance(raw_words=raw_words, target=t.label())
ds.append(instance)
else:
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
ds.append(instance)
data_bundle.set_dataset(ds, name)

_add_words_field(data_bundle, lower=self.lower)

# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)

# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)

# index
data_bundle = _indexize(data_bundle=data_bundle)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths=None):
data_bundle = SSTLoader().load(paths)
return self.process(data_bundle=data_bundle)


class SST2Pipe(_CLSPipe):
"""
加载SST2的数据, 处理完成之后DataSet将拥有以下的field

.. csv-table::
:header: "raw_words", "words", "target", "seq_len"

"it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43
"unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21
"...", "...", ., .

:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
def __init__(self, lower=False, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower

def process(self, data_bundle:DataBundle):
"""
可以处理的DataSet应该具备如下的结构

.. csv-table::
:header: "raw_words", "target"

"it 's a charming and... ", 1
"unflinchingly bleak and...", 1
"...", "..."

:param data_bundle:
:return:
"""
_add_words_field(data_bundle, self.lower)

data_bundle = self._tokenize(data_bundle=data_bundle)

src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)

tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
datasets = []
for name, dataset in data_bundle.datasets.items():
if dataset.has_field(Const.TARGET):
datasets.append(dataset)
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)

data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths=None):
"""

:param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。
:return: DataBundle
"""
data_bundle = SST2Loader().load(paths)
return self.process(data_bundle)


class IMDBPipe(_CLSPipe):
"""
经过本Pipe处理后DataSet将如下

.. csv-table:: 输出DataSet的field
:header: "raw_words", "words", "target", "seq_len"

"Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20
"Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31
"...", "[...]", ., .

其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值;
words列被设置为input; target列被设置为target。

:param bool lower: 是否将words列的数据小写。
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
"""
def __init__(self, lower:bool=False, tokenizer:str='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower

def process(self, data_bundle:DataBundle):
"""
期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型

.. csv-table:: 输入DataSet的field
:header: "raw_words", "target"

"Bromwell High is a cartoon ... ", "pos"
"Story of a man who has ...", "neg"
"...", "..."

:param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str,
target列应该为str。
:return:DataBundle
"""
# 替换<br />
def replace_br(raw_words):
raw_words = raw_words.replace("<br />", ' ')
return raw_words

for name, dataset in data_bundle.datasets.items():
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)

_add_words_field(data_bundle, lower=self.lower)
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
_indexize(data_bundle)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(Const.INPUT, Const.INPUT_LEN)
dataset.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths=None):
"""

:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。
:return: DataBundle
"""
# 读取数据
data_bundle = IMDBLoader().load(paths)
data_bundle = self.process(data_bundle)

return data_bundle




+ 149
- 0
fastNLP/io/pipe/conll.py View File

@@ -0,0 +1,149 @@
from .pipe import Pipe
from .. import DataBundle
from .utils import iob2, iob2bioes
from ... import Const
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field


class _NERPipe(Pipe):
"""
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的
Vocabulary转换为index。

raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。

:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。
"""
def __init__(self, encoding_type:str='bio', lower:bool=False, target_pad_val=0):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = iob2bioes
self.lower = lower
self.target_pad_val = int(target_pad_val)

def process(self, data_bundle:DataBundle)->DataBundle:
"""
支持的DataSet的field为

.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
:header: "raw_words", "target"

"[Nadim, Ladki]", "[B-PER, I-PER]"
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
"[...]", "[...]"


:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。
在传入DataBundle基础上原位修改。
:return: DataBundle

Example::

data_bundle = Conll2003Loader().load('/path/to/conll2003/')
data_bundle = Conll2003NERPipe().process(data_bundle)

# 获取train
tr_data = data_bundle.get_dataset('train')

# 获取target这个field的词表
target_vocab = data_bundle.get_vocab('target')
# 获取words这个field的词表
word_vocab = data_bundle.get_vocab('words')

"""
# 转换tag
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)

_add_words_field(data_bundle, lower=self.lower)

# index
_indexize(data_bundle)

input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]

for name, dataset in data_bundle.datasets.items():
dataset.set_pad_val(Const.TARGET, self.target_pad_val)
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)

return data_bundle

def process_from_file(self, paths) -> DataBundle:
"""

:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。
:return: DataBundle
"""
# 读取数据
data_bundle = Conll2003NERLoader().load(paths)
data_bundle = self.process(data_bundle)

return data_bundle


class Conll2003NERPipe(_NERPipe):
"""
Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的
Vocabulary转换为index。
经过该Pipe过后,DataSet中的内容如下所示

.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
:header: "raw_words", "words", "target", "seq_len"

"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 10
"[...]", "[...]", "[...]", .

raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。

:param: str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。
"""

def process_from_file(self, paths) -> DataBundle:
"""

:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。
:return: DataBundle
"""
# 读取数据
data_bundle = Conll2003NERLoader().load(paths)
data_bundle = self.process(data_bundle)

return data_bundle


class OntoNotesNERPipe(_NERPipe):
"""
处理OntoNotes的NER数据,处理之后DataSet中的field情况为

.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader
:header: "raw_words", "words", "target", "seq_len"

"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6
"[...]", "[...]", "[...]", .


:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param bool delete_unused_fields: 是否删除NER任务中用不到的field。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为-100。
"""

def process_from_file(self, paths):
data_bundle = OntoNotesNERLoader().load(paths)
return self.process(data_bundle)


+ 254
- 0
fastNLP/io/pipe/matching.py View File

@@ -0,0 +1,254 @@
import math

from .pipe import Pipe
from .utils import get_tokenizer
from ...core import Const
from ...core import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader

class MatchingBertPipe(Pipe):
"""
Matching任务的Bert pipe,输出的DataSet将包含以下的field

.. csv-table::
:header: "raw_words1", "raw_words2", "words", "target", "seq_len"

"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5
"...", "...", "[...]", ., .

words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。
words列被设置为input,target列被设置为target.

:param bool lower: 是否将word小写化。
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
:param int max_concat_sent_length: 如果concat后的句子长度超过了该值,则合并后的句子将被截断到这个长度,截断时同时对premise
和hypothesis按比例截断。
"""
def __init__(self, lower=False, tokenizer:str='raw', max_concat_sent_length:int=480):
super().__init__()

self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)
self.max_concat_sent_length = int(max_concat_sent_length)

def _tokenize(self, data_bundle, field_names, new_field_names):
"""

:param DataBundle data_bundle: DataBundle.
:param list field_names: List[str], 需要tokenize的field名称
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。
:return: 输入的DataBundle对象
"""
for name, dataset in data_bundle.datasets.items():
for field_name, new_field_name in zip(field_names, new_field_names):
dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle

def process(self, data_bundle):
for name, dataset in data_bundle.datasets.items():
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0))
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1))

if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()

data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)],
[Const.INPUTS(0), Const.INPUTS(1)])

# concat两个words
def concat(ins):
words0 = ins[Const.INPUTS(0)]
words1 = ins[Const.INPUTS(1)]
len0 = len(words0)
len1 = len(words1)
if len0 + len1 > self.max_concat_sent_length:
ratio = self.max_concat_sent_length / (len0 + len1)
len0 = math.floor(ratio * len0)
len1 = math.floor(ratio * len1)
words0 = words0[:len0]
words1 = words1[:len1]

words = words0 + ['[SEP]'] + words1
return words
for name, dataset in data_bundle.datasets.items():
dataset.apply(concat, new_field_name=Const.INPUT)
dataset.delete_field(Const.INPUTS(0))
dataset.delete_field(Const.INPUTS(1))

word_vocab = Vocabulary()
word_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)

target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = []
for name, dataset in data_bundle.datasets.items():
if dataset.has_field(Const.TARGET):
has_target_datasets.append(dataset)
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)

data_bundle.set_vocab(word_vocab, Const.INPUT)
data_bundle.set_vocab(target_vocab, Const.TARGET)

input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET]

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(*input_fields, flag=True)
dataset.set_target(*target_fields, flag=True)

return data_bundle


class RTEBertPipe(MatchingBertPipe):
def process_from_file(self, paths=None):
data_bundle = RTELoader().load(paths)
return self.process(data_bundle)


class SNLIBertPipe(MatchingBertPipe):
def process_from_file(self, paths=None):
data_bundle = SNLILoader().load(paths)
return self.process(data_bundle)


class QuoraBertPipe(MatchingBertPipe):
def process_from_file(self, paths):
data_bundle = QuoraLoader().load(paths)
return self.process(data_bundle)


class QNLIBertPipe(MatchingBertPipe):
def process_from_file(self, paths=None):
data_bundle = QNLILoader().load(paths)
return self.process(data_bundle)


class MNLIBertPipe(MatchingBertPipe):
def process_from_file(self, paths=None):
data_bundle = MNLILoader().load(paths)
return self.process(data_bundle)


class MatchingPipe(Pipe):
"""
Matching任务的Pipe。输出的DataSet将包含以下的field

.. csv-table::
:header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2"

"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7
"...", "...", "[...]", "[...]", ., ., .

words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target。

:param bool lower: 是否将所有raw_words转为小写。
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。
"""
def __init__(self, lower=False, tokenizer:str='raw'):
super().__init__()

self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)

def _tokenize(self, data_bundle, field_names, new_field_names):
"""

:param DataBundle data_bundle: DataBundle.
:param list field_names: List[str], 需要tokenize的field名称
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。
:return: 输入的DataBundle对象
"""
for name, dataset in data_bundle.datasets.items():
for field_name, new_field_name in zip(field_names, new_field_names):
dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle

def process(self, data_bundle):
"""
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有

.. csv-table::
:header: "raw_words1", "raw_words2", "target"

"The new rights are...", "Everyone really likes..", "entailment"
"This site includes a...", "The Government Executive...", "not_entailment"
"...", "..."

:param data_bundle:
:return:
"""
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
[Const.INPUTS(0), Const.INPUTS(1)])

if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()

word_vocab = Vocabulary()
word_vocab.from_dataset(data_bundle.datasets['train'], field_name=[Const.INPUTS(0), Const.INPUTS(1)],
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])

target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = []
for name, dataset in data_bundle.datasets.items():
if dataset.has_field(Const.TARGET):
has_target_datasets.append(dataset)
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)

data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
data_bundle.set_vocab(target_vocab, Const.TARGET)

input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LEN(0), Const.INPUT_LEN(1)]
target_fields = [Const.TARGET]

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LEN(0))
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LEN(1))
dataset.set_input(*input_fields, flag=True)
dataset.set_target(*target_fields, flag=True)

return data_bundle


class RTEPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = RTELoader().load(paths)
return self.process(data_bundle)


class SNLIPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = SNLILoader().load(paths)
return self.process(data_bundle)


class QuoraPipe(MatchingPipe):
def process_from_file(self, paths):
data_bundle = QuoraLoader().load(paths)
return self.process(data_bundle)

class QNLIPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = QNLILoader().load(paths)
return self.process(data_bundle)


class MNLIPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = MNLILoader().load(paths)
return self.process(data_bundle)


+ 9
- 0
fastNLP/io/pipe/pipe.py View File

@@ -0,0 +1,9 @@

from .. import DataBundle

class Pipe:
def process(self, data_bundle:DataBundle)->DataBundle:
raise NotImplementedError

def process_from_file(self, paths)->DataBundle:
raise NotImplementedError

+ 142
- 0
fastNLP/io/pipe/utils.py View File

@@ -0,0 +1,142 @@
from typing import List
from ...core import Vocabulary
from ...core import Const

def iob2(tags:List[str])->List[str]:
"""
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format

:param tags: 需要转换的tags
"""
for i, tag in enumerate(tags):
if tag == "O":
continue
split = tag.split("-")
if len(split) != 2 or split[0] not in ["I", "B"]:
raise TypeError("The encoding schema is not a valid IOB type.")
if split[0] == "B":
continue
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2
tags[i] = "B" + tag[1:]
elif tags[i - 1][1:] == tag[1:]:
continue
else: # conversion IOB1 to IOB2
tags[i] = "B" + tag[1:]
return tags

def iob2bioes(tags:List[str])->List[str]:
"""
将iob的tag转换为bioes编码
:param tags:
:return:
"""
new_tags = []
for i, tag in enumerate(tags):
if tag == 'O':
new_tags.append(tag)
else:
split = tag.split('-')[0]
if split == 'B':
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
else:
raise TypeError("Invalid IOB format.")
return new_tags


def get_tokenizer(tokenizer:str, lang='en'):
"""

:param str tokenizer: 获取tokenzier方法
:param str lang: 语言,当前仅支持en
:return: 返回tokenize函数
"""
if tokenizer == 'spacy':
import spacy
spacy.prefer_gpu()
if lang!='en':
raise RuntimeError("Spacy only supports en right right.")
en = spacy.load(lang)
tokenizer = lambda x: [w.text for w in en.tokenizer(x)]
elif tokenizer == 'raw':
tokenizer = _raw_split
else:
raise RuntimeError("Only support `spacy`, `raw` tokenizer.")
return tokenizer


def _raw_split(sent):
return sent.split()


def _indexize(data_bundle):
"""
在dataset中的"words"列建立词表,"target"列建立词表,并把词表加入到data_bundle中。

:param data_bundle:
:return:
"""
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)

tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.TARGET)

data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)

return data_bundle


def _add_words_field(data_bundle, lower=False):
"""
给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化

:param data_bundle:
:param bool lower:是否要小写化
:return: 传入的DataBundle
"""
for name, dataset in data_bundle.datasets.items():
dataset.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT)

if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUT].lower()
return data_bundle

def _drop_empty_instance(data_bundle, field_name):
"""
删除data_bundle的DataSet中存在的某个field为空的情况

:param data_bundle: DataBundle
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉
:return: 传入的DataBundle
"""
def empty_instance(ins):
if field_name:
field_value = ins[field_name]
if field_value in ((), {}, [], ''):
return True
return False
for _, field_value in ins.items():
if field_value in ((), {}, [], ''):
return True
return False

for name, dataset in data_bundle.datasets.items():
dataset.drop(empty_instance)

return data_bundle



+ 10
- 4
fastNLP/io/utils.py View File

@@ -1,9 +1,10 @@
import os

from typing import Union, Dict
from pathlib import Path


def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
"""
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
{
@@ -11,13 +12,14 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
'test': 'xxx' # 可能有,也可能没有
...
}
如果paths为不合法的,将直接进行raise相应的错误
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。

:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, str):
if isinstance(paths, (str, Path)):
paths = os.path.abspath(os.path.expanduser(paths))
if os.path.isfile(paths):
return {'train': paths}
elif os.path.isdir(paths):
@@ -37,6 +39,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
if 'train' not in files:
raise KeyError(f"There is no train file in {paths}.")
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
@@ -47,8 +51,10 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
raise KeyError("You have to include `train` in your dict.")
for key, value in paths.items():
if isinstance(key, str) and isinstance(value, str):
value = os.path.abspath(os.path.expanduser(value))
if not os.path.isfile(value):
raise TypeError(f"{value} is not a valid file.")
paths[key] = value
else:
raise TypeError("All keys and values in paths should be str.")
return paths


+ 0
- 0
test/embeddings/__init__.py View File


test/modules/encoder/test_bert.py → test/embeddings/test_bert.py View File


+ 21
- 0
test/embeddings/test_elmo_embedding.py View File

@@ -0,0 +1,21 @@

import unittest
from fastNLP import Vocabulary
from fastNLP.embeddings import ElmoEmbedding
import torch
import os

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
def test_download_small(self):
# import os
vocab = Vocabulary().add_word_lst("This is a test .".split())
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small')
words = torch.LongTensor([[0, 1, 2]])
print(elmo_embed(words).size())


# 首先保证所有权重可以加载;上传权重;验证可以下载




+ 19
- 0
test/io/loader/test_classification_loader.py View File

@@ -0,0 +1,19 @@

import unittest
from fastNLP.io.loader.classification import YelpFullLoader
from fastNLP.io.loader.classification import YelpPolarityLoader
from fastNLP.io.loader.classification import IMDBLoader
from fastNLP.io.loader.classification import SST2Loader
from fastNLP.io.loader.classification import SSTLoader
import os

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
def test_download(self):
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
loader().download()

def test_load(self):
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
data_bundle = loader().load()
print(data_bundle)

+ 22
- 0
test/io/loader/test_matching_loader.py View File

@@ -0,0 +1,22 @@

import unittest
from fastNLP.io.loader.matching import RTELoader
from fastNLP.io.loader.matching import QNLILoader
from fastNLP.io.loader.matching import SNLILoader
from fastNLP.io.loader.matching import QuoraLoader
from fastNLP.io.loader.matching import MNLILoader
import os

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestDownload(unittest.TestCase):
def test_download(self):
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
loader().download()
with self.assertRaises(Exception):
QuoraLoader().load()

def test_load(self):
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
data_bundle = loader().load()
print(data_bundle)


+ 13
- 0
test/io/pipe/test_classification.py View File

@@ -0,0 +1,13 @@
import unittest
import os

from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
with self.subTest(pipe=pipe):
print(pipe)
data_bundle = pipe(tokenizer='raw').process_from_file()
print(data_bundle)

+ 26
- 0
test/io/pipe/test_matching.py View File

@@ -0,0 +1,26 @@

import unittest
import os

from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
with self.subTest(pipe=pipe):
print(pipe)
data_bundle = pipe(tokenizer='raw').process_from_file()
print(data_bundle)


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestBertPipe(unittest.TestCase):
def test_process_from_file(self):
for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
with self.subTest(pipe=pipe):
print(pipe)
data_bundle = pipe(tokenizer='raw').process_from_file()
print(data_bundle)

Loading…
Cancel
Save