@@ -2,6 +2,6 @@ fastNLP.core.callback | |||
===================== | |||
.. automodule:: fastNLP.core.callback | |||
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError | |||
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.io.loader | |||
================= | |||
.. automodule:: fastNLP.io.loader | |||
:members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CoReferenceLoader | |||
:members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, LCQMCLoader, CoReferenceLoader | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.io.pipe | |||
=============== | |||
.. automodule:: fastNLP.io.pipe | |||
:members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CoReferencePipe | |||
:members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, LCQMCPipe, CNXNLIPipe, BQCorpusPipe, RenamePipe, GranularizePipe, MachingTruncatePipe, CoReferencePipe | |||
:inherited-members: | |||
@@ -2,7 +2,7 @@ fastNLP.io | |||
========== | |||
.. automodule:: fastNLP.io | |||
:members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver | |||
:members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, LCQMCLoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver | |||
:inherited-members: | |||
子模块 | |||
@@ -2,7 +2,7 @@ fastNLP | |||
======= | |||
.. automodule:: fastNLP | |||
:members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger | |||
:members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger | |||
:inherited-members: | |||
子模块 | |||
@@ -47,7 +47,7 @@ __all__ = [ | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"XNLILoader", | |||
"CNXNLILoader", | |||
"BQCorpusLoader", | |||
"LCQMCLoader", | |||
@@ -70,32 +70,61 @@ __all__ = [ | |||
"WeiboNERPipe", | |||
"CWSPipe", | |||
"Pipe", | |||
"CWSPipe", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"WeiboNERPipe", | |||
"PeopleDailyPipe", | |||
"Conll2003Pipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"CNXNLIBertPipe", | |||
"BQCorpusBertPipe", | |||
"LCQMCBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"LCQMCPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"RenamePipe", | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
'ModelLoader', | |||
'ModelSaver', | |||
] | |||
from .embed_loader import EmbedLoader | |||
from .data_bundle import DataBundle | |||
from .model_io import ModelLoader, ModelSaver | |||
import sys | |||
from .data_bundle import DataBundle | |||
from .embed_loader import EmbedLoader | |||
from .loader import * | |||
from .model_io import ModelLoader, ModelSaver | |||
from .pipe import * | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -54,7 +54,9 @@ __all__ = [ | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
"THUCNewsLoader", | |||
"WeiboSenti100kLoader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'Conll2003NERLoader', | |||
@@ -63,26 +65,31 @@ __all__ = [ | |||
"MsraNERLoader", | |||
"PeopleDailyNERLoader", | |||
"WeiboNERLoader", | |||
'CSVLoader', | |||
'JsonLoader', | |||
'CWSLoader', | |||
'MNLILoader', | |||
"QuoraLoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"CNXNLILoader", | |||
"BQCorpusLoader", | |||
"LCQMCLoader", | |||
"CoReferenceLoader" | |||
] | |||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader | |||
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, \ | |||
ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader | |||
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | |||
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | |||
from .coreference import CoReferenceLoader | |||
from .csv import CSVLoader | |||
from .cws import CWSLoader | |||
from .json import JsonLoader | |||
from .loader import Loader | |||
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | |||
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | |||
from .coreference import CoReferenceLoader | |||
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \ | |||
LCQMCLoader |
@@ -409,6 +409,7 @@ class THUCNewsLoader(Loader): | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育" | |||
"...", "..." | |||
@@ -446,13 +447,18 @@ class WeiboSenti100kLoader(Loader): | |||
别名: | |||
数据集简介:微博sentiment classification,二分类 | |||
原始数据内容为: | |||
label text | |||
0 六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒] | |||
1 听过一场!笑死了昂,一听茄子脱口秀,从此节操是路人![嘻嘻] //@中国梦网官微:@Pencil彭赛 @茄子脱口秀 [圣诞帽][圣诞树][平安果] | |||
.. code-block:: text | |||
label text | |||
0 六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒] | |||
1 听过一场!笑死了昂,一听茄子脱口秀,从此节操是路人![嘻嘻] //@中国梦网官微:@Pencil彭赛 @茄子脱口秀 [圣诞帽][圣诞树][平安果] | |||
读取后的Dataset将具有以下数据结构: | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0" | |||
"...", "..." | |||
@@ -15,14 +15,14 @@ import os | |||
import warnings | |||
from typing import Union, Dict | |||
from .csv import CSVLoader | |||
from .json import JsonLoader | |||
from .loader import Loader | |||
from .. import DataBundle | |||
from ..utils import check_loader_paths | |||
from ...core.const import Const | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from .csv import CSVLoader | |||
from ..utils import check_loader_paths | |||
class MNLILoader(Loader): | |||
@@ -348,8 +348,9 @@ class CNXNLILoader(Loader): | |||
.. csv-table:: | |||
:header: "raw_chars1", "raw_chars2", "target" | |||
"从概念上看,奶油收入有两个基本方面产品和地理.", "产品和地理是什么使奶油抹霜工作.", "1" | |||
""...", "...", "..." | |||
"...", "...", "..." | |||
""" | |||
@@ -412,6 +413,7 @@ class BQCorpusLoader(Loader): | |||
.. csv-table:: | |||
:header: "raw_chars1", "raw_chars2", "target" | |||
"不是邀请的如何贷款?", "我不是你们邀请的客人可以贷款吗?", "1" | |||
"如何满足微粒银行的审核", "建设银行有微粒贷的资格吗", "0" | |||
"...", "...", "..." | |||
@@ -458,9 +460,10 @@ class LCQMCLoader(Loader): | |||
.. csv-table:: | |||
:header: "raw_chars1", "raw_chars2", "target" | |||
"喜欢打篮球的男生喜欢什么样的女生?", "爱打篮球的男生喜欢什么样的女生?", "1" | |||
"晚上睡觉带着耳机听音乐有什么害处吗?", "妇可以戴耳机听音乐吗?", "0" | |||
""...", "...", "..." | |||
"...", "...", "..." | |||
""" | |||
@@ -9,9 +9,9 @@ Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``proce | |||
""" | |||
__all__ = [ | |||
"Pipe", | |||
"CWSPipe", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
@@ -20,35 +20,46 @@ __all__ = [ | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"WeiboNERPipe", | |||
"PeopleDailyPipe", | |||
"Conll2003Pipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"CNXNLIBertPipe", | |||
"BQCorpusBertPipe", | |||
"LCQMCBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"LCQMCPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"RenamePipe", | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
"CoReferencePipe" | |||
] | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe | |||
from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ | |||
WeiboSenti100kPipe | |||
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
from .pipe import Pipe | |||
from .conll import Conll2003Pipe | |||
from .cws import CWSPipe | |||
from .coreference import CoReferencePipe | |||
from .cws import CWSPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ | |||
LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe | |||
from .pipe import Pipe |
@@ -21,11 +21,11 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta | |||
from ..data_bundle import DataBundle | |||
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader | |||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader | |||
from ...core._logger import logger | |||
from ...core.const import Const | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.vocabulary import Vocabulary | |||
from ...core._logger import logger | |||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||
@@ -718,6 +718,7 @@ class THUCNewsPipe(_CLSPipe): | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育" | |||
"...", "..." | |||
@@ -826,6 +827,7 @@ class WeiboSenti100kPipe(_CLSPipe): | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0" | |||
"...", "..." | |||
@@ -16,20 +16,24 @@ __all__ = [ | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"LCQMCPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"LCQMCPipe", | |||
"RenamePipe", | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
] | |||
import warnings | |||
from .pipe import Pipe | |||
from .utils import get_tokenizer | |||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, LCQMCLoader | |||
from ..data_bundle import DataBundle | |||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, \ | |||
LCQMCLoader | |||
from ...core._logger import logger | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ...core._logger import logger | |||
from ..data_bundle import DataBundle | |||
class MatchingBertPipe(Pipe): | |||
@@ -145,7 +149,7 @@ class MatchingBertPipe(Pipe): | |||
f"data set but not in train data set!." | |||
warnings.warn(warn_msg) | |||
logger.warning(warn_msg) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field(Const.TARGET)] | |||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||
@@ -294,7 +298,7 @@ class MatchingPipe(Pipe): | |||
f"data set but not in train data set!." | |||
warnings.warn(warn_msg) | |||
logger.warning(warn_msg) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field(Const.TARGET)] | |||
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) | |||
@@ -345,8 +349,9 @@ class MNLIPipe(MatchingPipe): | |||
data_bundle = MNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class LCQMCPipe(MatchingPipe): | |||
def process_from_file(self, paths = None): | |||
def process_from_file(self, paths=None): | |||
data_bundle = LCQMCLoader().load(paths) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
@@ -358,14 +363,14 @@ class CNXNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) #使中文数据的field | |||
data_bundle = RenamePipe().process(data_bundle) # 使中文数据的field | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
return data_bundle | |||
class BQCorpusPipe(MatchingPipe): | |||
def process_from_file(self, paths = None): | |||
def process_from_file(self, paths=None): | |||
data_bundle = BQCorpusLoader().load(paths) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
@@ -374,12 +379,12 @@ class BQCorpusPipe(MatchingPipe): | |||
class RenamePipe(Pipe): | |||
def __init__(self, task = 'cn-nli'): | |||
def __init__(self, task='cn-nli'): | |||
super().__init__() | |||
self.task = task | |||
def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset | |||
if(self.task == 'cn-nli'): | |||
if (self.task == 'cn-nli'): | |||
for name, dataset in data_bundle.datasets.items(): | |||
if (dataset.has_field(Const.RAW_CHARS(0))): | |||
dataset.rename_field(Const.RAW_CHARS(0), Const.RAW_WORDS(0)) # RAW_CHARS->RAW_WORDS | |||
@@ -392,12 +397,12 @@ class RenamePipe(Pipe): | |||
else: | |||
raise RuntimeError( | |||
"field name of dataset is not qualified. It should have ether RAW_CHARS or WORDS") | |||
elif(self.task == 'cn-nli-bert'): | |||
elif (self.task == 'cn-nli-bert'): | |||
for name, dataset in data_bundle.datasets.items(): | |||
if (dataset.has_field(Const.RAW_CHARS(0))): | |||
dataset.rename_field(Const.RAW_CHARS(0), Const.RAW_WORDS(0)) # RAW_CHARS->RAW_WORDS | |||
dataset.rename_field(Const.RAW_CHARS(1), Const.RAW_WORDS(1)) | |||
elif(dataset.has_field(Const.RAW_WORDS(0))): | |||
elif (dataset.has_field(Const.RAW_WORDS(0))): | |||
dataset.rename_field(Const.RAW_WORDS(0), Const.RAW_CHARS(0)) | |||
dataset.rename_field(Const.RAW_WORDS(1), Const.RAW_CHARS(1)) | |||
dataset.rename_field(Const.INPUT, Const.CHAR_INPUT) | |||
@@ -409,15 +414,15 @@ class RenamePipe(Pipe): | |||
raise RuntimeError( | |||
"Only support task='cn-nli' or 'cn-nli-bert'" | |||
) | |||
return data_bundle | |||
class GranularizePipe(Pipe): | |||
def __init__(self, task = None): | |||
def __init__(self, task=None): | |||
super().__init__() | |||
self.task = task | |||
def _granularize(self, data_bundle, tag_map): | |||
""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
@@ -434,21 +439,22 @@ class GranularizePipe(Pipe): | |||
dataset.drop(lambda ins: ins[Const.TARGET] == -100) | |||
data_bundle.set_dataset(dataset, name) | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
task_tag_dict = { | |||
'XNLI':{'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2} | |||
'XNLI': {'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2} | |||
} | |||
if self.task in task_tag_dict: | |||
data_bundle = self._granularize(data_bundle=data_bundle, tag_map= task_tag_dict[self.task]) | |||
data_bundle = self._granularize(data_bundle=data_bundle, tag_map=task_tag_dict[self.task]) | |||
else: | |||
raise RuntimeError(f"Only support {task_tag_dict.keys()} task_tag_map.") | |||
return data_bundle | |||
class MachingTruncatePipe(Pipe): #truncate sentence for bert, modify seq_len | |||
class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len | |||
def __init__(self): | |||
super().__init__() | |||
def process(self, data_bundle: DataBundle): | |||
for name, dataset in data_bundle.datasets.items(): | |||
pass | |||
@@ -456,7 +462,7 @@ class MachingTruncatePipe(Pipe): #truncate sentence for bert, modify seq_len | |||
class LCQMCBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths = None): | |||
def process_from_file(self, paths=None): | |||
data_bundle = LCQMCLoader().load(paths) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
@@ -465,7 +471,7 @@ class LCQMCBertPipe(MatchingBertPipe): | |||
class BQCorpusBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths = None): | |||
def process_from_file(self, paths=None): | |||
data_bundle = BQCorpusLoader().load(paths) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
@@ -474,7 +480,7 @@ class BQCorpusBertPipe(MatchingBertPipe): | |||
class CNXNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths = None): | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||