@@ -6,13 +6,14 @@

[](http://fastnlp.readthedocs.io/?badge=latest)
fastNLP is a lightweight NLP toolkit. You can use it to quickly build a named entity recognition (NER), Chinese word segmentation, or text classification task, or use it to build complex network models for research. Its features include:
fastNLP is a lightweight NLP toolkit. You can use it to quickly build a sequence labelling ([NER](reproduction/seqence_labelling/ner/), POS tagging, etc.), Chinese word segmentation, text classification, [Matching](reproduction/matching/), coreference resolution, or summarization task, or use it to build complex network models for research. Its features include:
- A unified tabular data container that keeps preprocessing simple and clear, plus built-in DataSet loaders for many datasets that save you the preprocessing code.
- Handy NLP utilities, such as loading of pretrained embeddings and caching of intermediate data;
- Detailed Chinese documentation for reference;
- A unified tabular data container that keeps preprocessing simple and clear, plus built-in DataSet loaders for many datasets that save you the preprocessing code;
- Various training and testing components, such as the Trainer, the Tester, and assorted evaluation metrics;
- Handy NLP utilities, such as loading of pretrained embeddings (including ELMo and BERT) and caching of intermediate data;
- Detailed Chinese [documentation](https://fastnlp.readthedocs.io/) and tutorials for reference;
- Many high-level modules, such as Variational LSTM, Transformer, and CRF;
- Wrapped models such as CNNText and Biaffine, ready to use out of the box;
- Ready-to-use models for sequence labelling, Chinese word segmentation, text classification, Matching, coreference resolution, summarization, and more; [details](reproduction/)
- A convenient and extensible Trainer with many built-in callback functions for experiment logging, exception catching, and more (a minimal usage sketch follows this list).
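As a hedged illustration of how these components fit together (the tiny in-memory dataset, field names, and CNNText/Trainer arguments below roughly follow the fastNLP 0.4.x tutorials and are only a sketch, not part of this diff):

```python
from fastNLP import DataSet, Vocabulary, Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.models import CNNText

# A tiny in-memory DataSet; 'raw_sentence' and 'target' are illustrative field names.
data = DataSet({'raw_sentence': ['fastNLP is easy to use', 'this movie is terrible'],
                'target': [1, 0]})
data.apply(lambda ins: ins['raw_sentence'].lower().split(), new_field_name='words')

# Build a vocabulary from the tokenized field and convert the words to indices in place.
vocab = Vocabulary()
vocab.from_dataset(data, field_name='words')
vocab.index_dataset(data, field_name='words')

data.set_input('words')
data.set_target('target')

# CNNText takes (vocab_size, embed_dim) as its embedding spec in fastNLP 0.4.x.
model = CNNText((len(vocab), 50), num_classes=2, padding=2, dropout=0.1)

trainer = Trainer(train_data=data, model=model,
                  loss=CrossEntropyLoss(pred='pred', target='target'),
                  metrics=AccuracyMetric(), dev_data=data, n_epochs=3)
trainer.train()
```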
@@ -20,13 +21,14 @@ fastNLP is a lightweight NLP toolkit. You can use it to quickly
fastNLP depends on the following packages:
+ numpy
+ torch>=0.4.0
+ tqdm
+ nltk
+ numpy>=1.14.2
+ torch>=1.0.0
+ tqdm>=4.28.1
+ nltk>=3.4.1
+ requests
The installation of torch may depend on your operating system and CUDA version; please refer to the PyTorch website.
When the dependencies have been installed, you can run the following command to complete the installation
The installation of torch may depend on your operating system and CUDA version; please refer to the [PyTorch website](https://pytorch.org/).
Once the dependencies are installed, you can install fastNLP from the command line with
```shell
pip install fastNLP
@@ -77,8 +79,8 @@ fastNLP ships many components of three kinds of modules in its modules package, which help
fastNLP implements many complete models for different NLP tasks, all of which have been trained and tested.
You can find the related information in the following two places:
- [Introduction](reproduction/)
- [Source code](fastNLP/models/)
- [Model introduction](reproduction/)
- [Model source code](fastNLP/models/)

## Project Structure
@@ -93,7 +95,7 @@ The overall workflow of fastNLP is shown in the figure above; the project structure is as follows:
</tr>
<tr>
    <td><b> fastNLP.core </b></td>
    <td> 实现了核心功能,包括数据处理组件、训练器、测速器等 </td>
    <td> 实现了核心功能,包括数据处理组件、训练器、测试器等 </td>
</tr>
<tr>
    <td><b> fastNLP.models </b></td>
@@ -20,6 +20,7 @@ from collections import defaultdict
import torch
import torch.nn.functional as F
from ..core.const import Const
from .utils import _CheckError
from .utils import _CheckRes
from .utils import _build_args
@@ -28,6 +29,7 @@ from .utils import _check_function_or_method
from .utils import _get_func_signature
from .utils import seq_len_to_mask

class LossBase(object):
    """
    所有loss的基类。如果想了解其中的原理,请查看源码。
@@ -95,22 +97,7 @@ class LossBase(object):
        # if func_spect.varargs:
        #     raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
        #                     f"positional argument.).")

    def _fast_param_map(self, pred_dict, target_dict):
        """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
        such as pred_dict has one element, target_dict has one element

        :param pred_dict:
        :param target_dict:
        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
        """
        fast_param = {}
        if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
            fast_param['pred'] = list(pred_dict.values())[0]
            fast_param['target'] = list(target_dict.values())[0]
            return fast_param
        return fast_param

    def __call__(self, pred_dict, target_dict, check=False):
        """
        :param dict pred_dict: 模型的forward函数返回的dict
@@ -118,11 +105,7 @@ class LossBase(object):
        :param Boolean check: 每一次执行映射函数的时候是否检查映射表,默认为不检查
        :return:
        """
        fast_param = self._fast_param_map(pred_dict, target_dict)
        if fast_param:
            loss = self.get_loss(**fast_param)
            return loss

        if not self._checked:
            # 1. check consistence between signature and _param_map
            func_spect = inspect.getfullargspec(self.get_loss)
@@ -212,7 +195,6 @@ class LossFunc(LossBase):
        if not isinstance(key_map, dict):
            raise RuntimeError(f"Loss error: key_map expects a {type({})} but got a {type(key_map)}")
        self._init_param_map(key_map, **kwargs)
class CrossEntropyLoss(LossBase):
@@ -226,7 +208,7 @@ class CrossEntropyLoss(LossBase):
    :param seq_len: 句子的长度, 长度之外的token不会计算loss。
    :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
        传入seq_len.
    :param str reduction: 支持'elementwise_mean'和'sum'.
    :param str reduction: 支持'mean','sum'和'none'.

    Example::
@@ -234,16 +216,16 @@ class CrossEntropyLoss(LossBase):
    """

    def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='elementwise_mean'):
    def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'):
        super(CrossEntropyLoss, self).__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
        self.padding_idx = padding_idx
        assert reduction in ('elementwise_mean', 'sum')
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction

    def get_loss(self, pred, target, seq_len=None):
        if pred.dim()>2:
            if pred.size(1)!=target.size(1):
        if pred.dim() > 2:
            if pred.size(1) != target.size(1):
                pred = pred.transpose(1, 2)
            pred = pred.reshape(-1, pred.size(-1))
            target = target.reshape(-1)
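A hedged usage sketch of the key mapping and the new `reduction` argument (the field names 'tag_logits' and 'tag_ids' are invented for this example, not part of the diff):

```python
import torch
from fastNLP import CrossEntropyLoss

# Map the loss's `pred`/`target` arguments onto the names that the model's forward()
# and the DataSet actually produce; the Trainer later calls the loss with dicts keyed
# by these names.
loss_fn = CrossEntropyLoss(pred='tag_logits', target='tag_ids', reduction='mean')

logits = torch.randn(4, 7, 10)           # (batch, seq_len, num_classes)
labels = torch.randint(0, 10, (4, 7))    # (batch, seq_len)
print(loss_fn.get_loss(logits, labels))  # inputs with dim > 2 are flattened as shown above
```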
@@ -263,15 +245,18 @@ class L1Loss(LossBase):
    :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
    :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
    :param str reduction: 支持'mean','sum'和'none'.
    """

    def __init__(self, pred=None, target=None):
    def __init__(self, pred=None, target=None, reduction='mean'):
        super(L1Loss, self).__init__()
        self._init_param_map(pred=pred, target=target)
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction

    def get_loss(self, pred, target):
        return F.l1_loss(input=pred, target=target)
        return F.l1_loss(input=pred, target=target, reduction=self.reduction)
class BCELoss(LossBase):
@@ -282,14 +267,17 @@ class BCELoss(LossBase):
    :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
    :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
    :param str reduction: 支持'mean','sum'和'none'.
    """

    def __init__(self, pred=None, target=None):
    def __init__(self, pred=None, target=None, reduction='mean'):
        super(BCELoss, self).__init__()
        self._init_param_map(pred=pred, target=target)
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction

    def get_loss(self, pred, target):
        return F.binary_cross_entropy(input=pred, target=target)
        return F.binary_cross_entropy(input=pred, target=target, reduction=self.reduction)

class NLLLoss(LossBase):
@@ -300,14 +288,20 @@ class NLLLoss(LossBase):
    :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
    :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
    :param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替
        传入seq_len.
    :param str reduction: 支持'mean','sum'和'none'.
    """

    def __init__(self, pred=None, target=None):
    def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'):
        super(NLLLoss, self).__init__()
        self._init_param_map(pred=pred, target=target)
        assert reduction in ('mean', 'sum', 'none')
        self.reduction = reduction
        self.ignore_idx = ignore_idx

    def get_loss(self, pred, target):
        return F.nll_loss(input=pred, target=target)
        return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction)

class LossInForward(LossBase):
@@ -319,7 +313,7 @@ class LossInForward(LossBase):
    :param str loss_key: 在forward函数中loss的键名,默认为loss
    """

    def __init__(self, loss_key='loss'):
    def __init__(self, loss_key=Const.LOSS):
        super().__init__()
        if not isinstance(loss_key, str):
            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
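A hedged sketch of the case LossInForward covers: a model that already computes its own loss inside forward() and returns it under the key `Const.LOSS` (the toy model below is illustrative only):

```python
import torch
import torch.nn.functional as F
from fastNLP import LossInForward

class ModelWithLoss(torch.nn.Module):
    """Toy model whose forward() returns its own loss (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, words, target):
        pred = self.linear(words)
        loss = F.cross_entropy(pred, target)
        return {'loss': loss, 'pred': pred}   # Const.LOSS == 'loss'

# The Trainer then only needs to pick the value out of the returned dict:
loss_fn = LossInForward()  # equivalent to LossInForward(loss_key='loss')
```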
@@ -11,21 +11,35 @@
"""
__all__ = [
    'EmbedLoader',
    'DataInfo',
    'DataSetLoader',
    'CSVLoader',
    'JsonLoader',
    'ConllLoader',
    'SNLILoader',
    'SSTLoader',
    'PeopleDailyCorpusLoader',
    'Conll2003Loader',
    'ModelLoader',
    'ModelSaver',
    'SSTLoader',
    'MatchingLoader',
    'SNLILoader',
    'MNLILoader',
    'QNLILoader',
    'QuoraLoader',
    'RTELoader',
]

from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \
    SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader
from .base_loader import DataInfo, DataSetLoader
from .dataset_loader import CSVLoader, JsonLoader, ConllLoader, \
    PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
from .data_loader.sst import SSTLoader
from .data_loader.matching import MatchingLoader, SNLILoader, \
    MNLILoader, QNLILoader, QuoraLoader, RTELoader
@@ -10,6 +10,7 @@ from typing import Union, Dict
import os
from ..core.dataset import DataSet

class BaseLoader(object):
    """
    各个 Loader 的基类,提供了 API 的参考。
@@ -55,8 +56,6 @@ class BaseLoader(object):
        return obj

def _download_from_url(url, path):
    try:
        from tqdm.auto import tqdm
@@ -115,13 +114,11 @@ class DataInfo:
    经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。

    :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
    :param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader`
    :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
    """

    def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
    def __init__(self, vocabs: dict = None, datasets: dict = None):
        self.vocabs = vocabs or {}
        self.embeddings = embeddings or {}
        self.datasets = datasets or {}

    def __repr__(self):
@@ -133,6 +130,7 @@ class DataInfo:
            _str += '\t{} has {} entries.\n'.format(name, len(vocab))
        return _str

class DataSetLoader:
    """
    别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
@@ -213,7 +211,6 @@ class DataSetLoader:
        返回的 :class:`DataInfo` 对象有如下属性:

        - vocabs: 由从数据集中获取的词表组成的字典,每个词表
        - embeddings: (可选) 数据集对应的词嵌入
        - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`

        :param paths: 原始数据读取的路径
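A hedged sketch of what a DataInfo holds after this change (vocabularies and datasets only, no embeddings; the field names below are illustrative):

```python
from fastNLP import DataSet, Vocabulary
from fastNLP.io import DataInfo

train = DataSet({'words': [['a', 'b'], ['c']], 'target': [0, 1]})
dev = DataSet({'words': [['a', 'c']], 'target': [1]})

vocab = Vocabulary()
vocab.from_dataset(train, field_name='words')

# DataInfo now only groups vocabularies and datasets under their names.
info = DataInfo(vocabs={'words': vocab}, datasets={'train': train, 'dev': dev})
print(info)   # lists the datasets and the vocabulary sizes
```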
@@ -0,0 +1,19 @@
"""
用于读数据集的模块, 具体包括:
这些模块的使用方法如下:
"""
__all__ = [
    'SSTLoader',
    'MatchingLoader',
    'SNLILoader',
    'MNLILoader',
    'QNLILoader',
    'QuoraLoader',
    'RTELoader',
]

from .sst import SSTLoader
from .matching import MatchingLoader, SNLILoader, \
    MNLILoader, QNLILoader, QuoraLoader, RTELoader
@@ -0,0 +1,430 @@ | |||||
import os | |||||
from typing import Union, Dict | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ..base_loader import DataInfo, DataSetLoader | |||||
from ..dataset_loader import JsonLoader, CSVLoader | |||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ...modules.encoder._bert import BertTokenizer | |||||
class MatchingLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||||
读取Matching任务的数据集 | |||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
self.paths = paths | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 待读取数据集的路径名 | |||||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||||
的原始字符串文本,第三个为标签 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||||
""" | |||||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||||
对应的全路径文件名。 | |||||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||||
这个数据集的名字,如果不定义则默认为train。 | |||||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||||
:param bool get_index: 是否需要根据词表将文本转为index | |||||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||||
:param str auto_pad_token: 自动pad的内容 | |||||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||||
与此同时其他field不会被设置为input。默认值为True。
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||||
:return: | |||||
""" | |||||
if isinstance(set_input, str): | |||||
set_input = [set_input] | |||||
if isinstance(set_target, str): | |||||
set_target = [set_target] | |||||
if isinstance(set_input, bool): | |||||
auto_set_input = set_input | |||||
else: | |||||
auto_set_input = False | |||||
if isinstance(set_target, bool): | |||||
auto_set_target = set_target | |||||
else: | |||||
auto_set_target = False | |||||
if isinstance(paths, str): | |||||
if os.path.isdir(paths): | |||||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||||
else: | |||||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||||
else: | |||||
path = paths | |||||
data_info = DataInfo() | |||||
for data_name in path.keys(): | |||||
data_info.datasets[data_name] = self._load(path[data_name]) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if auto_set_input: | |||||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||||
if auto_set_target: | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.set_target(Const.TARGET) | |||||
if to_lower: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||||
is_input=auto_set_input) | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||||
is_input=auto_set_input) | |||||
if bert_tokenizer is not None: | |||||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
# 检查是否存在 | |||||
elif os.path.isdir(bert_tokenizer): | |||||
model_dir = bert_tokenizer | |||||
else: | |||||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||||
lines = f.readlines() | |||||
lines = [line.strip() for line in lines] | |||||
words_vocab.add_word_lst(lines) | |||||
words_vocab.build_vocab() | |||||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if isinstance(concat, bool): | |||||
concat = 'default' if concat else None | |||||
if concat is not None: | |||||
if isinstance(concat, str): | |||||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||||
'default': ['', '<sep>', '', '']} | |||||
if concat.lower() in CONCAT_MAP: | |||||
concat = CONCAT_MAP[concat] | |||||
else: | |||||
concat = 4 * [concat] | |||||
assert len(concat) == 4, \
    f'Please provide a list of 4 symbols: one for the beginning of the first sentence, ' \
    f'one for the end of the first sentence, one for the beginning of the second sentence, ' \
    f'and one for the end of the second sentence. Your input is {concat}'
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||||
is_input=auto_set_input) | |||||
if seq_len_type is not None: | |||||
if seq_len_type == 'seq_len': # | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'mask': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [1] * len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'bert': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if Const.INPUT not in data_set.get_field_names(): | |||||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||||
f'got {data_set.get_field_names()}') | |||||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||||
if auto_pad_length is not None: | |||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||||
if cut_text is not None: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
data_set_list = [d for n, d in data_info.datasets.items()] | |||||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||||
if bert_tokenizer is None: | |||||
words_vocab = Vocabulary(padding=auto_pad_token) | |||||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=[n for n in data_set_list[0].get_field_names() | |||||
if (Const.INPUT in n)], | |||||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||||
if 'train' not in n]) | |||||
target_vocab = Vocabulary(padding=None, unknown=None) | |||||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=Const.TARGET) | |||||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||||
if get_index: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||||
is_input=auto_set_input, is_target=auto_set_target) | |||||
if auto_pad_length is not None: | |||||
if seq_len_type == 'seq_len': | |||||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if isinstance(set_input, list): | |||||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||||
if isinstance(set_target, list): | |||||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||||
return data_info | |||||
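A hedged usage sketch of the process() API described above (the SNLI directory layout is assumed from the loader's default paths; everything here is illustrative and not part of the diff):

```python
from fastNLP.core import Const
from fastNLP.io import SNLILoader

# Assuming ./snli_1.0/ contains snli_1.0_train.jsonl, snli_1.0_dev.jsonl, snli_1.0_test.jsonl
data_info = SNLILoader().process(paths='./snli_1.0', to_lower=True,
                                 seq_len_type='seq_len', get_index=True)

train_set = data_info.datasets['train']
words_vocab = data_info.vocabs[Const.INPUT]    # vocabulary built from the training words
target_vocab = data_info.vocabs[Const.TARGET]  # label vocabulary
```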
class SNLILoader(MatchingLoader, JsonLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
paths = paths if paths is not None else { | |||||
'train': 'snli_1.0_train.jsonl', | |||||
'dev': 'snli_1.0_dev.jsonl', | |||||
'test': 'snli_1.0_test.jsonl'} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
JsonLoader.__init__(self, fields=fields) | |||||
def _load(self, path): | |||||
ds = JsonLoader._load(self, path) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class RTELoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` | |||||
读取RTE数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv'  # the test set has no label
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'sentence1': Const.INPUTS(0), | |||||
'sentence2': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds | |||||
class QNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` | |||||
读取QNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv'  # the test set has no label
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'question': Const.INPUTS(0), | |||||
'sentence': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds | |||||
class MNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev_matched': 'dev_matched.tsv', | |||||
'dev_mismatched': 'dev_mismatched.tsv', | |||||
'test_matched': 'test_matched.tsv', | |||||
'test_mismatched': 'test_mismatched.tsv', | |||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t') | |||||
self.fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
if ds[0][Const.TARGET] == 'hidden': | |||||
ds.delete_field(Const.TARGET) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class QuoraLoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||||
读取Quora数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv', | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
return ds |
@@ -16,8 +16,6 @@ __all__ = [ | |||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
'ConllLoader', | 'ConllLoader', | ||||
'SNLILoader', | |||||
'SSTLoader', | |||||
'PeopleDailyCorpusLoader', | 'PeopleDailyCorpusLoader', | ||||
'Conll2003Loader', | 'Conll2003Loader', | ||||
] | ] | ||||
@@ -30,7 +28,6 @@ from ..core.dataset import DataSet | |||||
from ..core.instance import Instance | from ..core.instance import Instance | ||||
from .file_reader import _read_csv, _read_json, _read_conll | from .file_reader import _read_csv, _read_json, _read_conll | ||||
from .base_loader import DataSetLoader, DataInfo | from .base_loader import DataSetLoader, DataInfo | ||||
from .data_loader.sst import SSTLoader | |||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..modules.encoder._bert import BertTokenizer | from ..modules.encoder._bert import BertTokenizer | ||||
@@ -111,7 +108,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
else: | else: | ||||
instance = Instance(words=sent_words) | instance = Instance(words=sent_words) | ||||
data_set.append(instance) | data_set.append(instance) | ||||
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len") | |||||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) | |||||
return data_set | return data_set | ||||
@@ -249,42 +246,6 @@ class JsonLoader(DataSetLoader): | |||||
return ds | return ds | ||||
class SNLILoader(JsonLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||||
""" | |||||
def __init__(self): | |||||
fields = { | |||||
'sentence1_parse': Const.INPUTS(0), | |||||
'sentence2_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
super(SNLILoader, self).__init__(fields=fields) | |||||
def _load(self, path): | |||||
ds = super(SNLILoader, self)._load(path) | |||||
def parse_tree(x): | |||||
t = Tree.fromstring(x) | |||||
return t.leaves() | |||||
ds.apply(lambda ins: parse_tree( | |||||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: parse_tree( | |||||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds | |||||
class CSVLoader(DataSetLoader): | class CSVLoader(DataSetLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | ||||
@@ -8,35 +8,7 @@ from torch import nn | |||||
from .base_model import BaseModel | from .base_model import BaseModel | ||||
from ..core.const import Const | from ..core.const import Const | ||||
from ..modules.encoder import BertModel | from ..modules.encoder import BertModel | ||||
class BertConfig: | |||||
def __init__( | |||||
self, | |||||
vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02 | |||||
): | |||||
self.vocab_size = vocab_size | |||||
self.hidden_size = hidden_size | |||||
self.num_hidden_layers = num_hidden_layers | |||||
self.num_attention_heads = num_attention_heads | |||||
self.intermediate_size = intermediate_size | |||||
self.hidden_act = hidden_act | |||||
self.hidden_dropout_prob = hidden_dropout_prob | |||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
self.max_position_embeddings = max_position_embeddings | |||||
self.type_vocab_size = type_vocab_size | |||||
self.initializer_range = initializer_range | |||||
from ..modules.encoder._bert import BertConfig | |||||
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
@@ -84,11 +56,17 @@ class BertForSequenceClassification(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
    @classmethod
    def from_pretrained(cls, num_labels, pretrained_model_dir):
        config = BertConfig(pretrained_model_dir)
        model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
        return model
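A hedged sketch of calling the new classmethod (the directory path is a placeholder for a local folder holding the vocab, config, and converted PyTorch weights):

```python
from fastNLP.models.bert import BertForSequenceClassification

# '/path/to/uncased_L-12_H-768_A-12' is a hypothetical local BERT directory.
model = BertForSequenceClassification.from_pretrained(
    num_labels=2, pretrained_model_dir='/path/to/uncased_L-12_H-768_A-12')
```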
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
pooled_output = self.dropout(pooled_output) | pooled_output = self.dropout(pooled_output) | ||||
@@ -151,11 +129,17 @@ class BertForMultipleChoice(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, 1) | self.classifier = nn.Linear(config.hidden_size, 1) | ||||
@classmethod | |||||
def from_pretrained(cls, num_choices, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | ||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | ||||
@@ -224,11 +208,17 @@ class BertForTokenClassification(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@classmethod | |||||
def from_pretrained(cls, num_labels, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | ||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
sequence_output = self.dropout(sequence_output) | sequence_output = self.dropout(sequence_output) | ||||
@@ -302,12 +292,18 @@ class BertForQuestionAnswering(BaseModel): | |||||
self.bert = BertModel.from_pretrained(bert_dir) | self.bert = BertModel.from_pretrained(bert_dir) | ||||
else: | else: | ||||
if config is None: | if config is None: | ||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
config = BertConfig(30522) | |||||
self.bert = BertModel(config) | |||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | ||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob) | # self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.qa_outputs = nn.Linear(config.hidden_size, 2) | self.qa_outputs = nn.Linear(config.hidden_size, 2) | ||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir): | |||||
config = BertConfig(pretrained_model_dir) | |||||
model = cls(config=config, bert_dir=pretrained_model_dir) | |||||
return model | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | ||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | ||||
logits = self.qa_outputs(sequence_output) | logits = self.qa_outputs(sequence_output) | ||||
@@ -15,7 +15,8 @@ class MLP(nn.Module):
    多层感知器

    :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层的hidden数目。MLP的层数为 len(size_layer) - 1
    :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和sigmoid,默认值为relu
    :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和
        sigmoid,默认值为relu
    :param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数
    :param str initial_method: 参数初始化方式
    :param float dropout: dropout概率,默认值为0
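A hedged construction sketch of the parameters described above (the layer sizes and dropout value are arbitrary example choices):

```python
import torch
from fastNLP.modules import MLP

# A 3-layer MLP: 100 -> 64 -> 32 -> 5; ReLU on the hidden layers, no output activation.
mlp = MLP([100, 64, 32, 5], activation='relu', dropout=0.1)
x = torch.randn(8, 100)
y = mlp(x)   # shape: (8, 5)
```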
@@ -26,6 +26,7 @@ import sys | |||||
CONFIG_FILE = 'bert_config.json' | CONFIG_FILE = 'bert_config.json' | ||||
class BertConfig(object): | class BertConfig(object): | ||||
"""Configuration class to store the configuration of a `BertModel`. | """Configuration class to store the configuration of a `BertModel`. | ||||
""" | """ | ||||
@@ -339,13 +340,19 @@ class BertModel(nn.Module): | |||||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | 如果你想使用预训练好的权重矩阵,请在以下网址下载. | ||||
sources:: | sources:: | ||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin", | |||||
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin", | |||||
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin", | |||||
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin", | |||||
'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||||
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin", | |||||
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin" | |||||
用预训练权重矩阵来建立BERT模型:: | 用预训练权重矩阵来建立BERT模型:: | ||||
@@ -562,6 +569,7 @@ class WordpieceTokenizer(object): | |||||
output_tokens.extend(sub_tokens) | output_tokens.extend(sub_tokens) | ||||
return output_tokens | return output_tokens | ||||
def load_vocab(vocab_file): | def load_vocab(vocab_file): | ||||
"""Loads a vocabulary file into a dictionary.""" | """Loads a vocabulary file into a dictionary.""" | ||||
vocab = collections.OrderedDict() | vocab = collections.OrderedDict() | ||||
@@ -692,6 +700,7 @@ class BasicTokenizer(object): | |||||
output.append(char) | output.append(char) | ||||
return "".join(output) | return "".join(output) | ||||
def _is_whitespace(char): | def _is_whitespace(char): | ||||
"""Checks whether `chars` is a whitespace character.""" | """Checks whether `chars` is a whitespace character.""" | ||||
# \t, \n, and \r are technically control characters but we treat them
@@ -3,6 +3,8 @@
The reproduced models include:
- [Star-Transformer](Star_transformer/)
- [Biaffine](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/biaffine_parser.py#L239)
- [CNNText](https://github.com/fastnlp/fastNLP/blob/999a14381747068e9e6a7cc370037b320197db00/fastNLP/models/cnn_text_classification.py#L12)
- ...

# Task reproduction
@@ -11,11 +13,11 @@
## Matching (natural language inference / sentence matching)
- [Matching task reproduction](matching/)
- [Matching task reproduction](matching)

## Sequence Labeling
- still in progress
- [NER](seqence_labelling/ner)

## Coreference resolution
@@ -2,7 +2,8 @@ import torch | |||||
import json | import json | ||||
import os | import os | ||||
from fastNLP import Vocabulary | from fastNLP import Vocabulary | ||||
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, SNLILoader | |||||
from fastNLP.io.dataset_loader import ConllLoader | |||||
from fastNLP.io.data_loader import SSTLoader, SNLILoader | |||||
from fastNLP.core import Const as C | from fastNLP.core import Const as C | ||||
import numpy as np | import numpy as np | ||||
@@ -16,12 +16,11 @@ class MatchingLoader(DataSetLoader): | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | ||||
读取Matching任务的数据集 | 读取Matching任务的数据集 | ||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | """ | ||||
def __init__(self, paths: dict=None): | def __init__(self, paths: dict=None): | ||||
""" | |||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | |||||
self.paths = paths | self.paths = paths | ||||
def _load(self, path): | def _load(self, path): | ||||
@@ -173,7 +172,7 @@ class MatchingLoader(DataSetLoader): | |||||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | ||||
if auto_pad_length is not None: | if auto_pad_length is not None: | ||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else 0) | |||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||||
if cut_text is not None: | if cut_text is not None: | ||||
for data_name, data_set in data_info.datasets.items(): | for data_name, data_set in data_info.datasets.items(): | ||||
@@ -209,15 +208,18 @@ class MatchingLoader(DataSetLoader): | |||||
is_input=auto_set_input, is_target=auto_set_target) | is_input=auto_set_input, is_target=auto_set_target) | ||||
if auto_pad_length is not None: | if auto_pad_length is not None: | ||||
if seq_len_type == 'seq_len': | |||||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||||
for data_name, data_set in data_info.datasets.items(): | for data_name, data_set in data_info.datasets.items(): | ||||
for fields in data_set.get_field_names(): | for fields in data_set.get_field_names(): | ||||
if Const.INPUT in fields: | if Const.INPUT in fields: | ||||
data_set.apply(lambda x: x[fields] + [words_vocab.padding] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | ||||
(auto_pad_length - len(x[fields])), new_field_name=fields, | (auto_pad_length - len(x[fields])), new_field_name=fields, | ||||
is_input=auto_set_input) | is_input=auto_set_input) | ||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
for data_name, data_set in data_info.datasets.items(): | for data_name, data_set in data_info.datasets.items(): | ||||
if isinstance(set_input, list): | if isinstance(set_input, list): | ||||
@@ -284,7 +286,7 @@ class RTELoader(MatchingLoader, CSVLoader): | |||||
paths = paths if paths is not None else { | paths = paths if paths is not None else { | ||||
'train': 'train.tsv', | 'train': 'train.tsv', | ||||
'dev': 'dev.tsv', | 'dev': 'dev.tsv', | ||||
# 'test': 'test.tsv' # test set has not label | |||||
'test': 'test.tsv'  # the test set has no label
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
self.fields = { | self.fields = { | ||||
@@ -298,7 +300,8 @@ class RTELoader(MatchingLoader, CSVLoader): | |||||
ds = CSVLoader._load(self, path) | ds = CSVLoader._load(self, path) | ||||
for k, v in self.fields.items(): | for k, v in self.fields.items(): | ||||
ds.rename_field(k, v) | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | for fields in ds.get_all_fields(): | ||||
if Const.INPUT in fields: | if Const.INPUT in fields: | ||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ||||
@@ -323,7 +326,7 @@ class QNLILoader(MatchingLoader, CSVLoader): | |||||
paths = paths if paths is not None else { | paths = paths if paths is not None else { | ||||
'train': 'train.tsv', | 'train': 'train.tsv', | ||||
'dev': 'dev.tsv', | 'dev': 'dev.tsv', | ||||
# 'test': 'test.tsv' # test set has not label | |||||
'test': 'test.tsv'  # the test set has no label
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
self.fields = { | self.fields = { | ||||
@@ -337,7 +340,8 @@ class QNLILoader(MatchingLoader, CSVLoader): | |||||
ds = CSVLoader._load(self, path) | ds = CSVLoader._load(self, path) | ||||
for k, v in self.fields.items(): | for k, v in self.fields.items(): | ||||
ds.rename_field(k, v) | |||||
if v in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | for fields in ds.get_all_fields(): | ||||
if Const.INPUT in fields: | if Const.INPUT in fields: | ||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | ||||
@@ -349,7 +353,7 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
""" | """ | ||||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | ||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | words1: list(str),第一句文本, premise | ||||
words2: list(str), 第二句文本, hypothesis | words2: list(str), 第二句文本, hypothesis | ||||
@@ -367,6 +371,7 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
'test_mismatched': 'test_mismatched.tsv', | 'test_mismatched': 'test_mismatched.tsv', | ||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | ||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | ||||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | ||||
} | } | ||||
MatchingLoader.__init__(self, paths=paths) | MatchingLoader.__init__(self, paths=paths) | ||||
@@ -400,6 +405,17 @@ class MNLILoader(MatchingLoader, CSVLoader): | |||||
class QuoraLoader(MatchingLoader, CSVLoader): | class QuoraLoader(MatchingLoader, CSVLoader): | ||||
""" | |||||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader` | |||||
读取Quora数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | def __init__(self, paths: dict=None): | ||||
paths = paths if paths is not None else { | paths = paths if paths is not None else { | ||||
@@ -1,5 +1,5 @@
numpy
torch>=0.4.0
tqdm
nltk
numpy>=1.14.2
torch>=1.0.0
tqdm>=4.28.1
nltk>=3.4.1
requests
@@ -1,7 +1,7 @@
import unittest
import os
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader
from fastNLP.io.dataset_loader import SSTLoader
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, JsonLoader
from fastNLP.io.data_loader import SSTLoader, SNLILoader
from reproduction.text_classification.data.yelpLoader import yelpLoader
@@ -61,3 +61,12 @@ class TestDatasetLoader(unittest.TestCase):
        print(info.vocabs)
        print(info.datasets)
        os.remove(train), os.remove(test)

    def test_import(self):
        import fastNLP
        from fastNLP.io import SNLILoader
        ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
                                  get_index=True, seq_len_type='seq_len')
        assert 'train' in ds.datasets
        assert len(ds.datasets) == 1
        assert len(ds.datasets['train']) == 3
@@ -8,8 +8,9 @@ from fastNLP.models.bert import * | |||||
class TestBert(unittest.TestCase): | class TestBert(unittest.TestCase): | ||||
def test_bert_1(self): | def test_bert_1(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForSequenceClassification(2) | |||||
model = BertForSequenceClassification(2, BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -22,8 +23,9 @@ class TestBert(unittest.TestCase): | |||||
def test_bert_2(self): | def test_bert_2(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForMultipleChoice(2) | |||||
model = BertForMultipleChoice(2, BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -36,8 +38,9 @@ class TestBert(unittest.TestCase): | |||||
def test_bert_3(self): | def test_bert_3(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForTokenClassification(7) | |||||
model = BertForTokenClassification(7, BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -50,8 +53,9 @@ class TestBert(unittest.TestCase): | |||||
def test_bert_4(self): | def test_bert_4(self): | ||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
model = BertForQuestionAnswering() | |||||
model = BertForQuestionAnswering(BertConfig(32000)) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -8,8 +8,9 @@ from fastNLP.models.bert import BertModel | |||||
class TestBert(unittest.TestCase): | class TestBert(unittest.TestCase): | ||||
def test_bert_1(self): | def test_bert_1(self): | ||||
model = BertModel(vocab_size=32000, hidden_size=768, | |||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||||
from fastNLP.modules.encoder._bert import BertConfig | |||||
config = BertConfig(32000) | |||||
model = BertModel(config) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
@@ -18,4 +19,4 @@ class TestBert(unittest.TestCase): | |||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | ||||
for layer in all_encoder_layers: | for layer in all_encoder_layers: | ||||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | self.assertEqual(tuple(layer.shape), (2, 3, 768)) | ||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) | |||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) |