diff --git a/README.md b/README.md
index 9d949482..a5ce3c64 100644
--- a/README.md
+++ b/README.md
@@ -6,13 +6,14 @@

[](http://fastnlp.readthedocs.io/?badge=latest)
-fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务; 也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性:
+fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner/)、POS-Tagging等)、中文分词、文本分类、[Matching](reproduction/matching/)、指代消解、摘要等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性:
-- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。
-- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等;
-- 详尽的中文文档以供查阅;
+- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码;
+- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等;
+- 各种方便的NLP工具,例如预处理embedding加载(包括EMLo和BERT); 中间数据cache等;
+- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、教程以供查阅;
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等;
-- 封装CNNText,Biaffine等模型可供直接使用;
+- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用; [详细链接](reproduction/)
- 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。
@@ -20,13 +21,14 @@ fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地
fastNLP 依赖如下包:
-+ numpy
-+ torch>=0.4.0
-+ tqdm
-+ nltk
++ numpy>=1.14.2
++ torch>=1.0.0
++ tqdm>=4.28.1
++ nltk>=3.4.1
++ requests
-其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 PyTorch 官网 。
-在依赖包安装完成的情况,您可以在命令行执行如下指令完成安装
+其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。
+在依赖包安装完成后,您可以在命令行执行如下指令完成安装
```shell
pip install fastNLP
@@ -77,8 +79,8 @@ fastNLP 在 modules 模块中内置了三种模块的诸多组件,可以帮助
fastNLP 为不同的 NLP 任务实现了许多完整的模型,它们都经过了训练和测试。
你可以在以下两个地方查看相关信息
-- [介绍](reproduction/)
-- [源码](fastNLP/models/)
+- [模型介绍](reproduction/)
+- [模型源码](fastNLP/models/)
## 项目结构
@@ -93,7 +95,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下:
fastNLP.core |
- 实现了核心功能,包括数据处理组件、训练器、测速器等 |
+ 实现了核心功能,包括数据处理组件、训练器、测试器等 |
fastNLP.models |
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 46a72802..14aacef0 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -20,6 +20,7 @@ from collections import defaultdict
import torch
import torch.nn.functional as F
+from ..core.const import Const
from .utils import _CheckError
from .utils import _CheckRes
from .utils import _build_args
@@ -28,6 +29,7 @@ from .utils import _check_function_or_method
from .utils import _get_func_signature
from .utils import seq_len_to_mask
+
class LossBase(object):
"""
所有loss的基类。如果想了解其中的原理,请查看源码。
@@ -95,22 +97,7 @@ class LossBase(object):
# if func_spect.varargs:
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
# f"positional argument.).")
-
- def _fast_param_map(self, pred_dict, target_dict):
- """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
- such as pred_dict has one element, target_dict has one element
- :param pred_dict:
- :param target_dict:
- :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
- """
- fast_param = {}
- if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
- fast_param['pred'] = list(pred_dict.values())[0]
- fast_param['target'] = list(target_dict.values())[0]
- return fast_param
- return fast_param
-
def __call__(self, pred_dict, target_dict, check=False):
"""
:param dict pred_dict: 模型的forward函数返回的dict
@@ -118,11 +105,7 @@ class LossBase(object):
:param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查
:return:
"""
- fast_param = self._fast_param_map(pred_dict, target_dict)
- if fast_param:
- loss = self.get_loss(**fast_param)
- return loss
-
+
if not self._checked:
# 1. check consistence between signature and _param_map
func_spect = inspect.getfullargspec(self.get_loss)
@@ -212,7 +195,6 @@ class LossFunc(LossBase):
if not isinstance(key_map, dict):
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
self._init_param_map(key_map, **kwargs)
-
class CrossEntropyLoss(LossBase):
@@ -226,7 +208,7 @@ class CrossEntropyLoss(LossBase):
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
传入seq_len.
- :param str reduction: 支持'elementwise_mean'和'sum'.
+ :param str reduction: 支持'mean','sum'和'none'.
Example::
@@ -234,16 +216,16 @@ class CrossEntropyLoss(LossBase):
"""
- def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='elementwise_mean'):
+ def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.padding_idx = padding_idx
- assert reduction in ('elementwise_mean', 'sum')
+ assert reduction in ('mean', 'sum', 'none')
self.reduction = reduction
def get_loss(self, pred, target, seq_len=None):
- if pred.dim()>2:
- if pred.size(1)!=target.size(1):
+ if pred.dim() > 2:
+ if pred.size(1) != target.size(1):
pred = pred.transpose(1, 2)
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)
@@ -263,15 +245,18 @@ class L1Loss(LossBase):
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target`
+ :param str reduction: 支持'mean','sum'和'none'.
"""
- def __init__(self, pred=None, target=None):
+ def __init__(self, pred=None, target=None, reduction='mean'):
super(L1Loss, self).__init__()
self._init_param_map(pred=pred, target=target)
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
def get_loss(self, pred, target):
- return F.l1_loss(input=pred, target=target)
+ return F.l1_loss(input=pred, target=target, reduction=self.reduction)
class BCELoss(LossBase):
@@ -282,14 +267,17 @@ class BCELoss(LossBase):
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
+ :param str reduction: 支持'mean','sum'和'none'.
"""
- def __init__(self, pred=None, target=None):
+ def __init__(self, pred=None, target=None, reduction='mean'):
super(BCELoss, self).__init__()
self._init_param_map(pred=pred, target=target)
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
def get_loss(self, pred, target):
- return F.binary_cross_entropy(input=pred, target=target)
+ return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction)
class NLLLoss(LossBase):
@@ -300,14 +288,20 @@ class NLLLoss(LossBase):
:param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
:param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
+ :param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替
+ 传入seq_len.
+ :param str reduction: 支持'mean','sum'和'none'.
"""
- def __init__(self, pred=None, target=None):
+ def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'):
super(NLLLoss, self).__init__()
self._init_param_map(pred=pred, target=target)
+ assert reduction in ('mean', 'sum', 'none')
+ self.reduction = reduction
+ self.ignore_idx = ignore_idx
def get_loss(self, pred, target):
- return F.nll_loss(input=pred, target=target)
+ return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction)
class LossInForward(LossBase):
@@ -319,7 +313,7 @@ class LossInForward(LossBase):
:param str loss_key: 在forward函数中loss的键名,默认为loss
"""
- def __init__(self, loss_key='loss'):
+ def __init__(self, loss_key=Const.LOSS):
super().__init__()
if not isinstance(loss_key, str):
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index 28f466a8..05d75f43 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -11,21 +11,35 @@
"""
__all__ = [
'EmbedLoader',
-
+
+ 'DataInfo',
'DataSetLoader',
+
'CSVLoader',
'JsonLoader',
'ConllLoader',
- 'SNLILoader',
- 'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
'ModelLoader',
'ModelSaver',
+
+ 'SSTLoader',
+
+ 'MatchingLoader',
+ 'SNLILoader',
+ 'MNLILoader',
+ 'QNLILoader',
+ 'QuoraLoader',
+ 'RTELoader',
]
from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \
- SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
+from .base_loader import DataInfo, DataSetLoader
+from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \
+ PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
+
+from .data_loader.sst import SSTLoader
+from .data_loader.matching import MatchingLoader, SNLILoader, \
+ MNLILoader, QNLILoader, QuoraLoader, RTELoader
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index 465fb7e8..8cff1da1 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -10,6 +10,7 @@ from typing import Union, Dict
import os
from ..core.dataset import DataSet
+
class BaseLoader(object):
"""
各个 Loader 的基类,提供了 API 的参考。
@@ -55,8 +56,6 @@ class BaseLoader(object):
return obj
-
-
def _download_from_url(url, path):
try:
from tqdm.auto import tqdm
@@ -115,13 +114,11 @@ class DataInfo:
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
- :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
"""
- def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
+ def __init__(self, vocabs: dict = None, datasets: dict = None):
self.vocabs = vocabs or {}
- self.embeddings = embeddings or {}
self.datasets = datasets or {}
def __repr__(self):
@@ -133,6 +130,7 @@ class DataInfo:
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str
+
class DataSetLoader:
"""
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
@@ -213,7 +211,6 @@ class DataSetLoader:
返回的 :class:`DataInfo` 对象有如下属性:
- vocabs: 由从数据集中获取的词表组成的字典,每个词表
- - embeddings: (可选) 数据集对应的词嵌入
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
:param paths: 原始数据读取的路径
diff --git a/fastNLP/io/data_loader/__init__.py b/fastNLP/io/data_loader/__init__.py
new file mode 100644
index 00000000..6f4dd973
--- /dev/null
+++ b/fastNLP/io/data_loader/__init__.py
@@ -0,0 +1,19 @@
+"""
+用于读数据集的模块, 具体包括:
+
+这些模块的使用方法如下:
+"""
+__all__ = [
+ 'SSTLoader',
+
+ 'MatchingLoader',
+ 'SNLILoader',
+ 'MNLILoader',
+ 'QNLILoader',
+ 'QuoraLoader',
+ 'RTELoader',
+]
+
+from .sst import SSTLoader
+from .matching import MatchingLoader, SNLILoader, \
+ MNLILoader, QNLILoader, QuoraLoader, RTELoader
diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
new file mode 100644
index 00000000..3d131bcb
--- /dev/null
+++ b/fastNLP/io/data_loader/matching.py
@@ -0,0 +1,430 @@
+import os
+
+from typing import Union, Dict
+
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+from ..base_loader import DataInfo, DataSetLoader
+from ..dataset_loader import JsonLoader, CSVLoader
+from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ...modules.encoder._bert import BertTokenizer
+
+
+class MatchingLoader(DataSetLoader):
+ """
+ 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
+
+ 读取Matching任务的数据集
+
+ :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
+ """
+
+ def __init__(self, paths: dict=None):
+ self.paths = paths
+
+ def _load(self, path):
+ """
+ :param str path: 待读取数据集的路径名
+ :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子
+ 的原始字符串文本,第三个为标签
+ """
+ raise NotImplementedError
+
+ def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
+ to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
+ cut_text: int = None, get_index=True, auto_pad_length: int=None,
+ auto_pad_token: str='', set_input: Union[list, str, bool]=True,
+ set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+ """
+ :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
+ 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
+ 对应的全路径文件名。
+ :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义
+ 这个数据集的名字,如果不定义则默认为train。
+ :param bool to_lower: 是否将文本自动转为小写。默认值为False。
+ :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` :
+ 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和
+ attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len
+ :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
+ :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
+ :param bool get_index: 是否需要根据词表将文本转为index
+ :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
+ :param str auto_pad_token: 自动pad的内容
+ :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
+ 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
+ 于此同时其他field不会被设置为input。默认值为True。
+ :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。
+ :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。
+ 如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果
+ 传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]'].
+ :return:
+ """
+ if isinstance(set_input, str):
+ set_input = [set_input]
+ if isinstance(set_target, str):
+ set_target = [set_target]
+ if isinstance(set_input, bool):
+ auto_set_input = set_input
+ else:
+ auto_set_input = False
+ if isinstance(set_target, bool):
+ auto_set_target = set_target
+ else:
+ auto_set_target = False
+ if isinstance(paths, str):
+ if os.path.isdir(paths):
+ path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
+ else:
+ path = {dataset_name if dataset_name is not None else 'train': paths}
+ else:
+ path = paths
+
+ data_info = DataInfo()
+ for data_name in path.keys():
+ data_info.datasets[data_name] = self._load(path[data_name])
+
+ for data_name, data_set in data_info.datasets.items():
+ if auto_set_input:
+ data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
+ if auto_set_target:
+ if Const.TARGET in data_set.get_field_names():
+ data_set.set_target(Const.TARGET)
+
+ if to_lower:
+ for data_name, data_set in data_info.datasets.items():
+ data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
+ is_input=auto_set_input)
+ data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
+ is_input=auto_set_input)
+
+ if bert_tokenizer is not None:
+ if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
+ PRETRAIN_URL = _get_base_url('bert')
+ model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
+ model_url = PRETRAIN_URL + model_name
+ model_dir = cached_path(model_url)
+ # 检查是否存在
+ elif os.path.isdir(bert_tokenizer):
+ model_dir = bert_tokenizer
+ else:
+ raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
+
+ words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
+ with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ words_vocab.add_word_lst(lines)
+ words_vocab.build_vocab()
+
+ tokenizer = BertTokenizer.from_pretrained(model_dir)
+
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
+ is_input=auto_set_input)
+
+ if isinstance(concat, bool):
+ concat = 'default' if concat else None
+ if concat is not None:
+ if isinstance(concat, str):
+ CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
+ 'default': ['', '', '', '']}
+ if concat.lower() in CONCAT_MAP:
+ concat = CONCAT_MAP[concat]
+ else:
+ concat = 4 * [concat]
+ assert len(concat) == 4, \
+ f'Please choose a list with 4 symbols which at the beginning of first sentence ' \
+ f'the end of first sentence, the begin of second sentence, and the end of second' \
+ f'sentence. Your input is {concat}'
+
+ for data_name, data_set in data_info.datasets.items():
+ data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
+ x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
+ data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
+ is_input=auto_set_input)
+
+ if seq_len_type is not None:
+ if seq_len_type == 'seq_len': #
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: len(x[fields]),
+ new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+ is_input=auto_set_input)
+ elif seq_len_type == 'mask':
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: [1] * len(x[fields]),
+ new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
+ is_input=auto_set_input)
+ elif seq_len_type == 'bert':
+ for data_name, data_set in data_info.datasets.items():
+ if Const.INPUT not in data_set.get_field_names():
+ raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
+ f'got {data_set.get_field_names()}')
+ data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
+ new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
+ data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
+ new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
+
+ if auto_pad_length is not None:
+ cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
+
+ if cut_text is not None:
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
+ data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
+ is_input=auto_set_input)
+
+ data_set_list = [d for n, d in data_info.datasets.items()]
+ assert len(data_set_list) > 0, f'There are NO data sets in data info!'
+
+ if bert_tokenizer is None:
+ words_vocab = Vocabulary(padding=auto_pad_token)
+ words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+ field_name=[n for n in data_set_list[0].get_field_names()
+ if (Const.INPUT in n)],
+ no_create_entry_dataset=[d for n, d in data_info.datasets.items()
+ if 'train' not in n])
+ target_vocab = Vocabulary(padding=None, unknown=None)
+ target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
+ field_name=Const.TARGET)
+ data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
+
+ if get_index:
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
+ is_input=auto_set_input)
+
+ if Const.TARGET in data_set.get_field_names():
+ data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
+ is_input=auto_set_input, is_target=auto_set_target)
+
+ if auto_pad_length is not None:
+ if seq_len_type == 'seq_len':
+ raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
+ f'so the seq_len_type cannot be `{seq_len_type}`!')
+ for data_name, data_set in data_info.datasets.items():
+ for fields in data_set.get_field_names():
+ if Const.INPUT in fields:
+ data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
+ (auto_pad_length - len(x[fields])), new_field_name=fields,
+ is_input=auto_set_input)
+ elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+ data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
+ new_field_name=fields, is_input=auto_set_input)
+
+ for data_name, data_set in data_info.datasets.items():
+ if isinstance(set_input, list):
+ data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
+ if isinstance(set_target, list):
+ data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
+
+ return data_info
+
+
+class SNLILoader(MatchingLoader, JsonLoader):
+ """
+ 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
+
+ 读取SNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
+ """
+
+ def __init__(self, paths: dict=None):
+ fields = {
+ 'sentence1_binary_parse': Const.INPUTS(0),
+ 'sentence2_binary_parse': Const.INPUTS(1),
+ 'gold_label': Const.TARGET,
+ }
+ paths = paths if paths is not None else {
+ 'train': 'snli_1.0_train.jsonl',
+ 'dev': 'snli_1.0_dev.jsonl',
+ 'test': 'snli_1.0_test.jsonl'}
+ MatchingLoader.__init__(self, paths=paths)
+ JsonLoader.__init__(self, fields=fields)
+
+ def _load(self, path):
+ ds = JsonLoader._load(self, path)
+
+ parentheses_table = str.maketrans({'(': None, ')': None})
+
+ ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(0))
+ ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(1))
+ ds.drop(lambda x: x[Const.TARGET] == '-')
+ return ds
+
+
+class RTELoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
+
+ 读取RTE数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev': 'dev.tsv',
+ 'test': 'test.tsv' # test set has not label
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ self.fields = {
+ 'sentence1': Const.INPUTS(0),
+ 'sentence2': Const.INPUTS(1),
+ 'label': Const.TARGET,
+ }
+ CSVLoader.__init__(self, sep='\t')
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+
+ for k, v in self.fields.items():
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
+ for fields in ds.get_all_fields():
+ if Const.INPUT in fields:
+ ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+ return ds
+
+
+class QNLILoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
+
+ 读取QNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev': 'dev.tsv',
+ 'test': 'test.tsv' # test set has not label
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ self.fields = {
+ 'question': Const.INPUTS(0),
+ 'sentence': Const.INPUTS(1),
+ 'label': Const.TARGET,
+ }
+ CSVLoader.__init__(self, sep='\t')
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+
+ for k, v in self.fields.items():
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
+ for fields in ds.get_all_fields():
+ if Const.INPUT in fields:
+ ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
+
+ return ds
+
+
+class MNLILoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
+
+ 读取MNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev_matched': 'dev_matched.tsv',
+ 'dev_mismatched': 'dev_mismatched.tsv',
+ 'test_matched': 'test_matched.tsv',
+ 'test_mismatched': 'test_mismatched.tsv',
+ # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
+ # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+
+ # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle)
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ CSVLoader.__init__(self, sep='\t')
+ self.fields = {
+ 'sentence1_binary_parse': Const.INPUTS(0),
+ 'sentence2_binary_parse': Const.INPUTS(1),
+ 'gold_label': Const.TARGET,
+ }
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+
+ for k, v in self.fields.items():
+ if k in ds.get_field_names():
+ ds.rename_field(k, v)
+
+ if Const.TARGET in ds.get_field_names():
+ if ds[0][Const.TARGET] == 'hidden':
+ ds.delete_field(Const.TARGET)
+
+ parentheses_table = str.maketrans({'(': None, ')': None})
+
+ ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(0))
+ ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
+ new_field_name=Const.INPUTS(1))
+ if Const.TARGET in ds.get_field_names():
+ ds.drop(lambda x: x[Const.TARGET] == '-')
+ return ds
+
+
+class QuoraLoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
+
+ 读取MNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
+
+ def __init__(self, paths: dict=None):
+ paths = paths if paths is not None else {
+ 'train': 'train.tsv',
+ 'dev': 'dev.tsv',
+ 'test': 'test.tsv',
+ }
+ MatchingLoader.__init__(self, paths=paths)
+ CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
+
+ def _load(self, path):
+ ds = CSVLoader._load(self, path)
+ return ds
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 558fe20e..2881e6e9 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -16,8 +16,6 @@ __all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
- 'SNLILoader',
- 'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader, DataInfo
-from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
else:
instance = Instance(words=sent_words)
data_set.append(instance)
- data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
+ data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN)
return data_set
@@ -249,42 +246,6 @@ class JsonLoader(DataSetLoader):
return ds
-class SNLILoader(JsonLoader):
- """
- 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
-
- 读取SNLI数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
- """
-
- def __init__(self):
- fields = {
- 'sentence1_parse': Const.INPUTS(0),
- 'sentence2_parse': Const.INPUTS(1),
- 'gold_label': Const.TARGET,
- }
- super(SNLILoader, self).__init__(fields=fields)
-
- def _load(self, path):
- ds = super(SNLILoader, self)._load(path)
-
- def parse_tree(x):
- t = Tree.fromstring(x)
- return t.leaves()
-
- ds.apply(lambda ins: parse_tree(
- ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
- ds.apply(lambda ins: parse_tree(
- ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
- ds.drop(lambda x: x[Const.TARGET] == '-')
- return ds
-
-
class CSVLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`
diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index 4846c7fa..fb186ce4 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -8,35 +8,7 @@ from torch import nn
from .base_model import BaseModel
from ..core.const import Const
from ..modules.encoder import BertModel
-
-
-class BertConfig:
-
- def __init__(
- self,
- vocab_size=30522,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- intermediate_size=3072,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=512,
- type_vocab_size=2,
- initializer_range=0.02
- ):
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.intermediate_size = intermediate_size
- self.hidden_act = hidden_act
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.type_vocab_size = type_vocab_size
- self.initializer_range = initializer_range
+from ..modules.encoder._bert import BertConfig
class BertForSequenceClassification(BaseModel):
@@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
+ @classmethod
+ def from_pretrained(cls, num_labels, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
@@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
+ @classmethod
+ def from_pretrained(cls, num_choices, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
@@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
+ @classmethod
+ def from_pretrained(cls, num_labels, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
sequence_output = self.dropout(sequence_output)
@@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel):
self.bert = BertModel.from_pretrained(bert_dir)
else:
if config is None:
- config = BertConfig()
- self.bert = BertModel(**config.__dict__)
+ config = BertConfig(30522)
+ self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
+ @classmethod
+ def from_pretrained(cls, pretrained_model_dir):
+ config = BertConfig(pretrained_model_dir)
+ model = cls(config=config, bert_dir=pretrained_model_dir)
+ return model
+
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
logits = self.qa_outputs(sequence_output)
diff --git a/fastNLP/modules/decoder/mlp.py b/fastNLP/modules/decoder/mlp.py
index c1579224..418b3a77 100644
--- a/fastNLP/modules/decoder/mlp.py
+++ b/fastNLP/modules/decoder/mlp.py
@@ -15,7 +15,8 @@ class MLP(nn.Module):
多层感知器
:param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1
- :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu
+ :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和
+ sigmoid,默认值为relu
:param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数
:param str initial_method: 参数初始化方式
:param float dropout: dropout概率,默认值为0
diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py
index 4669b511..61a5d7d1 100644
--- a/fastNLP/modules/encoder/_bert.py
+++ b/fastNLP/modules/encoder/_bert.py
@@ -26,6 +26,7 @@ import sys
CONFIG_FILE = 'bert_config.json'
+
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
@@ -339,13 +340,19 @@ class BertModel(nn.Module):
如果你想使用预训练好的权重矩阵,请在以下网址下载.
sources::
- 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
- 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
- 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
- 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
- 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
- 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
- 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+ 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin"
用预训练权重矩阵来建立BERT模型::
@@ -562,6 +569,7 @@ class WordpieceTokenizer(object):
output_tokens.extend(sub_tokens)
return output_tokens
+
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
@@ -692,6 +700,7 @@ class BasicTokenizer(object):
output.append(char)
return "".join(output)
+
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
diff --git a/reproduction/README.md b/reproduction/README.md
index 92652fb4..b6f61903 100644
--- a/reproduction/README.md
+++ b/reproduction/README.md
@@ -3,6 +3,8 @@
复现的模型有:
- [Star-Transformer](Star_transformer/)
+- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239)
+- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12)
- ...
# 任务复现
@@ -11,11 +13,11 @@
## Matching (自然语言推理/句子匹配)
-- [Matching 任务复现](matching/)
+- [Matching 任务复现](matching)
## Sequence Labeling (序列标注)
-- still in progress
+- [NER](seqence_labelling/ner)
## Coreference resolution (指代消解)
diff --git a/reproduction/Star_transformer/datasets.py b/reproduction/Star_transformer/datasets.py
index a9257fd4..1532a041 100644
--- a/reproduction/Star_transformer/datasets.py
+++ b/reproduction/Star_transformer/datasets.py
@@ -2,7 +2,8 @@ import torch
import json
import os
from fastNLP import Vocabulary
-from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader
+from fastNLP.io.dataset_loader import ConllLoader
+from fastNLP.io.data_loader import SSTLoader, SNLILoader
from fastNLP.core import Const as C
import numpy as np
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
index 9d948ec1..7c32899c 100644
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ b/reproduction/matching/data/MatchingDataLoader.py
@@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader):
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
读取Matching任务的数据集
+
+ :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
"""
def __init__(self, paths: dict=None):
- """
- :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
- """
self.paths = paths
def _load(self, path):
@@ -173,7 +172,7 @@ class MatchingLoader(DataSetLoader):
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
if auto_pad_length is not None:
- cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0)
+ cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
if cut_text is not None:
for data_name, data_set in data_info.datasets.items():
@@ -209,15 +208,18 @@ class MatchingLoader(DataSetLoader):
is_input=auto_set_input, is_target=auto_set_target)
if auto_pad_length is not None:
+ if seq_len_type == 'seq_len':
+ raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
+ f'so the seq_len_type cannot be `{seq_len_type}`!')
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
- data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])),
- new_field_name=fields, is_input=auto_set_input)
- elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
(auto_pad_length - len(x[fields])), new_field_name=fields,
is_input=auto_set_input)
+ elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
+ data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
+ new_field_name=fields, is_input=auto_set_input)
for data_name, data_set in data_info.datasets.items():
if isinstance(set_input, list):
@@ -284,7 +286,7 @@ class RTELoader(MatchingLoader, CSVLoader):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
- # 'test': 'test.tsv' # test set has not label
+ 'test': 'test.tsv' # test set has not label
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
@@ -298,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
- ds.rename_field(k, v)
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -323,7 +326,7 @@ class QNLILoader(MatchingLoader, CSVLoader):
paths = paths if paths is not None else {
'train': 'train.tsv',
'dev': 'dev.tsv',
- # 'test': 'test.tsv' # test set has not label
+ 'test': 'test.tsv' # test set has not label
}
MatchingLoader.__init__(self, paths=paths)
self.fields = {
@@ -337,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader):
ds = CSVLoader._load(self, path)
for k, v in self.fields.items():
- ds.rename_field(k, v)
+ if v in ds.get_field_names():
+ ds.rename_field(k, v)
for fields in ds.get_all_fields():
if Const.INPUT in fields:
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
@@ -349,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader):
"""
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
- 读取SNLI数据集,读取的DataSet包含fields::
+ 读取MNLI数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
words2: list(str), 第二句文本, hypothesis
@@ -367,6 +371,7 @@ class MNLILoader(MatchingLoader, CSVLoader):
'test_mismatched': 'test_mismatched.tsv',
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
+
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle)
}
MatchingLoader.__init__(self, paths=paths)
@@ -400,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader):
class QuoraLoader(MatchingLoader, CSVLoader):
+ """
+ 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
+
+ 读取MNLI数据集,读取的DataSet包含fields::
+
+ words1: list(str),第一句文本, premise
+ words2: list(str), 第二句文本, hypothesis
+ target: str, 真实标签
+
+ 数据来源:
+ """
def __init__(self, paths: dict=None):
paths = paths if paths is not None else {
diff --git a/requirements.txt b/requirements.txt
index 7ea8fdac..f8f7a951 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-numpy
-torch>=0.4.0
-tqdm
-nltk
+numpy>=1.14.2
+torch>=1.0.0
+tqdm>=4.28.1
+nltk>=3.4.1
requests
diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py
index 7cff3c12..09ad8c83 100644
--- a/test/io/test_dataset_loader.py
+++ b/test/io/test_dataset_loader.py
@@ -1,7 +1,7 @@
import unittest
import os
-from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader
-from fastNLP.io.dataset_loader import SSTLoader
+from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
+from fastNLP.io.data_loader import SSTLoader, SNLILoader
from reproduction.text_classification.data.yelpLoader import yelpLoader
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
print(info.vocabs)
print(info.datasets)
os.remove(train), os.remove(test)
+
+ def test_import(self):
+ import fastNLP
+ from fastNLP.io import SNLILoader
+ ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+ get_index=True, seq_len_type='seq_len')
+ assert 'train' in ds.datasets
+ assert len(ds.datasets) == 1
+ assert len(ds.datasets['train']) == 3
diff --git a/test/models/test_bert.py b/test/models/test_bert.py
index 7177f31b..38a16f9b 100644
--- a/test/models/test_bert.py
+++ b/test/models/test_bert.py
@@ -8,8 +8,9 @@ from fastNLP.models.bert import *
class TestBert(unittest.TestCase):
def test_bert_1(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForSequenceClassification(2)
+ model = BertForSequenceClassification(2, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -22,8 +23,9 @@ class TestBert(unittest.TestCase):
def test_bert_2(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForMultipleChoice(2)
+ model = BertForMultipleChoice(2, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -36,8 +38,9 @@ class TestBert(unittest.TestCase):
def test_bert_3(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForTokenClassification(7)
+ model = BertForTokenClassification(7, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -50,8 +53,9 @@ class TestBert(unittest.TestCase):
def test_bert_4(self):
from fastNLP.core.const import Const
+ from fastNLP.modules.encoder._bert import BertConfig
- model = BertForQuestionAnswering()
+ model = BertForQuestionAnswering(BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
diff --git a/test/modules/encoder/test_bert.py b/test/modules/encoder/test_bert.py
index 78bcf633..2a799478 100644
--- a/test/modules/encoder/test_bert.py
+++ b/test/modules/encoder/test_bert.py
@@ -8,8 +8,9 @@ from fastNLP.models.bert import BertModel
class TestBert(unittest.TestCase):
def test_bert_1(self):
- model = BertModel(vocab_size=32000, hidden_size=768,
- num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+ from fastNLP.modules.encoder._bert import BertConfig
+ config = BertConfig(32000)
+ model = BertModel(config)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
@@ -18,4 +19,4 @@ class TestBert(unittest.TestCase):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
for layer in all_encoder_layers:
self.assertEqual(tuple(layer.shape), (2, 3, 768))
- self.assertEqual(tuple(pooled_output.shape), (2, 768))
\ No newline at end of file
+ self.assertEqual(tuple(pooled_output.shape), (2, 768))