@@ -89,13 +89,15 @@ class MatchingBertPipe(Pipe):
         data_bundle.set_vocab(word_vocab, Const.INPUT)
         data_bundle.set_vocab(target_vocab, Const.TARGET)
-        input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET]
+        input_fields = [Const.INPUT, Const.INPUT_LEN]
         target_fields = [Const.TARGET]
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUT)
             dataset.set_input(*input_fields, flag=True)
-            dataset.set_target(*target_fields, flag=True)
+            for fields in target_fields:
+                if dataset.has_field(fields):
+                    dataset.set_target(fields, flag=True)
         return data_bundle
@@ -210,14 +212,16 @@ class MatchingPipe(Pipe):
         data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
         data_bundle.set_vocab(target_vocab, Const.TARGET)
-        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET]
+        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
         target_fields = [Const.TARGET]
         for name, dataset in data_bundle.datasets.items():
             dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
             dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
             dataset.set_input(*input_fields, flag=True)
-            dataset.set_target(*target_fields, flag=True)
+            for fields in target_fields:
+                if dataset.has_field(fields):
+                    dataset.set_target(fields, flag=True)
         return data_bundle
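
The guarded set_target loop added in both hunks exists because the test splits of these GLUE-style
datasets ship without gold labels, so Const.TARGET may be missing from some DataSets in the bundle.
A minimal, self-contained sketch of the pattern (the toy field values are made up for illustration):

    from fastNLP import DataSet
    from fastNLP.core.const import Const

    train_ds = DataSet({Const.INPUT: [['a', 'b']], Const.TARGET: ['entailment']})
    test_ds = DataSet({Const.INPUT: [['c', 'd']]})  # unlabeled test split: no target field

    for ds in (train_ds, test_ds):
        for field in [Const.TARGET]:
            if ds.has_field(field):          # only mark the field when it actually exists
                ds.set_target(field, flag=True)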
@@ -1,435 +0,0 @@
-"""
-The content of this file has been merged into fastNLP.io.data_loader; this file is no longer updated.
-"""
-import os
-from typing import Union, Dict
-from fastNLP.core.const import Const
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.data_bundle import DataBundle, DataSetLoader
-from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
-from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from fastNLP.modules.encoder._bert import BertTokenizer
-class MatchingLoader(DataSetLoader):
-    """
-    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
-    Reads datasets for matching tasks.
-    :param dict paths: keys are dataset names (e.g. train, dev, test), values are the corresponding file names
-    """
-    def __init__(self, paths: dict=None):
-        self.paths = paths
-    def _load(self, path):
-        """
-        :param str path: path of the dataset file to read
-        :return: fastNLP.DataSet ds: a DataSet object that must contain 3 fields: two of them hold the raw
-            text of the two sentences, and the third one holds the label
-        """
-        raise NotImplementedError
-    def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
-                to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
-                cut_text: int = None, get_index=True, auto_pad_length: int=None,
-                auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
-        """
-        :param paths: str or Dict[str, str]. If str, it is either a folder containing the datasets or the full
-            path of a single file: for a folder, dataset names and file names are looked up in self.paths;
-            if Dict, it maps dataset names (e.g. train, dev, test) to full file paths.
-        :param str dataset_name: if paths is the full path of a single dataset file, dataset_name can be used
-            to name that dataset; if not given it defaults to train.
-        :param bool to_lower: whether to lowercase the text automatically. Default: False.
-        :param str seq_len_type: the kind of seq_len information to provide. ``seq_len``: a single number as the
-            sentence length; ``mask``: a 0/1 mask matrix as the sentence length; ``bert``: segment_type_ids
-            (0 for the first sentence, 1 for the second) plus an attention mask (a 0/1 mask matrix).
-            Default: None, i.e. no seq_len is provided.
-        :param str bert_tokenizer: path of the folder holding the vocabulary used by the bert tokenizer
-        :param int cut_text: truncate anything longer than cut_text. Default: None, i.e. no truncation.
-        :param bool get_index: whether to convert the text to indices with the vocabulary
-        :param int auto_pad_length: pad the text to this length automatically (longer text is truncated);
-            by default no automatic padding is done
-        :param str auto_pad_token: the token used for automatic padding
-        :param set_input: if True, fields whose names contain Const.INPUT are automatically set as input; if
-            False, no field is set as input. If a str or List[str] is given, the corresponding fields are set
-            as input and no other field is touched. Default: True.
-        :param set_target: controls which fields may be set as target; same usage as set_input. Default: True.
-        :param concat: whether to concatenate the two sentences. If False, no concatenation. If True, a <sep>
-            token is inserted between the two sentences. If a list of length 4 is given, its elements are the
-            markers inserted before the first sentence, after the first sentence, before the second sentence
-            and after the second sentence, respectively. If the string ``bert`` is given, the bert-style
-            concatenation is used, which is equivalent to ['[CLS]', '[SEP]', '', '[SEP]'].
-        :return:
-        """
-        if isinstance(set_input, str):
-            set_input = [set_input]
-        if isinstance(set_target, str):
-            set_target = [set_target]
-        if isinstance(set_input, bool):
-            auto_set_input = set_input
-        else:
-            auto_set_input = False
-        if isinstance(set_target, bool):
-            auto_set_target = set_target
-        else:
-            auto_set_target = False
-        if isinstance(paths, str):
-            if os.path.isdir(paths):
-                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
-            else:
-                path = {dataset_name if dataset_name is not None else 'train': paths}
-        else:
-            path = paths
-        data_info = DataBundle()
-        for data_name in path.keys():
-            data_info.datasets[data_name] = self._load(path[data_name])
-        for data_name, data_set in data_info.datasets.items():
-            if auto_set_input:
-                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
-            if auto_set_target:
-                if Const.TARGET in data_set.get_field_names():
-                    data_set.set_target(Const.TARGET)
-        if to_lower:
-            for data_name, data_set in data_info.datasets.items():
-                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
-                               is_input=auto_set_input)
-                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
-                               is_input=auto_set_input)
-        if bert_tokenizer is not None:
-            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
-                PRETRAIN_URL = _get_base_url('bert')
-                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
-                model_url = PRETRAIN_URL + model_name
-                model_dir = cached_path(model_url)
-                # check whether the model exists
-            elif os.path.isdir(bert_tokenizer):
-                model_dir = bert_tokenizer
-            else:
-                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
-            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
-            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
-                lines = f.readlines()
-            lines = [line.strip() for line in lines]
-            words_vocab.add_word_lst(lines)
-            words_vocab.build_vocab()
-            tokenizer = BertTokenizer.from_pretrained(model_dir)
-            for data_name, data_set in data_info.datasets.items():
-                for fields in data_set.get_field_names():
-                    if Const.INPUT in fields:
-                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
-                                       is_input=auto_set_input)
-        if isinstance(concat, bool):
-            concat = 'default' if concat else None
-        if concat is not None:
-            if isinstance(concat, str):
-                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
-                              'default': ['', '<sep>', '', '']}
-                if concat.lower() in CONCAT_MAP:
-                    concat = CONCAT_MAP[concat]
-                else:
-                    concat = 4 * [concat]
-            assert len(concat) == 4, \
-                f'Please choose a list with 4 symbols which are inserted at the beginning of the first sentence, ' \
-                f'the end of the first sentence, the beginning of the second sentence, and the end of the second ' \
-                f'sentence. Your input is {concat}'
-            for data_name, data_set in data_info.datasets.items():
-                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
-                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
-                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
-                               is_input=auto_set_input)
-        if seq_len_type is not None:
-            if seq_len_type == 'seq_len':
-                for data_name, data_set in data_info.datasets.items():
-                    for fields in data_set.get_field_names():
-                        if Const.INPUT in fields:
-                            data_set.apply(lambda x: len(x[fields]),
-                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
-                                           is_input=auto_set_input)
-            elif seq_len_type == 'mask':
-                for data_name, data_set in data_info.datasets.items():
-                    for fields in data_set.get_field_names():
-                        if Const.INPUT in fields:
-                            data_set.apply(lambda x: [1] * len(x[fields]),
-                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
-                                           is_input=auto_set_input)
-            elif seq_len_type == 'bert':
-                for data_name, data_set in data_info.datasets.items():
-                    if Const.INPUT not in data_set.get_field_names():
-                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
-                                       f'got {data_set.get_field_names()}')
-                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
-                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
-                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
-                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
-        if auto_pad_length is not None:
-            cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
-        if cut_text is not None:
-            for data_name, data_set in data_info.datasets.items():
-                for fields in data_set.get_field_names():
-                    if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
-                        data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
-                                       is_input=auto_set_input)
-        data_set_list = [d for n, d in data_info.datasets.items()]
-        assert len(data_set_list) > 0, f'There are NO data sets in data info!'
-        if bert_tokenizer is None:
-            words_vocab = Vocabulary(padding=auto_pad_token)
-        words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
-                                               field_name=[n for n in data_set_list[0].get_field_names()
-                                                           if (Const.INPUT in n)],
-                                               no_create_entry_dataset=[d for n, d in data_info.datasets.items()
-                                                                        if 'train' not in n])
-        target_vocab = Vocabulary(padding=None, unknown=None)
-        target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
-                                                 field_name=Const.TARGET)
-        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
-        if get_index:
-            for data_name, data_set in data_info.datasets.items():
-                for fields in data_set.get_field_names():
-                    if Const.INPUT in fields:
-                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
-                                       is_input=auto_set_input)
-                if Const.TARGET in data_set.get_field_names():
-                    data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
-                                   is_input=auto_set_input, is_target=auto_set_target)
-        if auto_pad_length is not None:
-            if seq_len_type == 'seq_len':
-                raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
-                                   f'so the seq_len_type cannot be `{seq_len_type}`!')
-            for data_name, data_set in data_info.datasets.items():
-                for fields in data_set.get_field_names():
-                    if Const.INPUT in fields:
-                        data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
-                                       (auto_pad_length - len(x[fields])), new_field_name=fields,
-                                       is_input=auto_set_input)
-                    elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
-                        data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
-                                       new_field_name=fields, is_input=auto_set_input)
-        for data_name, data_set in data_info.datasets.items():
-            if isinstance(set_input, list):
-                data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
-            if isinstance(set_target, list):
-                data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
-        return data_info
-class SNLILoader(MatchingLoader, JsonLoader):
-    """
-    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
-    Reads the SNLI dataset; the resulting DataSet contains the fields::
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
-    """
-    def __init__(self, paths: dict=None):
-        fields = {
-            'sentence1_binary_parse': Const.INPUTS(0),
-            'sentence2_binary_parse': Const.INPUTS(1),
-            'gold_label': Const.TARGET,
-        }
-        paths = paths if paths is not None else {
-            'train': 'snli_1.0_train.jsonl',
-            'dev': 'snli_1.0_dev.jsonl',
-            'test': 'snli_1.0_test.jsonl'}
-        MatchingLoader.__init__(self, paths=paths)
-        JsonLoader.__init__(self, fields=fields)
-    def _load(self, path):
-        ds = JsonLoader._load(self, path)
-        parentheses_table = str.maketrans({'(': None, ')': None})
-        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(0))
-        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(1))
-        ds.drop(lambda x: x[Const.TARGET] == '-')
-        return ds
-class RTELoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
-    Reads the RTE dataset; the resulting DataSet contains the fields::
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-    Data source:
-    """
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev': 'dev.tsv',
-            'test': 'test.tsv'  # the test set has no labels
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        self.fields = {
-            'sentence1': Const.INPUTS(0),
-            'sentence2': Const.INPUTS(1),
-            'label': Const.TARGET,
-        }
-        CSVLoader.__init__(self, sep='\t')
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-        for k, v in self.fields.items():
-            if v in ds.get_field_names():
-                ds.rename_field(k, v)
-        for fields in ds.get_all_fields():
-            if Const.INPUT in fields:
-                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
-        return ds
-class QNLILoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
-    Reads the QNLI dataset; the resulting DataSet contains the fields::
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-    Data source:
-    """
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev': 'dev.tsv',
-            'test': 'test.tsv'  # the test set has no labels
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        self.fields = {
-            'question': Const.INPUTS(0),
-            'sentence': Const.INPUTS(1),
-            'label': Const.TARGET,
-        }
-        CSVLoader.__init__(self, sep='\t')
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-        for k, v in self.fields.items():
-            if v in ds.get_field_names():
-                ds.rename_field(k, v)
-        for fields in ds.get_all_fields():
-            if Const.INPUT in fields:
-                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
-        return ds
-class MNLILoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
-    Reads the MNLI dataset; the resulting DataSet contains the fields::
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-    Data source:
-    """
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev_matched': 'dev_matched.tsv',
-            'dev_mismatched': 'dev_mismatched.tsv',
-            'test_matched': 'test_matched.tsv',
-            'test_mismatched': 'test_mismatched.tsv',
-            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
-            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
-            # test_0.9_matched and test_0.9_mismatched come from MNLI version 0.9 (data source: kaggle)
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        CSVLoader.__init__(self, sep='\t')
-        self.fields = {
-            'sentence1_binary_parse': Const.INPUTS(0),
-            'sentence2_binary_parse': Const.INPUTS(1),
-            'gold_label': Const.TARGET,
-        }
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-        for k, v in self.fields.items():
-            if k in ds.get_field_names():
-                ds.rename_field(k, v)
-        if Const.TARGET in ds.get_field_names():
-            if ds[0][Const.TARGET] == 'hidden':
-                ds.delete_field(Const.TARGET)
-        parentheses_table = str.maketrans({'(': None, ')': None})
-        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(0))
-        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
-                 new_field_name=Const.INPUTS(1))
-        if Const.TARGET in ds.get_field_names():
-            ds.drop(lambda x: x[Const.TARGET] == '-')
-        return ds
-class QuoraLoader(MatchingLoader, CSVLoader):
-    """
-    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
-    Reads the Quora dataset; the resulting DataSet contains the fields::
-        words1: list(str), the first sentence (premise)
-        words2: list(str), the second sentence (hypothesis)
-        target: str, the gold label
-    Data source:
-    """
-    def __init__(self, paths: dict=None):
-        paths = paths if paths is not None else {
-            'train': 'train.tsv',
-            'dev': 'dev.tsv',
-            'test': 'test.tsv',
-        }
-        MatchingLoader.__init__(self, paths=paths)
-        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
-    def _load(self, path):
-        ds = CSVLoader._load(self, path)
-        return ds
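
The file deleted above (the old MatchingLoader/SNLILoader/... hierarchy with its many-switch
process() method) is superseded by the Pipe classes in fastNLP.io.pipe.matching that the scripts
below migrate to. A rough before/after sketch of one call site; the 'path/to/snli/data' placeholder
is copied from the old scripts and is not a real path:

    # old style: loader plus a heavyweight process() call
    # data_info = SNLILoader().process(paths='path/to/snli/data', to_lower=True,
    #                                  seq_len_type='seq_len', get_index=True, concat=False)

    # new style: one Pipe object; tokenization, lowercasing, vocab building and indexing
    # all happen inside process_from_file()
    from fastNLP.io.pipe.matching import SNLIPipe
    data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
    print(data_bundle)   # a DataBundle exposing .datasets (train/dev/test) and .vocabs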
@@ -2,8 +2,12 @@ import random
 import numpy as np
 import torch
-from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam
-from fastNLP.io.data_loader import SNLILoader, RTELoader, MNLILoader, QNLILoader, QuoraLoader
+from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
+from fastNLP.core.callback import WarmupCallback, EvaluateCallback
+from fastNLP.core.optimizer import AdamW
+from fastNLP.embeddings import BertEmbedding
+from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe, \
+    QNLIBertPipe, QuoraBertPipe
 from reproduction.matching.model.bert import BertForNLI
@@ -12,16 +16,22 @@ from reproduction.matching.model.bert import BertForNLI
 class BERTConfig:
     task = 'snli'
     batch_size_per_gpu = 6
     n_epochs = 6
     lr = 2e-5
-    seq_len_type = 'bert'
+    warm_up_rate = 0.1
     seed = 42
-    save_path = None  # where to store the model; None means the model is not saved.
     train_dataset_name = 'train'
     dev_dataset_name = 'dev'
     test_dataset_name = 'test'
-    bert_dir = 'path/to/bert/dir'  # folder containing the pretrained BERT weights
+    save_path = None  # where to store the model; None means the model is not saved.
+    to_lower = True  # lowercase the input text
+    tokenizer = 'spacy'  # tokenize with spacy
+    bert_model_dir_or_name = 'bert-base-uncased'
 arg = BERTConfig()
@@ -37,58 +47,52 @@ if n_gpu > 0:
 # load data set
 if arg.task == 'snli':
-    data_info = SNLILoader().process(
-        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
-        bert_tokenizer=arg.bert_dir, cut_text=512,
-        get_index=True, concat='bert',
-    )
+    data_bundle = SNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'rte':
-    data_info = RTELoader().process(
-        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
-        bert_tokenizer=arg.bert_dir, cut_text=512,
-        get_index=True, concat='bert',
-    )
+    data_bundle = RTEBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'qnli':
-    data_info = QNLILoader().process(
-        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
-        bert_tokenizer=arg.bert_dir, cut_text=512,
-        get_index=True, concat='bert',
-    )
+    data_bundle = QNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'mnli':
-    data_info = MNLILoader().process(
-        paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
-        bert_tokenizer=arg.bert_dir, cut_text=512,
-        get_index=True, concat='bert',
-    )
+    data_bundle = MNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'quora':
-    data_info = QuoraLoader().process(
-        paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
-        bert_tokenizer=arg.bert_dir, cut_text=512,
-        get_index=True, concat='bert',
-    )
+    data_bundle = QuoraBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 else:
     raise RuntimeError(f'NOT support {arg.task} task yet!')
+print(data_bundle)  # print details in data_bundle
+# load embedding
+embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name)
 # define model
-model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)
+model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
+# define optimizer and callback
+optimizer = AdamW(lr=arg.lr, params=model.parameters())
+callbacks = [WarmupCallback(warmup=arg.warm_up_rate, schedule='linear'), ]
+if arg.task in ['snli']:
+    callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
+    # evaluate test set in every epoch if task is snli.
 # define trainer
-trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
-                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
+                  optimizer=optimizer,
                   batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                   n_epochs=arg.n_epochs, print_every=-1,
-                  dev_data=data_info.datasets[arg.dev_dataset_name],
+                  dev_data=data_bundle.datasets[arg.dev_dataset_name],
                   metrics=AccuracyMetric(), metric_key='acc',
                   device=[i for i in range(torch.cuda.device_count())],
                   check_code_level=-1,
-                  save_path=arg.save_path)
+                  save_path=arg.save_path,
+                  callbacks=callbacks)
 # train model
 trainer.train(load_best_model=True)
 # define tester
 tester = Tester(
-    data=data_info.datasets[arg.test_dataset_name],
+    data=data_bundle.datasets[arg.test_dataset_name],
     model=model,
     metrics=AccuracyMetric(),
     batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
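
Condensed, the rewritten BERT script wires the pieces together roughly as below. This is only a
sketch: it assumes the defaults from BERTConfig above (bert-base-uncased, warm-up rate 0.1) and ends
with the Tester call that the hunk above truncates:

    data_bundle = SNLIBertPipe(lower=True, tokenizer='spacy').process_from_file()
    embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name='bert-base-uncased')
    model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
    optimizer = AdamW(lr=2e-5, params=model.parameters())
    callbacks = [WarmupCallback(warmup=0.1, schedule='linear')]
    trainer = Trainer(train_data=data_bundle.datasets['train'], model=model, optimizer=optimizer,
                      dev_data=data_bundle.datasets['dev'], metrics=AccuracyMetric(),
                      callbacks=callbacks)
    trainer.train(load_best_model=True)
    Tester(data=data_bundle.datasets['test'], model=model, metrics=AccuracyMetric()).test()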
@@ -1,9 +1,9 @@
 import argparse
 import torch
-from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
+from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss
 from fastNLP.embeddings import StaticEmbedding
-from fastNLP.io.data_loader import QNLILoader, RTELoader, SNLILoader, MNLILoader
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe
 from reproduction.matching.model.cntn import CNTNModel
@@ -13,14 +13,12 @@ argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glo
 argument.add_argument('--batch-size-per-gpu', type=int, default=256)
 argument.add_argument('--n-epochs', type=int, default=200)
 argument.add_argument('--lr', type=float, default=1e-5)
-argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask')
 argument.add_argument('--save-dir', type=str, default=None)
 argument.add_argument('--cntn-depth', type=int, default=1)
 argument.add_argument('--cntn-ns', type=int, default=200)
 argument.add_argument('--cntn-k-top', type=int, default=10)
 argument.add_argument('--cntn-r', type=int, default=5)
 argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
-argument.add_argument('--max-len', type=int, default=50)
 arg = argument.parse_args()
 # dataset dict
@@ -45,30 +43,25 @@ else:
     num_labels = 3
 # load data set
-if arg.dataset == 'qnli':
-    data_info = QNLILoader().process(
-        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
-        get_index=True, concat=False, auto_pad_length=arg.max_len)
+if arg.dataset == 'snli':
+    data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
 elif arg.dataset == 'rte':
-    data_info = RTELoader().process(
-        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
-        get_index=True, concat=False, auto_pad_length=arg.max_len)
-elif arg.dataset == 'snli':
-    data_info = SNLILoader().process(
-        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
-        get_index=True, concat=False, auto_pad_length=arg.max_len)
+    data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
+elif arg.dataset == 'qnli':
+    data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file()
 elif arg.dataset == 'mnli':
-    data_info = MNLILoader().process(
-        paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
-        get_index=True, concat=False, auto_pad_length=arg.max_len)
+    data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file()
 else:
-    raise ValueError(f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!')
+    raise RuntimeError(f'NOT support {arg.dataset} dataset yet!')
+print(data_bundle)  # print details in data_bundle
 # load embedding
 if arg.embedding == 'word2vec':
-    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300', requires_grad=True)
+    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300',
+                                requires_grad=True)
 elif arg.embedding == 'glove':
-    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300',
+    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
                                 requires_grad=True)
 else:
     raise ValueError(f'now we only support word2vec or glove embedding for cntn model!')
@@ -79,11 +72,12 @@ model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=nu
 print(model)
 # define trainer
-trainer = Trainer(train_data=data_info.datasets['train'], model=model,
+trainer = Trainer(train_data=data_bundle.datasets['train'], model=model,
                   optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+                  loss=CrossEntropyLoss(),
                   batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                   n_epochs=arg.n_epochs, print_every=-1,
-                  dev_data=data_info.datasets[dev_dict[arg.dataset]],
+                  dev_data=data_bundle.datasets[dev_dict[arg.dataset]],
                   metrics=AccuracyMetric(), metric_key='acc',
                   device=[i for i in range(torch.cuda.device_count())],
                   check_code_level=-1)
@@ -93,7 +87,7 @@ trainer.train(load_best_model=True)
 # define tester
 tester = Tester(
-    data=data_info.datasets[test_dict[arg.dataset]],
+    data=data_bundle.datasets[test_dict[arg.dataset]],
     model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
@@ -6,10 +6,11 @@ from torch.optim import Adamax
 from torch.optim.lr_scheduler import StepLR
 from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
-from fastNLP.core.callback import GradientClipCallback, LRScheduler
-from fastNLP.embeddings.static_embedding import StaticEmbedding
-from fastNLP.embeddings.elmo_embedding import ElmoEmbedding
-from fastNLP.io.data_loader import SNLILoader, RTELoader, MNLILoader, QNLILoader, QuoraLoader
+from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback
+from fastNLP.core.losses import CrossEntropyLoss
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.embeddings import ElmoEmbedding
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
 from fastNLP.models.snli import ESIM
@@ -17,18 +18,21 @@ from fastNLP.models.snli import ESIM
 class ESIMConfig:
     task = 'snli'
     embedding = 'glove'
     batch_size_per_gpu = 196
     n_epochs = 30
     lr = 2e-3
-    seq_len_type = 'seq_len'
-    # seq_len: process() uses len(words) as the length information;
-    # mask: a 0/1 mask matrix is used as the length information;
     seed = 42
-    save_path = None  # where to store the model; None means the model is not saved.
     train_dataset_name = 'train'
     dev_dataset_name = 'dev'
     test_dataset_name = 'test'
+    save_path = None  # where to store the model; None means the model is not saved.
+    to_lower = True  # lowercase the input text
+    tokenizer = 'spacy'  # tokenize with spacy
 arg = ESIMConfig()
@@ -44,43 +48,32 @@ if n_gpu > 0:
 # load data set
 if arg.task == 'snli':
-    data_info = SNLILoader().process(
-        paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
-        get_index=True, concat=False,
-    )
+    data_bundle = SNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'rte':
-    data_info = RTELoader().process(
-        paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
-        get_index=True, concat=False,
-    )
+    data_bundle = RTEPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'qnli':
-    data_info = QNLILoader().process(
-        paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
-        get_index=True, concat=False,
-    )
+    data_bundle = QNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'mnli':
-    data_info = MNLILoader().process(
-        paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
-        get_index=True, concat=False,
-    )
+    data_bundle = MNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 elif arg.task == 'quora':
-    data_info = QuoraLoader().process(
-        paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type,
-        get_index=True, concat=False,
-    )
+    data_bundle = QuoraPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
 else:
     raise RuntimeError(f'NOT support {arg.task} task yet!')
+print(data_bundle)  # print details in data_bundle
 # load embedding
 if arg.embedding == 'elmo':
-    embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
+    embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium',
+                              requires_grad=True)
 elif arg.embedding == 'glove':
-    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False)
+    embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
+                                requires_grad=True, normalize=False)
 else:
     raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')
 # define model
-model = ESIM(embedding, num_labels=len(data_info.vocabs[Const.TARGET]))
+model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET]))
 # define optimizer and callback
 optimizer = Adamax(lr=arg.lr, params=model.parameters())
@@ -91,23 +84,29 @@ callbacks = [
     LRScheduler(scheduler),
 ]
+if arg.task in ['snli']:
+    callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
+    # evaluate test set in every epoch if task is snli.
 # define trainer
-trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
+trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
                   optimizer=optimizer,
+                  loss=CrossEntropyLoss(),
                   batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                   n_epochs=arg.n_epochs, print_every=-1,
-                  dev_data=data_info.datasets[arg.dev_dataset_name],
+                  dev_data=data_bundle.datasets[arg.dev_dataset_name],
                   metrics=AccuracyMetric(), metric_key='acc',
                   device=[i for i in range(torch.cuda.device_count())],
                   check_code_level=-1,
-                  save_path=arg.save_path)
+                  save_path=arg.save_path,
+                  callbacks=callbacks)
 # train model
 trainer.train(load_best_model=True)
 # define tester
 tester = Tester(
-    data=data_info.datasets[arg.test_dataset_name],
+    data=data_bundle.datasets[arg.test_dataset_name],
     model=model,
     metrics=AccuracyMetric(),
     batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
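
For the word-level pipes the two sentence fields are indexed against a single word vocabulary, which
the hunks above register under Const.INPUTS(0) ('words1') and which the embeddings are now built
from. A toy forward call against the ESIM model built above, with random index tensors standing in
for a real batch:

    import torch
    from fastNLP.core.const import Const

    words1 = torch.randint(0, 100, (4, 12))             # [batch, seq_len] indices into the shared vocab
    words2 = torch.randint(0, 100, (4, 9))
    seq_len1 = torch.full((4,), 12, dtype=torch.long)   # true lengths of words1
    seq_len2 = torch.full((4,), 9, dtype=torch.long)    # true lengths of words2
    out = model(words1, words2, seq_len1, seq_len2)     # the four input fields set by MatchingPipe
    print(out[Const.OUTPUT].shape)                      # logits: [batch, num_labels]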
@@ -6,12 +6,11 @@ from torch.optim import Adadelta
 from torch.optim.lr_scheduler import StepLR
 from fastNLP import CrossEntropyLoss
-from fastNLP import cache_results
 from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
-from fastNLP.core.callback import LRScheduler, FitlogCallback
+from fastNLP.core.callback import LRScheduler, EvaluateCallback
 from fastNLP.embeddings import StaticEmbedding
-from fastNLP.io.data_loader import MNLILoader, QNLILoader, SNLILoader, RTELoader
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
 from reproduction.matching.model.mwan import MwanModel
 import fitlog
@@ -46,47 +45,25 @@ for k in arg.__dict__:
 # load data set
 if arg.task == 'snli':
-    @cache_results(f'snli_mwan.pkl')
-    def read_snli():
-        data_info = SNLILoader().process(
-            paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
-            get_index=True, concat=False, extra_split=['/', '%', '-'],
-        )
-        return data_info
-    data_info = read_snli()
+    data_bundle = SNLIPipe(lower=True, tokenizer='spacy').process_from_file()
 elif arg.task == 'rte':
-    @cache_results(f'rte_mwan.pkl')
-    def read_rte():
-        data_info = RTELoader().process(
-            paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
-            get_index=True, concat=False, extra_split=['/', '%', '-'],
-        )
-        return data_info
-    data_info = read_rte()
+    data_bundle = RTEPipe(lower=True, tokenizer='spacy').process_from_file()
 elif arg.task == 'qnli':
-    data_info = QNLILoader().process(
-        paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
-        get_index=True, concat=False, cut_text=512, extra_split=['/', '%', '-'],
-    )
+    data_bundle = QNLIPipe(lower=True, tokenizer='spacy').process_from_file()
 elif arg.task == 'mnli':
-    @cache_results(f'mnli_v0.9_mwan.pkl')
-    def read_mnli():
-        data_info = MNLILoader().process(
-            paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
-            get_index=True, concat=False, extra_split=['/', '%', '-'],
-        )
-        return data_info
-    data_info = read_mnli()
+    data_bundle = MNLIPipe(lower=True, tokenizer='spacy').process_from_file()
+elif arg.task == 'quora':
+    data_bundle = QuoraPipe(lower=True, tokenizer='spacy').process_from_file()
 else:
     raise RuntimeError(f'NOT support {arg.task} task yet!')
-print(data_info)
-print(len(data_info.vocabs['words']))
+print(data_bundle)
+print(len(data_bundle.vocabs[Const.INPUTS(0)]))
 model = MwanModel(
-    num_class = len(data_info.vocabs[Const.TARGET]),
-    EmbLayer = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False),
+    num_class = len(data_bundle.vocabs[Const.TARGET]),
+    EmbLayer = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], requires_grad=False, normalize=False),
     ElmoLayer = None,
     args_of_imm = {
        "input_size" : 300 ,
@@ -105,21 +82,20 @@ callbacks = [
 ]
 if arg.task in ['snli']:
-    callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1))
+    callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.testset_name]))
 elif arg.task == 'mnli':
-    callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'],
-                                     'dev_mismatched': data_info.datasets['dev_mismatched']},
-                                    verbose=1))
+    callbacks.append(EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'],
+                                            'dev_mismatched': data_bundle.datasets['dev_mismatched']},))
 trainer = Trainer(
-    train_data = data_info.datasets['train'],
+    train_data = data_bundle.datasets['train'],
     model = model,
     optimizer = optimizer,
     num_workers = 0,
     batch_size = arg.batch_size,
     n_epochs = arg.n_epochs,
     print_every = -1,
-    dev_data = data_info.datasets[arg.devset_name],
+    dev_data = data_bundle.datasets[arg.devset_name],
     metrics = AccuracyMetric(pred = "pred", target = "target"),
     metric_key = 'acc',
     device = [i for i in range(torch.cuda.device_count())],
@@ -130,7 +106,7 @@ trainer = Trainer(
 trainer.train(load_best_model=True)
 tester = Tester(
-    data=data_info.datasets[arg.testset_name],
+    data=data_bundle.datasets[arg.testset_name],
     model=model,
     metrics=AccuracyMetric(),
     batch_size=arg.batch_size,
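
EvaluateCallback, which replaces FitlogCallback in this script, accepts either a single DataSet or
a dict mapping names to DataSets, as the two branches above show. A minimal sketch of both forms,
assuming a data_bundle produced by one of the pipes:

    from fastNLP.core.callback import EvaluateCallback

    # a single extra dataset evaluated after every epoch
    cb_single = EvaluateCallback(data=data_bundle.datasets['test'])

    # several named datasets (the MNLI case with matched/mismatched dev sets)
    cb_multi = EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'],
                                      'dev_mismatched': data_bundle.datasets['dev_mismatched']})
    callbacks = [cb_single]   # then handed to Trainer(..., callbacks=callbacks)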
@@ -3,39 +3,28 @@ import torch
 import torch.nn as nn
 from fastNLP.core.const import Const
-from fastNLP.models import BaseModel
-from fastNLP.embeddings.bert import BertModel
+from fastNLP.models.base_model import BaseModel
+from fastNLP.embeddings import BertEmbedding
 class BertForNLI(BaseModel):
-    # TODO: still in progress
-    def __init__(self, class_num=3, bert_dir=None):
+    def __init__(self, bert_embed: BertEmbedding, class_num=3):
         super(BertForNLI, self).__init__()
-        if bert_dir is not None:
-            self.bert = BertModel.from_pretrained(bert_dir)
-        else:
-            self.bert = BertModel()
-        hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1)
-        self.classifier = nn.Linear(hidden_size, class_num)
-    def forward(self, words, seq_len1, seq_len2, target=None):
+        self.embed = bert_embed
+        self.classifier = nn.Linear(self.embed.embedding_dim, class_num)
+    def forward(self, words):
         """
         :param torch.Tensor words: [batch_size, seq_len] input_ids
-        :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids
-        :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask
-        :param torch.Tensor target: [batch]
         :return:
         """
-        _, pooled_output = self.bert(words, seq_len1, seq_len2)
-        logits = self.classifier(pooled_output)
+        hidden = self.embed(words)
+        logits = self.classifier(hidden)
-        if target is not None:
-            loss_func = torch.nn.CrossEntropyLoss()
-            loss = loss_func(logits, target)
-            return {Const.OUTPUT: logits, Const.LOSS: loss}
         return {Const.OUTPUT: logits}
-    def predict(self, words, seq_len1, seq_len2, target=None):
-        return self.forward(words, seq_len1, seq_len2)
+    def predict(self, words):
+        logits = self.forward(words)[Const.OUTPUT]
+        return {Const.OUTPUT: logits.argmax(dim=-1)}
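
BertForNLI now takes a ready-made BertEmbedding instead of a bert_dir, and forward/predict consume
only the 'words' field; the token_type_ids/attention_mask arguments disappear from the signature
because the embedding layer is expected to derive them itself (an assumption consistent with the
removed parameters above, not something this diff states). A construction sketch reusing names from
the BERT script:

    from fastNLP.core.const import Const
    from fastNLP.embeddings import BertEmbedding

    vocab = data_bundle.vocabs[Const.INPUT]          # word vocab from a *BertPipe bundle
    embed = BertEmbedding(vocab, model_dir_or_name='bert-base-uncased')
    model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
    pred = model.predict(words)[Const.OUTPUT]        # words: LongTensor [batch_size, seq_len]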
@@ -3,10 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from torch.nn import CrossEntropyLoss
-from fastNLP.models import BaseModel
-from fastNLP.embeddings.embedding import TokenEmbedding
+from fastNLP.models.base_model import BaseModel
+from fastNLP.embeddings import TokenEmbedding
 from fastNLP.core.const import Const
@@ -83,13 +81,12 @@ class CNTNModel(BaseModel):
         self.weight_V = nn.Linear(2 * ns, r)
         self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels))
-    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
+    def forward(self, words1, words2, seq_len1, seq_len2):
         """
         :param words1: [batch, seq_len, emb_size] Question.
         :param words2: [batch, seq_len, emb_size] Answer.
         :param seq_len1: [batch]
         :param seq_len2: [batch]
-        :param target: [batch] Gold labels.
         :return:
         """
         in_q = self.embedding(words1)
@@ -109,12 +106,7 @@ class CNTNModel(BaseModel):
         in_a = self.fc_q(in_a.view(in_a.size(0), -1))
         score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1))))
-        if target is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(score, target)
-            return {Const.LOSS: loss, Const.OUTPUT: score}
-        else:
-            return {Const.OUTPUT: score}
+        return {Const.OUTPUT: score}
-    def predict(self, **kwargs):
-        return self.forward(**kwargs)
+    def predict(self, words1, words2, seq_len1, seq_len2):
+        return self.forward(words1, words2, seq_len1, seq_len2)
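
With the loss computation stripped out of forward(), training depends on the loss object passed to
Trainer (loss=CrossEntropyLoss() in the cntn script above), which pairs the model's Const.OUTPUT
('pred') with the dataset's 'target' field. Per batch it computes roughly the equivalent of this
self-contained toy snippet (random tensors, not real data):

    import torch
    import torch.nn.functional as F
    from fastNLP.core.const import Const

    output = {Const.OUTPUT: torch.randn(4, 3, requires_grad=True)}  # stand-in for forward()'s return
    target = torch.tensor([0, 2, 1, 1])                             # stand-in for the batch's target field
    loss = F.cross_entropy(output[Const.OUTPUT], target)            # what CrossEntropyLoss() evaluates
    loss.backward()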
@@ -2,10 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
-from fastNLP.models import BaseModel
-from fastNLP.embeddings.embedding import TokenEmbedding
+from fastNLP.models.base_model import BaseModel
+from fastNLP.embeddings import TokenEmbedding
 from fastNLP.core.const import Const
 from fastNLP.core.utils import seq_len_to_mask
@@ -42,13 +40,12 @@ class ESIMModel(BaseModel):
         nn.init.xavier_uniform_(self.classifier[1].weight.data)
         nn.init.xavier_uniform_(self.classifier[4].weight.data)
-    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
+    def forward(self, words1, words2, seq_len1, seq_len2):
         """
         :param words1: [batch, seq_len]
         :param words2: [batch, seq_len]
         :param seq_len1: [batch]
         :param seq_len2: [batch]
-        :param target:
         :return:
         """
         mask1 = seq_len_to_mask(seq_len1, words1.size(1))
@@ -82,16 +79,10 @@ class ESIMModel(BaseModel):
         logits = torch.tanh(self.classifier(out))
         # logits = self.classifier(out)
-        if target is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits, target)
-            return {Const.LOSS: loss, Const.OUTPUT: logits}
-        else:
-            return {Const.OUTPUT: logits}
+        return {Const.OUTPUT: logits}
-    def predict(self, **kwargs):
-        pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
+    def predict(self, words1, words2, seq_len1, seq_len2):
+        pred = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT].argmax(-1)
         return {Const.OUTPUT: pred}
 # input [batch_size, len , hidden]
@@ -1,10 +0,0 @@
-import unittest
-from ..data import MatchingDataLoader
-from fastNLP.core.vocabulary import Vocabulary
-class TestCWSDataLoader(unittest.TestCase):
-    def test_case1(self):
-        snli_loader = MatchingDataLoader()
-        # TODO: still in progress