@@ -1,6 +1,9 @@ | |||||
language: python | language: python | ||||
python: | python: | ||||
- "3.6" | - "3.6" | ||||
env:
- TRAVIS=1 | |||||
# command to install dependencies | # command to install dependencies | ||||
install: | install: | ||||
- pip install --quiet -r requirements.txt | - pip install --quiet -r requirements.txt | ||||
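The TRAVIS=1 variable added above presumably lets test code detect CI runs; a minimal sketch of such a check (the helper name and the way the test suite actually consumes the variable are assumptions)::

    import os

    # Hypothetical helper: the variable name matches the env entry above,
    # but how the tests use it is an assumption, not stated in the diff.
    def running_on_ci() -> bool:
        return os.environ.get("TRAVIS") == "1"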
@@ -48,6 +48,11 @@ class DataSetGetter: | |||||
return len(self.dataset) | return len(self.dataset) | ||||
def collate_fn(self, batch: list): | def collate_fn(self, batch: list): | ||||
""" | |||||
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | |||||
:return: | |||||
""" | |||||
# TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 | # TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 | ||||
batch_x = {n:[] for n in self.inputs.keys()} | batch_x = {n:[] for n in self.inputs.keys()} | ||||
batch_y = {n:[] for n in self.targets.keys()} | batch_y = {n:[] for n in self.targets.keys()} | ||||
@@ -208,6 +213,13 @@ class OnlineDataIter(BatchIter): | |||||
def _to_tensor(batch, field_dtype): | def _to_tensor(batch, field_dtype): | ||||
""" | |||||
:param batch: np.array() | |||||
:param field_dtype: 数据类型 | |||||
:return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor, | |||||
返回的batch就是原来的数据,且flag为False | |||||
""" | |||||
try: | try: | ||||
if field_dtype is not None and isinstance(field_dtype, type)\ | if field_dtype is not None and isinstance(field_dtype, type)\ | ||||
and issubclass(field_dtype, Number) \ | and issubclass(field_dtype, Number) \ | ||||
@@ -7,12 +7,14 @@ class Const: | |||||
具体列表:: | 具体列表:: | ||||
INPUT 模型的序列输入 words(复数words1, words2) | |||||
CHAR_INPUT 模型character输入 chars(复数chars1, chars2) | |||||
INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2) | |||||
OUTPUT 模型输出 pred(复数pred1, pred2) | |||||
TARGET 真实目标 target(复数target1,target2) | |||||
LOSS 损失函数 loss (复数loss1,loss2) | |||||
INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2)
CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2) | |||||
INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2) | |||||
OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2) | |||||
TARGET 真实目标 target(具有多列target时,依次使用target1,target2) | |||||
LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2) | |||||
RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2) | |||||
RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2) | |||||
""" | """ | ||||
INPUT = 'words' | INPUT = 'words' | ||||
@@ -21,6 +23,8 @@ class Const: | |||||
OUTPUT = 'pred' | OUTPUT = 'pred' | ||||
TARGET = 'target' | TARGET = 'target' | ||||
LOSS = 'loss' | LOSS = 'loss' | ||||
RAW_WORD = 'raw_words' | |||||
RAW_CHAR = 'raw_chars' | |||||
@staticmethod | @staticmethod | ||||
def INPUTS(i): | def INPUTS(i): | ||||
@@ -34,6 +38,16 @@ class Const: | |||||
i = int(i) + 1 | i = int(i) + 1 | ||||
return Const.CHAR_INPUT + str(i) | return Const.CHAR_INPUT + str(i) | ||||
@staticmethod | |||||
def RAW_WORDS(i): | |||||
i = int(i) + 1 | |||||
return Const.RAW_WORD + str(i) | |||||
@staticmethod | |||||
def RAW_CHARS(i): | |||||
i = int(i) + 1 | |||||
return Const.RAW_CHAR + str(i) | |||||
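The new RAW_WORD/RAW_CHAR constants follow the same 1-based suffix convention as the existing helpers; a quick illustrative check (a sketch, assuming fastNLP is importable as in this diff)::

    from fastNLP.core.const import Const

    # Index i is 0-based and maps to the suffix i + 1, mirroring INPUTS()/CHAR_INPUTS().
    assert Const.RAW_WORD == 'raw_words'
    assert Const.RAW_WORDS(0) == 'raw_words1'
    assert Const.RAW_CHARS(1) == 'raw_chars2'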
@staticmethod | @staticmethod | ||||
def INPUT_LENS(i): | def INPUT_LENS(i): | ||||
"""得到第 i 个 ``INPUT_LEN`` 的命名""" | """得到第 i 个 ``INPUT_LEN`` 的命名""" | ||||
@@ -291,6 +291,7 @@ import _pickle as pickle | |||||
import warnings | import warnings | ||||
import numpy as np | import numpy as np | ||||
from copy import deepcopy | |||||
from .field import AutoPadder | from .field import AutoPadder | ||||
from .field import FieldArray | from .field import FieldArray | ||||
@@ -298,6 +299,7 @@ from .instance import Instance | |||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .field import AppendToTargetOrInputException | from .field import AppendToTargetOrInputException | ||||
from .field import SetInputOrTargetException | from .field import SetInputOrTargetException | ||||
from .const import Const | |||||
class DataSet(object): | class DataSet(object): | ||||
""" | """ | ||||
@@ -349,7 +351,11 @@ class DataSet(object): | |||||
self.idx]) | self.idx]) | ||||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | ||||
return self.dataset.field_arrays[item][self.idx] | return self.dataset.field_arrays[item][self.idx] | ||||
def items(self): | |||||
ins = self.dataset[self.idx] | |||||
return ins.items() | |||||
def __repr__(self): | def __repr__(self): | ||||
return self.dataset[self.idx].__repr__() | return self.dataset[self.idx].__repr__() | ||||
@@ -497,6 +503,7 @@ class DataSet(object): | |||||
else: | else: | ||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
field.pop(index) | field.pop(index) | ||||
return self | |||||
def delete_field(self, field_name): | def delete_field(self, field_name): | ||||
""" | """ | ||||
@@ -505,7 +512,22 @@ class DataSet(object): | |||||
:param str field_name: 需要删除的field的名称. | :param str field_name: 需要删除的field的名称. | ||||
""" | """ | ||||
self.field_arrays.pop(field_name) | self.field_arrays.pop(field_name) | ||||
return self | |||||
def copy_field(self, field_name, new_field_name): | |||||
""" | |||||
深度copy名为field_name的field到new_field_name | |||||
:param str field_name: 需要copy的field。 | |||||
:param str new_field_name: copy生成的field名称 | |||||
:return: self | |||||
""" | |||||
if not self.has_field(field_name): | |||||
raise KeyError(f"Field:{field_name} not found in DataSet.") | |||||
fieldarray = deepcopy(self.get_field(field_name)) | |||||
self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray) | |||||
return self | |||||
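Since pop, delete_field and copy_field now return self, calls can be chained; a small usage sketch with toy data::

    from fastNLP import DataSet

    ds = DataSet({'words': [['a', 'b'], ['c']]})
    # copy_field deep-copies the whole column, so later edits to the copy
    # do not affect the original field; both methods return the DataSet itself.
    ds.copy_field('words', 'words_backup').delete_field('words_backup')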
def has_field(self, field_name): | def has_field(self, field_name): | ||||
""" | """ | ||||
判断DataSet中是否有名为field_name这个field | 判断DataSet中是否有名为field_name这个field | ||||
@@ -701,7 +723,7 @@ class DataSet(object): | |||||
results.append(func(ins[field_name])) | results.append(func(ins[field_name])) | ||||
except Exception as e: | except Exception as e: | ||||
if idx != -1: | if idx != -1: | ||||
print("Exception happens at the `{}`th instance.".format(idx)) | |||||
print("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||||
raise e | raise e | ||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | ||||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | ||||
@@ -766,10 +788,11 @@ class DataSet(object): | |||||
results = [] | results = [] | ||||
for idx, ins in enumerate(self._inner_iter()): | for idx, ins in enumerate(self._inner_iter()): | ||||
results.append(func(ins)) | results.append(func(ins)) | ||||
except Exception as e: | |||||
except BaseException as e: | |||||
if idx != -1: | if idx != -1: | ||||
print("Exception happens at the `{}`th instance.".format(idx)) | print("Exception happens at the `{}`th instance.".format(idx)) | ||||
raise e | raise e | ||||
# results = [func(ins) for ins in self._inner_iter()] | # results = [func(ins) for ins in self._inner_iter()] | ||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | ||||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | ||||
@@ -779,7 +802,7 @@ class DataSet(object): | |||||
return results | return results | ||||
def add_seq_len(self, field_name:str, new_field_name='seq_len'): | |||||
def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): | |||||
""" | """ | ||||
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 | 将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。
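A minimal sketch of the new default: the lengths land in the field named by Const.INPUT_LEN ('seq_len')::

    from fastNLP import DataSet

    ds = DataSet({'words': [['I', 'like', 'it'], ['ok']]})
    ds.add_seq_len('words')        # len() of each cell, stored in the 'seq_len' field
    print(ds[0]['seq_len'])        # 3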
@@ -7,6 +7,7 @@ from typing import Any | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from collections import Counter | from collections import Counter | ||||
from .utils import _is_iterable | |||||
class SetInputOrTargetException(Exception): | class SetInputOrTargetException(Exception): | ||||
def __init__(self, msg, index=None, field_name=None): | def __init__(self, msg, index=None, field_name=None): | ||||
@@ -443,15 +444,6 @@ def _get_ele_type_and_dim(cell:Any, dim=0): | |||||
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | ||||
def _is_iterable(value): | |||||
# 检查是否是iterable的, duck typing | |||||
try: | |||||
iter(value) | |||||
return True | |||||
except BaseException as e: | |||||
return False | |||||
class Padder: | class Padder: | ||||
""" | """ | ||||
别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder` | 别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder` | ||||
@@ -35,6 +35,13 @@ class Instance(object): | |||||
:param Any field: 新增field的内容 | :param Any field: 新增field的内容 | ||||
""" | """ | ||||
self.fields[field_name] = field | self.fields[field_name] = field | ||||
def items(self): | |||||
""" | |||||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||||
:return: | |||||
""" | |||||
return self.fields.items() | |||||
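items() simply exposes the underlying fields dict; a usage sketch::

    from fastNLP import Instance

    ins = Instance(words=['hello', 'world'], target=1)
    for field_name, field_value in ins.items():   # iterates (name, value) pairs
        print(field_name, field_value)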
def __getitem__(self, name): | def __getitem__(self, name): | ||||
if name in self.fields: | if name in self.fields: | ||||
@@ -4,6 +4,7 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户 | |||||
__all__ = [ | __all__ = [ | ||||
"cache_results", | "cache_results", | ||||
"seq_len_to_mask", | "seq_len_to_mask", | ||||
"get_seq_len" | |||||
] | ] | ||||
import _pickle | import _pickle | ||||
@@ -730,3 +731,23 @@ def iob2bioes(tags: List[str]) -> List[str]: | |||||
else: | else: | ||||
raise TypeError("Invalid IOB format.") | raise TypeError("Invalid IOB format.") | ||||
return new_tags | return new_tags | ||||
def _is_iterable(value): | |||||
# 检查是否是iterable的, duck typing | |||||
try: | |||||
iter(value) | |||||
return True | |||||
except BaseException as e: | |||||
return False | |||||
def get_seq_len(words, pad_value=0): | |||||
""" | |||||
给定batch_size x max_len的words矩阵,返回句子长度 | |||||
:param words: batch_size x max_len | |||||
:return: (batch_size,) | |||||
""" | |||||
mask = words.ne(pad_value) | |||||
return mask.sum(dim=-1) |
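get_seq_len relies on .ne and .sum, so it assumes the padded word matrix is a torch tensor; a sketch under that assumption::

    import torch
    from fastNLP.core.utils import get_seq_len

    words = torch.tensor([[3, 5, 7, 0, 0],
                          [2, 4, 0, 0, 0]])       # 0 is the pad id
    print(get_seq_len(words, pad_value=0))         # tensor([3, 2])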
@@ -4,12 +4,12 @@ __all__ = [ | |||||
] | ] | ||||
from functools import wraps | from functools import wraps | ||||
from collections import Counter, defaultdict | |||||
from collections import Counter | |||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .utils import Option | from .utils import Option | ||||
from functools import partial | from functools import partial | ||||
import numpy as np | import numpy as np | ||||
from .utils import _is_iterable | |||||
class VocabularyOption(Option): | class VocabularyOption(Option): | ||||
def __init__(self, | def __init__(self, | ||||
@@ -131,11 +131,11 @@ class Vocabulary(object): | |||||
""" | """ | ||||
在新加入word时,检查_no_create_word的设置。 | 在新加入word时,检查_no_create_word的设置。 | ||||
:param str, List[str] word: | |||||
:param str, List[str] word:
:param bool no_create_entry: | :param bool no_create_entry: | ||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(word, str): | |||||
if isinstance(word, str) or not _is_iterable(word): | |||||
word = [word] | word = [word] | ||||
for w in word: | for w in word: | ||||
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): | if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): | ||||
@@ -257,35 +257,45 @@ class Vocabulary(object): | |||||
vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | ||||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | ||||
:param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||||
目前仅支持 ``str`` , ``List[str]`` , ``List[List[str]]`` | |||||
:param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||||
Default: ``None`` | |||||
:param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||||
目前支持 ``str`` , ``List[str]`` | |||||
:param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||||
Default: ``None``. | |||||
""" | """ | ||||
def index_instance(ins): | |||||
def index_instance(field): | |||||
""" | """ | ||||
有几种情况, str, 1d-list, 2d-list | 有几种情况, str, 1d-list, 2d-list | ||||
:param ins: | :param ins: | ||||
:return: | :return: | ||||
""" | """ | ||||
field = ins[field_name] | |||||
if isinstance(field, str): | |||||
if isinstance(field, str) or not _is_iterable(field): | |||||
return self.to_index(field) | return self.to_index(field) | ||||
elif isinstance(field, list): | |||||
if not isinstance(field[0], list): | |||||
else: | |||||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||||
return [self.to_index(w) for w in field] | return [self.to_index(w) for w in field] | ||||
else: | else: | ||||
if isinstance(field[0][0], list): | |||||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||||
raise RuntimeError("Only support field with 2 dimensions.") | raise RuntimeError("Only support field with 2 dimensions.") | ||||
return [[self.to_index(c) for c in w] for w in field] | return [[self.to_index(c) for c in w] for w in field] | ||||
if new_field_name is None: | |||||
new_field_name = field_name | |||||
new_field_name = new_field_name or field_name | |||||
if type(new_field_name) == type(field_name): | |||||
if isinstance(new_field_name, list): | |||||
assert len(new_field_name) == len(field_name), "new_field_name should have the same number of " \
"elements as field_name."
elif isinstance(new_field_name, str): | |||||
field_name = [field_name] | |||||
new_field_name = [new_field_name] | |||||
else: | |||||
raise TypeError("field_name and new_field_name can only be str or List[str].") | |||||
for idx, dataset in enumerate(datasets): | for idx, dataset in enumerate(datasets): | ||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
try: | try: | ||||
dataset.apply(index_instance, new_field_name=new_field_name) | |||||
for f_n, n_f_n in zip(field_name, new_field_name): | |||||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | |||||
except Exception as e: | except Exception as e: | ||||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | print("When processing the `{}` dataset, the following error occurred.".format(idx)) | ||||
raise e | raise e | ||||
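With the change above, field_name and new_field_name may be parallel lists; a minimal usage sketch with toy data::

    from fastNLP import DataSet, Vocabulary

    ds = DataSet({'words': [['a', 'b'], ['b', 'c']],
                  'raw_words': [['a', 'b'], ['b', 'c']]})
    vocab = Vocabulary()
    vocab.from_dataset(ds, field_name='words')
    # Each (field_name, new_field_name) pair is indexed via apply_field;
    # here the columns are overwritten in place with word ids.
    vocab.index_dataset(ds, field_name=['words', 'raw_words'],
                        new_field_name=['words', 'raw_words'])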
@@ -306,9 +316,8 @@ class Vocabulary(object): | |||||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | ||||
:param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . | :param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . | ||||
构建词典所使用的 field(s), 支持一个或多个field | |||||
若有多个 DataSet, 每个DataSet都必须有这些field. | |||||
目前仅支持的field结构: ``str`` , ``List[str]`` , ``list[List[str]]`` | |||||
构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field.
目前支持的field结构: ``str`` , ``List[str]``
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain | :param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain | ||||
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | 的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | ||||
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | 中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | ||||
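A short sketch of the pattern described above (train_data, dev_data and test_data are placeholder DataSet objects): words seen only in dev/test are registered with no_create_entry=True, so they can map onto pretrained vectors without becoming regular vocab entries::

    vocab = Vocabulary()
    # Only train_data creates normal entries; dev/test words are tracked separately.
    vocab.from_dataset(train_data, field_name='words',
                       no_create_entry_dataset=[dev_data, test_data])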
@@ -326,14 +335,14 @@ class Vocabulary(object): | |||||
def construct_vocab(ins, no_create_entry=False): | def construct_vocab(ins, no_create_entry=False): | ||||
for fn in field_name: | for fn in field_name: | ||||
field = ins[fn] | field = ins[fn] | ||||
if isinstance(field, str): | |||||
if isinstance(field, str) or not _is_iterable(field): | |||||
self.add_word(field, no_create_entry=no_create_entry) | self.add_word(field, no_create_entry=no_create_entry) | ||||
elif isinstance(field, (list, np.ndarray)): | |||||
if not isinstance(field[0], (list, np.ndarray)): | |||||
else: | |||||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||||
for word in field: | for word in field: | ||||
self.add_word(word, no_create_entry=no_create_entry) | self.add_word(word, no_create_entry=no_create_entry) | ||||
else: | else: | ||||
if isinstance(field[0][0], (list, np.ndarray)): | |||||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||||
raise RuntimeError("Only support field with 2 dimensions.") | raise RuntimeError("Only support field with 2 dimensions.") | ||||
for words in field: | for words in field: | ||||
for word in words: | for word in words: | ||||
@@ -343,8 +352,8 @@ class Vocabulary(object): | |||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
try: | try: | ||||
dataset.apply(construct_vocab) | dataset.apply(construct_vocab) | ||||
except Exception as e: | |||||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||||
except BaseException as e: | |||||
print("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||||
raise e | raise e | ||||
else: | else: | ||||
raise TypeError("Only DataSet type is allowed.") | raise TypeError("Only DataSet type is allowed.") | ||||
@@ -10,6 +10,7 @@ __all__ = [ | |||||
"StaticEmbedding", | "StaticEmbedding", | ||||
"ElmoEmbedding", | "ElmoEmbedding", | ||||
"BertEmbedding", | "BertEmbedding", | ||||
"BertWordPieceEncoder", | |||||
"StackEmbedding", | "StackEmbedding", | ||||
"LSTMCharEmbedding", | "LSTMCharEmbedding", | ||||
"CNNCharEmbedding", | "CNNCharEmbedding", | ||||
@@ -20,7 +21,7 @@ __all__ = [ | |||||
from .embedding import Embedding | from .embedding import Embedding | ||||
from .static_embedding import StaticEmbedding | from .static_embedding import StaticEmbedding | ||||
from .elmo_embedding import ElmoEmbedding | from .elmo_embedding import ElmoEmbedding | ||||
from .bert_embedding import BertEmbedding | |||||
from .bert_embedding import BertEmbedding, BertWordPieceEncoder | |||||
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | ||||
from .stack_embedding import StackEmbedding | from .stack_embedding import StackEmbedding | ||||
from .utils import get_embeddings | from .utils import get_embeddings |
@@ -8,7 +8,7 @@ import numpy as np | |||||
from itertools import chain | from itertools import chain | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from ..io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | ||||
from .contextual_embedding import ContextualEmbedding | from .contextual_embedding import ContextualEmbedding | ||||
@@ -60,10 +60,8 @@ class BertEmbedding(ContextualEmbedding): | |||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | ||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | ||||
model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name)) | model_dir = os.path.expanduser(os.path.abspath(model_dir_or_name)) | ||||
@@ -133,11 +131,9 @@ class BertWordPieceEncoder(nn.Module): | |||||
pooled_cls: bool = False, requires_grad: bool=False): | pooled_cls: bool = False, requires_grad: bool=False): | ||||
super().__init__() | super().__init__() | ||||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | ||||
model_dir = model_dir_or_name | model_dir = model_dir_or_name | ||||
@@ -8,7 +8,7 @@ import json | |||||
import codecs | import codecs | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from ..io.file_utils import cached_path, _get_base_url, PRETRAINED_ELMO_MODEL_DIR | |||||
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR | |||||
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | ||||
from .contextual_embedding import ContextualEmbedding | from .contextual_embedding import ContextualEmbedding | ||||
@@ -53,10 +53,8 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | ||||
PRETRAIN_URL = _get_base_url('elmo') | |||||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | ||||
model_dir = model_dir_or_name | model_dir = model_dir_or_name | ||||
@@ -7,7 +7,7 @@ import numpy as np | |||||
import warnings | import warnings | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_base_url, cached_path | |||||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||||
from .embedding import TokenEmbedding | from .embedding import TokenEmbedding | ||||
from ..modules.utils import _get_file_name_base_on_postfix | from ..modules.utils import _get_file_name_base_on_postfix | ||||
@@ -60,10 +60,8 @@ class StaticEmbedding(TokenEmbedding): | |||||
embedding_dim = int(embedding_dim) | embedding_dim = int(embedding_dim) | ||||
model_path = None | model_path = None | ||||
elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | ||||
PRETRAIN_URL = _get_base_url('static') | |||||
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_path = cached_path(model_url) | |||||
model_url = _get_embedding_url('static', model_dir_or_name.lower()) | |||||
model_path = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | ||||
model_path = model_dir_or_name | model_path = model_dir_or_name | ||||
@@ -84,8 +82,8 @@ class StaticEmbedding(TokenEmbedding): | |||||
if lowered_word not in lowered_vocab.word_count: | if lowered_word not in lowered_vocab.word_count: | ||||
lowered_vocab.add_word(lowered_word) | lowered_vocab.add_word(lowered_word) | ||||
lowered_vocab._no_create_word[lowered_word] += 1 | lowered_vocab._no_create_word[lowered_word] += 1 | ||||
print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered " | |||||
f"words.") | |||||
print(f"All word in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} " | |||||
f"words, {len(lowered_vocab)} unique lowered words.") | |||||
if model_path: | if model_path: | ||||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | ||||
else: | else: | ||||
@@ -5,10 +5,10 @@ __all__ = [ | |||||
] | ] | ||||
import _pickle as pickle | import _pickle as pickle | ||||
import os | |||||
from typing import Union, Dict | from typing import Union, Dict | ||||
import os | import os | ||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from ..core.vocabulary import Vocabulary | |||||
class BaseLoader(object): | class BaseLoader(object): | ||||
@@ -111,7 +111,10 @@ def _uncompress(src, dst): | |||||
class DataBundle: | class DataBundle: | ||||
""" | """ | ||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。 | |||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | |||||
DataSetLoader的load函数生成,可以通过下面的set/get方法读取和设置其中的DataSet与Vocabulary。
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | ||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | ||||
@@ -121,6 +124,88 @@ class DataBundle: | |||||
self.vocabs = vocabs or {} | self.vocabs = vocabs or {} | ||||
self.datasets = datasets or {} | self.datasets = datasets or {} | ||||
def set_vocab(self, vocab, field_name): | |||||
""" | |||||
向DataBundle中增加vocab
:param Vocabulary vocab: 词表 | |||||
:param str field_name: 这个vocab对应的field名称 | |||||
:return: | |||||
""" | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
self.vocabs[field_name] = vocab | |||||
def set_dataset(self, dataset, name): | |||||
""" | |||||
:param DataSet dataset: 传递给DataBundle的DataSet | |||||
:param str name: dataset的名称 | |||||
:return: | |||||
""" | |||||
self.datasets[name] = dataset | |||||
def get_dataset(self, name:str): | |||||
""" | |||||
获取名为name的dataset | |||||
:param str name: dataset的名称,一般为'train', 'dev', 'test' | |||||
:return: DataSet | |||||
""" | |||||
return self.datasets[name] | |||||
def get_vocab(self, field_name:str): | |||||
""" | |||||
获取field名为field_name对应的vocab | |||||
:param str field_name: 名称 | |||||
:return: Vocabulary | |||||
""" | |||||
return self.vocabs[field_name] | |||||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): | |||||
""" | |||||
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | |||||
data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | |||||
data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False | |||||
:param str field_names: field的名称 | |||||
:param bool flag: 将field_name的input状态设置为flag | |||||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型,而是直接
使用第一行的数据推断本列数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 | |||||
""" | |||||
for field_name in field_names: | |||||
for name, dataset in self.datasets.items(): | |||||
if not ignore_miss_field and not dataset.has_field(field_name): | |||||
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") | |||||
if not dataset.has_field(field_name): | |||||
continue | |||||
else: | |||||
dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_field=True): | |||||
""" | |||||
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: | |||||
data_bundle.set_target('target', 'seq_len') # 将target和seq_len这两个field的target属性设置为True
data_bundle.set_target('target', flag=False) # 将target这个field的target属性设置为False
:param str field_names: field的名称 | |||||
:param bool flag: 将field_name的target状态设置为flag | |||||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型,而是直接
使用第一行的数据推断本列数据的类型和维度。
:param bool ignore_miss_field: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略; 如果为False,则报错 | |||||
""" | |||||
for field_name in field_names: | |||||
for name, dataset in self.datasets.items(): | |||||
if not ignore_miss_field and not dataset.has_field(field_name): | |||||
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") | |||||
if not dataset.has_field(field_name): | |||||
continue | |||||
else: | |||||
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||||
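Taken together, the new accessors give DataBundle a small dict-like API; a usage sketch (data_bundle, my_dataset and my_vocab are placeholders)::

    data_bundle.set_dataset(my_dataset, 'train')    # register a DataSet under a name
    train_ds = data_bundle.get_dataset('train')     # fetch it back by name
    data_bundle.set_vocab(my_vocab, 'words')        # attach a Vocabulary for a field
    data_bundle.set_input('words', 'seq_len')       # mark fields as input in every dataset
    data_bundle.set_target('target')                # mark fields as target in every dataset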
def __repr__(self): | def __repr__(self): | ||||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | _str = 'In total {} datasets:\n'.format(len(self.datasets)) | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
@@ -3,38 +3,47 @@ from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ..base_loader import DataSetLoader | from ..base_loader import DataSetLoader | ||||
from ..file_reader import _read_conll | from ..file_reader import _read_conll | ||||
from typing import Union, Dict | |||||
from ..utils import check_loader_paths | |||||
from ..base_loader import DataBundle | |||||
class ConllLoader(DataSetLoader): | class ConllLoader(DataSetLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | ||||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 | |||||
该符号在conll 2003中被用为文档分割符。 | |||||
列号从0开始, 每列对应内容为:: | |||||
Column Type | |||||
0 Document ID | |||||
1 Part number | |||||
2 Word number | |||||
3 Word itself | |||||
4 Part-of-Speech | |||||
5 Parse bit | |||||
6 Predicate lemma | |||||
7 Predicate Frameset ID | |||||
8 Word sense | |||||
9 Speaker/Author | |||||
10 Named Entities | |||||
11:N Predicate Arguments | |||||
N Coreference | |||||
:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||||
该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||||
Example:: | |||||
# 文件中的内容 | |||||
Nadim NNP B-NP B-PER | |||||
Ladki NNP I-NP I-PER | |||||
AL-AIN NNP B-NP B-LOC | |||||
United NNP B-NP B-LOC | |||||
Arab NNP I-NP I-LOC | |||||
Emirates NNPS I-NP I-LOC | |||||
1996-12-06 CD I-NP O | |||||
... | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 | |||||
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||||
# 例如用ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])读取时,返回的DataSet中raw_words列与pos列的内容都是List[str]
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||||
""" | """ | ||||
def __init__(self, headers, indexes=None, dropna=False): | |||||
def __init__(self, headers, indexes=None, dropna=True): | |||||
super(ConllLoader, self).__init__() | super(ConllLoader, self).__init__() | ||||
if not isinstance(headers, (list, tuple)): | if not isinstance(headers, (list, tuple)): | ||||
raise TypeError( | raise TypeError( | ||||
@@ -49,25 +58,74 @@ class ConllLoader(DataSetLoader): | |||||
self.indexes = indexes | self.indexes = indexes | ||||
def _load(self, path): | def _load(self, path): | ||||
""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。 | |||||
:param str path: 文件的路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | ds = DataSet() | ||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | ||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | ins = {h: data[i] for i, h in enumerate(self.headers)} | ||||
ds.append(Instance(**ins)) | ds.append(Instance(**ins)) | ||||
return ds | return ds | ||||
def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||||
名包含'train'、 'dev'、 'test'则会报错 | |||||
Example:: | |||||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train, dev, test等有所变化 | |||||
# 可以通过以下的方式取出DataSet | |||||
tr_data = data_bundle.datasets['train'] | |||||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||||
(2) 传入文件path | |||||
Example:: | |||||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||||
Example:: | |||||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||||
dev_data = data_bundle.datasets['dev'] | |||||
:return: 返回的 :class:`~fastNLP.io.DataBundle` ,其datasets属性为从名称('train'等)到 :class:`~fastNLP.DataSet` 的字典
""" | |||||
paths = check_loader_paths(paths) | |||||
datasets = {name: self._load(path) for name, path in paths.items()} | |||||
data_bundle = DataBundle(datasets=datasets) | |||||
return data_bundle | |||||
class Conll2003Loader(ConllLoader): | class Conll2003Loader(ConllLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` | 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` | ||||
读取Conll2003数据 | |||||
该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003 | |||||
找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||||
返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。 | |||||
.. csv-table:: Conll2003Loader处理之后的DataSet
   :header: "raw_words", "words", "target", "seq_len"

   "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
   "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5
   "[...]", "[...]", "[...]", .
关于数据集的更多信息,参考: | |||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
headers = [ | headers = [ | ||||
'tokens', 'pos', 'chunks', 'ner', | |||||
'raw_words', 'pos', 'chunks', 'ner', | |||||
] | ] | ||||
super(Conll2003Loader, self).__init__(headers=headers) | super(Conll2003Loader, self).__init__(headers=headers) |
@@ -121,7 +121,7 @@ class MatchingLoader(DataSetLoader): | |||||
PRETRAIN_URL = _get_base_url('bert') | PRETRAIN_URL = _get_base_url('bert') | ||||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_dir = cached_path(model_url) | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | # 检查是否存在 | ||||
elif os.path.isdir(bert_tokenizer): | elif os.path.isdir(bert_tokenizer): | ||||
model_dir = bert_tokenizer | model_dir = bert_tokenizer | ||||
@@ -5,7 +5,7 @@ from ..base_loader import DataBundle | |||||
from ..dataset_loader import CSVLoader | from ..dataset_loader import CSVLoader | ||||
from ...core.vocabulary import Vocabulary, VocabularyOption | from ...core.vocabulary import Vocabulary, VocabularyOption | ||||
from ...core.const import Const | from ...core.const import Const | ||||
from ..utils import check_dataloader_paths | |||||
from ..utils import check_loader_paths | |||||
class MTL16Loader(CSVLoader): | class MTL16Loader(CSVLoader): | ||||
@@ -38,7 +38,7 @@ class MTL16Loader(CSVLoader): | |||||
src_vocab_opt: VocabularyOption = None, | src_vocab_opt: VocabularyOption = None, | ||||
tgt_vocab_opt: VocabularyOption = None,): | tgt_vocab_opt: VocabularyOption = None,): | ||||
paths = check_dataloader_paths(paths) | |||||
paths = check_loader_paths(paths) | |||||
datasets = {} | datasets = {} | ||||
info = DataBundle() | info = DataBundle() | ||||
for name, path in paths.items(): | for name, path in paths.items(): | ||||
@@ -8,7 +8,7 @@ from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ..utils import check_dataloader_paths, get_tokenizer | |||||
from ..utils import check_loader_paths, get_tokenizer | |||||
class SSTLoader(DataSetLoader): | class SSTLoader(DataSetLoader): | ||||
@@ -67,7 +67,7 @@ class SSTLoader(DataSetLoader): | |||||
paths, train_subtree=True, | paths, train_subtree=True, | ||||
src_vocab_op: VocabularyOption = None, | src_vocab_op: VocabularyOption = None, | ||||
tgt_vocab_op: VocabularyOption = None,): | tgt_vocab_op: VocabularyOption = None,): | ||||
paths = check_dataloader_paths(paths) | |||||
paths = check_loader_paths(paths) | |||||
input_name, target_name = 'words', 'target' | input_name, target_name = 'words', 'target' | ||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | ||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | tgt_vocab = Vocabulary(unknown=None, padding=None) \ | ||||
@@ -129,7 +129,7 @@ class SST2Loader(CSVLoader): | |||||
tgt_vocab_opt: VocabularyOption = None, | tgt_vocab_opt: VocabularyOption = None, | ||||
char_level_op=False): | char_level_op=False): | ||||
paths = check_dataloader_paths(paths) | |||||
paths = check_loader_paths(paths) | |||||
datasets = {} | datasets = {} | ||||
info = DataBundle() | info = DataBundle() | ||||
for name, path in paths.items(): | for name, path in paths.items(): | ||||
@@ -155,7 +155,9 @@ class SST2Loader(CSVLoader): | |||||
for dataset in datasets.values(): | for dataset in datasets.values(): | ||||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | ||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | ||||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[ | |||||
dataset for name, dataset in datasets.items() if name!='train' | |||||
]) | |||||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | ||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | tgt_vocab = Vocabulary(unknown=None, padding=None) \ | ||||
@@ -8,7 +8,7 @@ from ...core.instance import Instance | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ..base_loader import DataBundle, DataSetLoader | from ..base_loader import DataBundle, DataSetLoader | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from ..utils import check_dataloader_paths, get_tokenizer | |||||
from ..utils import check_loader_paths, get_tokenizer | |||||
class YelpLoader(DataSetLoader): | class YelpLoader(DataSetLoader): | ||||
@@ -62,7 +62,7 @@ class YelpLoader(DataSetLoader): | |||||
src_vocab_op: VocabularyOption = None, | src_vocab_op: VocabularyOption = None, | ||||
tgt_vocab_op: VocabularyOption = None, | tgt_vocab_op: VocabularyOption = None, | ||||
char_level_op=False): | char_level_op=False): | ||||
paths = check_dataloader_paths(paths) | |||||
paths = check_loader_paths(paths) | |||||
info = DataBundle(datasets=self.load(paths)) | info = DataBundle(datasets=self.load(paths)) | ||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | ||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | tgt_vocab = Vocabulary(unknown=None, padding=None) \ | ||||
@@ -114,25 +114,3 @@ def _cut_long_sentence(sent, max_sample_length=200): | |||||
else: | else: | ||||
cutted_sentence.append(sent) | cutted_sentence.append(sent) | ||||
return cutted_sentence | return cutted_sentence | ||||
def _add_seg_tag(data): | |||||
""" | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||||
:return: list of ([word], [pos]) | |||||
""" | |||||
_processed = [] | |||||
for word_list, pos_list, _, _ in data: | |||||
new_sample = [] | |||||
for word, pos in zip(word_list, pos_list): | |||||
if len(word) == 1: | |||||
new_sample.append((word, 'S-' + pos)) | |||||
else: | |||||
new_sample.append((word[0], 'B-' + pos)) | |||||
for c in word[1:-1]: | |||||
new_sample.append((c, 'M-' + pos)) | |||||
new_sample.append((word[-1], 'E-' + pos)) | |||||
_processed.append(list(map(list, zip(*new_sample)))) | |||||
return _processed |
@@ -2,7 +2,7 @@ | |||||
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | ||||
""" | """ | ||||
import json | import json | ||||
import warnings | |||||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | ||||
""" | """ | ||||
@@ -91,7 +91,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
with open(path, 'r', encoding=encoding) as f: | with open(path, 'r', encoding=encoding) as f: | ||||
sample = [] | sample = [] | ||||
start = next(f).strip() | start = next(f).strip() | ||||
if '-DOCSTART-' not in start and start!='': | |||||
if start!='': | |||||
sample.append(start.split()) | sample.append(start.split()) | ||||
for line_idx, line in enumerate(f, 1): | for line_idx, line in enumerate(f, 1): | ||||
line = line.strip() | line = line.strip() | ||||
@@ -103,13 +103,13 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
yield line_idx, res | yield line_idx, res | ||||
except Exception as e: | except Exception as e: | ||||
if dropna: | if dropna: | ||||
warnings.warn('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
continue | continue | ||||
raise ValueError('invalid instance ends at line: {}'.format(line_idx)) | |||||
raise ValueError('Invalid instance ends at line: {}'.format(line_idx)) | |||||
elif line.startswith('#'): | elif line.startswith('#'): | ||||
continue | continue | ||||
else: | else: | ||||
if not line.startswith('-DOCSTART-'): | |||||
sample.append(line.split()) | |||||
sample.append(line.split()) | |||||
if len(sample) > 0: | if len(sample) > 0: | ||||
try: | try: | ||||
res = parse_conll(sample) | res = parse_conll(sample) | ||||
@@ -7,7 +7,7 @@ import requests | |||||
import tempfile | import tempfile | ||||
from tqdm import tqdm | from tqdm import tqdm | ||||
import shutil | import shutil | ||||
import hashlib | |||||
from requests import HTTPError | |||||
PRETRAINED_BERT_MODEL_DIR = { | PRETRAINED_BERT_MODEL_DIR = { | ||||
@@ -23,15 +23,25 @@ PRETRAINED_BERT_MODEL_DIR = { | |||||
'cn': 'bert-base-chinese-29d0a84a.zip', | 'cn': 'bert-base-chinese-29d0a84a.zip', | ||||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | 'cn-base': 'bert-base-chinese-29d0a84a.zip', | ||||
'multilingual': 'bert-base-multilingual-cased.zip', | |||||
'multilingual-base-uncased': 'bert-base-multilingual-uncased.zip', | |||||
'multilingual-base-cased': 'bert-base-multilingual-cased.zip', | |||||
'bert-base-chinese': 'bert-base-chinese.zip', | |||||
'bert-base-cased': 'bert-base-cased.zip', | |||||
'bert-base-cased-finetuned-mrpc': 'bert-base-cased-finetuned-mrpc.zip', | |||||
'bert-large-cased-wwm': 'bert-large-cased-wwm.zip', | |||||
'bert-large-uncased': 'bert-large-uncased.zip', | |||||
'bert-large-cased': 'bert-large-cased.zip', | |||||
'bert-base-uncased': 'bert-base-uncased.zip', | |||||
'bert-large-uncased-wwm': 'bert-large-uncased-wwm.zip', | |||||
'bert-chinese-wwm': 'bert-chinese-wwm.zip', | |||||
'bert-base-multilingual-cased': 'bert-base-multilingual-cased.zip', | |||||
'bert-base-multilingual-uncased': 'bert-base-multilingual-uncased.zip', | |||||
} | } | ||||
PRETRAINED_ELMO_MODEL_DIR = { | PRETRAINED_ELMO_MODEL_DIR = { | ||||
'en': 'elmo_en-d39843fe.tar.gz', | 'en': 'elmo_en-d39843fe.tar.gz', | ||||
'en-small': "elmo_en_Small.zip" | |||||
'en-small': "elmo_en_Small.zip", | |||||
'en-original-5.5b': 'elmo_en_Original_5.5B.zip', | |||||
'en-original': 'elmo_en_Original.zip', | |||||
'en-medium': 'elmo_en_Medium.zip' | |||||
} | } | ||||
PRETRAIN_STATIC_FILES = { | PRETRAIN_STATIC_FILES = { | ||||
@@ -42,34 +52,68 @@ PRETRAIN_STATIC_FILES = { | |||||
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", | 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", | ||||
'cn': "tencent_cn-dab24577.tar.gz", | 'cn': "tencent_cn-dab24577.tar.gz", | ||||
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", | 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", | ||||
'sgns-literature-word':'sgns.literature.word.txt.zip', | |||||
'glove-42b-300d': 'glove.42B.300d.zip', | |||||
'glove-6b-50d': 'glove.6B.50d.zip', | |||||
'glove-6b-100d': 'glove.6B.100d.zip', | |||||
'glove-6b-200d': 'glove.6B.200d.zip', | |||||
'glove-6b-300d': 'glove.6B.300d.zip', | |||||
'glove-840b-300d': 'glove.840B.300d.zip', | |||||
'glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', | |||||
'glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', | |||||
'glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', | |||||
'glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip' | |||||
} | |||||
DATASET_DIR = { | |||||
'aclImdb': "imdb.zip", | |||||
"yelp-review-full":"yelp_review_full.tar.gz", | |||||
"yelp-review-polarity": "yelp_review_polarity.tar.gz", | |||||
"mnli": "MNLI.zip", | |||||
"snli": "SNLI.zip", | |||||
"qnli": "QNLI.zip", | |||||
"sst-2": "SST-2.zip", | |||||
"sst": "SST.zip", | |||||
"rte": "RTE.zip" | |||||
} | } | ||||
def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||||
def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path: | |||||
""" | """ | ||||
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 | |||||
给定一个url,尝试通过从url解析出来的文件名filename,到{cache_dir}/{name}/{filename}下寻找这个文件,
(1)如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir
(2)如果name=None, 则没有中间的{name}这一层结构;否则中间结构就为{name}
如果缓存中已有该文件,就直接返回路径;
如果没有该文件,则尝试用传入的url下载。
也可以直接传入文件名(可以是具体的文件名,也可以是文件夹),同样会先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并
将文件放入到cache_dir中. | 将文件放入到cache_dir中. | ||||
:param url_or_filename: 文件的下载url或者文件路径 | |||||
:param cache_dir: 文件的缓存文件夹 | |||||
:param str url_or_filename: 文件的下载url或者文件名称。 | |||||
:param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 | |||||
:param str name: 中间一层的名称。如embedding, dataset | |||||
:return: | :return: | ||||
""" | """ | ||||
if cache_dir is None: | if cache_dir is None: | ||||
dataset_cache = Path(get_default_cache_path()) | |||||
data_cache = Path(get_default_cache_path()) | |||||
else: | else: | ||||
dataset_cache = cache_dir | |||||
data_cache = cache_dir | |||||
if name: | |||||
data_cache = os.path.join(data_cache, name) | |||||
parsed = urlparse(url_or_filename) | parsed = urlparse(url_or_filename) | ||||
if parsed.scheme in ("http", "https"): | if parsed.scheme in ("http", "https"): | ||||
# URL, so get it from the cache (downloading if necessary) | # URL, so get it from the cache (downloading if necessary) | ||||
return get_from_cache(url_or_filename, dataset_cache) | |||||
elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): | |||||
return get_from_cache(url_or_filename, Path(data_cache)) | |||||
elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists(): | |||||
# File, and it exists. | # File, and it exists. | ||||
return Path(url_or_filename) | |||||
return Path(os.path.join(data_cache, url_or_filename)) | |||||
elif parsed.scheme == "": | elif parsed.scheme == "": | ||||
# File, but it doesn't exist. | # File, but it doesn't exist. | ||||
raise FileNotFoundError("file {} not found".format(url_or_filename)) | |||||
raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache)) | |||||
else: | else: | ||||
# Something unknown | # Something unknown | ||||
raise ValueError( | raise ValueError( | ||||
@@ -79,8 +123,12 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||||
def get_filepath(filepath): | def get_filepath(filepath): | ||||
""" | """ | ||||
如果filepath中只有一个文件,则直接返回对应的全路径. | |||||
:param filepath: | |||||
如果filepath为文件夹, | |||||
如果内含多个文件, 返回filepath | |||||
如果只有一个文件, 返回filepath + filename | |||||
如果filepath为文件 | |||||
返回filepath | |||||
:param str filepath: 路径 | |||||
:return: | :return: | ||||
""" | """ | ||||
if os.path.isdir(filepath): | if os.path.isdir(filepath): | ||||
@@ -89,14 +137,17 @@ def get_filepath(filepath): | |||||
return os.path.join(filepath, files[0]) | return os.path.join(filepath, files[0]) | ||||
else: | else: | ||||
return filepath | return filepath | ||||
return filepath | |||||
elif os.path.isfile(filepath): | |||||
return filepath | |||||
else: | |||||
raise FileNotFoundError(f"{filepath} is not a valid file or directory.") | |||||
def get_default_cache_path(): | def get_default_cache_path(): | ||||
""" | """ | ||||
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 | ||||
:return: | |||||
:return: str | |||||
""" | """ | ||||
if 'FASTNLP_CACHE_DIR' in os.environ: | if 'FASTNLP_CACHE_DIR' in os.environ: | ||||
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') | fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') | ||||
@@ -109,17 +160,66 @@ def get_default_cache_path(): | |||||
def _get_base_url(name): | def _get_base_url(name): | ||||
""" | |||||
根据name返回下载的url地址。 | |||||
:param str name: 支持dataset和embedding两种 | |||||
:return: | |||||
""" | |||||
# 返回的URL结尾必须是/ | # 返回的URL结尾必须是/ | ||||
if 'FASTNLP_BASE_URL' in os.environ: | |||||
fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] | |||||
if fastnlp_base_url.endswith('/'): | |||||
return fastnlp_base_url | |||||
environ_name = "FASTNLP_{}_URL".format(name.upper()) | |||||
if environ_name in os.environ: | |||||
url = os.environ[environ_name] | |||||
if url.endswith('/'): | |||||
return url | |||||
else: | else: | ||||
return fastnlp_base_url + '/' | |||||
return url + '/' | |||||
else: | else: | ||||
# TODO 替换 | |||||
dbbrain_url = "http://dbcloud.irocn.cn:8989/api/public/dl/" | |||||
return dbbrain_url | |||||
URLS = { | |||||
'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/", | |||||
"dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/" | |||||
} | |||||
if name.lower() not in URLS: | |||||
raise KeyError(f"{name} is not recognized.") | |||||
return URLS[name.lower()] | |||||
def _get_embedding_url(type, name): | |||||
""" | |||||
给定embedding类型和名称,返回下载url
:param str type: 支持static, bert, elmo。即embedding的类型 | |||||
:param str name: embedding的名称, 例如en, cn, based等 | |||||
:return: str, 下载的url地址 | |||||
""" | |||||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | |||||
"bert": PRETRAINED_BERT_MODEL_DIR, | |||||
"static":PRETRAIN_STATIC_FILES} | |||||
map = PRETRAIN_MAP.get(type, None) | |||||
if map: | |||||
filename = map.get(name, None) | |||||
if filename: | |||||
url = _get_base_url('embedding') + filename | |||||
return url | |||||
raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys()))) | |||||
else: | |||||
raise KeyError(f"There is no {type}. Only supports bert, elmo, static") | |||||
def _get_dataset_url(name): | |||||
""" | |||||
给定dataset的名称,返回下载url | |||||
:param str name: 给定dataset的名称,比如imdb, sst-2等 | |||||
:return: str | |||||
""" | |||||
filename = DATASET_DIR.get(name, None) | |||||
if filename: | |||||
url = _get_base_url('dataset') + filename | |||||
return url | |||||
else: | |||||
raise KeyError(f"There is no {name}.") | |||||
def split_filename_suffix(filepath): | def split_filename_suffix(filepath): | ||||
@@ -136,9 +236,9 @@ def split_filename_suffix(filepath): | |||||
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | def get_from_cache(url: str, cache_dir: Path = None) -> Path: | ||||
""" | """ | ||||
尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。 | |||||
如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 | |||||
尝试在cache_dir中寻找url定义的资源; 如果没有找到,则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的
文件解压,将解压后的文件全部放在cache_dir文件夹中。 | |||||
如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 | |||||
""" | """ | ||||
cache_dir.mkdir(parents=True, exist_ok=True) | cache_dir.mkdir(parents=True, exist_ok=True) | ||||
@@ -173,63 +273,68 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||||
# GET file object | # GET file object | ||||
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) | req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) | ||||
content_length = req.headers.get("Content-Length") | |||||
total = int(content_length) if content_length is not None else None | |||||
progress = tqdm(unit="B", total=total) | |||||
with open(temp_filename, "wb") as temp_file: | |||||
for chunk in req.iter_content(chunk_size=1024): | |||||
if chunk: # filter out keep-alive new chunks | |||||
progress.update(len(chunk)) | |||||
temp_file.write(chunk) | |||||
progress.close() | |||||
print(f"Finish download from {url}.") | |||||
# 开始解压 | |||||
delete_temp_dir = None | |||||
if suffix in ('.zip', '.tar.gz'): | |||||
uncompress_temp_dir = tempfile.mkdtemp() | |||||
delete_temp_dir = uncompress_temp_dir | |||||
print(f"Start to uncompress file to {uncompress_temp_dir}") | |||||
if suffix == '.zip': | |||||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||||
if req.status_code==200: | |||||
content_length = req.headers.get("Content-Length") | |||||
total = int(content_length) if content_length is not None else None | |||||
progress = tqdm(unit="B", total=total, unit_scale=1) | |||||
with open(temp_filename, "wb") as temp_file: | |||||
for chunk in req.iter_content(chunk_size=1024*16): | |||||
if chunk: # filter out keep-alive new chunks | |||||
progress.update(len(chunk)) | |||||
temp_file.write(chunk) | |||||
progress.close() | |||||
print(f"Finish download from {url}.") | |||||
# 开始解压 | |||||
delete_temp_dir = None | |||||
if suffix in ('.zip', '.tar.gz'): | |||||
uncompress_temp_dir = tempfile.mkdtemp() | |||||
delete_temp_dir = uncompress_temp_dir | |||||
print(f"Start to uncompress file to {uncompress_temp_dir}") | |||||
if suffix == '.zip': | |||||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||||
else: | |||||
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||||
filenames = os.listdir(uncompress_temp_dir) | |||||
if len(filenames)==1: | |||||
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): | |||||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||||
cache_path.mkdir(parents=True, exist_ok=True) | |||||
print("Finish un-compressing file.") | |||||
else: | else: | ||||
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||||
filenames = os.listdir(uncompress_temp_dir) | |||||
if len(filenames)==1: | |||||
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): | |||||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||||
cache_path.mkdir(parents=True, exist_ok=True) | |||||
print("Finish un-compressing file.") | |||||
uncompress_temp_dir = temp_filename | |||||
cache_path = str(cache_path) + suffix | |||||
success = False | |||||
try: | |||||
# 复制到指定的位置 | |||||
print(f"Copy file to {cache_path}") | |||||
if os.path.isdir(uncompress_temp_dir): | |||||
for filename in os.listdir(uncompress_temp_dir): | |||||
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): | |||||
shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path/filename) | |||||
else: | |||||
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) | |||||
else: | |||||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||||
success = True | |||||
except Exception as e: | |||||
print(e) | |||||
raise e | |||||
finally: | |||||
if not success: | |||||
if cache_path.exists(): | |||||
if cache_path.is_file(): | |||||
os.remove(cache_path) | |||||
else: | |||||
shutil.rmtree(cache_path) | |||||
if delete_temp_dir: | |||||
shutil.rmtree(delete_temp_dir) | |||||
os.close(fd) | |||||
os.remove(temp_filename) | |||||
return get_filepath(cache_path) | |||||
else: | else: | ||||
uncompress_temp_dir = temp_filename | |||||
cache_path = str(cache_path) + suffix | |||||
success = False | |||||
try: | |||||
# 复制到指定的位置 | |||||
print(f"Copy file to {cache_path}") | |||||
if os.path.isdir(uncompress_temp_dir): | |||||
for filename in os.listdir(uncompress_temp_dir): | |||||
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) | |||||
else: | |||||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||||
success = True | |||||
except Exception as e: | |||||
print(e) | |||||
raise e | |||||
finally: | |||||
if not success: | |||||
if cache_path.exists(): | |||||
if cache_path.is_file(): | |||||
os.remove(cache_path) | |||||
else: | |||||
shutil.rmtree(cache_path) | |||||
if delete_temp_dir: | |||||
shutil.rmtree(delete_temp_dir) | |||||
os.close(fd) | |||||
os.remove(temp_filename) | |||||
return get_filepath(cache_path) | |||||
raise HTTPError(f"Fail to download from {url}.") | |||||
def unzip_file(file: Path, to: Path): | def unzip_file(file: Path, to: Path): | ||||
@@ -0,0 +1,30 @@ | |||||
""" | |||||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下 | |||||
三个方法: __init__(), _load(), load()。其中__init__()用于申明读取参数,以及说明该Loader支持的数据格式和读取后DataSet中的field | |||||
; _load(path)方法传入一个文件路径,读取单个文件并返回DataSet; load(paths)用于读取文件夹下的文件,并返回DataBundle。load()方法 | |||||
支持以下几种类型的参数 | |||||
Example:: | |||||
(0) 如果传入None,将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 | |||||
(1) 如果传入的是一个文件path,则返回的DataBundle包含一个名为train的DataSet可以通过data_bundle.datasets['train']获取 | |||||
(2) 传入的是一个文件夹目录,将读取的是这个文件夹下文件名中包含'train', 'test', 'dev'的文件,其它文件会被忽略。 | |||||
假设某个目录下的文件为 | |||||
-train.txt | |||||
-dev.txt | |||||
-test.txt | |||||
-other.txt | |||||
Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev'], | |||||
data_bundle.datasets['test']获取对应的DataSet,其中other.txt的内容会被忽略。 | |||||
假设某个目录下的文件为 | |||||
-train.txt | |||||
-dev.txt | |||||
Loader().load('/path/to/dir')读取,返回的data_bundle中可以用data_bundle.datasets['train'], data_bundle.datasets['dev']获取 | |||||
对应的DataSet。 | |||||
(3) 传入一个dict,key为dataset的名称,value是该dataset的文件路径。 | |||||
paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} | |||||
Loader().load(paths) # 返回的data_bundle可以通过以下的方式获取相应的DataSet, data_bundle.datasets['train'], data_bundle.datasets['dev'], | |||||
data_bundle.datasets['test'] | |||||
""" | |||||
@@ -0,0 +1,369 @@ | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from .loader import Loader | |||||
import warnings | |||||
import os | |||||
import random | |||||
import shutil | |||||
import numpy as np | |||||
class YelpLoader(Loader): | |||||
""" | |||||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader` | |||||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||||
Example:: | |||||
"1","I got 'new' tires from the..." | |||||
"1","Don't waste your time..." | |||||
读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。 | |||||
读取的DataSet将具备以下的数据结构 | |||||
.. csv-table:: | |||||
:header: "raw_words", "target" | |||||
"I got 'new' tires from them and... ", "1" | |||||
"Don't waste your time. We had two...", "1" | |||||
"...", "..." | |||||
""" | |||||
def __init__(self): | |||||
super(YelpLoader, self).__init__() | |||||
def _load(self, path: str=None): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
sep_index = line.index(',') | |||||
target = line[:sep_index] | |||||
raw_words = line[sep_index + 1:] | |||||
if target.startswith("\""): | |||||
target = target[1:] | |||||
if target.endswith("\""): | |||||
target = target[:-1] | |||||
if raw_words.endswith("\""): | |||||
raw_words = raw_words[:-1] | |||||
if raw_words.startswith('"'): | |||||
raw_words = raw_words[1:] | |||||
raw_words = raw_words.replace('""', '"') # 替换双引号 | |||||
if raw_words: | |||||
ds.append(Instance(raw_words=raw_words, target=target)) | |||||
return ds | |||||
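# Worked example of the parsing above, using the first line from the class docstring:
#   line = '"1","I got \'new\' tires from the..."'
#   -> target = '1', raw_words = "I got 'new' tires from the..."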
class YelpFullLoader(YelpLoader): | |||||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||||
""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||||
in Neural Information Processing Systems 28 (NIPS 2015) | |||||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv, | |||||
dev.csv三个文件。 | |||||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||||
:param int seed: 划分dev时的随机数种子 | |||||
:return: str, 数据集的目录地址 | |||||
""" | |||||
dataset_name = 'yelp-review-full' | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载 | |||||
re_download = True | |||||
if dev_ratio>0: | |||||
dev_line_count = 0 | |||||
tr_line_count = 0 | |||||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: | |||||
for line in f1: | |||||
tr_line_count += 1 | |||||
for line in f2: | |||||
dev_line_count += 1 | |||||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||||
re_download = True | |||||
else: | |||||
re_download = False | |||||
if re_download: | |||||
shutil.rmtree(data_dir) | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||||
if dev_ratio > 0: | |||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||||
random.seed(int(seed)) | |||||
try: | |||||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ | |||||
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: | |||||
for line in f: | |||||
if random.random() < dev_ratio: | |||||
f2.write(line) | |||||
else: | |||||
f1.write(line) | |||||
os.remove(os.path.join(data_dir, 'train.csv')) | |||||
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) | |||||
finally: | |||||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||||
return data_dir | |||||
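# Worked example of the consistency check above: with dev_ratio=0.1, a cached split of
# 90,000 train lines and 10,000 dev lines satisfies 10,000 == 0.1*(90,000+10,000), so
# np.isclose(...) holds and the cached split is reused; a 5,000-line dev file fails the
# check and triggers re-download.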
class YelpPolarityLoader(YelpLoader): | |||||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||||
""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||||
in Neural Information Processing Systems 28 (NIPS 2015) | |||||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev | |||||
:param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据. 如果为0,则不划分dev | |||||
:param int seed: 划分dev时的随机数种子 | |||||
:return: str, 数据集的目录地址 | |||||
""" | |||||
dataset_name = 'yelp-review-polarity' | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||||
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求 | |||||
re_download = True | |||||
if dev_ratio>0: | |||||
dev_line_count = 0 | |||||
tr_line_count = 0 | |||||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.csv'), 'r', encoding='utf-8') as f2: | |||||
for line in f1: | |||||
tr_line_count += 1 | |||||
for line in f2: | |||||
dev_line_count += 1 | |||||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||||
re_download = True | |||||
else: | |||||
re_download = False | |||||
if re_download: | |||||
shutil.rmtree(data_dir) | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||||
if not os.path.exists(os.path.join(data_dir, 'dev.csv')): | |||||
if dev_ratio > 0: | |||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||||
random.seed(int(seed)) | |||||
try: | |||||
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ | |||||
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: | |||||
for line in f: | |||||
if random.random() < dev_ratio: | |||||
f2.write(line) | |||||
else: | |||||
f1.write(line) | |||||
os.remove(os.path.join(data_dir, 'train.csv')) | |||||
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) | |||||
finally: | |||||
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): | |||||
os.remove(os.path.join(data_dir, 'middle_file.csv')) | |||||
return data_dir | |||||
class IMDBLoader(Loader): | |||||
""" | |||||
别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader` | |||||
IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 | |||||
DataSet具备以下的结构: | |||||
.. csv-table:: | |||||
:header: "raw_words", "target" | |||||
"Bromwell High is a cartoon ... ", "pos" | |||||
"Story of a man who has ...", "neg" | |||||
"...", "..." | |||||
""" | |||||
def __init__(self): | |||||
super(IMDBLoader, self).__init__() | |||||
def _load(self, path: str): | |||||
dataset = DataSet() | |||||
with open(path, 'r', encoding="utf-8") as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if not line: | |||||
continue | |||||
parts = line.split('\t') | |||||
target = parts[0] | |||||
words = parts[1] | |||||
if words: | |||||
dataset.append(Instance(raw_words=words, target=target)) | |||||
if len(dataset) == 0: | |||||
raise RuntimeError(f"{path} has no valid data.") | |||||
return dataset | |||||
def download(self, dev_ratio: float = 0.1, seed: int = 0): | |||||
""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||||
http://www.aclweb.org/anthology/P11-1015 | |||||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev | |||||
:param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev | |||||
:param int seed: 划分dev时的随机数种子 | |||||
:return: str, 数据集的目录地址 | |||||
""" | |||||
dataset_name = 'aclImdb' | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||||
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求 | |||||
re_download = True | |||||
if dev_ratio>0: | |||||
dev_line_count = 0 | |||||
tr_line_count = 0 | |||||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.txt'), 'r', encoding='utf-8') as f2: | |||||
for line in f1: | |||||
tr_line_count += 1 | |||||
for line in f2: | |||||
dev_line_count += 1 | |||||
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005): | |||||
re_download = True | |||||
else: | |||||
re_download = False | |||||
if re_download: | |||||
shutil.rmtree(data_dir) | |||||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
if dev_ratio > 0: | |||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||||
random.seed(int(seed)) | |||||
try: | |||||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||||
for line in f: | |||||
if random.random() < dev_ratio: | |||||
f2.write(line) | |||||
else: | |||||
f1.write(line) | |||||
os.remove(os.path.join(data_dir, 'train.txt')) | |||||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||||
finally: | |||||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||||
return data_dir | |||||
class SSTLoader(Loader): | |||||
""" | |||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader` | |||||
读取之后的DataSet具有以下的结构 | |||||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | |||||
:header: "raw_words" | |||||
"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." | |||||
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." | |||||
"..." | |||||
raw_words列是str。 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path: str): | |||||
""" | |||||
从path读取SST文件 | |||||
:param str path: 文件路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
ds.append(Instance(raw_words=line)) | |||||
return ds | |||||
def download(self): | |||||
""" | |||||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||||
https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf | |||||
:return: str, 数据集的目录地址 | |||||
""" | |||||
output_dir = self._get_dataset_path(dataset_name='sst') | |||||
return output_dir | |||||
class SST2Loader(Loader): | |||||
""" | |||||
数据SST2的Loader | |||||
读取之后DataSet将如下所示 | |||||
.. csv-table:: 下面是使用SST2Loader读取的DataSet所具备的field
:header: "raw_words", "target" | |||||
"it 's a charming and often affecting...", "1" | |||||
"unflinchingly bleak and...", "0" | |||||
"..." | |||||
test的DataSet没有target列。 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path: str): | |||||
""" | |||||
从path读取SST2文件 | |||||
:param str path: 数据路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
f.readline() # 跳过header | |||||
if 'test' in os.path.split(path)[1]: | |||||
warnings.warn("SST2's test file has no target.") | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
sep_index = line.index('\t') | |||||
raw_words = line[sep_index + 1:] | |||||
if raw_words: | |||||
ds.append(Instance(raw_words=raw_words)) | |||||
else: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
raw_words = line[:-2] | |||||
target = line[-1] | |||||
if raw_words: | |||||
ds.append(Instance(raw_words=raw_words, target=target)) | |||||
return ds | |||||
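# Worked example of the labelled (non-test) branch above, using a row from the class docstring:
#   line = "it 's a charming and often affecting...\t1"
#   -> raw_words = line[:-2] = "it 's a charming and often affecting..."
#   -> target = line[-1] = "1"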
def download(self): | |||||
""" | |||||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | |||||
https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf | |||||
:return: | |||||
""" | |||||
output_dir = self._get_dataset_path(dataset_name='sst-2') | |||||
return output_dir |
@@ -0,0 +1,264 @@ | |||||
from typing import Dict, Union | |||||
from .loader import Loader | |||||
from ... import DataSet | |||||
from ..file_reader import _read_conll | |||||
from ... import Instance | |||||
from .. import DataBundle | |||||
from ..utils import check_loader_paths | |||||
from ... import Const | |||||
class ConllLoader(Loader): | |||||
""" | |||||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | |||||
ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||||
Example:: | |||||
# 文件中的内容 | |||||
Nadim NNP B-NP B-PER | |||||
Ladki NNP I-NP I-PER | |||||
AL-AIN NNP B-NP B-LOC | |||||
United NNP B-NP B-LOC | |||||
Arab NNP I-NP I-LOC | |||||
Emirates NNPS I-NP I-LOC | |||||
1996-12-06 CD I-NP O | |||||
... | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第3列
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||||
ConllLoader返回的DataSet的field由传入的headers确定。 | |||||
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||||
""" | |||||
def __init__(self, headers, indexes=None, dropna=True): | |||||
super(ConllLoader, self).__init__() | |||||
if not isinstance(headers, (list, tuple)): | |||||
raise TypeError( | |||||
'invalid headers: {}, should be list of strings'.format(headers)) | |||||
self.headers = headers | |||||
self.dropna = dropna | |||||
if indexes is None: | |||||
self.indexes = list(range(len(self.headers))) | |||||
else: | |||||
if len(indexes) != len(headers): | |||||
raise ValueError | |||||
self.indexes = indexes | |||||
def _load(self, path): | |||||
""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||||
:param str path: 文件的路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | |||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
class Conll2003Loader(ConllLoader): | |||||
""" | |||||
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 | |||||
Example:: | |||||
Nadim NNP B-NP B-PER | |||||
Ladki NNP I-NP I-PER | |||||
AL-AIN NNP B-NP B-LOC | |||||
United NNP B-NP B-LOC | |||||
Arab NNP I-NP I-LOC | |||||
Emirates NNPS I-NP I-LOC | |||||
1996-12-06 CD I-NP O | |||||
... | |||||
返回的DataSet的内容为 | |||||
.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 | |||||
:header: "raw_words", "pos", "chunk", "ner" | |||||
"[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" | |||||
"[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||||
"[...]", "[...]", "[...]", "[...]" | |||||
""" | |||||
def __init__(self): | |||||
headers = [ | |||||
'raw_words', 'pos', 'chunk', 'ner', | |||||
] | |||||
super(Conll2003Loader, self).__init__(headers=headers) | |||||
def _load(self, path): | |||||
""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||||
:param str path: 文件的路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | |||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
doc_start = False | |||||
for i, h in enumerate(self.headers): | |||||
field = data[i] | |||||
if str(field[0]).startswith('-DOCSTART-'): | |||||
doc_start = True | |||||
break | |||||
if doc_start: | |||||
continue | |||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
def download(self, output_dir=None): | |||||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||||
class Conll2003NERLoader(ConllLoader): | |||||
""" | |||||
用于读取conll2003任务的NER数据。 | |||||
Example:: | |||||
Nadim NNP B-NP B-PER | |||||
Ladki NNP I-NP I-PER | |||||
AL-AIN NNP B-NP B-LOC | |||||
United NNP B-NP B-LOC | |||||
Arab NNP I-NP I-LOC | |||||
Emirates NNPS I-NP I-LOC | |||||
1996-12-06 CD I-NP O | |||||
... | |||||
返回的DataSet的内容为 | |||||
.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码 | |||||
:header: "raw_words", "target" | |||||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||||
"[...]", "[...]" | |||||
""" | |||||
def __init__(self): | |||||
headers = [ | |||||
'raw_words', 'target', | |||||
] | |||||
super().__init__(headers=headers, indexes=[0, 3]) | |||||
def _load(self, path): | |||||
""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||||
:param str path: 文件的路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | |||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
doc_start = False | |||||
for i, h in enumerate(self.headers): | |||||
field = data[i] | |||||
if str(field[0]).startswith('-DOCSTART-'): | |||||
doc_start = True | |||||
break | |||||
if doc_start: | |||||
continue | |||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
def download(self): | |||||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||||
class OntoNotesNERLoader(ConllLoader): | |||||
""" | |||||
用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 | |||||
https://github.com/yhcc/OntoNotes-5.0-NER。OntoNotesNERLoader将取第4列和第11列的内容。
返回的DataSet的内容为
.. csv-table:: 下面是使用OntoNotesNERLoader读取的DataSet所具备的结构, target列是BIO编码
:header: "raw_words", "target" | |||||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||||
"[...]", "[...]" | |||||
""" | |||||
def __init__(self): | |||||
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10]) | |||||
def _load(self, path:str): | |||||
dataset = super()._load(path) | |||||
def convert_to_bio(tags): | |||||
bio_tags = [] | |||||
flag = None | |||||
for tag in tags: | |||||
label = tag.strip("()*") | |||||
if '(' in tag: | |||||
bio_label = 'B-' + label | |||||
flag = label | |||||
elif flag: | |||||
bio_label = 'I-' + flag | |||||
else: | |||||
bio_label = 'O' | |||||
if ')' in tag: | |||||
flag = None | |||||
bio_tags.append(bio_label) | |||||
return bio_tags | |||||
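# Worked example of convert_to_bio (the tag strings follow the OntoNotes column format,
# which is an assumption here, not shown in this diff):
#   ['(PERSON*', '*)', '*', '(GPE)'] -> ['B-PERSON', 'I-PERSON', 'O', 'B-GPE']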
def convert_word(words): | |||||
converted_words = [] | |||||
for word in words: | |||||
word = word.replace('/.', '.') # 有些结尾的.是/.形式的 | |||||
if not word.startswith('-'): | |||||
converted_words.append(word) | |||||
continue | |||||
# 以下是由于这些符号被转义了,再转回来 | |||||
tfrs = {'-LRB-':'(', | |||||
'-RRB-': ')', | |||||
'-LSB-': '[', | |||||
'-RSB-': ']', | |||||
'-LCB-': '{', | |||||
'-RCB-': '}' | |||||
} | |||||
if word in tfrs: | |||||
converted_words.append(tfrs[word]) | |||||
else: | |||||
converted_words.append(word) | |||||
return converted_words | |||||
dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) | |||||
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||||
return dataset | |||||
def download(self): | |||||
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " | |||||
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") | |||||
class CTBLoader(Loader): | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path:str): | |||||
pass |
@@ -0,0 +1,32 @@ | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ..file_reader import _read_csv | |||||
from .loader import Loader | |||||
class CSVLoader(Loader): | |||||
""" | |||||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | |||||
读取CSV格式的数据集, 返回 ``DataSet`` 。 | |||||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||||
Default: ``False`` | |||||
""" | |||||
def __init__(self, headers=None, sep=",", dropna=False): | |||||
super().__init__() | |||||
self.headers = headers | |||||
self.sep = sep | |||||
self.dropna = dropna | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
for idx, data in _read_csv(path, headers=self.headers, | |||||
sep=self.sep, dropna=self.dropna): | |||||
ds.append(Instance(**data)) | |||||
return ds | |||||
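# A minimal usage sketch (file path and headers are hypothetical):
# loader = CSVLoader(headers=['raw_words', 'target'], sep='\t', dropna=True)
# data_bundle = loader.load('/path/to/train.tsv')   # each row becomes one Instance in datasets['train']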
@@ -0,0 +1,41 @@ | |||||
from .loader import Loader | |||||
from ...core import DataSet, Instance | |||||
class CWSLoader(Loader): | |||||
""" | |||||
分词任务数据加载器, | |||||
SigHan2005的数据可以用xxx下载并预处理 | |||||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||||
Example:: | |||||
上海 浦东 开发 与 法制 建设 同步 | |||||
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) | |||||
... | |||||
该Loader读取后的DataSet具有如下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_words" | |||||
"上海 浦东 开发 与 法制 建设 同步" | |||||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||||
"..." | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path:str): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
ds.append(Instance(raw_words=line)) | |||||
return ds | |||||
def download(self, output_dir=None): | |||||
raise RuntimeError("You can refer {} for sighan2005's data downloading.") |
@@ -0,0 +1,40 @@ | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ..file_reader import _read_json | |||||
from .loader import Loader | |||||
class JsonLoader(Loader): | |||||
""" | |||||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` | |||||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||||
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , | |||||
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 | |||||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||||
Default: ``False`` | |||||
""" | |||||
def __init__(self, fields=None, dropna=False): | |||||
super(JsonLoader, self).__init__() | |||||
self.dropna = dropna | |||||
self.fields = None | |||||
self.fields_list = None | |||||
if fields: | |||||
self.fields = {} | |||||
for k, v in fields.items(): | |||||
self.fields[k] = k if v is None else v | |||||
self.fields_list = list(self.fields.keys()) | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||||
if self.fields: | |||||
ins = {self.fields[k]: v for k, v in d.items()} | |||||
else: | |||||
ins = d | |||||
ds.append(Instance(**ins)) | |||||
return ds |
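# A minimal usage sketch (the fields mapping mirrors the one SNLILoader uses later in this
# diff; the file path is hypothetical):
# loader = JsonLoader(fields={'sentence1': 'raw_words1', 'sentence2': 'raw_words2', 'gold_label': 'target'})
# ds = loader._load('/path/to/data.jsonl')   # one json object per line -> one Instance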
@@ -0,0 +1,75 @@ | |||||
from ... import DataSet | |||||
from .. import DataBundle | |||||
from ..utils import check_loader_paths | |||||
from typing import Union, Dict | |||||
import os | |||||
from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path | |||||
class Loader: | |||||
def __init__(self): | |||||
pass | |||||
def _load(self, path:str) -> DataSet: | |||||
raise NotImplementedError | |||||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式 | |||||
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||||
(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件 | |||||
名包含'train'、 'dev'、 'test'则会报错 | |||||
Example:: | |||||
data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、 | |||||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||||
tr_data = data_bundle.datasets['train'] | |||||
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段 | |||||
(2) 传入文件路径 | |||||
Example:: | |||||
data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||||
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet | |||||
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test | |||||
Example:: | |||||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||||
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||||
dev_data = data_bundle.datasets['dev'] | |||||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||||
""" | |||||
if paths is None: | |||||
paths = self.download() | |||||
paths = check_loader_paths(paths) | |||||
datasets = {name: self._load(path) for name, path in paths.items()} | |||||
data_bundle = DataBundle(datasets=datasets) | |||||
return data_bundle | |||||
def download(self): | |||||
raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | |||||
def _get_dataset_path(self, dataset_name): | |||||
""" | |||||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存 | |||||
:param str dataset_name: 数据集的名称 | |||||
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | |||||
""" | |||||
default_cache_path = get_default_cache_path() | |||||
url = _get_dataset_url(dataset_name) | |||||
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | |||||
return output_dir | |||||
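# Sketch of a custom Loader (illustrative only, not part of this diff): subclasses implement
# _load(), while download() is optional and only needed when the data can be fetched automatically.
# class MyLineLoader(Loader):
#     def _load(self, path):
#         ds = DataSet()
#         with open(path, 'r', encoding='utf-8') as f:
#             for line in f:
#                 line = line.strip()
#                 if line:
#                     ds.append(Instance(raw_words=line))
#         return ds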
@@ -0,0 +1,309 @@ | |||||
import warnings | |||||
from .loader import Loader | |||||
from .json import JsonLoader | |||||
from ...core import Const | |||||
from .. import DataBundle | |||||
import os | |||||
from typing import Union, Dict | |||||
from ...core import DataSet | |||||
from ...core import Instance | |||||
__all__ = ['MNLILoader', | |||||
"QuoraLoader", | |||||
"SNLILoader", | |||||
"QNLILoader", | |||||
"RTELoader"] | |||||
class MNLILoader(Loader): | |||||
""" | |||||
读取MNLI任务的数据,读取之后的DataSet中包含以下的内容: raw_words1是sentence1, raw_words2是sentence2, target是gold_label。测试集中没
有target列。
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "target" | |||||
"The new rights are...", "Everyone really likes..", "neutral" | |||||
"This site includes a...", "The Government Executive...", "contradiction" | |||||
"...", "...","." | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path:str): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
f.readline() # 跳过header | |||||
if path.endswith("test.tsv"): | |||||
warnings.warn("RTE's test file has no target.") | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[8] | |||||
raw_words2 = parts[9] | |||||
if raw_words1 and raw_words2: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||||
else: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[8] | |||||
raw_words2 = parts[9] | |||||
target = parts[-1] | |||||
if raw_words1 and raw_words2 and target: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||||
return ds | |||||
def load(self, paths:str=None): | |||||
""" | |||||
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, | |||||
test_mismatched.tsv, train.tsv这些文件
:return: DataBundle | |||||
""" | |||||
if paths: | |||||
paths = os.path.abspath(os.path.expanduser(paths)) | |||||
else: | |||||
paths = self.download() | |||||
if not os.path.isdir(paths): | |||||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||||
files = {'dev_matched':"dev_matched.tsv", | |||||
"dev_mismatched":"dev_mismatched.tsv", | |||||
"test_matched":"test_matched.tsv", | |||||
"test_mismatched":"test_mismatched.tsv", | |||||
"train":'train.tsv'} | |||||
datasets = {} | |||||
for name, filename in files.items(): | |||||
filepath = os.path.join(paths, filename) | |||||
if not os.path.isfile(filepath): | |||||
if 'test' not in name: | |||||
raise FileNotFoundError(f"{name} not found in directory {filepath}.") | |||||
datasets[name] = self._load(filepath) | |||||
data_bundle = DataBundle(datasets=datasets) | |||||
return data_bundle | |||||
def download(self): | |||||
""" | |||||
如果你使用了这个数据,请引用 | |||||
https://www.nyu.edu/projects/bowman/multinli/paper.pdf | |||||
:return: | |||||
""" | |||||
output_dir = self._get_dataset_path('mnli') | |||||
return output_dir | |||||
class SNLILoader(JsonLoader): | |||||
""" | |||||
读取之后的DataSet中的field情况为 | |||||
.. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field | |||||
:header: "raw_words1", "raw_words2", "target" | |||||
"The new rights are...", "Everyone really likes..", "neutral" | |||||
"This site includes a...", "The Government Executive...", "entailment" | |||||
"...", "...", "." | |||||
""" | |||||
def __init__(self): | |||||
super().__init__(fields={ | |||||
'sentence1': Const.RAW_WORDS(0), | |||||
'sentence2': Const.RAW_WORDS(1), | |||||
'gold_label': Const.TARGET, | |||||
}) | |||||
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||||
:param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl | |||||
和snli_1.0_test.jsonl三个文件。 | |||||
:return: 返回的:class:`~fastNLP.io.DataBundle` | |||||
""" | |||||
_paths = {} | |||||
if paths is None: | |||||
paths = self.download() | |||||
if paths: | |||||
if os.path.isdir(paths): | |||||
if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')): | |||||
raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}") | |||||
_paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl') | |||||
for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']: | |||||
filepath = os.path.join(paths, filename) | |||||
_paths[filename.split('_')[-1].split('.')[0]] = filepath | |||||
paths = _paths | |||||
else: | |||||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||||
datasets = {name: self._load(path) for name, path in paths.items()} | |||||
data_bundle = DataBundle(datasets=datasets) | |||||
return data_bundle | |||||
def download(self): | |||||
""" | |||||
如果您的文章使用了这份数据,请引用 | |||||
http://nlp.stanford.edu/pubs/snli_paper.pdf | |||||
:return: str | |||||
""" | |||||
return self._get_dataset_path('snli') | |||||
class QNLILoader(JsonLoader): | |||||
""" | |||||
QNLI数据集的Loader, | |||||
加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label | |||||
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "target" | |||||
"What came into force after the new...", "As of that day...", "entailment" | |||||
"What is the first major...", "The most important tributaries", "not_entailment" | |||||
"...","." | |||||
test数据集没有target列 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
f.readline() # 跳过header | |||||
if path.endswith("test.tsv"): | |||||
warnings.warn("QNLI's test file has no target.") | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[1] | |||||
raw_words2 = parts[2] | |||||
if raw_words1 and raw_words2: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||||
else: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[1] | |||||
raw_words2 = parts[2] | |||||
target = parts[-1] | |||||
if raw_words1 and raw_words2 and target: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||||
return ds | |||||
def download(self): | |||||
""" | |||||
如果您的实验使用到了该数据,请引用 | |||||
TODO 补充 | |||||
:return: | |||||
""" | |||||
return self._get_dataset_path('qnli') | |||||
class RTELoader(Loader): | |||||
""" | |||||
RTE数据的loader | |||||
加载的DataSet将具备以下的field, raw_words1是sentence1,raw_words2是sentence2, target是label
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "target" | |||||
"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" | |||||
"Yet, we now are discovering that...", "Bacteria is winning...", "entailment" | |||||
"...","." | |||||
test数据集没有target列 | |||||
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path:str): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
f.readline() # 跳过header | |||||
if path.endswith("test.tsv"): | |||||
warnings.warn("RTE's test file has no target.") | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[1] | |||||
raw_words2 = parts[2] | |||||
if raw_words1 and raw_words2: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||||
else: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[1] | |||||
raw_words2 = parts[2] | |||||
target = parts[-1] | |||||
if raw_words1 and raw_words2 and target: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||||
return ds | |||||
def download(self): | |||||
return self._get_dataset_path('rte') | |||||
class QuoraLoader(Loader): | |||||
""" | |||||
Quora matching任务的数据集Loader | |||||
支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 | |||||
Example:: | |||||
1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970 | |||||
1 How can I stop my depression ? What can I do to stop being depressed ? 339556 | |||||
... | |||||
加载的DataSet将具备以下的field | |||||
.. csv-table::
:header: "raw_words1", "raw_words2", "target"
"How do I get funding for my web based startup idea ?", "How do I get seed funding pre product ?", "1"
"How can I stop my depression ?", "What can I do to stop being depressed ?", "1"
"...", "...", "..."
""" | |||||
def __init__(self): | |||||
super().__init__() | |||||
def _load(self, path:str): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
parts = line.split('\t') | |||||
raw_words1 = parts[1] | |||||
raw_words2 = parts[2] | |||||
target = parts[0] | |||||
if raw_words1 and raw_words2 and target: | |||||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||||
return ds | |||||
def download(self): | |||||
raise RuntimeError("Quora cannot be downloaded automatically.") |
@@ -0,0 +1,8 @@ | |||||
""" | |||||
Pipe用于处理数据,所有的Pipe都包含一个process(data_bundle)方法,传入一个DataBundle对象, 在传入的DataBundle上进行原位修改,并将其返回;
process_from_file(paths)则传入文件路径,读取后返回一个DataBundle。process(data_bundle)或者process_from_file(paths)返回的DataBundle
中的DataSet一般都包含原文与转换为index的输入,以及转换为index的target;除了DataSet之外,还会包含将field转为index时所建立的词表。 | |||||
""" |
@@ -0,0 +1,444 @@ | |||||
from nltk import Tree | |||||
from ..base_loader import DataBundle | |||||
from ...core.vocabulary import Vocabulary | |||||
from ...core.const import Const | |||||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||||
from ...core import DataSet, Instance | |||||
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance | |||||
from .pipe import Pipe | |||||
import re | |||||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||||
from ...core import cache_results | |||||
class _CLSPipe(Pipe): | |||||
""" | |||||
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 | |||||
""" | |||||
def __init__(self, tokenizer:str='spacy', lang='en'): | |||||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | |||||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||||
""" | |||||
将DataBundle中的数据进行tokenize | |||||
:param DataBundle data_bundle: | |||||
:param str field_name: | |||||
:param str new_field_name: | |||||
:return: 传入的DataBundle对象 | |||||
""" | |||||
new_field_name = new_field_name or field_name | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||||
return data_bundle | |||||
def _granularize(self, data_bundle, tag_map): | |||||
""" | |||||
该函数对data_bundle中'target'列中的内容进行转换。 | |||||
:param data_bundle: | |||||
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, | |||||
且将"1"认为是第0类。 | |||||
:return: 传入的data_bundle | |||||
""" | |||||
for name in list(data_bundle.datasets.keys()): | |||||
dataset = data_bundle.get_dataset(name) | |||||
dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET, | |||||
new_field_name=Const.TARGET) | |||||
dataset.drop(lambda ins:ins[Const.TARGET] == -100) | |||||
data_bundle.set_dataset(dataset, name) | |||||
return data_bundle | |||||
def _clean_str(words): | |||||
""" | |||||
heavily borrowed from github | |||||
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||||
:param words: List[str], 需要清洗的词列表
:return: | |||||
""" | |||||
words_collection = [] | |||||
for word in words: | |||||
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||||
continue | |||||
tt = nonalpnum.split(word) | |||||
t = ''.join(tt) | |||||
if t != '': | |||||
words_collection.append(t) | |||||
return words_collection | |||||
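# Worked example: ['-lrb-', "n't", 'good', '.'] -> ["n't", 'good']
# ('-lrb-' is skipped explicitly; '.' has no characters kept by nonalpnum and is dropped)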
class YelpFullPipe(_CLSPipe): | |||||
""" | |||||
处理YelpFull的数据, 处理之后DataSet中的内容如下 | |||||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||||
:header: "raw_words", "words", "target", "seq_len" | |||||
"It 's a ...", "[4, 2, 10, ...]", 0, 10 | |||||
"Offers that ...", "[20, 40, ...]", 1, 21 | |||||
"...", "[...]", ., . | |||||
:param bool lower: 是否对输入进行小写化。 | |||||
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉3;若为3, 则为3分类问题,将
1、2归为1类,3归为1类,4、5归为1类;若为5, 则为5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||||
""" | |||||
def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'): | |||||
super().__init__(tokenizer=tokenizer, lang='en') | |||||
self.lower = lower | |||||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||||
self.granularity = granularity | |||||
if granularity==2: | |||||
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} | |||||
elif granularity==3: | |||||
self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2} | |||||
else: | |||||
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} | |||||
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): | |||||
""" | |||||
将DataBundle中的数据进行tokenize | |||||
:param DataBundle data_bundle: | |||||
:param str field_name: | |||||
:param str new_field_name: | |||||
:return: 传入的DataBundle对象 | |||||
""" | |||||
new_field_name = new_field_name or field_name | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||||
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) | |||||
return data_bundle | |||||
def process(self, data_bundle): | |||||
""" | |||||
传入的DataSet应该具备如下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_words", "target" | |||||
"I got 'new' tires from them and... ", "1" | |||||
"Don't waste your time. We had two...", "1" | |||||
"...", "..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
# 复制一列words | |||||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||||
# 进行tokenize | |||||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||||
# 根据granularity设置tag | |||||
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) | |||||
# 删除空行 | |||||
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) | |||||
# index | |||||
data_bundle = _indexize(data_bundle=data_bundle) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUT) | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
""" | |||||
:param paths: | |||||
:return: DataBundle | |||||
""" | |||||
data_bundle = YelpFullLoader().load(paths) | |||||
return self.process(data_bundle=data_bundle) | |||||
class YelpPolarityPipe(_CLSPipe): | |||||
""" | |||||
处理YelpPolarity的数据, 处理之后DataSet中的内容如下 | |||||
.. csv-table:: 下面是使用YelpPolarityPipe处理后的DataSet所具备的field
:header: "raw_words", "words", "target", "seq_len" | |||||
"It 's a ...", "[4, 2, 10, ...]", 0, 10 | |||||
"Offers that ...", "[20, 40, ...]", 1, 21 | |||||
"...", "[...]", ., . | |||||
:param bool lower: 是否对输入进行小写化。 | |||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||||
""" | |||||
def __init__(self, lower:bool=False, tokenizer:str='spacy'): | |||||
super().__init__(tokenizer=tokenizer, lang='en') | |||||
self.lower = lower | |||||
def process(self, data_bundle): | |||||
# 复制一列words | |||||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||||
# 进行tokenize | |||||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||||
# index | |||||
data_bundle = _indexize(data_bundle=data_bundle) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUT) | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
""" | |||||
:param str paths: | |||||
:return: DataBundle | |||||
""" | |||||
data_bundle = YelpPolarityLoader().load(paths) | |||||
return self.process(data_bundle=data_bundle) | |||||
class SSTPipe(_CLSPipe): | |||||
""" | |||||
别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe` | |||||
经过该Pipe之后,DataSet中具备的field如下所示 | |||||
.. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field | |||||
:header: "raw_words", "words", "target", "seq_len" | |||||
"It 's a ...", "[4, 2, 10, ...]", 0, 16 | |||||
"Offers that ...", "[20, 40, ...]", 1, 18 | |||||
"...", "[...]", ., . | |||||
:param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` | |||||
:param bool train_subtree: 是否将train集通过子树扩展数据。 | |||||
:param bool lower: 是否对输入进行小写化。 | |||||
:param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将 | |||||
0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 | |||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||||
""" | |||||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | |||||
super().__init__(tokenizer=tokenizer, lang='en') | |||||
self.subtree = subtree | |||||
self.train_tree = train_subtree | |||||
self.lower = lower | |||||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||||
self.granularity = granularity | |||||
if granularity==2: | |||||
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} | |||||
elif granularity==3: | |||||
self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2} | |||||
else: | |||||
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
def process(self, data_bundle:DataBundle): | |||||
""" | |||||
对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 | |||||
.. csv-table:: | |||||
:header: "raw_words" | |||||
"(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..." | |||||
"(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..." | |||||
"..." | |||||
:param DataBundle data_bundle: 需要处理的DataBundle对象 | |||||
:return: | |||||
""" | |||||
# 先取出subtree | |||||
for name in list(data_bundle.datasets.keys()): | |||||
dataset = data_bundle.get_dataset(name) | |||||
ds = DataSet() | |||||
use_subtree = self.subtree or (name == 'train' and self.train_tree) | |||||
for ins in dataset: | |||||
raw_words = ins['raw_words'] | |||||
tree = Tree.fromstring(raw_words) | |||||
if use_subtree: | |||||
for t in tree.subtrees(): | |||||
raw_words = " ".join(t.leaves()) | |||||
instance = Instance(raw_words=raw_words, target=t.label()) | |||||
ds.append(instance) | |||||
else: | |||||
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) | |||||
ds.append(instance) | |||||
data_bundle.set_dataset(ds, name) | |||||
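# Worked example of the subtree expansion above (toy tree, SST-style 0-4 labels, for illustration only):
#   raw_words = "(3 (2 It) (4 good))"
#   with use_subtree=True this yields ("It good", "3"), ("It", "2") and ("good", "4");
#   with use_subtree=False only the full sentence ("It good", "3") is kept.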
_add_words_field(data_bundle, lower=self.lower) | |||||
# 进行tokenize | |||||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) | |||||
# 根据granularity设置tag | |||||
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) | |||||
# index | |||||
data_bundle = _indexize(data_bundle=data_bundle) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUT) | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = SSTLoader().load(paths) | |||||
return self.process(data_bundle=data_bundle) | |||||
class SST2Pipe(_CLSPipe): | |||||
""" | |||||
加载SST2的数据, 处理完成之后DataSet将拥有以下的field | |||||
.. csv-table:: | |||||
:header: "raw_words", "words", "target", "seq_len" | |||||
"it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43 | |||||
"unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21 | |||||
"...", "...", ., . | |||||
:param bool lower: 是否对输入进行小写化。 | |||||
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 | |||||
""" | |||||
def __init__(self, lower=False, tokenizer='spacy'): | |||||
super().__init__(tokenizer=tokenizer, lang='en') | |||||
self.lower = lower | |||||
def process(self, data_bundle:DataBundle): | |||||
""" | |||||
可以处理的DataSet应该具备如下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_words", "target" | |||||
"it 's a charming and... ", 1 | |||||
"unflinchingly bleak and...", 1 | |||||
"...", "..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
_add_words_field(data_bundle, self.lower) | |||||
data_bundle = self._tokenize(data_bundle=data_bundle) | |||||
src_vocab = Vocabulary() | |||||
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, | |||||
no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if | |||||
name != 'train']) | |||||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||||
datasets = [] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
if dataset.has_field(Const.TARGET): | |||||
datasets.append(dataset) | |||||
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) | |||||
data_bundle.set_vocab(src_vocab, Const.INPUT) | |||||
data_bundle.set_vocab(tgt_vocab, Const.TARGET) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUT) | |||||
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
data_bundle.set_target(Const.TARGET) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
""" | |||||
:param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 | |||||
:return: DataBundle | |||||
""" | |||||
data_bundle = SST2Loader().load(paths) | |||||
return self.process(data_bundle) | |||||
class IMDBPipe(_CLSPipe): | |||||
""" | |||||
经过本Pipe处理后DataSet将如下 | |||||
.. csv-table:: 输出DataSet的field | |||||
:header: "raw_words", "words", "target", "seq_len" | |||||
"Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20 | |||||
"Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31 | |||||
"...", "[...]", ., . | |||||
其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; | |||||
words列被设置为input; target列被设置为target。 | |||||
:param bool lower: 是否将words列的数据小写。 | |||||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||||
""" | |||||
def __init__(self, lower:bool=False, tokenizer:str='spacy'): | |||||
super().__init__(tokenizer=tokenizer, lang='en') | |||||
self.lower = lower | |||||
def process(self, data_bundle:DataBundle): | |||||
""" | |||||
期待的DataBundle中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型
.. csv-table:: 输入DataSet的field | |||||
:header: "raw_words", "target" | |||||
"Bromwell High is a cartoon ... ", "pos" | |||||
"Story of a man who has ...", "neg" | |||||
"...", "..." | |||||
:param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str,
target列应该为str。 | |||||
:return:DataBundle | |||||
""" | |||||
# 替换<br /> | |||||
def replace_br(raw_words): | |||||
raw_words = raw_words.replace("<br />", ' ') | |||||
return raw_words | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) | |||||
_add_words_field(data_bundle, lower=self.lower) | |||||
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) | |||||
_indexize(data_bundle) | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUT) | |||||
dataset.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
dataset.set_target(Const.TARGET) | |||||
return data_bundle | |||||
def process_from_file(self, paths=None): | |||||
""" | |||||
:param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 | |||||
:return: DataBundle | |||||
""" | |||||
# 读取数据 | |||||
data_bundle = IMDBLoader().load(paths) | |||||
data_bundle = self.process(data_bundle) | |||||
return data_bundle | |||||
@@ -0,0 +1,149 @@ | |||||
from .pipe import Pipe | |||||
from .. import DataBundle | |||||
from .utils import iob2, iob2bioes | |||||
from ... import Const | |||||
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | |||||
from .utils import _indexize, _add_words_field | |||||
class _NERPipe(Pipe): | |||||
""" | |||||
NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 | |||||
(创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 | |||||
Vocabulary转换为index。 | |||||
raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 | |||||
target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。 | |||||
:param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
:param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
:param int target_pad_val: target的padding值,target这一列pad的位置值为target_pad_val。默认为0。
""" | |||||
def __init__(self, encoding_type:str='bio', lower:bool=False, target_pad_val=0): | |||||
if encoding_type == 'bio': | |||||
self.convert_tag = iob2 | |||||
else: | |||||
self.convert_tag = iob2bioes | |||||
self.lower = lower | |||||
self.target_pad_val = int(target_pad_val) | |||||
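# For example, with encoding_type='bioes' the BIO tags ['B-PER', 'I-PER', 'O', 'B-LOC']
# are rewritten by iob2bioes (standard BIOES scheme) to ['B-PER', 'E-PER', 'O', 'S-LOC'].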
def process(self, data_bundle:DataBundle)->DataBundle: | |||||
""" | |||||
支持的DataSet的field为 | |||||
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | |||||
:header: "raw_words", "target" | |||||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||||
"[...]", "[...]" | |||||
:param DataBundle data_bundle: the DataSets in the DataBundle must contain the two fields raw_words and target, and the content of both
fields must be List[str]. The DataBundle is modified in place.
:return: DataBundle | |||||
Example:: | |||||
data_bundle = Conll2003Loader().load('/path/to/conll2003/') | |||||
data_bundle = Conll2003NERPipe().process(data_bundle) | |||||
# get the train split
tr_data = data_bundle.get_dataset('train') | |||||
# get the vocabulary of the target field
target_vocab = data_bundle.get_vocab('target') | |||||
# get the vocabulary of the words field
word_vocab = data_bundle.get_vocab('words') | |||||
""" | |||||
# convert the tag encoding
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||||
_add_words_field(data_bundle, lower=self.lower) | |||||
# index | |||||
_indexize(data_bundle) | |||||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.set_pad_val(Const.TARGET, self.target_pad_val) | |||||
dataset.add_seq_len(Const.INPUT) | |||||
data_bundle.set_input(*input_fields) | |||||
data_bundle.set_target(*target_fields) | |||||
return data_bundle | |||||
def process_from_file(self, paths) -> DataBundle: | |||||
""" | |||||
:param paths: see the load function of :class:`fastNLP.io.loader.ConllLoader` for the supported path types.
:return: DataBundle | |||||
""" | |||||
# load the data
data_bundle = Conll2003NERLoader().load(paths) | |||||
data_bundle = self.process(data_bundle) | |||||
return data_bundle | |||||
class Conll2003NERPipe(_NERPipe): | |||||
""" | |||||
Processing Pipe for the Conll2003 NER task. This Pipe (1) copies the raw_words column and names the copy words; (2) builds vocabularies
on the words and target columns (creating :class:`fastNLP.Vocabulary` objects, so the returned DataBundle contains two Vocabulary
objects); (3) converts the words and target columns to indices using the corresponding Vocabulary.
After this Pipe, the content of a DataSet looks as follows
.. csv-table:: Following is a demo layout of DataSet after Conll2003NERPipe processing
:header: "raw_words", "words", "target", "seq_len" | |||||
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 10 | |||||
"[...]", "[...]", "[...]", . | |||||
The raw_words column is List[str], the unconverted raw data; the words column is List[int], the input converted to indices; the
target column is List[int], the target converted to indices. In the returned DataSets, words, target and seq_len are set as input;
target and seq_len are set as target.
:param str encoding_type: which encoding to use for the target column; supports bio and bioes.
:param bool lower: whether to lowercase words before building the vocabulary; this almost never needs to be True.
:param int target_pad_val: padding value for target; padded positions in the target column are filled with this value. Defaults to 0.
""" | |||||
def process_from_file(self, paths) -> DataBundle: | |||||
""" | |||||
:param paths: see the load function of :class:`fastNLP.io.loader.ConllLoader` for the supported path types.
:return: DataBundle | |||||
""" | |||||
# load the data
data_bundle = Conll2003NERLoader().load(paths) | |||||
data_bundle = self.process(data_bundle) | |||||
return data_bundle | |||||
class OntoNotesNERPipe(_NERPipe): | |||||
""" | |||||
Processes NER data from OntoNotes; after processing, the fields in a DataSet are
.. csv-table:: Following is a demo layout of DataSet after OntoNotesNERPipe processing
:header: "raw_words", "words", "target", "seq_len" | |||||
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6 | |||||
"[...]", "[...]", "[...]", . | |||||
:param bool lower: whether to lowercase words before building the vocabulary; this almost never needs to be True.
:param int target_pad_val: padding value for target; padded positions in the target column are filled with this value. Defaults to 0.
""" | |||||
def process_from_file(self, paths): | |||||
data_bundle = OntoNotesNERLoader().load(paths) | |||||
return self.process(data_bundle) | |||||
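A hedged usage sketch mirroring the Example in `_NERPipe.process` above (the module path fastNLP.io.pipe.conll and the OntoNotes path are assumptions, not taken from this diff):

# usage sketch for OntoNotesNERPipe; import path and data path are assumed
from fastNLP.io.pipe.conll import OntoNotesNERPipe

pipe = OntoNotesNERPipe(encoding_type='bioes')             # tags are converted with iob2 + iob2bioes
data_bundle = pipe.process_from_file('/path/to/ontonotes') # placeholder path
tr_data = data_bundle.get_dataset('train')                 # raw_words, words, target, seq_len
target_vocab = data_bundle.get_vocab('target')             # tag vocabulary built from the train split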
@@ -0,0 +1,254 @@ | |||||
import math | |||||
from .pipe import Pipe | |||||
from .utils import get_tokenizer | |||||
from ...core import Const | |||||
from ...core import Vocabulary | |||||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader | |||||
class MatchingBertPipe(Pipe): | |||||
""" | |||||
Bert pipe for Matching tasks; the output DataSet will contain the following fields
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "words", "target", "seq_len" | |||||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10 | |||||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5 | |||||
"...", "...", "[...]", ., . | |||||
The words column is raw_words1 (the premise) and raw_words2 (the hypothesis) joined with "[SEP]" and converted to indices.
The words column is set as input and the target column is set as target.
:param bool lower: whether to lowercase the words.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports spacy and raw; raw splits on whitespace.
:param int max_concat_sent_length: if the concatenated sentence is longer than this value, it is truncated to this length; the premise
and the hypothesis are truncated proportionally.
""" | |||||
def __init__(self, lower=False, tokenizer:str='raw', max_concat_sent_length:int=480): | |||||
super().__init__() | |||||
self.lower = bool(lower) | |||||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | |||||
self.max_concat_sent_length = int(max_concat_sent_length) | |||||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||||
""" | |||||
:param DataBundle data_bundle: DataBundle.
:param list field_names: List[str], names of the fields to tokenize.
:param list new_field_names: List[str], names of the tokenized fields, one-to-one with field_names.
:return: the input DataBundle object
""" | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
for field_name, new_field_name in zip(field_names, new_field_names): | |||||
dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name, | |||||
new_field_name=new_field_name) | |||||
return data_bundle | |||||
def process(self, data_bundle): | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0)) | |||||
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1)) | |||||
if self.lower: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset[Const.INPUTS(0)].lower() | |||||
dataset[Const.INPUTS(1)].lower() | |||||
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
[Const.INPUTS(0), Const.INPUTS(1)]) | |||||
# concatenate the two word sequences
def concat(ins): | |||||
words0 = ins[Const.INPUTS(0)] | |||||
words1 = ins[Const.INPUTS(1)] | |||||
len0 = len(words0) | |||||
len1 = len(words1) | |||||
if len0 + len1 > self.max_concat_sent_length: | |||||
ratio = self.max_concat_sent_length / (len0 + len1) | |||||
len0 = math.floor(ratio * len0) | |||||
len1 = math.floor(ratio * len1) | |||||
words0 = words0[:len0] | |||||
words1 = words1[:len1] | |||||
words = words0 + ['[SEP]'] + words1 | |||||
return words | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.apply(concat, new_field_name=Const.INPUT) | |||||
dataset.delete_field(Const.INPUTS(0)) | |||||
dataset.delete_field(Const.INPUTS(1)) | |||||
word_vocab = Vocabulary() | |||||
word_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, | |||||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||||
name != 'train']) | |||||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||||
target_vocab = Vocabulary(padding=None, unknown=None) | |||||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||||
has_target_datasets = [] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
if dataset.has_field(Const.TARGET): | |||||
has_target_datasets.append(dataset) | |||||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||||
data_bundle.set_vocab(word_vocab, Const.INPUT) | |||||
data_bundle.set_vocab(target_vocab, Const.TARGET) | |||||
input_fields = [Const.INPUT, Const.INPUT_LEN] | |||||
target_fields = [Const.TARGET] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUT) | |||||
dataset.set_input(*input_fields, flag=True) | |||||
dataset.set_target(*target_fields, flag=True) | |||||
return data_bundle | |||||
class RTEBertPipe(MatchingBertPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = RTELoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class SNLIBertPipe(MatchingBertPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = SNLILoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class QuoraBertPipe(MatchingBertPipe): | |||||
def process_from_file(self, paths): | |||||
data_bundle = QuoraLoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class QNLIBertPipe(MatchingBertPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = QNLILoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class MNLIBertPipe(MatchingBertPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = MNLILoader().load(paths) | |||||
return self.process(data_bundle) | |||||
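For the Bert-style pipes above, a short hedged sketch (the import path is the one used by the tests later in this diff; paths=None is assumed to fall back to the loader's default data location):

# usage sketch for the Bert-style matching pipes
from fastNLP.io.pipe.matching import RTEBertPipe

data_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file()
train_ds = data_bundle.get_dataset('train')
# each instance carries a single 'words' field:
#   tokenized premise + ['[SEP]'] + tokenized hypothesis, truncated to max_concat_sent_length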
class MatchingPipe(Pipe): | |||||
""" | |||||
Pipe for Matching tasks. The output DataSet will contain the following fields
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2" | |||||
"The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13 | |||||
"This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7 | |||||
"...", "...", "[...]", "[...]", ., ., . | |||||
words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as target.
:param bool lower: whether to lowercase all raw_words.
:param str tokenizer: how to tokenize the raw data. Supports spacy and raw; spacy splits with spacy, raw splits on whitespace.
""" | |||||
def __init__(self, lower=False, tokenizer:str='raw'): | |||||
super().__init__() | |||||
self.lower = bool(lower) | |||||
self.tokenizer = get_tokenizer(tokenizer=tokenizer) | |||||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||||
""" | |||||
:param DataBundle data_bundle: DataBundle.
:param list field_names: List[str], names of the fields to tokenize.
:param list new_field_names: List[str], names of the tokenized fields, one-to-one with field_names.
:return: the input DataBundle object
""" | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
for field_name, new_field_name in zip(field_names, new_field_names): | |||||
dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name, | |||||
new_field_name=new_field_name) | |||||
return data_bundle | |||||
def process(self, data_bundle): | |||||
""" | |||||
The DataSets in the accepted DataBundle should have the following fields; the target column may be absent
.. csv-table:: | |||||
:header: "raw_words1", "raw_words2", "target" | |||||
"The new rights are...", "Everyone really likes..", "entailment" | |||||
"This site includes a...", "The Government Executive...", "not_entailment" | |||||
"...", "..." | |||||
:param data_bundle: | |||||
:return: | |||||
""" | |||||
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)], | |||||
[Const.INPUTS(0), Const.INPUTS(1)]) | |||||
if self.lower: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset[Const.INPUTS(0)].lower() | |||||
dataset[Const.INPUTS(1)].lower() | |||||
word_vocab = Vocabulary() | |||||
word_vocab.from_dataset(data_bundle.datasets['train'], field_name=[Const.INPUTS(0), Const.INPUTS(1)], | |||||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||||
name != 'train']) | |||||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)]) | |||||
target_vocab = Vocabulary(padding=None, unknown=None) | |||||
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||||
has_target_datasets = [] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
if dataset.has_field(Const.TARGET): | |||||
has_target_datasets.append(dataset) | |||||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||||
data_bundle.set_vocab(word_vocab, Const.INPUTS(0)) | |||||
data_bundle.set_vocab(target_vocab, Const.TARGET) | |||||
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
target_fields = [Const.TARGET] | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
dataset.set_input(*input_fields, flag=True) | |||||
dataset.set_target(*target_fields, flag=True) | |||||
return data_bundle | |||||
class RTEPipe(MatchingPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = RTELoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class SNLIPipe(MatchingPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = SNLILoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class QuoraPipe(MatchingPipe): | |||||
def process_from_file(self, paths): | |||||
data_bundle = QuoraLoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class QNLIPipe(MatchingPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = QNLILoader().load(paths) | |||||
return self.process(data_bundle) | |||||
class MNLIPipe(MatchingPipe): | |||||
def process_from_file(self, paths=None): | |||||
data_bundle = MNLILoader().load(paths) | |||||
return self.process(data_bundle) | |||||
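And the corresponding sketch for the non-Bert pipes, which keep the two sentences separate:

# usage sketch for the two-sentence matching pipes
from fastNLP.io.pipe.matching import RTEPipe

data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
train_ds = data_bundle.get_dataset('train')
# inputs: words1, words2, seq_len1, seq_len2; target: target
print(data_bundle.get_vocab('words1'))   # shared word vocabulary, registered under words1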
@@ -0,0 +1,9 @@ | |||||
from .. import DataBundle | |||||
class Pipe: | |||||
def process(self, data_bundle:DataBundle)->DataBundle: | |||||
raise NotImplementedError | |||||
def process_from_file(self, paths)->DataBundle: | |||||
raise NotImplementedError |
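The base class only pins down the interface: `process` transforms a DataBundle in place and returns it, while `process_from_file` reads raw files first. A minimal hypothetical subclass, reusing the Pipe and DataBundle imported above (the LowercasePipe name and its behaviour are made up for illustration):

class LowercasePipe(Pipe):
    """Hypothetical Pipe that only lowercases the raw_words field of every DataSet."""

    def process(self, data_bundle: DataBundle) -> DataBundle:
        for name, dataset in data_bundle.datasets.items():
            # apply str.lower to every raw_words value, overwriting the field
            dataset.apply_field(str.lower, field_name='raw_words', new_field_name='raw_words')
        return data_bundle

    def process_from_file(self, paths) -> DataBundle:
        # a real subclass would call its matching Loader here and then self.process(...)
        raise NotImplementedError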
@@ -0,0 +1,142 @@ | |||||
from typing import List | |||||
from ...core import Vocabulary | |||||
from ...core import Const | |||||
def iob2(tags:List[str])->List[str]: | |||||
""" | |||||
Check whether the tags form valid IOB data; IOB1 tags are automatically converted to IOB2. For the difference between the two formats see https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
:param tags: the tags to convert
:return: the converted tags (the list is also modified in place)
""" | |||||
for i, tag in enumerate(tags): | |||||
if tag == "O": | |||||
continue | |||||
split = tag.split("-") | |||||
if len(split) != 2 or split[0] not in ["I", "B"]: | |||||
raise TypeError("The encoding schema is not a valid IOB type.") | |||||
if split[0] == "B": | |||||
continue | |||||
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 | |||||
tags[i] = "B" + tag[1:] | |||||
elif tags[i - 1][1:] == tag[1:]: | |||||
continue | |||||
else: # conversion IOB1 to IOB2 | |||||
tags[i] = "B" + tag[1:] | |||||
return tags | |||||
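For example, an IOB1 sequence whose entities start with I- tags is rewritten in place (expected output shown in the comment):

tags = ["I-PER", "I-PER", "O", "I-LOC"]
print(iob2(tags))  # ['B-PER', 'I-PER', 'O', 'B-LOC']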
def iob2bioes(tags:List[str])->List[str]: | |||||
""" | |||||
Convert IOB tags to BIOES encoding.
:param tags: the IOB tags to convert
:return: the tags in BIOES encoding
""" | |||||
new_tags = [] | |||||
for i, tag in enumerate(tags): | |||||
if tag == 'O': | |||||
new_tags.append(tag) | |||||
else: | |||||
split = tag.split('-')[0] | |||||
if split == 'B': | |||||
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': | |||||
new_tags.append(tag) | |||||
else: | |||||
new_tags.append(tag.replace('B-', 'S-')) | |||||
elif split == 'I': | |||||
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I': | |||||
new_tags.append(tag) | |||||
else: | |||||
new_tags.append(tag.replace('I-', 'E-')) | |||||
else: | |||||
raise TypeError("Invalid IOB format.") | |||||
return new_tags | |||||
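Continuing the example, entity-final tokens become E- tags and single-token entities become S- tags:

print(iob2bioes(["B-PER", "I-PER", "O", "B-LOC"]))  # ['B-PER', 'E-PER', 'O', 'S-LOC']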
def get_tokenizer(tokenizer:str, lang='en'): | |||||
""" | |||||
:param str tokenizer: which tokenizer to return; supports spacy and raw
:param str lang: language; currently only en is supported
:return: a tokenize function
""" | |||||
if tokenizer == 'spacy': | |||||
import spacy | |||||
spacy.prefer_gpu() | |||||
if lang!='en': | |||||
raise RuntimeError("Spacy only supports en right right.") | |||||
en = spacy.load(lang) | |||||
tokenizer = lambda x: [w.text for w in en.tokenizer(x)] | |||||
elif tokenizer == 'raw': | |||||
tokenizer = _raw_split | |||||
else: | |||||
raise RuntimeError("Only support `spacy`, `raw` tokenizer.") | |||||
return tokenizer | |||||
def _raw_split(sent): | |||||
return sent.split() | |||||
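Both branches return a plain callable; a small sketch (the spacy branch additionally requires the spacy en model to be installed):

tokenize = get_tokenizer('raw')
print(tokenize("Bromwell High is a cartoon comedy ."))  # ['Bromwell', 'High', 'is', 'a', 'cartoon', 'comedy', '.']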
def _indexize(data_bundle): | |||||
""" | |||||
Build a vocabulary on the "words" column and another on the "target" column of the datasets, and add both vocabularies to data_bundle.
:param data_bundle: the DataBundle to index
:return: the input DataBundle
""" | |||||
src_vocab = Vocabulary() | |||||
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, | |||||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||||
name != 'train']) | |||||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||||
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET) | |||||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.TARGET) | |||||
data_bundle.set_vocab(src_vocab, Const.INPUT) | |||||
data_bundle.set_vocab(tgt_vocab, Const.TARGET) | |||||
return data_bundle | |||||
def _add_words_field(data_bundle, lower=False): | |||||
""" | |||||
Copy the raw_words column of every dataset in data_bundle into a words column, lowercasing it if lower is True.
:param data_bundle: the DataBundle to modify
:param bool lower: whether to lowercase the words column
:return: the input DataBundle
""" | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT) | |||||
if lower: | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset[Const.INPUT].lower() | |||||
return data_bundle | |||||
def _drop_empty_instance(data_bundle, field_name): | |||||
""" | |||||
Drop instances from the DataSets of data_bundle in which the given field is empty.
:param data_bundle: DataBundle
:param str field_name: which field to check; if None, an instance is dropped when any of its fields is empty
:return: the input DataBundle
""" | |||||
def empty_instance(ins): | |||||
if field_name: | |||||
field_value = ins[field_name] | |||||
if field_value in ((), {}, [], ''): | |||||
return True | |||||
return False | |||||
for _, field_value in ins.items(): | |||||
if field_value in ((), {}, [], ''): | |||||
return True | |||||
return False | |||||
for name, dataset in data_bundle.datasets.items(): | |||||
dataset.drop(empty_instance) | |||||
return data_bundle | |||||
@@ -1,9 +1,10 @@ | |||||
import os | import os | ||||
from typing import Union, Dict | from typing import Union, Dict | ||||
from pathlib import Path | |||||
def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
""" | """ | ||||
Check that the files passed to the loader are valid. If the path is valid, return a dict that contains at least the key 'train', similar to the result below | Check that the files passed to the loader are valid. If the path is valid, return a dict that contains at least the key 'train', similar to the result below | ||||
{ | { | ||||
@@ -11,13 +12,14 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
'test': 'xxx' # may or may not be present | 'test': 'xxx' # may or may not be present | ||||
... | ... | ||||
} | } | ||||
If paths is invalid, the corresponding error is raised directly | |||||
If paths is invalid, the corresponding error is raised directly. An error is also raised if paths does not contain train. | |||||
:param paths: the path. It can be a single file path (that file is then treated as the train file); or a directory, under which train (a file whose name | |||||
:param str paths: the path. It can be a single file path (that file is then treated as the train file); or a directory, under which train (a file whose name | |||||
contains the string train), test.txt and dev.txt will be looked for; or a dict whose keys are user-defined file names and whose values are the paths of those files. | contains the string train), test.txt and dev.txt will be looked for; or a dict whose keys are user-defined file names and whose values are the paths of those files. | ||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(paths, str): | |||||
if isinstance(paths, (str, Path)): | |||||
paths = os.path.abspath(os.path.expanduser(paths)) | |||||
if os.path.isfile(paths): | if os.path.isfile(paths): | ||||
return {'train': paths} | return {'train': paths} | ||||
elif os.path.isdir(paths): | elif os.path.isdir(paths): | ||||
@@ -37,6 +39,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
path_pair = ('test', filename) | path_pair = ('test', filename) | ||||
if path_pair: | if path_pair: | ||||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | files[path_pair[0]] = os.path.join(paths, path_pair[1]) | ||||
if 'train' not in files: | |||||
raise KeyError(f"There is no train file in {paths}.") | |||||
return files | return files | ||||
else: | else: | ||||
raise FileNotFoundError(f"{paths} is not a valid file path.") | raise FileNotFoundError(f"{paths} is not a valid file path.") | ||||
@@ -47,8 +51,10 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
raise KeyError("You have to include `train` in your dict.") | raise KeyError("You have to include `train` in your dict.") | ||||
for key, value in paths.items(): | for key, value in paths.items(): | ||||
if isinstance(key, str) and isinstance(value, str): | if isinstance(key, str) and isinstance(value, str): | ||||
value = os.path.abspath(os.path.expanduser(value)) | |||||
if not os.path.isfile(value): | if not os.path.isfile(value): | ||||
raise TypeError(f"{value} is not a valid file.") | raise TypeError(f"{value} is not a valid file.") | ||||
paths[key] = value | |||||
else: | else: | ||||
raise TypeError("All keys and values in paths should be str.") | raise TypeError("All keys and values in paths should be str.") | ||||
return paths | return paths | ||||
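All three accepted forms therefore normalize to the same dict shape; a hedged sketch (the module path fastNLP.io.utils and the data paths are assumptions):

from fastNLP.io.utils import check_loader_paths  # module path assumed for this utils file

print(check_loader_paths('/data/conll2003/train.txt'))
# single file -> {'train': '/data/conll2003/train.txt'}
print(check_loader_paths('/data/conll2003'))
# directory -> e.g. {'train': '.../train.txt', 'dev': '.../dev.txt', 'test': '.../test.txt'}
print(check_loader_paths({'train': '/data/a.txt', 'dev': '/data/b.txt'}))
# dict -> validated (train key present, files exist) and returned with absolute paths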
@@ -0,0 +1,21 @@ | |||||
import unittest | |||||
from fastNLP import Vocabulary | |||||
from fastNLP.embeddings import ElmoEmbedding | |||||
import torch | |||||
import os | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestDownload(unittest.TestCase): | |||||
def test_download_small(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small') | |||||
words = torch.LongTensor([[0, 1, 2]]) | |||||
print(elmo_embed(words).size()) | |||||
# first make sure all weights can be loaded; then upload the weights; then verify they can be downloaded
@@ -0,0 +1,19 @@ | |||||
import unittest | |||||
from fastNLP.io.loader.classification import YelpFullLoader | |||||
from fastNLP.io.loader.classification import YelpPolarityLoader | |||||
from fastNLP.io.loader.classification import IMDBLoader | |||||
from fastNLP.io.loader.classification import SST2Loader | |||||
from fastNLP.io.loader.classification import SSTLoader | |||||
import os | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestDownload(unittest.TestCase): | |||||
def test_download(self): | |||||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: | |||||
loader().download() | |||||
def test_load(self): | |||||
for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]: | |||||
data_bundle = loader().load() | |||||
print(data_bundle) |
@@ -0,0 +1,22 @@ | |||||
import unittest | |||||
from fastNLP.io.loader.matching import RTELoader | |||||
from fastNLP.io.loader.matching import QNLILoader | |||||
from fastNLP.io.loader.matching import SNLILoader | |||||
from fastNLP.io.loader.matching import QuoraLoader | |||||
from fastNLP.io.loader.matching import MNLILoader | |||||
import os | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestDownload(unittest.TestCase): | |||||
def test_download(self): | |||||
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: | |||||
loader().download() | |||||
with self.assertRaises(Exception): | |||||
QuoraLoader().load() | |||||
def test_load(self): | |||||
for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]: | |||||
data_bundle = loader().load() | |||||
print(data_bundle) | |||||
@@ -0,0 +1,13 @@ | |||||
import unittest | |||||
import os | |||||
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | |||||
for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]: | |||||
with self.subTest(pipe=pipe): | |||||
print(pipe) | |||||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||||
print(data_bundle) |
@@ -0,0 +1,26 @@ | |||||
import unittest | |||||
import os | |||||
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe | |||||
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | |||||
for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]: | |||||
with self.subTest(pipe=pipe): | |||||
print(pipe) | |||||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||||
print(data_bundle) | |||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis") | |||||
class TestBertPipe(unittest.TestCase): | |||||
def test_process_from_file(self): | |||||
for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]: | |||||
with self.subTest(pipe=pipe): | |||||
print(pipe) | |||||
data_bundle = pipe(tokenizer='raw').process_from_file() | |||||
print(data_bundle) |