@@ -0,0 +1,98 @@
import os


def find_all(path='../fastNLP'):
    """Walk the fastNLP package, collect every name exported through __all__ and every
    alias declared in a docstring, then print the entries that do not match."""
    head_list = []
    alias_list = []
    for path, dirs, files in os.walk(path):
        for file in files:
            if file.endswith('.py'):
                name = ".".join(path.split('/')[1:])
                if file.split('.')[0] != "__init__":
                    name = name + '.' + file.split('.')[0]
                if len(name.split('.')) < 3 or name.startswith('fastNLP.core'):
                    heads, alias = find_one(path + '/' + file)
                    for h in heads:
                        head_list.append(name + "." + h)
                    for a in alias:
                        alias_list.append(a)
    heads = {}
    for h in head_list:
        end = h.split('.')[-1]
        file = h[:-len(end) - 1]
        if end not in heads:
            heads[end] = set()
        heads[end].add(file)
    alias = {}
    for a in alias_list:
        for each in a:
            end = each.split('.')[-1]
            file = each[:-len(end) - 1]
            if end not in alias:
                alias[end] = set()
            alias[end].add(file)
    print("IN alias NOT IN heads")
    for item in alias:
        if item not in heads:
            print(item, alias[item])
        elif len(heads[item]) != 2:
            # the checker expects an aliased name to appear in __all__ of exactly two modules
            print(item, alias[item], heads[item])
    print("\n\nIN heads NOT IN alias")
    for item in heads:
        if item not in alias:
            print(item, heads[item])


def find_class(path):
    """Parse a .py file and return a dict mapping each class name to its parent class(es)."""
    with open(path, 'r') as fin:
        lines = fin.readlines()
    pars = {}
    for i, line in enumerate(lines):
        if line.strip().startswith('class'):
            line = line.strip()[len('class'):-1].strip()
            if line[-1] == ')':
                line = line[:-1].split('(')
                name = line[0].strip()
                parents = line[1].split(',')
                for i in range(len(parents)):
                    parents[i] = parents[i].strip()
                if len(parents) == 1:
                    pars[name] = parents[0]
                else:
                    pars[name] = tuple(parents)
    return pars


def find_one(path):
    """Collect the __all__ entries and the alias pairs declared in a single .py file."""
    head_list = []
    alias = []
    with open(path, 'r') as fin:
        lines = fin.readlines()
    flag = False
    for i, line in enumerate(lines):
        if line.strip().startswith('__all__'):
            line = line.strip()[len('__all__'):].strip()
            if line[-1] == ']':
                # single-line form: __all__ = ["name"]
                line = line[1:-1].strip()[1:].strip()
                head_list.append(line.strip("\"").strip("\'").strip())
            else:
                # multi-line __all__: collect entries until the closing bracket
                flag = True
        elif line.strip() == ']':
            flag = False
        elif flag:
            line = line.strip()[:-1].strip("\"").strip("\'").strip()
            if len(line) == 0 or line[0] == '#':
                continue
            head_list.append(line)
        if line.startswith('def') or line.startswith('class'):
            # alias convention: two lines below the definition the docstring starts with "别名:"
            if lines[i + 2].strip().startswith("别名:"):
                names = lines[i + 2].strip()[len("别名:"):].split()
                names[0] = names[0][len(":class:`"):-1]
                names[1] = names[1][len(":class:`"):-1]
                alias.append((names[0], names[1]))
    return head_list, alias


if __name__ == "__main__":
    find_all()  # used to check that __all__ exports and alias docstrings stay consistent
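# A hedged illustration of the layout find_one() expects from a fastNLP source file
# (the names below are only an example of the convention, not taken from this diff):
#
#     __all__ = [
#         "DataSet",
#     ]
#
#     class DataSet(object):
#         """
#         别名::class:`fastNLP.DataSet` :class:`fastNLP.core.dataset.DataSet`
#         """
#
# For such a file find_one() returns (["DataSet"], [("fastNLP.DataSet", "fastNLP.core.dataset.DataSet")]),
# and find_all() then reports any name whose alias pair and __all__ exports disagree.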
@@ -13,11 +13,11 @@ The most commonly used fastNLP components can be imported directly from the fastNLP package; their
__all__ = [
    "Instance",
    "FieldArray",
    "DataSetIter",
    "BatchIter",
    "TorchLoaderIter",
    "Vocabulary",
    "DataSet",
    "Const",
@@ -51,7 +51,8 @@ __all__ = [
    "LossFunc",
    "CrossEntropyLoss",
    "L1Loss", "BCELoss",
    "L1Loss",
    "BCELoss",
    "NLLLoss",
    "LossInForward",
@@ -7,6 +7,7 @@ torch.FloatTensor. Every embedding can use `self.num_embedding` to get
__all__ = [
    "Embedding",
    "TokenEmbedding",
    "StaticEmbedding",
    "ElmoEmbedding",
    "BertEmbedding",
@@ -14,14 +15,14 @@ __all__ = [
    "StackEmbedding",
    "LSTMCharEmbedding",
    "CNNCharEmbedding",
    "get_embeddings"
    "get_embeddings",
]
from .embedding import Embedding
from .embedding import Embedding, TokenEmbedding
from .static_embedding import StaticEmbedding
from .elmo_embedding import ElmoEmbedding
from .bert_embedding import BertEmbedding, BertWordPieceEncoder
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
from .stack_embedding import StackEmbedding
from .utils import get_embeddings
from .utils import get_embeddings
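# With TokenEmbedding re-exported at the package level, downstream code can import it
# without reaching into the submodule (hedged sketch):
#
#     from fastNLP.embeddings import Embedding, TokenEmbedding, get_embeddings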
@@ -133,10 +133,12 @@ def _get_logger(name=None, level='INFO'):
class FastNLPLogger(logging.Logger):
    def add_file(self, path, level):
    def add_file(self, path='./log.txt', level='INFO'):
        """add log output file and level"""
        _add_file_handler(self, path, level)
    def set_stdout(self, stdout, level):
    def set_stdout(self, stdout='tqdm', level='INFO'):
        """set stdout format and level"""
        _set_stdout_handler(self, stdout, level)
_logger = _init_logger(path=None)
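# Hedged usage sketch of the new defaults, assuming the module-level instance is the
# public entry point (as in `from fastNLP import logger`):
#
#     logger.add_file()      # same as logger.add_file('./log.txt', 'INFO')
#     logger.set_stdout()    # same as logger.set_stdout('tqdm', 'INFO')
#     logger.info('training started')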
@@ -89,13 +89,15 @@ class MatchingBertPipe(Pipe):
        data_bundle.set_vocab(word_vocab, Const.INPUT)
        data_bundle.set_vocab(target_vocab, Const.TARGET)
        input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET]
        input_fields = [Const.INPUT, Const.INPUT_LEN]
        target_fields = [Const.TARGET]
        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.INPUT)
            dataset.set_input(*input_fields, flag=True)
            dataset.set_target(*target_fields, flag=True)
            for fields in target_fields:
                if dataset.has_field(fields):
                    dataset.set_target(fields, flag=True)
        return data_bundle
@@ -210,14 +212,16 @@ class MatchingPipe(Pipe):
        data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
        data_bundle.set_vocab(target_vocab, Const.TARGET)
        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET]
        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
        target_fields = [Const.TARGET]
        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
            dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
            dataset.set_input(*input_fields, flag=True)
            dataset.set_target(*target_fields, flag=True)
            for fields in target_fields:
                if dataset.has_field(fields):
                    dataset.set_target(fields, flag=True)
        return data_bundle
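# Why the has_field guard matters (hedged illustration): GLUE-style test splits ship
# without labels, so the test DataSet in a bundle may have no target field at all, and
# setting the target only when the field exists keeps process() from raising on them:
#
#     for name, dataset in data_bundle.datasets.items():
#         print(name, dataset.has_field(Const.TARGET))   # e.g. train -> True, test -> False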
@@ -1,435 +0,0 @@
"""
The contents of this file have been merged into fastNLP.io.data_loader; this file is no longer updated.
"""
import os | |||
from typing import Union, Dict | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.io.data_bundle import DataBundle, DataSetLoader | |||
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader | |||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
class MatchingLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`

    Reads data sets for matching tasks.

    :param dict paths: keys are data set names (e.g. train, dev, test), values are the corresponding file names
    """
def __init__(self, paths: dict=None): | |||
self.paths = paths | |||
    def _load(self, path):
        """
        :param str path: path of the data set file to read
        :return: fastNLP.DataSet ds: a DataSet object that must contain 3 fields: two of them hold the raw string
            text of the two sentences, and the third holds the label
        """
raise NotImplementedError | |||
    def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
                to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                cut_text: int = None, get_index=True, auto_pad_length: int=None,
                auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
        """
        :param paths: str or Dict[str, str]. If a str, it is either the folder containing the data sets or the full
            path of a single file: for a folder, the data set names and file names are looked up in self.paths;
            if a Dict, keys are data set names (e.g. train, dev, test) and values are the corresponding full file paths.
        :param str dataset_name: if paths is the full path of a single data set file, dataset_name can be used to name
            that data set; if not given, the name defaults to train.
        :param bool to_lower: whether to lowercase the text automatically. Defaults to False.
        :param str seq_len_type: the kind of seq_len information to provide. ``seq_len``: a single number as the
            sentence length; ``mask``: a 0/1 mask matrix as the sentence length; ``bert``: segment_type_ids (0 for the
            first sentence, 1 for the second) plus an attention mask (a 0/1 mask matrix). Defaults to None, i.e. no
            seq_len is provided.
        :param str bert_tokenizer: path of the folder containing the vocabulary used by the bert tokenizer
        :param int cut_text: truncate anything longer than cut_text. Defaults to None, i.e. no truncation.
        :param bool get_index: whether to convert the text to indices using the vocabulary
        :param int auto_pad_length: whether to automatically pad the text to a fixed length (text longer than this
            is truncated). By default no automatic padding is done.
        :param str auto_pad_token: the token used for automatic padding
        :param set_input: if True, fields whose names contain Const.INPUT are automatically set as input; if False,
            no field is set as input. If a str or List[str] is passed, the corresponding fields are set as input and
            no other field is set as input. Defaults to True.
        :param set_target: controls which fields may be set as target; used the same way as set_input. Defaults to True.
        :param concat: whether to concatenate the two sentences. If False, no concatenation. If True, a <sep> token is
            inserted between the two sentences. If a list of length 4 is passed, its items are the markers inserted
            before the first sentence, after the first sentence, before the second sentence, and after the second
            sentence. If the string ``bert`` is passed, bert-style concatenation is used, which is equivalent to
            ['[CLS]', '[SEP]', '', '[SEP]'].
        :return:
        """
if isinstance(set_input, str): | |||
set_input = [set_input] | |||
if isinstance(set_target, str): | |||
set_target = [set_target] | |||
if isinstance(set_input, bool): | |||
auto_set_input = set_input | |||
else: | |||
auto_set_input = False | |||
if isinstance(set_target, bool): | |||
auto_set_target = set_target | |||
else: | |||
auto_set_target = False | |||
if isinstance(paths, str): | |||
if os.path.isdir(paths): | |||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||
else: | |||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||
else: | |||
path = paths | |||
data_info = DataBundle() | |||
for data_name in path.keys(): | |||
data_info.datasets[data_name] = self._load(path[data_name]) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if auto_set_input: | |||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||
if auto_set_target: | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.set_target(Const.TARGET) | |||
if to_lower: | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||
is_input=auto_set_input) | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||
is_input=auto_set_input) | |||
if bert_tokenizer is not None: | |||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
                # check whether the directory exists
elif os.path.isdir(bert_tokenizer): | |||
model_dir = bert_tokenizer | |||
else: | |||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||
lines = f.readlines() | |||
lines = [line.strip() for line in lines] | |||
words_vocab.add_word_lst(lines) | |||
words_vocab.build_vocab() | |||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
if isinstance(concat, bool): | |||
concat = 'default' if concat else None | |||
if concat is not None: | |||
if isinstance(concat, str): | |||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||
'default': ['', '<sep>', '', '']} | |||
if concat.lower() in CONCAT_MAP: | |||
concat = CONCAT_MAP[concat] | |||
else: | |||
concat = 4 * [concat] | |||
            assert len(concat) == 4, \
                f'Please pass a list with 4 symbols: the marker at the beginning of the first sentence, ' \
                f'the end of the first sentence, the beginning of the second sentence, and the end of the ' \
                f'second sentence. Your input is {concat}'
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||
is_input=auto_set_input) | |||
if seq_len_type is not None: | |||
if seq_len_type == 'seq_len': # | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'mask': | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [1] * len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'bert': | |||
for data_name, data_set in data_info.datasets.items(): | |||
if Const.INPUT not in data_set.get_field_names(): | |||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||
f'got {data_set.get_field_names()}') | |||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
if auto_pad_length is not None: | |||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||
if cut_text is not None: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||
is_input=auto_set_input) | |||
data_set_list = [d for n, d in data_info.datasets.items()] | |||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
if bert_tokenizer is None: | |||
words_vocab = Vocabulary(padding=auto_pad_token) | |||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)], | |||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||
if 'train' not in n]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=Const.TARGET) | |||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||
if get_index: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||
is_input=auto_set_input) | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
if auto_pad_length is not None: | |||
if seq_len_type == 'seq_len': | |||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||
new_field_name=fields, is_input=auto_set_input) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if isinstance(set_input, list): | |||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||
if isinstance(set_target, list): | |||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||
return data_info | |||
class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Reads the SNLI data set. The loaded DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """
def __init__(self, paths: dict=None): | |||
fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
paths = paths if paths is not None else { | |||
'train': 'snli_1.0_train.jsonl', | |||
'dev': 'snli_1.0_dev.jsonl', | |||
'test': 'snli_1.0_test.jsonl'} | |||
MatchingLoader.__init__(self, paths=paths) | |||
JsonLoader.__init__(self, fields=fields) | |||
def _load(self, path): | |||
ds = JsonLoader._load(self, path) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`

    Reads the RTE data set. The loaded DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
            'test': 'test.tsv'  # the test set has no labels
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'sentence1': Const.INPUTS(0), | |||
'sentence2': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if v in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds | |||
class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`

    Reads the QNLI data set. The loaded DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
            'test': 'test.tsv'  # the test set has no labels
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'question': Const.INPUTS(0), | |||
'sentence': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if v in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds | |||
class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

    Reads the MNLI data set. The loaded DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev_matched': 'dev_matched.tsv', | |||
'dev_mismatched': 'dev_mismatched.tsv', | |||
'test_matched': 'test_matched.tsv', | |||
'test_mismatched': 'test_mismatched.tsv', | |||
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
            # test_0.9_matched and test_0.9_mismatched come from the MNLI 0.9 release (data source: Kaggle)
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t') | |||
self.fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
if Const.TARGET in ds.get_field_names(): | |||
if ds[0][Const.TARGET] == 'hidden': | |||
ds.delete_field(Const.TARGET) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
if Const.TARGET in ds.get_field_names(): | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`

    Reads the Quora question-pair data set. The loaded DataSet contains the fields::

        words1: list(str), the first sentence
        words2: list(str), the second sentence
        target: str, the gold label

    Data source:
    """
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv', | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
return ds |
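# For reference, the deleted loaders above were driven through process(), e.g.
#     data_info = SNLILoader().process(paths='path/to/snli/data', to_lower=True, ...)
# while the rest of this change moves the reproduction scripts onto the Pipe API,
# whose hedged one-line equivalent is:
#     data_bundle = SNLIPipe(lower=True, tokenizer='spacy').process_from_file()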
@@ -2,8 +2,12 @@ import random | |||
import numpy as np | |||
import torch | |||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam | |||
from fastNLP.io.data_loader import SNLILoader, RTELoader, MNLILoader, QNLILoader, QuoraLoader | |||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||
from fastNLP.core.callback import WarmupCallback, EvaluateCallback | |||
from fastNLP.core.optimizer import AdamW | |||
from fastNLP.embeddings import BertEmbedding | |||
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\ | |||
QNLIBertPipe, QuoraBertPipe | |||
from reproduction.matching.model.bert import BertForNLI | |||
@@ -12,16 +16,22 @@ from reproduction.matching.model.bert import BertForNLI | |||
class BERTConfig:
    task = 'snli'
    batch_size_per_gpu = 6
    n_epochs = 6
    lr = 2e-5
    seq_len_type = 'bert'
    warm_up_rate = 0.1
    seed = 42
    save_path = None  # where to save the model; None means the model is not saved.
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    save_path = None  # where to save the model; None means the model is not saved.
    bert_dir = 'path/to/bert/dir'  # folder containing the pre-trained BERT weights
    to_lower = True  # lowercase the text
    tokenizer = 'spacy'  # tokenize with spaCy
    bert_model_dir_or_name = 'bert-base-uncased'

arg = BERTConfig()
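# The config is a plain attribute holder, so a run can override fields in place before
# training starts (hedged example):
#     arg.task = 'rte'
#     arg.batch_size_per_gpu = 16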
@@ -37,58 +47,52 @@ if n_gpu > 0: | |||
# load data set | |||
if arg.task == 'snli': | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
data_bundle = SNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'rte': | |||
data_info = RTELoader().process( | |||
paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
data_bundle = RTEBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'qnli': | |||
data_info = QNLILoader().process( | |||
paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
data_bundle = QNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'mnli': | |||
data_info = MNLILoader().process( | |||
paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
data_bundle = MNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'quora': | |||
data_info = QuoraLoader().process( | |||
paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type, | |||
bert_tokenizer=arg.bert_dir, cut_text=512, | |||
get_index=True, concat='bert', | |||
) | |||
data_bundle = QuoraBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
else: | |||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||
print(data_bundle) # print details in data_bundle | |||
# load embedding | |||
embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name) | |||
# define model | |||
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir) | |||
model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET])) | |||
# define optimizer and callback | |||
optimizer = AdamW(lr=arg.lr, params=model.parameters()) | |||
callbacks = [WarmupCallback(warmup=arg.warm_up_rate, schedule='linear'), ] | |||
if arg.task in ['snli']: | |||
callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name])) | |||
# evaluate test set in every epoch if task is snli. | |||
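# The same pattern extends to other tasks; for MNLI both dev splits could be monitored
# (hedged sketch mirroring the MWAN script later in this diff):
#     callbacks.append(EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'],
#                                             'dev_mismatched': data_bundle.datasets['dev_mismatched']}))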
# define trainer | |||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||
trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model, | |||
optimizer=optimizer, | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
n_epochs=arg.n_epochs, print_every=-1, | |||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||
dev_data=data_bundle.datasets[arg.dev_dataset_name], | |||
metrics=AccuracyMetric(), metric_key='acc', | |||
device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1, | |||
save_path=arg.save_path) | |||
save_path=arg.save_path, | |||
callbacks=callbacks) | |||
# train model | |||
trainer.train(load_best_model=True) | |||
# define tester | |||
tester = Tester( | |||
data=data_info.datasets[arg.test_dataset_name], | |||
data=data_bundle.datasets[arg.test_dataset_name], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
@@ -1,9 +1,9 @@ | |||
import argparse | |||
import torch | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.io.data_loader import QNLILoader, RTELoader, SNLILoader, MNLILoader | |||
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe | |||
from reproduction.matching.model.cntn import CNTNModel | |||
@@ -13,14 +13,12 @@ argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glo | |||
argument.add_argument('--batch-size-per-gpu', type=int, default=256) | |||
argument.add_argument('--n-epochs', type=int, default=200) | |||
argument.add_argument('--lr', type=float, default=1e-5) | |||
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask') | |||
argument.add_argument('--save-dir', type=str, default=None) | |||
argument.add_argument('--cntn-depth', type=int, default=1) | |||
argument.add_argument('--cntn-ns', type=int, default=200) | |||
argument.add_argument('--cntn-k-top', type=int, default=10) | |||
argument.add_argument('--cntn-r', type=int, default=5) | |||
argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli') | |||
argument.add_argument('--max-len', type=int, default=50) | |||
arg = argument.parse_args() | |||
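# Example invocation (hedged; the flags mirror the argparse definitions above and the
# script file name is assumed):
#     python matching_cntn.py --dataset snli --embedding glove --batch-size-per-gpu 128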
# dataset dict | |||
@@ -45,30 +43,25 @@ else: | |||
num_labels = 3 | |||
# load data set | |||
if arg.dataset == 'qnli': | |||
data_info = QNLILoader().process( | |||
paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
get_index=True, concat=False, auto_pad_length=arg.max_len) | |||
if arg.dataset == 'snli': | |||
data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file() | |||
elif arg.dataset == 'rte': | |||
data_info = RTELoader().process( | |||
paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
get_index=True, concat=False, auto_pad_length=arg.max_len) | |||
elif arg.dataset == 'snli': | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
get_index=True, concat=False, auto_pad_length=arg.max_len) | |||
data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file() | |||
elif arg.dataset == 'qnli': | |||
data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file() | |||
elif arg.dataset == 'mnli': | |||
data_info = MNLILoader().process( | |||
paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
get_index=True, concat=False, auto_pad_length=arg.max_len) | |||
data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file() | |||
else: | |||
raise ValueError(f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!') | |||
    raise RuntimeError(f'NOT support {arg.dataset} dataset yet!')
print(data_bundle) # print details in data_bundle | |||
# load embedding | |||
if arg.embedding == 'word2vec': | |||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300', requires_grad=True) | |||
embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300', | |||
requires_grad=True) | |||
elif arg.embedding == 'glove': | |||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300', | |||
embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d', | |||
requires_grad=True) | |||
else: | |||
raise ValueError(f'now we only support word2vec or glove embedding for cntn model!') | |||
@@ -79,11 +72,12 @@ model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=nu | |||
print(model) | |||
# define trainer | |||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
trainer = Trainer(train_data=data_bundle.datasets['train'], model=model, | |||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||
loss=CrossEntropyLoss(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
n_epochs=arg.n_epochs, print_every=-1, | |||
dev_data=data_info.datasets[dev_dict[arg.dataset]], | |||
dev_data=data_bundle.datasets[dev_dict[arg.dataset]], | |||
metrics=AccuracyMetric(), metric_key='acc', | |||
device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1) | |||
@@ -93,7 +87,7 @@ trainer.train(load_best_model=True) | |||
# define tester | |||
tester = Tester( | |||
data=data_info.datasets[test_dict[arg.dataset]], | |||
data=data_bundle.datasets[test_dict[arg.dataset]], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
@@ -6,10 +6,11 @@ from torch.optim import Adamax | |||
from torch.optim.lr_scheduler import StepLR | |||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||
from fastNLP.core.callback import GradientClipCallback, LRScheduler | |||
from fastNLP.embeddings.static_embedding import StaticEmbedding | |||
from fastNLP.embeddings.elmo_embedding import ElmoEmbedding | |||
from fastNLP.io.data_loader import SNLILoader, RTELoader, MNLILoader, QNLILoader, QuoraLoader | |||
from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback | |||
from fastNLP.core.losses import CrossEntropyLoss | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.embeddings import ElmoEmbedding | |||
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe | |||
from fastNLP.models.snli import ESIM | |||
@@ -17,18 +18,21 @@ from fastNLP.models.snli import ESIM | |||
class ESIMConfig:
    task = 'snli'
    embedding = 'glove'
    batch_size_per_gpu = 196
    n_epochs = 30
    lr = 2e-3
    seq_len_type = 'seq_len'
    # 'seq_len' means the length information is represented as len(words) during process();
    # 'mask' means the length information is represented as a 0/1 mask matrix;
    seed = 42
    save_path = None  # where to save the model; None means the model is not saved.
    train_dataset_name = 'train'
    dev_dataset_name = 'dev'
    test_dataset_name = 'test'
    save_path = None  # where to save the model; None means the model is not saved.
    to_lower = True  # lowercase the text
    tokenizer = 'spacy'  # tokenize with spaCy

arg = ESIMConfig()
@@ -44,43 +48,32 @@ if n_gpu > 0: | |||
# load data set | |||
if arg.task == 'snli': | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
data_bundle = SNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'rte': | |||
data_info = RTELoader().process( | |||
paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
data_bundle = RTEPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'qnli': | |||
data_info = QNLILoader().process( | |||
paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
data_bundle = QNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'mnli': | |||
data_info = MNLILoader().process( | |||
paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
data_bundle = MNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
elif arg.task == 'quora': | |||
data_info = QuoraLoader().process( | |||
paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type, | |||
get_index=True, concat=False, | |||
) | |||
data_bundle = QuoraPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file() | |||
else: | |||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||
print(data_bundle) # print details in data_bundle | |||
# load embedding | |||
if arg.embedding == 'elmo': | |||
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium', | |||
requires_grad=True) | |||
elif arg.embedding == 'glove': | |||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False) | |||
embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d', | |||
requires_grad=True, normalize=False) | |||
else: | |||
raise RuntimeError(f'NOT support {arg.embedding} embedding yet!') | |||
# define model | |||
model = ESIM(embedding, num_labels=len(data_info.vocabs[Const.TARGET])) | |||
model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET])) | |||
# define optimizer and callback | |||
optimizer = Adamax(lr=arg.lr, params=model.parameters()) | |||
@@ -91,23 +84,29 @@ callbacks = [ | |||
LRScheduler(scheduler), | |||
] | |||
if arg.task in ['snli']: | |||
callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name])) | |||
# evaluate test set in every epoch if task is snli. | |||
# define trainer | |||
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model, | |||
trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model, | |||
optimizer=optimizer, | |||
loss=CrossEntropyLoss(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
n_epochs=arg.n_epochs, print_every=-1, | |||
dev_data=data_info.datasets[arg.dev_dataset_name], | |||
dev_data=data_bundle.datasets[arg.dev_dataset_name], | |||
metrics=AccuracyMetric(), metric_key='acc', | |||
device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1, | |||
save_path=arg.save_path) | |||
save_path=arg.save_path, | |||
callbacks=callbacks) | |||
# train model | |||
trainer.train(load_best_model=True) | |||
# define tester | |||
tester = Tester( | |||
data=data_info.datasets[arg.test_dataset_name], | |||
data=data_bundle.datasets[arg.test_dataset_name], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
@@ -6,12 +6,11 @@ from torch.optim import Adadelta | |||
from torch.optim.lr_scheduler import StepLR | |||
from fastNLP import CrossEntropyLoss | |||
from fastNLP import cache_results | |||
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const | |||
from fastNLP.core.callback import LRScheduler, FitlogCallback | |||
from fastNLP.core.callback import LRScheduler, EvaluateCallback | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.io.data_loader import MNLILoader, QNLILoader, SNLILoader, RTELoader | |||
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe | |||
from reproduction.matching.model.mwan import MwanModel | |||
import fitlog | |||
@@ -46,47 +45,25 @@ for k in arg.__dict__: | |||
# load data set | |||
if arg.task == 'snli': | |||
@cache_results(f'snli_mwan.pkl') | |||
def read_snli(): | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, | |||
get_index=True, concat=False, extra_split=['/','%','-'], | |||
) | |||
return data_info | |||
data_info = read_snli() | |||
data_bundle = SNLIPipe(lower=True, tokenizer='spacy').process_from_file() | |||
elif arg.task == 'rte': | |||
@cache_results(f'rte_mwan.pkl') | |||
def read_rte(): | |||
data_info = RTELoader().process( | |||
paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, | |||
get_index=True, concat=False, extra_split=['/','%','-'], | |||
) | |||
return data_info | |||
data_info = read_rte() | |||
data_bundle = RTEPipe(lower=True, tokenizer='spacy').process_from_file() | |||
elif arg.task == 'qnli': | |||
data_info = QNLILoader().process( | |||
paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, | |||
get_index=True, concat=False , cut_text=512, extra_split=['/','%','-'], | |||
) | |||
data_bundle = QNLIPipe(lower=True, tokenizer='spacy').process_from_file() | |||
elif arg.task == 'mnli': | |||
@cache_results(f'mnli_v0.9_mwan.pkl') | |||
def read_mnli(): | |||
data_info = MNLILoader().process( | |||
paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None, | |||
get_index=True, concat=False, extra_split=['/','%','-'], | |||
) | |||
return data_info | |||
data_info = read_mnli() | |||
data_bundle = MNLIPipe(lower=True, tokenizer='spacy').process_from_file() | |||
elif arg.task == 'quora': | |||
data_bundle = QuoraPipe(lower=True, tokenizer='spacy').process_from_file() | |||
else: | |||
raise RuntimeError(f'NOT support {arg.task} task yet!') | |||
print(data_info) | |||
print(len(data_info.vocabs['words'])) | |||
print(data_bundle) | |||
print(len(data_bundle.vocabs[Const.INPUTS(0)])) | |||
model = MwanModel( | |||
num_class = len(data_info.vocabs[Const.TARGET]), | |||
EmbLayer = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False), | |||
num_class = len(data_bundle.vocabs[Const.TARGET]), | |||
EmbLayer = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], requires_grad=False, normalize=False), | |||
ElmoLayer = None, | |||
args_of_imm = { | |||
"input_size" : 300 , | |||
@@ -105,21 +82,20 @@ callbacks = [ | |||
] | |||
if arg.task in ['snli']: | |||
callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1)) | |||
callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.testset_name])) | |||
elif arg.task == 'mnli': | |||
callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'], | |||
'dev_mismatched': data_info.datasets['dev_mismatched']}, | |||
verbose=1)) | |||
callbacks.append(EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'], | |||
'dev_mismatched': data_bundle.datasets['dev_mismatched']},)) | |||
trainer = Trainer( | |||
train_data = data_info.datasets['train'], | |||
train_data = data_bundle.datasets['train'], | |||
model = model, | |||
optimizer = optimizer, | |||
num_workers = 0, | |||
batch_size = arg.batch_size, | |||
n_epochs = arg.n_epochs, | |||
print_every = -1, | |||
dev_data = data_info.datasets[arg.devset_name], | |||
dev_data = data_bundle.datasets[arg.devset_name], | |||
metrics = AccuracyMetric(pred = "pred" , target = "target"), | |||
metric_key = 'acc', | |||
device = [i for i in range(torch.cuda.device_count())], | |||
@@ -130,7 +106,7 @@ trainer = Trainer( | |||
trainer.train(load_best_model=True) | |||
tester = Tester( | |||
data=data_info.datasets[arg.testset_name], | |||
data=data_bundle.datasets[arg.testset_name], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=arg.batch_size, | |||
@@ -3,39 +3,28 @@ import torch | |||
import torch.nn as nn | |||
from fastNLP.core.const import Const | |||
from fastNLP.models import BaseModel | |||
from fastNLP.embeddings.bert import BertModel | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.embeddings import BertEmbedding | |||
class BertForNLI(BaseModel): | |||
# TODO: still in progress | |||
def __init__(self, class_num=3, bert_dir=None): | |||
def __init__(self, bert_embed: BertEmbedding, class_num=3): | |||
super(BertForNLI, self).__init__() | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
self.bert = BertModel() | |||
hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1) | |||
self.classifier = nn.Linear(hidden_size, class_num) | |||
def forward(self, words, seq_len1, seq_len2, target=None): | |||
self.embed = bert_embed | |||
self.classifier = nn.Linear(self.embed.embedding_dim, class_num) | |||
def forward(self, words): | |||
""" | |||
:param torch.Tensor words: [batch_size, seq_len] input_ids | |||
:param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids | |||
:param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask | |||
:param torch.Tensor target: [batch] | |||
:return: | |||
""" | |||
_, pooled_output = self.bert(words, seq_len1, seq_len2) | |||
logits = self.classifier(pooled_output) | |||
hidden = self.embed(words) | |||
logits = self.classifier(hidden) | |||
if target is not None: | |||
loss_func = torch.nn.CrossEntropyLoss() | |||
loss = loss_func(logits, target) | |||
return {Const.OUTPUT: logits, Const.LOSS: loss} | |||
return {Const.OUTPUT: logits} | |||
def predict(self, words, seq_len1, seq_len2, target=None): | |||
return self.forward(words, seq_len1, seq_len2) | |||
def predict(self, words): | |||
logits = self.forward(words)[Const.OUTPUT] | |||
return {Const.OUTPUT: logits.argmax(dim=-1)} | |||
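# Hedged construction sketch for the refactored model: the BertEmbedding now owns the
# encoder, so the model only needs the embedding object and the number of classes,
# as in the training script earlier in this diff:
#     embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name='bert-base-uncased')
#     model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))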
@@ -3,10 +3,8 @@ import torch.nn as nn | |||
import torch.nn.functional as F | |||
import numpy as np | |||
from torch.nn import CrossEntropyLoss | |||
from fastNLP.models import BaseModel | |||
from fastNLP.embeddings.embedding import TokenEmbedding | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.embeddings import TokenEmbedding | |||
from fastNLP.core.const import Const | |||
@@ -83,13 +81,12 @@ class CNTNModel(BaseModel): | |||
self.weight_V = nn.Linear(2 * ns, r) | |||
self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels)) | |||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||
def forward(self, words1, words2, seq_len1, seq_len2): | |||
""" | |||
:param words1: [batch, seq_len, emb_size] Question. | |||
:param words2: [batch, seq_len, emb_size] Answer. | |||
:param seq_len1: [batch] | |||
:param seq_len2: [batch] | |||
:param target: [batch] Glod labels. | |||
:return: | |||
""" | |||
in_q = self.embedding(words1) | |||
@@ -109,12 +106,7 @@ class CNTNModel(BaseModel): | |||
in_a = self.fc_q(in_a.view(in_a.size(0), -1)) | |||
score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1)))) | |||
if target is not None: | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct(score, target) | |||
return {Const.LOSS: loss, Const.OUTPUT: score} | |||
else: | |||
return {Const.OUTPUT: score} | |||
return {Const.OUTPUT: score} | |||
def predict(self, **kwargs): | |||
return self.forward(**kwargs) | |||
def predict(self, words1, words2, seq_len1, seq_len2): | |||
return self.forward(words1, words2, seq_len1, seq_len2) |
@@ -2,10 +2,8 @@ import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn import CrossEntropyLoss | |||
from fastNLP.models import BaseModel | |||
from fastNLP.embeddings.embedding import TokenEmbedding | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.embeddings import TokenEmbedding | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.utils import seq_len_to_mask | |||
@@ -42,13 +40,12 @@ class ESIMModel(BaseModel): | |||
nn.init.xavier_uniform_(self.classifier[1].weight.data) | |||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | |||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||
def forward(self, words1, words2, seq_len1, seq_len2): | |||
""" | |||
:param words1: [batch, seq_len] | |||
:param words2: [batch, seq_len] | |||
:param seq_len1: [batch] | |||
:param seq_len2: [batch] | |||
:param target: | |||
:return: | |||
""" | |||
mask1 = seq_len_to_mask(seq_len1, words1.size(1)) | |||
@@ -82,16 +79,10 @@ class ESIMModel(BaseModel): | |||
logits = torch.tanh(self.classifier(out)) | |||
# logits = self.classifier(out) | |||
if target is not None: | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct(logits, target) | |||
return {Const.LOSS: loss, Const.OUTPUT: logits} | |||
else: | |||
return {Const.OUTPUT: logits} | |||
return {Const.OUTPUT: logits} | |||
def predict(self, **kwargs): | |||
pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1) | |||
def predict(self, words1, words2, seq_len1, seq_len2): | |||
pred = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT].argmax(-1) | |||
return {Const.OUTPUT: pred} | |||
# input [batch_size, len , hidden] | |||
@@ -1,10 +0,0 @@ | |||
import unittest | |||
from ..data import MatchingDataLoader | |||
from fastNLP.core.vocabulary import Vocabulary | |||
class TestCWSDataLoader(unittest.TestCase): | |||
def test_case1(self): | |||
snli_loader = MatchingDataLoader() | |||
# TODO: still in progress | |||
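# A hedged sketch of an equivalent check against the new Pipe API (test name and paths assumed):
#     import unittest
#     from fastNLP.io.pipe.matching import RTEPipe
#
#     class TestMatchingPipe(unittest.TestCase):
#         def test_rte(self):
#             bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file('path/to/rte/data')
#             self.assertIn('train', bundle.datasets)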