@@ -37,7 +37,7 @@ __all__ = [
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "SQuADMetric",
    "ExtractiveQAMetric",

    "Optimizer",
    "SGD",
@@ -61,3 +61,4 @@ __version__ = '0.4.0'
from .core import *
from . import models
from . import modules
from .io import data_loader

@@ -21,7 +21,7 @@ from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester

@@ -6,7 +6,7 @@ __all__ = [
    "MetricBase",
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "SQuADMetric"
    "ExtractiveQAMetric"
]

import inspect
@@ -24,6 +24,7 @@ from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
from abc import abstractmethod


class MetricBase(object):
    """
    Base class for all metrics. Every Metric passed to a Trainer or Tester must inherit from this class and override the evaluate() and get_metric() methods.
@@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1):
    return y_pred_topk, y_prob_topk


class SQuADMetric(MetricBase):
    r"""
    Alias: :class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric`
class ExtractiveQAMetric(MetricBase):
    """
    Alias: :class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric`

    Metric for the SQuAD dataset.
    Metric for extractive QA tasks such as SQuAD.

    :param pred1: mapping for `pred1` in the parameter map; None means the mapping is `pred1` -> `pred1`
    :param pred2: mapping for `pred2` in the parameter map; None means the mapping is `pred2` -> `pred2`
@@ -755,7 +756,7 @@ class SQuADMetric(MetricBase):

    def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
                 beta=1, right_open=True, print_predict_stat=False):

        super(SQuADMetric, self).__init__()
        super(ExtractiveQAMetric, self).__init__()

        self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)

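For context, a minimal usage sketch of the renamed metric. The field names `pred_start`, `pred_end`, `answer_start`, and `answer_end` are hypothetical; they stand in for whatever the model's output fields and the DataSet's target fields are actually called:

```python
from fastNLP import ExtractiveQAMetric, Tester

# Map the metric's four expected arguments onto concrete field names;
# passing None for a parameter keeps the default name (e.g. pred1 -> pred1).
metric = ExtractiveQAMetric(pred1='pred_start', pred2='pred_end',
                            target1='answer_start', target2='answer_end',
                            beta=1, right_open=True)

# Typically handed to a Tester or Trainer, e.g.:
# tester = Tester(data=dev_data, model=model, metrics=metric)
# tester.test()
```
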
@@ -4,16 +4,26 @@
These modules are used as follows:
"""
__all__ = [
    'SSTLoader',

    'IMDBLoader',
    'MatchingLoader',
    'SNLILoader',
    'MNLILoader',
    'MTL16Loader',
    'QNLILoader',
    'QuoraLoader',
    'RTELoader',
    'SSTLoader',
    'SNLILoader',
    'YelpLoader',
]

from .imdb import IMDBLoader
from .matching import MatchingLoader
from .mnli import MNLILoader
from .mtl import MTL16Loader
from .qnli import QNLILoader
from .quora import QuoraLoader
from .rte import RTELoader
from .snli import SNLILoader
from .sst import SSTLoader
from .matching import MatchingLoader, SNLILoader, \
    MNLILoader, QNLILoader, QuoraLoader, RTELoader
from .yelp import YelpLoader

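After this change the loaders are importable directly from `fastNLP.io.data_loader`; a minimal sketch (the file path is a placeholder):

```python
from fastNLP.io.data_loader import SSTLoader

loader = SSTLoader(subtree=False, fine_grained=False)
dataset = loader.load('path/to/sst/train.txt')  # placeholder path
print(len(dataset))
```
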
@@ -0,0 +1,96 @@
from typing import Union, Dict

from ..embed_loader import EmbeddingOption, EmbedLoader
from ..base_loader import DataSetLoader, DataInfo
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.const import Const
from ..utils import get_tokenizer


class IMDBLoader(DataSetLoader):
    """
    Reads the IMDB dataset. The DataSet contains the following fields:

        words: list(str), the text to classify
        target: str, the label of the text
    """

    def __init__(self):
        super(IMDBLoader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                target = parts[0]
                words = self.tokenizer(parts[1].lower())
                dataset.append(Instance(words=words, target=target))

        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")

        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):

        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info

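A usage sketch for the loader above (paths are placeholders; note that `process` splits the train file 9:1 into train/dev itself, and `Const.INPUT`/`Const.TARGET` resolve to the field names `words`/`target`):

```python
from fastNLP.io.data_loader import IMDBLoader

loader = IMDBLoader()
info = loader.process({'train': 'imdb/train.txt', 'test': 'imdb/test.txt'})

train_data = info.datasets['train']   # 90% of the original train file
dev_data = info.datasets['dev']       # the held-out 10%
src_vocab = info.vocabs['words']
tgt_vocab = info.vocabs['target']
```
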
@@ -5,14 +5,13 @@ from typing import Union, Dict

from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..base_loader import DataInfo, DataSetLoader
from ..dataset_loader import JsonLoader, CSVLoader
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ...modules.encoder._bert import BertTokenizer


class MatchingLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader`

    Reads datasets for Matching tasks.

@@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader):
            data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])

        return data_info
class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Reads the SNLI dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict=None):
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        paths = paths if paths is not None else {
            'train': 'snli_1.0_train.jsonl',
            'dev': 'snli_1.0_dev.jsonl',
            'test': 'snli_1.0_test.jsonl'}
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        ds = JsonLoader._load(self, path)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`

    Reads the RTE dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if v in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds
class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`

    Reads the QNLI dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if v in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds
class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

    Reads the MNLI dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev_matched': 'dev_matched.tsv',
            'dev_mismatched': 'dev_mismatched.tsv',
            'test_matched': 'test_matched.tsv',
            'test_mismatched': 'test_mismatched.tsv',
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
            # test_0.9_matched and test_0.9_mismatched belong to MNLI 0.9 (data source: Kaggle)
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)

        if Const.TARGET in ds.get_field_names():
            if ds[0][Const.TARGET] == 'hidden':
                ds.delete_field(Const.TARGET)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        if Const.TARGET in ds.get_field_names():
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`

    Reads the Quora dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv',
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))

    def _load(self, path):
        ds = CSVLoader._load(self, path)
        return ds

@@ -0,0 +1,60 @@
from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`

    Reads the MNLI dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev_matched': 'dev_matched.tsv',
            'dev_mismatched': 'dev_mismatched.tsv',
            'test_matched': 'test_matched.tsv',
            'test_mismatched': 'test_mismatched.tsv',
            # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
            # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
            # test_0.9_matched and test_0.9_mismatched belong to MNLI 0.9 (data source: Kaggle)
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)

        if Const.TARGET in ds.get_field_names():
            if ds[0][Const.TARGET] == 'hidden':
                ds.delete_field(Const.TARGET)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        if Const.TARGET in ds.get_field_names():
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds

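`_load` recovers plain token lists from the `sentence*_binary_parse` fields by deleting parentheses and splitting on whitespace; a short illustration with a made-up parse string:

```python
parentheses_table = str.maketrans({'(': None, ')': None})

parse = "( ( The cat ) ( sat down ) )"
tokens = parse.translate(parentheses_table).strip().split()
print(tokens)  # ['The', 'cat', 'sat', 'down']
```
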
@@ -0,0 +1,65 @@
from typing import Union, Dict

from ..base_loader import DataInfo
from ..dataset_loader import CSVLoader
from ...core.vocabulary import Vocabulary, VocabularyOption
from ...core.const import Const
from ..utils import check_dataloader_paths


class MTL16Loader(CSVLoader):
    """
    Reads the MTL16 dataset. The DataSet contains the following fields:

        words: list(str), the text to classify
        target: str, the label of the text

    Data source: https://pan.baidu.com/s/1c2L6vdA
    """

    def __init__(self):
        super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t')

    def _load(self, path):
        dataset = super(MTL16Loader, self)._load(path)
        dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT)
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")

        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,):

        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info

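A usage sketch (paths are placeholders; `check_dataloader_paths` is assumed to pass a validated `{name: path}` dict like the one below straight through to `self.load`):

```python
from fastNLP.io.data_loader import MTL16Loader

loader = MTL16Loader()
info = loader.process({'train': 'mtl16/apparel/train.txt',
                       'dev': 'mtl16/apparel/dev.txt',
                       'test': 'mtl16/apparel/test.txt'})
print(len(info.datasets['train']), info.vocabs['words'], info.vocabs['target'])
```
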
@@ -0,0 +1,45 @@
from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`

    Reads the QNLI dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds

@@ -0,0 +1,32 @@
from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`

    Reads the Quora dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv',
        }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))

    def _load(self, path):
        ds = CSVLoader._load(self, path)
        return ds

@@ -0,0 +1,45 @@
from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import CSVLoader


class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader`

    Reads the RTE dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source:
    """

    def __init__(self, paths: dict=None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no labels
        }
        MatchingLoader.__init__(self, paths=paths)
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds

@@ -0,0 +1,44 @@
from ...core import Const

from .matching import MatchingLoader
from ..dataset_loader import JsonLoader


class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader`

    Reads the SNLI dataset. The DataSet contains the fields::

        words1: list(str), the first sentence (premise)
        words2: list(str), the second sentence (hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict=None):
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        paths = paths if paths is not None else {
            'train': 'snli_1.0_train.jsonl',
            'dev': 'snli_1.0_dev.jsonl',
            'test': 'snli_1.0_test.jsonl'}
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        ds = JsonLoader._load(self, path)

        parentheses_table = str.maketrans({'(': None, ')': None})

        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
                 new_field_name=Const.INPUTS(1))
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds

@@ -1,19 +1,19 @@
from typing import Iterable
from typing import Union, Dict
from nltk import Tree
import spacy

from ..base_loader import DataInfo, DataSetLoader
from ..dataset_loader import CSVLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.const import Const
from ...core.instance import Instance
from ..utils import check_dataloader_paths, get_tokenizer


class SSTLoader(DataSetLoader):
    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
    DATA_DIR = 'sst/'

    """
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader`

    Reads the SST dataset. The DataSet contains the fields::

@@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader):
    :param fine_grained: whether to use the SST-5 label scheme; if ``False``, use SST-2. Default: ``False``
    """

    URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
    DATA_DIR = 'sst/'

    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

@@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader):

        return info


class SST2Loader(CSVLoader):
    """
    Data source "SST-2": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8'
    """

    def __init__(self):
        super(SST2Loader, self).__init__(sep='\t')
        self.tokenizer = get_tokenizer()
        self.field = {'sentence': Const.INPUT, 'label': Const.TARGET}

    def _load(self, path: str) -> DataSet:
        ds = super(SST2Loader, self)._load(path)
        ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
        print("all count:", len(ds))
        return ds

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):

        paths = check_dataloader_paths(paths)
        datasets = {}
        info = DataInfo()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        input_name, target_name = Const.INPUT, Const.TARGET
        info.vocabs = {}
        # split words into characters
        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info

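`wordtochar` (defined identically in `SST2Loader.process`, `IMDBLoader.process`, and `YelpLoader.process`) flattens a token list into characters, using an empty string as the word separator; a standalone illustration:

```python
def wordtochar(words):
    chars = []
    for word in words:
        word = word.lower()
        for char in word:
            chars.append(char)
        chars.append('')  # empty-string separator between words
    chars.pop()           # drop the trailing separator
    return chars

print(wordtochar(["New", "York"]))
# ['n', 'e', 'w', '', 'y', 'o', 'r', 'k']
```
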
@@ -0,0 +1,126 @@
import csv
from typing import Iterable, Union, Dict

from ...core.const import Const
from ...core import DataSet, Instance, Vocabulary
from ...core.vocabulary import VocabularyOption
from ..base_loader import DataInfo, DataSetLoader
from ..utils import check_dataloader_paths, get_tokenizer


class YelpLoader(DataSetLoader):
    """
    Reads the Yelp-full/Yelp-polarity dataset. The DataSet contains the fields:

        words: list(str), the text to classify
        target: str, the label of the text
        chars: list(str), the un-indexed character list

    Dataset: yelp_full/yelp_polarity

    :param fine_grained: whether to use the fine-grained five-class labels; if ``False``, use two classes. Default: ``False``
    :param lower: whether to lowercase the text automatically. Default: ``False``
    """

    def __init__(self, fine_grained=False, lower=False):
        super(YelpLoader, self).__init__()
        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
                 '4.0': 'positive', '5.0': 'very positive'}
        if not fine_grained:
            tag_v['1.0'] = tag_v['2.0']
            tag_v['5.0'] = tag_v['4.0']
        self.fine_grained = fine_grained
        self.tag_v = tag_v
        self.lower = lower
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        ds = DataSet()
        csv_reader = csv.reader(open(path, encoding='utf-8'))
        all_count = 0
        real_count = 0
        for row in csv_reader:
            all_count += 1
            if len(row) == 2:
                target = self.tag_v[row[0] + ".0"]
                words = clean_str(row[1], self.tokenizer, self.lower)
                if len(words) != 0:
                    ds.append(Instance(words=words, target=target))
                    real_count += 1
        print("all count:", all_count)
        print("real count:", real_count)
        return ds

    def process(self, paths: Union[str, Dict[str, str]],
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                char_level_op=False):
        paths = check_dataloader_paths(paths)
        info = DataInfo(datasets=self.load(paths))
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        def wordtochar(words):
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        input_name, target_name = Const.INPUT, Const.TARGET
        info.vocabs = {}
        # split words into characters
        if char_level_op:
            for dataset in info.datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
        else:
            src_vocab.from_dataset(*_train_ds, field_name=input_name)
            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
            info.vocabs[input_name] = src_vocab

        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)

        info.vocabs[target_name] = tgt_vocab

        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info


def clean_str(sentence, tokenizer, char_lower=False):
    """
    Heavily borrowed from GitHub:
    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
    :param sentence: a str
    :return:
    """
    if char_lower:
        sentence = sentence.lower()
    import re
    nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
    words = tokenizer(sentence)
    words_collection = []
    for word in words:
        if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
            continue
        tt = nonalpnum.split(word)
        t = ''.join(tt)
        if t != '':
            words_collection.append(t)
    return words_collection

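The regex in `clean_str` keeps only alphanumerics plus `?`, `!`, and `'` inside each token (tokens in the `-lrb-`/`-rrb-` skip list are dropped before this step); a quick illustration:

```python
import re

nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')

for word in ['great!', 'co-operate', '<sssss>']:
    t = ''.join(nonalpnum.split(word))
    print(repr(word), '->', repr(t))
# 'great!'     -> 'great!'
# 'co-operate' -> 'cooperate'
# '<sssss>'    -> 'sssss'   (normally removed by the skip list first)
```
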
@@ -1,14 +0,0 @@
__all__ = [
    "MaxPool",
    "MaxPoolWithMask",
    "AvgPool",
    "MultiHeadAttention",
]

from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask

from .attention import MultiHeadAttention

@@ -8,9 +8,9 @@ import torch
import torch.nn.functional as F
from torch import nn

from ..dropout import TimestepDropout
from fastNLP.modules.dropout import TimestepDropout

from ..utils import initial_parameter
from fastNLP.modules.utils import initial_parameter


class DotAttention(nn.Module):

@@ -3,7 +3,7 @@ __all__ = [
]

from torch import nn

from ..aggregator.attention import MultiHeadAttention
from fastNLP.modules.encoder.attention import MultiHeadAttention
from ..dropout import TimestepDropout

@@ -11,7 +11,7 @@ Coreference resolution finds all expressions in a text that refer to the same real-world entity.
Because of copyright issues, the dataset cannot be provided for download here; please obtain it yourself.
The raw dataset is in the conll format; for details, see the official description page that comes with the dataset.
The implementation follows the preprocessing of the paper's author, Lee; see the [link](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh) for details.
The processed dataset is in json format, for example:
```
{
@@ -25,12 +25,12 @@ Coreference resolution finds all expressions in a text that refer to the same real-world entity.

### Embedding downloads
[turian embedding](https://lil.cs.washington.edu/coref/turian.50d.txt)
[glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip)

## Running
```python
```shell
# training
CUDA_VISIBLE_DEVICES=0 python train.py
# evaluation
@@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py

## Results
The original authors report 67.2% on the test set; the AllenNLP reimplementation reports [63.0%](https://allennlp.org/models).
The AllenNLP training did not include speaker information or variational dropout, and used only 100 antecedents rather than 250.
With the same hyperparameters and configuration as AllenNLP, this reimplementation achieves an F1 of 63.6%.

## Issues

@@ -2,7 +2,7 @@
Several well-known models for Matching tasks are reproduced here with fastNLP, aiming to match the performance reported in the papers. The evaluation metric for all of these tasks is accuracy (%).
Reproduced models (in order of publication):
- CNTN: model code (still in progress) [](); training code (still in progress) []().
- CNTN: [model code](model/cntn.py); [training code](matching_cntn.py).
  Paper: [Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
- ESIM: [model code](model/esim.py); [training code](matching_esim.py).
  Paper: [Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
@@ -21,7 +21,7 @@

model name | SNLI | MNLI | RTE | QNLI | Quora
:---: | :---: | :---: | :---: | :---: | :---:
CNTN [](); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
CNTN [code](model/cntn.py); [paper](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - |
ESIM [code](model/esim.py); [paper](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
DIIN [](); [paper](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
MwAN [](); [paper](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
@@ -44,7 +44,7 @@ Performance on Test set:

model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
:---: | :---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16
__performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16

## MNLI
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@@ -60,7 +60,7 @@ Performance on Test set (matched/mismatched):

model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
:---: | :---: | :---: | :---: | :---: | :---: |
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |
__performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - |

## RTE

@@ -92,7 +92,7 @@ Performance on __Dev__ set:

model name | CNTN | ESIM | DIIN | MwAN | BERT
:---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 76.97 | - | 74.6 | -
__performance__ | 62.38 | 76.97 | - | 74.6 | -

## Quora

@@ -3,3 +3,5 @@ torch>=1.0.0
tqdm>=4.28.1
nltk>=3.4.1
requests
spacy
h5py