@@ -137,6 +137,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: | |||||
<b>致谢</b> | <b>致谢</b> | ||||
感谢 [深脑云](http://www.dbcloud.ai/) 提供的模型与数据存储、下载服务。 | 感谢 [深脑云](http://www.dbcloud.ai/) 提供的模型与数据存储、下载服务。 | ||||
<a href="http://www.dbcloud.ai/"> | <a href="http://www.dbcloud.ai/"> | ||||
<img src="http://www.dbcloud.ai/static/images/logo/logo.png"> | <img src="http://www.dbcloud.ai/static/images/logo/logo.png"> | ||||
</a> | </a> | ||||
@@ -37,6 +37,8 @@ PRETRAINED_BERT_MODEL_DIR = { | |||||
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', | 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', | ||||
'en-distilbert-base-uncased': 'distilbert-base-uncased.zip', | |||||
'multi-base-cased': 'bert-base-multilingual-cased.zip', | 'multi-base-cased': 'bert-base-multilingual-cased.zip', | ||||
'multi-base-uncased': 'bert-base-multilingual-uncased.zip', | 'multi-base-uncased': 'bert-base-multilingual-uncased.zip', | ||||
@@ -340,7 +340,7 @@ class QuoraLoader(Loader): | |||||
class CNXNLILoader(Loader): | class CNXNLILoader(Loader): | ||||
""" | """ | ||||
别名: | 别名: | ||||
数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理 | |||||
数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize | |||||
原始数据为: | 原始数据为: | ||||
train中的数据包括premise,hypo和label三个field | train中的数据包括premise,hypo和label三个field | ||||
dev和test中的数据为csv或json格式,包括十多个field,这里只取与以上三个field中的数据 | dev和test中的数据为csv或json格式,包括十多个field,这里只取与以上三个field中的数据 | ||||
@@ -358,8 +358,6 @@ class CNXNLILoader(Loader): | |||||
super(CNXNLILoader, self).__init__() | super(CNXNLILoader, self).__init__() | ||||
def _load(self, path: str = None): | def _load(self, path: str = None): | ||||
#csv_loader = CSVLoader(sep='\t') | |||||
#ds_all = csv_loader._load(path) | |||||
ds_all = DataSet() | ds_all = DataSet() | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
head_name_list = f.readline().strip().split('\t') | head_name_list = f.readline().strip().split('\t') | ||||
@@ -386,17 +384,15 @@ class CNXNLILoader(Loader): | |||||
return ds_zh | return ds_zh | ||||
def _load_train(self, path: str = None): | def _load_train(self, path: str = None): | ||||
#csv_loader = CSVLoader(sep='\t') | |||||
#ds = csv_loader._load(path) | |||||
ds = DataSet() | ds = DataSet() | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
next(f) | next(f) | ||||
for line in f: | for line in f: | ||||
raw_instance = line.strip().split('\t') | raw_instance = line.strip().split('\t') | ||||
premise = raw_instance[0] | |||||
hypo = raw_instance[1] | |||||
label = raw_instance[-1] | |||||
premise = "".join(raw_instance[0].split())# 把已经分好词的premise和hypo强制还原为character segmentation | |||||
hypo = "".join(raw_instance[1].split()) | |||||
label = "".join(raw_instance[-1].split()) | |||||
if premise: | if premise: | ||||
ds.append(Instance(premise=premise, hypo=hypo, label=label)) | ds.append(Instance(premise=premise, hypo=hypo, label=label)) | ||||
@@ -466,6 +466,7 @@ class LCQMCBertPipe(MatchingBertPipe): | |||||
data_bundle = LCQMCLoader().load(paths) | data_bundle = LCQMCLoader().load(paths) | ||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
data_bundle = self.process(data_bundle) | data_bundle = self.process(data_bundle) | ||||
data_bundle = TruncateBertPipe(task='cn').process(data_bundle) | |||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
return data_bundle | return data_bundle | ||||
@@ -475,6 +476,7 @@ class BQCorpusBertPipe(MatchingBertPipe): | |||||
data_bundle = BQCorpusLoader().load(paths) | data_bundle = BQCorpusLoader().load(paths) | ||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
data_bundle = self.process(data_bundle) | data_bundle = self.process(data_bundle) | ||||
data_bundle = TruncateBertPipe(task='cn').process(data_bundle) | |||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
return data_bundle | return data_bundle | ||||
@@ -485,5 +487,41 @@ class CNXNLIBertPipe(MatchingBertPipe): | |||||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | ||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
data_bundle = self.process(data_bundle) | data_bundle = self.process(data_bundle) | ||||
data_bundle = TruncateBertPipe(task='cn').process(data_bundle) | |||||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | ||||
return data_bundle | return data_bundle | ||||
class TruncateBertPipe(Pipe): | |||||
def __init__(self, task='cn'): | |||||
super().__init__() | |||||
self.task = task | |||||
def _truncate(self, sentence_index:list, sep_index_vocab): | |||||
# 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index | |||||
sep_index_words = sentence_index.index(sep_index_vocab) | |||||
words_before_sep = sentence_index[:sep_index_words] | |||||
words_after_sep = sentence_index[sep_index_words:] # 注意此部分包括了[SEP] | |||||
if self.task == 'cn': | |||||
# 中文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过250 | |||||
words_before_sep = words_before_sep[:250] | |||||
words_after_sep = words_after_sep[:250] | |||||
elif self.task == 'en': | |||||
# 英文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过215 | |||||
words_before_sep = words_before_sep[:215] | |||||
words_after_sep = words_after_sep[:215] | |||||
else: | |||||
raise RuntimeError("Only support 'cn' or 'en' task.") | |||||
return words_before_sep + words_after_sep | |||||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||||
for name in data_bundle.datasets.keys(): | |||||
dataset = data_bundle.get_dataset(name) | |||||
sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') | |||||
dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words') | |||||
# truncate之后需要更新seq_len | |||||
dataset.add_seq_len(field_name='words') | |||||
return data_bundle | |||||
@@ -16,6 +16,7 @@ import unicodedata | |||||
import torch | import torch | ||||
from torch import nn | from torch import nn | ||||
import numpy as np | |||||
from ..utils import _get_file_name_base_on_postfix | from ..utils import _get_file_name_base_on_postfix | ||||
from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR | from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR | ||||
@@ -24,6 +25,24 @@ from ...core import logger | |||||
CONFIG_FILE = 'bert_config.json' | CONFIG_FILE = 'bert_config.json' | ||||
VOCAB_NAME = 'vocab.txt' | VOCAB_NAME = 'vocab.txt' | ||||
BERT_KEY_RENAME_MAP_1 = { | |||||
'gamma': 'weight', | |||||
'beta': 'bias', | |||||
'distilbert.embeddings': 'bert.embeddings', | |||||
'distilbert.transformer': 'bert.encoder', | |||||
} | |||||
BERT_KEY_RENAME_MAP_2 = { | |||||
'q_lin': 'self.query', | |||||
'k_lin': 'self.key', | |||||
'v_lin': 'self.value', | |||||
'out_lin': 'output.dense', | |||||
'sa_layer_norm': 'attention.output.LayerNorm', | |||||
'ffn.lin1': 'intermediate.dense', | |||||
'ffn.lin2': 'output.dense', | |||||
'output_layer_norm': 'output.LayerNorm', | |||||
} | |||||
class BertConfig(object): | class BertConfig(object): | ||||
"""Configuration class to store the configuration of a `BertModel`. | """Configuration class to store the configuration of a `BertModel`. | ||||
@@ -162,6 +181,55 @@ class BertLayerNorm(nn.Module): | |||||
return self.weight * x + self.bias | return self.weight * x + self.bias | ||||
class DistilBertEmbeddings(nn.Module): | |||||
def __init__(self, config): | |||||
super(DistilBertEmbeddings, self).__init__() | |||||
def create_sinusoidal_embeddings(n_pos, dim, out): | |||||
position_enc = np.array([ | |||||
[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] | |||||
for pos in range(n_pos) | |||||
]) | |||||
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) | |||||
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) | |||||
out.detach_() | |||||
out.requires_grad = False | |||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) | |||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) | |||||
if config.sinusoidal_pos_embds: | |||||
create_sinusoidal_embeddings(n_pos=config.max_position_embeddings, | |||||
dim=config.hidden_size, | |||||
out=self.position_embeddings.weight) | |||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||||
def forward(self, input_ids, token_type_ids): | |||||
""" | |||||
Parameters | |||||
---------- | |||||
input_ids: torch.tensor(bs, max_seq_length) | |||||
The token ids to embed. | |||||
token_type_ids: no used. | |||||
Outputs | |||||
------- | |||||
embeddings: torch.tensor(bs, max_seq_length, dim) | |||||
The embedded tokens (plus position embeddings, no token_type embeddings) | |||||
""" | |||||
seq_length = input_ids.size(1) | |||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) | |||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) | |||||
word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) | |||||
position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) | |||||
embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim) | |||||
embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) | |||||
embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) | |||||
return embeddings | |||||
class BertEmbeddings(nn.Module): | class BertEmbeddings(nn.Module): | ||||
"""Construct the embeddings from word, position and token_type embeddings. | """Construct the embeddings from word, position and token_type embeddings. | ||||
""" | """ | ||||
@@ -383,9 +451,22 @@ class BertModel(nn.Module): | |||||
super(BertModel, self).__init__() | super(BertModel, self).__init__() | ||||
self.config = config | self.config = config | ||||
self.hidden_size = self.config.hidden_size | self.hidden_size = self.config.hidden_size | ||||
self.embeddings = BertEmbeddings(config) | |||||
self.model_type = 'bert' | |||||
if hasattr(config, 'sinusoidal_pos_embds'): | |||||
self.model_type = 'distilbert' | |||||
elif 'model_type' in kwargs: | |||||
self.model_type = kwargs['model_type'].lower() | |||||
if self.model_type == 'distilbert': | |||||
self.embeddings = DistilBertEmbeddings(config) | |||||
else: | |||||
self.embeddings = BertEmbeddings(config) | |||||
self.encoder = BertEncoder(config) | self.encoder = BertEncoder(config) | ||||
self.pooler = BertPooler(config) | |||||
if self.model_type != 'distilbert': | |||||
self.pooler = BertPooler(config) | |||||
else: | |||||
logger.info('DistilBert has NOT pooler, will use hidden states of [CLS] token as pooled output.') | |||||
self.apply(self.init_bert_weights) | self.apply(self.init_bert_weights) | ||||
def init_bert_weights(self, module): | def init_bert_weights(self, module): | ||||
@@ -427,7 +508,10 @@ class BertModel(nn.Module): | |||||
extended_attention_mask, | extended_attention_mask, | ||||
output_all_encoded_layers=output_all_encoded_layers) | output_all_encoded_layers=output_all_encoded_layers) | ||||
sequence_output = encoded_layers[-1] | sequence_output = encoded_layers[-1] | ||||
pooled_output = self.pooler(sequence_output) | |||||
if self.model_type != 'distilbert': | |||||
pooled_output = self.pooler(sequence_output) | |||||
else: | |||||
pooled_output = sequence_output[:, 0] | |||||
if not output_all_encoded_layers: | if not output_all_encoded_layers: | ||||
encoded_layers = encoded_layers[-1] | encoded_layers = encoded_layers[-1] | ||||
return encoded_layers, pooled_output | return encoded_layers, pooled_output | ||||
@@ -445,9 +529,7 @@ class BertModel(nn.Module): | |||||
# Load config | # Load config | ||||
config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json') | config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json') | ||||
config = BertConfig.from_json_file(config_file) | config = BertConfig.from_json_file(config_file) | ||||
# logger.info("Model config {}".format(config)) | |||||
# Instantiate model. | |||||
model = cls(config, *inputs, **kwargs) | |||||
if state_dict is None: | if state_dict is None: | ||||
weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin') | weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin') | ||||
state_dict = torch.load(weights_path, map_location='cpu') | state_dict = torch.load(weights_path, map_location='cpu') | ||||
@@ -455,20 +537,40 @@ class BertModel(nn.Module): | |||||
logger.error(f'Cannot load parameters through `state_dict` variable.') | logger.error(f'Cannot load parameters through `state_dict` variable.') | ||||
raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | ||||
model_type = 'BERT' | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
for key_name in BERT_KEY_RENAME_MAP_1: | |||||
if key_name in key: | |||||
new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_1[key_name]) | |||||
if 'distilbert' in key: | |||||
model_type = 'DistilBert' | |||||
break | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
old_keys = [] | old_keys = [] | ||||
new_keys = [] | new_keys = [] | ||||
for key in state_dict.keys(): | for key in state_dict.keys(): | ||||
new_key = None | new_key = None | ||||
if 'gamma' in key: | |||||
new_key = key.replace('gamma', 'weight') | |||||
if 'beta' in key: | |||||
new_key = key.replace('beta', 'bias') | |||||
for key_name in BERT_KEY_RENAME_MAP_2: | |||||
if key_name in key: | |||||
new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_2[key_name]) | |||||
break | |||||
if new_key: | if new_key: | ||||
old_keys.append(key) | old_keys.append(key) | ||||
new_keys.append(new_key) | new_keys.append(new_key) | ||||
for old_key, new_key in zip(old_keys, new_keys): | for old_key, new_key in zip(old_keys, new_keys): | ||||
state_dict[new_key] = state_dict.pop(old_key) | state_dict[new_key] = state_dict.pop(old_key) | ||||
# Instantiate model. | |||||
model = cls(config, model_type=model_type, *inputs, **kwargs) | |||||
missing_keys = [] | missing_keys = [] | ||||
unexpected_keys = [] | unexpected_keys = [] | ||||
error_msgs = [] | error_msgs = [] | ||||
@@ -494,7 +596,7 @@ class BertModel(nn.Module): | |||||
logger.warning("Weights from pretrained model not used in {}: {}".format( | logger.warning("Weights from pretrained model not used in {}: {}".format( | ||||
model.__class__.__name__, unexpected_keys)) | model.__class__.__name__, unexpected_keys)) | ||||
logger.info(f"Load pre-trained BERT parameters from file {weights_path}.") | |||||
logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") | |||||
return model | return model | ||||
@@ -11,9 +11,12 @@ with open('LICENSE', encoding='utf-8') as f: | |||||
with open('requirements.txt', encoding='utf-8') as f: | with open('requirements.txt', encoding='utf-8') as f: | ||||
reqs = f.read() | reqs = f.read() | ||||
pkgs = [p for p in find_packages() if p.startswith('fastNLP')] | |||||
print(pkgs) | |||||
setup( | setup( | ||||
name='FastNLP', | name='FastNLP', | ||||
version='0.4.9', | |||||
version='0.4.10', | |||||
url='https://github.com/fastnlp/fastNLP', | url='https://github.com/fastnlp/fastNLP', | ||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||
@@ -21,6 +24,6 @@ setup( | |||||
license='Apache License', | license='Apache License', | ||||
author='FudanNLP', | author='FudanNLP', | ||||
python_requires='>=3.6', | python_requires='>=3.6', | ||||
packages=find_packages(), | |||||
packages=pkgs, | |||||
install_requires=reqs.strip().split('\n'), | install_requires=reqs.strip().split('\n'), | ||||
) | ) |
@@ -238,7 +238,7 @@ class TestDataSetIter(unittest.TestCase): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||
for iter in ds: | for iter in ds: | ||||
self.assertEqual(iter.__repr__(), """+--------------+--------+ | self.assertEqual(iter.__repr__(), """+--------------+--------+ | ||||
| x | y | | |||||
| x | y | | |||||
+--------------+--------+ | +--------------+--------+ | ||||
| [1, 2, 3, 4] | [5, 6] | | | [1, 2, 3, 4] | [5, 6] | | ||||
+--------------+--------+""") | +--------------+--------+""") | ||||
@@ -6,3 +6,4 @@ premise hypo label | |||||
一 段 时间 来 看 , 这 一 运动 似乎 要 取得 成功 , 但 政治 事件 , 加 上 帕内尔 在 一个 令 人 愤慨 的 离婚案 中 被 称为 共同 答辩人 , 导致 许多 人 撤回 他们 的 支持 . 帕内尔 在 一个 令 人 愤慨 的 离婚 问题 上 的 法律 问题 使 这 场 运动 受到 了 影响 . entailment | 一 段 时间 来 看 , 这 一 运动 似乎 要 取得 成功 , 但 政治 事件 , 加 上 帕内尔 在 一个 令 人 愤慨 的 离婚案 中 被 称为 共同 答辩人 , 导致 许多 人 撤回 他们 的 支持 . 帕内尔 在 一个 令 人 愤慨 的 离婚 问题 上 的 法律 问题 使 这 场 运动 受到 了 影响 . entailment | ||||
看 在 这里 , 他 说 我们 不 希望 任何 律师 混在 这 一 点 . 他 说 看看 那 张 纸 neutral | 看 在 这里 , 他 说 我们 不 希望 任何 律师 混在 这 一 点 . 他 说 看看 那 张 纸 neutral | ||||
Soderstrom 在 创伤 中心 进行 了 多次 筛选 测试 . 测试 必须 在 创伤 中心 进行 比较 , 否则 就 会 无效 . neutral | Soderstrom 在 创伤 中心 进行 了 多次 筛选 测试 . 测试 必须 在 创伤 中心 进行 比较 , 否则 就 会 无效 . neutral | ||||
嗯 , 这 是 一 种 明显 的 我 的 意思 是 , 他们 甚至 把 它 带 到 现在 呢 , 他们 在 电视 上 做 广告 , 你 知道 如果 你 知道 , 如果 你 知道 这样 做 , 或者 如果 你 需要 这 个呃 , 我们 会 告 你 和 你 你 不用 给 我们 钱 , 但 他们 不 告诉 你 的 是 如果 他们 赢 了 你 给 他们 至少 三分之一 他们 赢 的 东西 , 所以 我 不 知道 它 是呃 , 它 得到 了 现在 做 更 多 的 生意 , 而 不 是呃 实际上 是 在 处理 犯罪 而 不 是 与 呃嗯 他们 的 律师 只 是 为了 钱 , 我 相信 , 我 知道 我 同意 你 , 我 认为 你 是 真实 的 你. 非常 正确 的 是 , 我 认为 他们 应该 有 同等 数量 的 你 知道 也许 他们 可以 有 几 个 , 但 我 认为 大多数 他们 应该 不 是 律师 在 事实 , 这 是 方式 他们 已经 进入 政治 , 这 是 因为 在 法律 上 , 你 知道 的 循环 和 一切 , 但 我 不 知道 我们 是 在 马里兰州 和呃 , 我们 有 同样 的 东西 人满为患 , 和呃 他们 让 他们 出来 我 的 意思 是 只 是 普通 的 监狱 判决 的 事情 , 他们 让. 他们 是 因为 他们 没有 任何 地方 可以 留住 他们 所以 你 可以 知道呃 , 除非 是 一个 重大 的 罪行 , 但呃 , 即使 是 小小的 东西 , 我 的 意思 是 那些 在 美国 失去 的 人 是 受害者 和 谁 可能 是 抢劫 或 毒品 , 或者 其他 什么 , 他们 是 谁 要 支付 , 他们 是 一个 会 受苦 , 另 一个 你 知道 的 人 , 如果 他们 被 逮捕 , 如果 他们 逮捕 他们嗯 , 然后 呢 , 你 知道 的 时间 法律 接管 了 一 半 时间 呃 他们 要么 让 他们 走 , 或者 他们 下 了 一个 句子 , 因为 他们 有 一个 律师 , 你 知道 的 感觉 他们 是 不 是 所有 在 一起 当 他们 做到 了 .它 我 不 知道 我们 怎么 到 这 一 点 , 虽然 . neutral |
@@ -31,7 +31,7 @@ class TestMatchingLoad(unittest.TestCase): | |||||
'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), | 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), | ||||
'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | 'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), | ||||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), | ||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 7, 6), False), | |||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), | 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), | ||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
@@ -38,7 +38,7 @@ class TestRunMatchingPipe(unittest.TestCase): | |||||
'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), | 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), | ||||
'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), | ||||
'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), | ||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 7, 6), (37, 3), False), | |||||
'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), | |||||
'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), | 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), | ||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||