@@ -37,7 +37,7 @@ __all__ = [ | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"SQuADMetric", | |||||
"ExtractiveQAMetric", | |||||
"Optimizer", | "Optimizer", | ||||
"SGD", | "SGD", | ||||
@@ -61,3 +61,4 @@ __version__ = '0.4.0' | |||||
from .core import * | from .core import * | ||||
from . import models | from . import models | ||||
from . import modules | from . import modules | ||||
from .io import data_loader |
@@ -21,7 +21,7 @@ from .dataset import DataSet | |||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | ||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, SQuADMetric | |||||
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric | |||||
from .optimizer import Optimizer, SGD, Adam | from .optimizer import Optimizer, SGD, Adam | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | ||||
from .tester import Tester | from .tester import Tester | ||||
@@ -6,7 +6,7 @@ __all__ = [ | |||||
"MetricBase", | "MetricBase", | ||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"SQuADMetric" | |||||
"ExtractiveQAMetric" | |||||
] | ] | ||||
import inspect | import inspect | ||||
@@ -24,6 +24,7 @@ from .utils import seq_len_to_mask | |||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from abc import abstractmethod | from abc import abstractmethod | ||||
class MetricBase(object): | class MetricBase(object): | ||||
""" | """ | ||||
所有metrics的基类,,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 | 所有metrics的基类,,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。 | ||||
@@ -735,11 +736,11 @@ def _pred_topk(y_prob, k=1): | |||||
return y_pred_topk, y_prob_topk | return y_pred_topk, y_prob_topk | ||||
class SQuADMetric(MetricBase): | |||||
r""" | |||||
别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||||
class ExtractiveQAMetric(MetricBase): | |||||
""" | |||||
别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` | |||||
SQuAD数据集metric | |||||
抽取式QA(如SQuAD)的metric. | |||||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | :param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | ||||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | :param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | ||||
@@ -755,7 +756,7 @@ class SQuADMetric(MetricBase): | |||||
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | ||||
beta=1, right_open=True, print_predict_stat=False): | beta=1, right_open=True, print_predict_stat=False): | ||||
super(SQuADMetric, self).__init__() | |||||
super(ExtractiveQAMetric, self).__init__() | |||||
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | ||||
@@ -4,16 +4,26 @@ | |||||
这些模块的使用方法如下: | 这些模块的使用方法如下: | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'SSTLoader', | |||||
'IMDBLoader', | |||||
'MatchingLoader', | 'MatchingLoader', | ||||
'SNLILoader', | |||||
'MNLILoader', | 'MNLILoader', | ||||
'MTL16Loader', | |||||
'QNLILoader', | 'QNLILoader', | ||||
'QuoraLoader', | 'QuoraLoader', | ||||
'RTELoader', | 'RTELoader', | ||||
'SSTLoader', | |||||
'SNLILoader', | |||||
'YelpLoader', | |||||
] | ] | ||||
from .imdb import IMDBLoader | |||||
from .matching import MatchingLoader | |||||
from .mnli import MNLILoader | |||||
from .mtl import MTL16Loader | |||||
from .qnli import QNLILoader | |||||
from .quora import QuoraLoader | |||||
from .rte import RTELoader | |||||
from .snli import SNLILoader | |||||
from .sst import SSTLoader | from .sst import SSTLoader | ||||
from .matching import MatchingLoader, SNLILoader, \ | |||||
MNLILoader, QNLILoader, QuoraLoader, RTELoader | |||||
from .yelp import YelpLoader |
@@ -0,0 +1,96 @@ | |||||
from typing import Union, Dict | |||||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||||
from ..base_loader import DataSetLoader, DataInfo | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.const import Const | |||||
from ..utils import get_tokenizer | |||||
class IMDBLoader(DataSetLoader):
    """
    Loader for the IMDB text-classification data set.

    Every loaded DataSet contains the fields::

        words: list(str), the tokenized, lower-cased text to classify
        target: str, the label of the text

    Each non-empty input line is split on tabs; the first column is used as
    the label and the second column as the text.
    """

    def __init__(self):
        super(IMDBLoader, self).__init__()
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        """
        Read one tab-separated file into a DataSet.

        :param path: path of the file to read (opened as UTF-8).
        :return: a DataSet with the fields ``words`` and ``target``.
        :raises RuntimeError: if the file does not yield a single instance.
        """
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split('\t')
                target = parts[0]
                words = self.tokenizer(parts[1].lower())
                dataset.append(Instance(words=words, target=target))

        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")

        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):
        """
        Load the data sets, build the vocabularies and index the data.

        :param paths: a single path or a mapping ``{name: path}``; it is
            normalized with ``check_dataloader_paths`` exactly like the other
            loaders of this package do. A ``train`` entry is required.
        :param src_vocab_opt: keyword options for the word Vocabulary, or None
            for the defaults.
        :param tgt_vocab_opt: keyword options for the label Vocabulary, or
            None for a Vocabulary without unknown/padding entries.
        :param char_level_op: if True, additionally create a ``chars`` field
            holding the characters of ``words``.
        :return: a DataInfo carrying the indexed data sets and the vocabularies.
        """
        # Local import: only ``get_tokenizer`` is imported from ``..utils`` at
        # module level. The annotation advertises ``str`` inputs, which the
        # bare ``paths.items()`` call could never handle.
        from ..utils import check_dataloader_paths

        paths = check_dataloader_paths(paths)

        info = DataInfo()
        datasets = {name: self.load(path) for name, path in paths.items()}

        def wordtochar(words):
            # Flatten words into characters, separated by '' markers.
            chars = []
            for word in words:
                chars.extend(word.lower())
                chars.append('')
            if chars:
                # Drop the trailing separator; guarded so that an empty token
                # list yields [] instead of raising IndexError.
                chars.pop()
            return chars

        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        # The dev set is always carved out of train (split argument 0.1,
        # no shuffling); this replaces any 'dev' entry supplied by the caller.
        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for dataset in info.datasets.values():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
@@ -5,14 +5,13 @@ from typing import Union, Dict | |||||
from ...core.const import Const | from ...core.const import Const | ||||
from ...core.vocabulary import Vocabulary | from ...core.vocabulary import Vocabulary | ||||
from ..base_loader import DataInfo, DataSetLoader | from ..base_loader import DataInfo, DataSetLoader | ||||
from ..dataset_loader import JsonLoader, CSVLoader | |||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | ||||
from ...modules.encoder._bert import BertTokenizer | from ...modules.encoder._bert import BertTokenizer | ||||
class MatchingLoader(DataSetLoader): | class MatchingLoader(DataSetLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader` | |||||
读取Matching任务的数据集 | 读取Matching任务的数据集 | ||||
@@ -227,204 +226,3 @@ class MatchingLoader(DataSetLoader): | |||||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | ||||
return data_info | return data_info | ||||
class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Loader for the SNLI data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'snli_1.0_train.jsonl',
                'dev': 'snli_1.0_dev.jsonl',
                'test': 'snli_1.0_test.jsonl',
            }
        # JSON keys read from the file -> field names used in the DataSet.
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        """Load one jsonl file, strip parse brackets, tokenize, drop '-' labels."""
        ds = JsonLoader._load(self, path)
        drop_parens = str.maketrans('', '', '()')

        def tokenizer_for(field_name):
            # Build the per-instance function for one sentence field.
            def tokenize(ins):
                return ins[field_name].translate(drop_parens).strip().split()
            return tokenize

        for field_name in (Const.INPUTS(0), Const.INPUTS(1)):
            ds.apply(tokenizer_for(field_name), new_field_name=field_name)

        # Instances whose gold label is '-' are removed.
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`

    Loader for the RTE data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no label
        }
        MatchingLoader.__init__(self, paths=paths)
        # Column names found in the tsv file -> field names used in the DataSet.
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        """
        Load one tsv file, rename its columns and tokenize the sentences.

        :param path: path of the tsv file.
        :return: the loaded DataSet.
        """
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            # Test for the *source* column name. The previous check looked for
            # the destination name ``v``, so no column was ever renamed.
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            # Whitespace-tokenize every field whose name contains Const.INPUT.
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds
class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`

    Loader for the QNLI data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        paths = paths if paths is not None else {
            'train': 'train.tsv',
            'dev': 'dev.tsv',
            'test': 'test.tsv'  # the test set has no label
        }
        MatchingLoader.__init__(self, paths=paths)
        # Column names found in the tsv file -> field names used in the DataSet.
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        """
        Load one tsv file, rename its columns and tokenize the sentences.

        :param path: path of the tsv file.
        :return: the loaded DataSet.
        """
        ds = CSVLoader._load(self, path)

        for k, v in self.fields.items():
            # Test for the *source* column name. The previous check looked for
            # the destination name ``v``, so no column was ever renamed.
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        for fields in ds.get_all_fields():
            # Whitespace-tokenize every field whose name contains Const.INPUT.
            if Const.INPUT in fields:
                ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)

        return ds
class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`

    Loader for the MNLI data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev_matched': 'dev_matched.tsv',
                'dev_mismatched': 'dev_mismatched.tsv',
                'test_matched': 'test_matched.tsv',
                'test_mismatched': 'test_mismatched.tsv',
                # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
                # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
                # test_0.9_matched / mismatched belong to MNLI 0.9 (source: kaggle)
            }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        # Column names found in the tsv file -> field names used in the DataSet.
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        """Load one tsv file, rename columns, strip parse brackets and tokenize."""
        ds = CSVLoader._load(self, path)

        for column, field_name in self.fields.items():
            if column in ds.get_field_names():
                ds.rename_field(column, field_name)

        # A target column whose first value is 'hidden' is removed.
        if Const.TARGET in ds.get_field_names() and ds[0][Const.TARGET] == 'hidden':
            ds.delete_field(Const.TARGET)

        drop_parens = str.maketrans('', '', '()')

        def tokenizer_for(field_name):
            # Build the per-instance function for one sentence field.
            def tokenize(ins):
                return ins[field_name].translate(drop_parens).strip().split()
            return tokenize

        for field_name in (Const.INPUTS(0), Const.INPUTS(1)):
            ds.apply(tokenizer_for(field_name), new_field_name=field_name)

        if Const.TARGET in ds.get_field_names():
            # Instances whose gold label is '-' are removed.
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`

    Loader for the Quora data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev': 'dev.tsv',
                'test': 'test.tsv',
            }
        MatchingLoader.__init__(self, paths=paths)
        # Column headers are passed explicitly to the CSV loader.
        column_names = (Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')
        CSVLoader.__init__(self, sep='\t', headers=column_names)

    def _load(self, path):
        """Delegate loading to the CSV loader without further processing."""
        return CSVLoader._load(self, path)
@@ -0,0 +1,60 @@ | |||||
from ...core import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`

    Loader for the MNLI data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev_matched': 'dev_matched.tsv',
                'dev_mismatched': 'dev_mismatched.tsv',
                'test_matched': 'test_matched.tsv',
                'test_mismatched': 'test_mismatched.tsv',
                # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
                # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
                # test_0.9_matched / mismatched belong to MNLI 0.9 (source: kaggle)
            }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        # Column names found in the tsv file -> field names used in the DataSet.
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        """Load one tsv file, rename columns, strip parse brackets and tokenize."""
        ds = CSVLoader._load(self, path)

        for column, field_name in self.fields.items():
            if column in ds.get_field_names():
                ds.rename_field(column, field_name)

        # A target column whose first value is 'hidden' is removed.
        if Const.TARGET in ds.get_field_names() and ds[0][Const.TARGET] == 'hidden':
            ds.delete_field(Const.TARGET)

        drop_parens = str.maketrans('', '', '()')

        def make_tokenizer(field_name):
            # Build the per-instance function for one sentence field.
            def tokenize(ins):
                return ins[field_name].translate(drop_parens).strip().split()
            return tokenize

        for field_name in (Const.INPUTS(0), Const.INPUTS(1)):
            ds.apply(make_tokenizer(field_name), new_field_name=field_name)

        if Const.TARGET in ds.get_field_names():
            # Instances whose gold label is '-' are removed.
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
@@ -0,0 +1,65 @@ | |||||
from typing import Union, Dict | |||||
from ..base_loader import DataInfo | |||||
from ..dataset_loader import CSVLoader | |||||
from ...core.vocabulary import Vocabulary, VocabularyOption | |||||
from ...core.const import Const | |||||
from ..utils import check_dataloader_paths | |||||
class MTL16Loader(CSVLoader):
    """
    Loader for the MTL16 data sets. The loaded DataSet contains the fields::

        words: list(str), the text to classify
        target: str, the label of the text

    Data source: https://pan.baidu.com/s/1c2L6vdA
    """

    def __init__(self):
        # Files are tab-separated with the columns (target, input).
        super().__init__(headers=(Const.TARGET, Const.INPUT), sep='\t')

    def _load(self, path):
        """
        Load one file and replace the text by its lower-cased, whitespace-split tokens.

        :raises RuntimeError: if the file yields no instance.
        """
        dataset = super()._load(path)

        def lower_and_split(ins):
            return ins[Const.INPUT].lower().split()

        dataset.apply(lower_and_split, new_field_name=Const.INPUT)
        if not len(dataset):
            raise RuntimeError(f"{path} has no valid data.")
        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,):
        """
        Load the data sets, build both vocabularies from ``train`` and index all sets.

        :param paths: a path or a mapping ``{name: path}``, normalized by
            ``check_dataloader_paths``.
        :param src_vocab_opt: keyword options for the input Vocabulary, or None.
        :param tgt_vocab_opt: keyword options for the target Vocabulary, or
            None for a Vocabulary without unknown/padding entries.
        :return: a DataInfo carrying the indexed data sets and the vocabularies.
        """
        paths = check_dataloader_paths(paths)
        info = DataInfo()
        datasets = {name: self.load(path) for name, path in paths.items()}

        if src_vocab_opt is None:
            src_vocab = Vocabulary()
        else:
            src_vocab = Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

        if tgt_vocab_opt is None:
            tgt_vocab = Vocabulary(unknown=None, padding=None)
        else:
            tgt_vocab = Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)

        info.vocabs = {Const.INPUT: src_vocab, Const.TARGET: tgt_vocab}
        info.datasets = datasets

        for dataset in info.datasets.values():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
@@ -0,0 +1,45 @@ | |||||
from ...core import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class QNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader`

    Loader for the QNLI data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev': 'dev.tsv',
                'test': 'test.tsv',  # the test set has no label
            }
        MatchingLoader.__init__(self, paths=paths)
        # Column names found in the tsv file -> field names used in the DataSet.
        self.fields = {
            'question': Const.INPUTS(0),
            'sentence': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        """Load one tsv file, rename its columns and whitespace-tokenize the inputs."""
        ds = CSVLoader._load(self, path)

        for column, field_name in self.fields.items():
            if column in ds.get_field_names():
                ds.rename_field(column, field_name)

        for field_name in ds.get_all_fields():
            # Only fields whose name contains Const.INPUT are tokenized.
            if Const.INPUT not in field_name:
                continue
            ds.apply(lambda ins: ins[field_name].strip().split(), new_field_name=field_name)

        return ds
@@ -0,0 +1,32 @@ | |||||
from ...core import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class QuoraLoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader`

    Loader for the Quora data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev': 'dev.tsv',
                'test': 'test.tsv',
            }
        MatchingLoader.__init__(self, paths=paths)
        # Column headers are passed explicitly to the CSV loader.
        column_names = (Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')
        CSVLoader.__init__(self, sep='\t', headers=column_names)

    def _load(self, path):
        """Delegate loading to the CSV loader without further processing."""
        return CSVLoader._load(self, path)
@@ -0,0 +1,45 @@ | |||||
from ...core import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class RTELoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader`

    Loader for the RTE data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source:
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev': 'dev.tsv',
                'test': 'test.tsv',  # the test set has no label
            }
        MatchingLoader.__init__(self, paths=paths)
        # Column names found in the tsv file -> field names used in the DataSet.
        self.fields = {
            'sentence1': Const.INPUTS(0),
            'sentence2': Const.INPUTS(1),
            'label': Const.TARGET,
        }
        CSVLoader.__init__(self, sep='\t')

    def _load(self, path):
        """Load one tsv file, rename its columns and whitespace-tokenize the inputs."""
        ds = CSVLoader._load(self, path)

        for column, field_name in self.fields.items():
            if column in ds.get_field_names():
                ds.rename_field(column, field_name)

        for field_name in ds.get_all_fields():
            # Only fields whose name contains Const.INPUT are tokenized.
            if Const.INPUT not in field_name:
                continue
            ds.apply(lambda ins: ins[field_name].strip().split(), new_field_name=field_name)

        return ds
@@ -0,0 +1,44 @@ | |||||
from ...core import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import JsonLoader | |||||
class SNLILoader(MatchingLoader, JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader`

    Loader for the SNLI data set. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self, paths: dict = None):
        if paths is None:
            paths = {
                'train': 'snli_1.0_train.jsonl',
                'dev': 'snli_1.0_dev.jsonl',
                'test': 'snli_1.0_test.jsonl',
            }
        # JSON keys read from the file -> field names used in the DataSet.
        fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        MatchingLoader.__init__(self, paths=paths)
        JsonLoader.__init__(self, fields=fields)

    def _load(self, path):
        """Load one jsonl file, strip parse brackets, tokenize, drop '-' labels."""
        ds = JsonLoader._load(self, path)
        drop_parens = str.maketrans('', '', '()')

        def make_tokenizer(field_name):
            # Build the per-instance function for one sentence field.
            def tokenize(ins):
                return ins[field_name].translate(drop_parens).strip().split()
            return tokenize

        for field_name in (Const.INPUTS(0), Const.INPUTS(1)):
            ds.apply(make_tokenizer(field_name), new_field_name=field_name)

        # Instances whose gold label is '-' are removed.
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
@@ -1,19 +1,19 @@ | |||||
from typing import Iterable | |||||
from typing import Union, Dict | |||||
from nltk import Tree | from nltk import Tree | ||||
import spacy | |||||
from ..base_loader import DataInfo, DataSetLoader | from ..base_loader import DataInfo, DataSetLoader | ||||
from ..dataset_loader import CSVLoader | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | from ...core.vocabulary import VocabularyOption, Vocabulary | ||||
from ...core.dataset import DataSet | from ...core.dataset import DataSet | ||||
from ...core.const import Const | |||||
from ...core.instance import Instance | from ...core.instance import Instance | ||||
from ..utils import check_dataloader_paths, get_tokenizer | from ..utils import check_dataloader_paths, get_tokenizer | ||||
class SSTLoader(DataSetLoader): | class SSTLoader(DataSetLoader): | ||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
DATA_DIR = 'sst/' | |||||
""" | """ | ||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader` | |||||
读取SST数据集, DataSet包含fields:: | 读取SST数据集, DataSet包含fields:: | ||||
@@ -26,6 +26,9 @@ class SSTLoader(DataSetLoader): | |||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | ||||
""" | """ | ||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
DATA_DIR = 'sst/' | |||||
def __init__(self, subtree=False, fine_grained=False): | def __init__(self, subtree=False, fine_grained=False): | ||||
self.subtree = subtree | self.subtree = subtree | ||||
@@ -98,3 +101,72 @@ class SSTLoader(DataSetLoader): | |||||
return info | return info | ||||
class SST2Loader(CSVLoader):
    """
    Loader for the SST-2 data set (tab-separated files).

    Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
    """

    def __init__(self):
        super(SST2Loader, self).__init__(sep='\t')
        self.tokenizer = get_tokenizer()
        # Column names found in the tsv file -> field names used in the DataSet.
        self.field = {'sentence': Const.INPUT, 'label': Const.TARGET}

    def _load(self, path: str) -> DataSet:
        """
        Load one tsv file, rename its columns and tokenize the input text.

        :param path: path of the tsv file.
        :return: the loaded DataSet; its size is printed as a side effect.
        """
        ds = super(SST2Loader, self)._load(path)
        # Apply the rename map first. It was previously defined but never
        # used, so the Const.INPUT lookup below ran against un-renamed columns.
        for k, v in self.field.items():
            if k in ds.get_field_names():
                ds.rename_field(k, v)
        ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT)
        print("all count:", len(ds))
        return ds

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):
        """
        Load the data sets, build both vocabularies from ``train`` and index all sets.

        :param paths: a path or a mapping ``{name: path}``, normalized by
            ``check_dataloader_paths``.
        :param src_vocab_opt: keyword options for the input Vocabulary, or None.
        :param tgt_vocab_opt: keyword options for the target Vocabulary, or
            None for a Vocabulary without unknown/padding entries.
        :param char_level_op: if True, additionally create a Const.CHAR_INPUT
            field holding the characters of the input words.
        :return: a DataInfo carrying the indexed data sets and the vocabularies.
        """
        paths = check_dataloader_paths(paths)
        info = DataInfo()
        datasets = {name: self.load(path) for name, path in paths.items()}

        def wordtochar(words):
            # Flatten words into characters, separated by '' markers.
            chars = []
            for word in words:
                chars.extend(word.lower())
                chars.append('')
            if chars:
                # Drop the trailing separator; guarded so that an empty token
                # list yields [] instead of raising IndexError.
                chars.pop()
            return chars

        # Additionally split the words into characters when requested.
        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)

        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
        src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET)
        tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET)

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for dataset in info.datasets.values():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
@@ -0,0 +1,126 @@ | |||||
import csv | |||||
from typing import Iterable | |||||
from ...core.const import Const | |||||
from ...core import DataSet, Instance, Vocabulary | |||||
from ...core.vocabulary import VocabularyOption | |||||
from ..base_loader import DataInfo,DataSetLoader | |||||
from typing import Union, Dict | |||||
from ..utils import check_dataloader_paths, get_tokenizer | |||||
class YelpLoader(DataSetLoader):
    """
    Loader for the Yelp review data sets (yelp_full / yelp_polarity).

    The loaded DataSet contains the fields::

        words: list(str), the text to classify
        target: str, the label of the text
        chars: list(str), un-indexed characters (stored under
            Const.CHAR_INPUT, only when ``process`` runs with ``char_level_op=True``)

    :param fine_grained: if ``True`` the five ratings keep five distinct
        labels; if ``False`` rating 1 shares the label of rating 2
        ('negative') and rating 5 shares the label of rating 4 ('positive').
        Default: ``False``
    :param lower: whether the text is lower-cased before tokenization.
        Default: ``False``
    """

    def __init__(self, fine_grained=False, lower=False):
        super(YelpLoader, self).__init__()
        tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
                 '4.0': 'positive', '5.0': 'very positive'}
        if not fine_grained:
            tag_v['1.0'] = tag_v['2.0']
            tag_v['5.0'] = tag_v['4.0']
        self.fine_grained = fine_grained
        self.tag_v = tag_v
        self.lower = lower
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        """
        Read one csv file with the columns (rating, text) into a DataSet.

        Rows that do not have exactly two columns, and rows whose cleaned text
        is empty, are skipped. Both counters are printed as a side effect.

        :param path: path of the csv file (opened as UTF-8).
        :return: a DataSet with the fields ``words`` and ``target``.
        """
        ds = DataSet()
        all_count = 0
        real_count = 0
        # ``with`` closes the handle, which used to leak. ``newline=''`` is
        # what the csv module asks for when it is handed a file object.
        with open(path, encoding='utf-8', newline='') as f:
            for row in csv.reader(f):
                all_count += 1
                if len(row) == 2:
                    target = self.tag_v[row[0] + ".0"]
                    words = clean_str(row[1], self.tokenizer, self.lower)
                    if len(words) != 0:
                        ds.append(Instance(words=words, target=target))
                        real_count += 1
        print("all count:", all_count)
        print("real count:", real_count)
        return ds

    def process(self, paths: Union[str, Dict[str, str]],
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                char_level_op=False):
        """
        Load the data sets, build the vocabularies and index the data.

        :param paths: a path or a mapping ``{name: path}``, normalized by
            ``check_dataloader_paths``.
        :param train_ds: names of the data sets the vocabularies are built
            from; all loaded data sets are used when it is empty or None.
        :param src_vocab_op: keyword options for the input Vocabulary, or None.
        :param tgt_vocab_op: keyword options for the target Vocabulary, or
            None for a Vocabulary without unknown/padding entries.
        :param char_level_op: if True a character field is created and the
            word vocabulary is neither built nor stored; otherwise the words
            are indexed.
        :return: a DataInfo carrying the data sets and the vocabularies.
        """
        paths = check_dataloader_paths(paths)
        info = DataInfo(datasets=self.load(paths))
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        def wordtochar(words):
            # Flatten words into characters, separated by '' markers.
            chars = []
            for word in words:
                chars.extend(word.lower())
                chars.append('')
            if chars:
                # Drop the trailing separator; guarded so that an empty token
                # list yields [] instead of raising IndexError.
                chars.pop()
            return chars

        input_name, target_name = Const.INPUT, Const.TARGET
        info.vocabs = {}
        # Either split the words into characters or index them.
        if char_level_op:
            for dataset in info.datasets.values():
                dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
        else:
            src_vocab.from_dataset(*_train_ds, field_name=input_name)
            src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name)
            info.vocabs[input_name] = src_vocab

        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
        tgt_vocab.index_dataset(
            *info.datasets.values(),
            field_name=target_name, new_field_name=target_name)

        info.vocabs[target_name] = tgt_vocab

        # The dev set is always carved out of train (split argument 0.1,
        # no shuffling); this replaces any 'dev' data set loaded above.
        info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False)

        for dataset in info.datasets.values():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
def clean_str(sentence, tokenizer, char_lower=False):
    """
    Tokenize a sentence and strip unwanted characters from every token.

    Heavily borrowed from github
    https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb

    :param sentence: the text to clean, a str
    :param tokenizer: callable that turns the text into an iterable of tokens
    :param char_lower: lower-case the text before tokenizing when True
    :return: list of cleaned, non-empty tokens
    """
    import re

    text = sentence.lower() if char_lower else sentence
    # Everything except digits, ASCII letters, '?', '!' and the apostrophe
    # is removed from a token.
    unwanted = re.compile('[^0-9a-zA-Z?!\']+')
    # Bracket/sentence markers that are dropped as whole tokens.
    skipped = ('-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-')

    cleaned = []
    for token in tokenizer(text):
        if token in skipped:
            continue
        stripped = unwanted.sub('', token)
        if stripped:
            cleaned.append(stripped)
    return cleaned
@@ -1,14 +0,0 @@ | |||||
__all__ = [ | |||||
"MaxPool", | |||||
"MaxPoolWithMask", | |||||
"AvgPool", | |||||
"MultiHeadAttention", | |||||
] | |||||
from .pooling import MaxPool | |||||
from .pooling import MaxPoolWithMask | |||||
from .pooling import AvgPool | |||||
from .pooling import AvgPoolWithMask | |||||
from .attention import MultiHeadAttention |
@@ -8,9 +8,9 @@ import torch | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | from torch import nn | ||||
from ..dropout import TimestepDropout | |||||
from fastNLP.modules.dropout import TimestepDropout | |||||
from ..utils import initial_parameter | |||||
from fastNLP.modules.utils import initial_parameter | |||||
class DotAttention(nn.Module): | class DotAttention(nn.Module): |
@@ -3,7 +3,7 @@ __all__ = [ | |||||
] | ] | ||||
from torch import nn | from torch import nn | ||||
from ..aggregator.attention import MultiHeadAttention | |||||
from fastNLP.modules.encoder.attention import MultiHeadAttention | |||||
from ..dropout import TimestepDropout | from ..dropout import TimestepDropout | ||||
@@ -11,7 +11,7 @@ Coreference resolution是查找文本中指向同一现实实体的所有表达 | |||||
由于版权问题,本文无法提供数据集的下载,请自行下载。 | 由于版权问题,本文无法提供数据集的下载,请自行下载。 | ||||
原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 | 原始数据集的格式为conll格式,详细介绍参考数据集给出的官方介绍页面。 | ||||
代码实现采用了论文作者Lee的预处理方法,具体细节参加[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 | |||||
代码实现采用了论文作者Lee的预处理方法,具体细节参见[链接](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)。 | |||||
处理之后的数据集为json格式,例子: | 处理之后的数据集为json格式,例子: | ||||
``` | ``` | ||||
{ | { | ||||
@@ -25,12 +25,12 @@ Coreference resolution是查找文本中指向同一现实实体的所有表达 | |||||
### embedding 数据集下载 | ### embedding 数据集下载 | ||||
[turian embedding](https://lil.cs.washington.edu/coref/turian.50d.txt) | [turian embedding](https://lil.cs.washington.edu/coref/turian.50d.txt) | ||||
[glove embedding]( https://nlp.stanford.edu/data/glove.840B.300d.zip) | |||||
[glove embedding](https://nlp.stanford.edu/data/glove.840B.300d.zip) | |||||
## 运行 | ## 运行 | ||||
```python | |||||
```shell | |||||
# 训练代码 | # 训练代码 | ||||
CUDA_VISIBLE_DEVICES=0 python train.py | CUDA_VISIBLE_DEVICES=0 python train.py | ||||
# 测试代码 | # 测试代码 | ||||
@@ -39,9 +39,9 @@ CUDA_VISIBLE_DEVICES=0 python valid.py | |||||
## 结果 | ## 结果 | ||||
原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 | 原论文作者在测试集上取得了67.2%的结果,AllenNLP复现的结果为 [63.0%](https://allennlp.org/models)。 | ||||
其中allenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 | |||||
其中AllenNLP训练时没有加入speaker信息,没有variational dropout以及只使用了100的antecedents而不是250。 | |||||
在与allenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 | |||||
在与AllenNLP使用同样的超参和配置时,本代码复现取得了63.6%的F1值。 | |||||
## 问题 | ## 问题 |
@@ -2,7 +2,7 @@ | |||||
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | 这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%). | ||||
复现的模型有(按论文发表时间顺序排序): | 复现的模型有(按论文发表时间顺序排序): | ||||
- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[](). | |||||
- CNTN:[模型代码](model/cntn.py); [训练代码](matching_cntn.py). | |||||
论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | 论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). | ||||
- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | - ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py). | ||||
论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | 论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). | ||||
@@ -21,7 +21,7 @@ | |||||
model name | SNLI | MNLI | RTE | QNLI | Quora | model name | SNLI | MNLI | RTE | QNLI | Quora | ||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | ||||
CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - | | |||||
CNTN [代码](model/cntn.py); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 77.79 vs - | 63.29/63.16(dev) vs - | 57.04(dev) vs - | 62.38(dev) vs - | - | | |||||
ESIM[代码](model/esim.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | ESIM[代码](model/esim.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - | | ||||
DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 | | ||||
MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 | | ||||
@@ -44,7 +44,7 @@ Performance on Test set: | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large | ||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | ||||
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16 | |||||
__performance__ | 77.79 | 88.13 | - | 87.9 | 90.6 | 91.16 | |||||
## MNLI | ## MNLI | ||||
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) | ||||
@@ -60,7 +60,7 @@ Performance on Test set(matched/mismatched): | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | ||||
:---: | :---: | :---: | :---: | :---: | :---: | | :---: | :---: | :---: | :---: | :---: | :---: | | ||||
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||||
__performance__ | 63.29/63.16(dev) | 77.78/76.49 | - | 77.3/76.7(dev) | - | | |||||
## RTE | ## RTE | ||||
@@ -92,7 +92,7 @@ Performance on __Dev__ set: | |||||
model name | CNTN | ESIM | DIIN | MwAN | BERT | model name | CNTN | ESIM | DIIN | MwAN | BERT | ||||
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | ||||
__performance__ | - | 76.97 | - | 74.6 | - | |||||
__performance__ | 62.38 | 76.97 | - | 74.6 | - | |||||
## Quora | ## Quora | ||||
@@ -3,3 +3,5 @@ torch>=1.0.0 | |||||
tqdm>=4.28.1 | tqdm>=4.28.1 | ||||
nltk>=3.4.1 | nltk>=3.4.1 | ||||
requests | requests | ||||
spacy | |||||
h5py |