diff --git a/docs/source/fastNLP.core.callback.rst b/docs/source/fastNLP.core.callback.rst index 75b5d0cd..5a508e03 100644 --- a/docs/source/fastNLP.core.callback.rst +++ b/docs/source/fastNLP.core.callback.rst @@ -2,6 +2,6 @@ fastNLP.core.callback ===================== .. automodule:: fastNLP.core.callback - :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError + :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError :inherited-members: diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst index f6c72be8..c6d0dc55 100644 --- a/docs/source/fastNLP.io.loader.rst +++ b/docs/source/fastNLP.io.loader.rst @@ -2,6 +2,6 @@ fastNLP.io.loader ================= .. automodule:: fastNLP.io.loader - :members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CoReferenceLoader + :members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, LCQMCLoader, CoReferenceLoader :inherited-members: diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst index ee389e8c..178d35a9 100644 --- a/docs/source/fastNLP.io.pipe.rst +++ b/docs/source/fastNLP.io.pipe.rst @@ -2,6 +2,6 @@ fastNLP.io.pipe =============== .. automodule:: fastNLP.io.pipe - :members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CoReferencePipe + :members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, LCQMCPipe, CNXNLIPipe, BQCorpusPipe, RenamePipe, GranularizePipe, MachingTruncatePipe, CoReferencePipe :inherited-members: diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index 7118039d..54373df4 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -2,7 +2,7 @@ fastNLP.io ========== .. automodule:: fastNLP.io - :members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver + :members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, LCQMCLoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver :inherited-members: 子模块 diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index e92807d7..097ad0b2 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -2,7 +2,7 @@ fastNLP ======= .. automodule:: fastNLP - :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger + :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, EarlyStopError, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger :inherited-members: 子模块 diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index 63fde69a..54d2d8b6 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -47,7 +47,7 @@ __all__ = [ "SNLILoader", "QNLILoader", "RTELoader", - "XNLILoader", + "CNXNLILoader", "BQCorpusLoader", "LCQMCLoader", @@ -70,32 +70,61 @@ __all__ = [ "WeiboNERPipe", "CWSPipe", - + + "Pipe", + + "CWSPipe", + + "YelpFullPipe", + "YelpPolarityPipe", + "SSTPipe", + "SST2Pipe", + "IMDBPipe", + "ChnSentiCorpPipe", + "THUCNewsPipe", + "WeiboSenti100kPipe", + + "Conll2003NERPipe", + "OntoNotesNERPipe", + "MsraNERPipe", + "WeiboNERPipe", + "PeopleDailyPipe", + "Conll2003Pipe", + "MatchingBertPipe", "RTEBertPipe", "SNLIBertPipe", "QuoraBertPipe", "QNLIBertPipe", "MNLIBertPipe", + "CNXNLIBertPipe", + "BQCorpusBertPipe", + "LCQMCBertPipe", "MatchingPipe", "RTEPipe", "SNLIPipe", "QuoraPipe", "QNLIPipe", "MNLIPipe", + "LCQMCPipe", + "CNXNLIPipe", + "BQCorpusPipe", + "RenamePipe", + "GranularizePipe", + "MachingTruncatePipe", 'ModelLoader', 'ModelSaver', ] -from .embed_loader import EmbedLoader -from .data_bundle import DataBundle -from .model_io import ModelLoader, ModelSaver +import sys +from .data_bundle import DataBundle +from .embed_loader import EmbedLoader from .loader import * +from .model_io import ModelLoader, ModelSaver from .pipe import * - -import sys from ..doc_utils import doc_process + doc_process(sys.modules[__name__]) \ No newline at end of file diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index 4ad228b0..5fb9fd91 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -54,7 +54,9 @@ __all__ = [ 'SSTLoader', 'SST2Loader', "ChnSentiCorpLoader", - + "THUCNewsLoader", + "WeiboSenti100kLoader", + 'ConllLoader', 'Conll2003Loader', 'Conll2003NERLoader', @@ -63,26 +65,31 @@ __all__ = [ "MsraNERLoader", "PeopleDailyNERLoader", "WeiboNERLoader", - + 'CSVLoader', 'JsonLoader', - + 'CWSLoader', - + 'MNLILoader', "QuoraLoader", "SNLILoader", "QNLILoader", "RTELoader", - + "CNXNLILoader", + "BQCorpusLoader", + "LCQMCLoader", + "CoReferenceLoader" ] -from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader +from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, \ + ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader +from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader +from .coreference import CoReferenceLoader from .csv import CSVLoader from .cws import CWSLoader from .json import JsonLoader from .loader import Loader -from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader -from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader -from .coreference import CoReferenceLoader \ No newline at end of file +from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \ + LCQMCLoader diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index 004f3ebd..e0c894a2 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -409,6 +409,7 @@ class THUCNewsLoader(Loader): .. csv-table:: :header: "raw_words", "target" + "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育" "...", "..." @@ -446,13 +447,18 @@ class WeiboSenti100kLoader(Loader): 别名: 数据集简介:微博sentiment classification,二分类 原始数据内容为: - label text - 0 六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒] - 1 听过一场!笑死了昂,一听茄子脱口秀,从此节操是路人![嘻嘻] //@中国梦网官微:@Pencil彭赛 @茄子脱口秀 [圣诞帽][圣诞树][平安果] + + .. .. code-block:: text + + label text + 0 六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒] + 1 听过一场!笑死了昂,一听茄子脱口秀,从此节操是路人![嘻嘻] //@中国梦网官微:@Pencil彭赛 @茄子脱口秀 [圣诞帽][圣诞树][平安果] + 读取后的Dataset将具有以下数据结构: .. csv-table:: :header: "raw_chars", "target" + "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0" "...", "..." diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 80889507..bf4eec81 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -15,14 +15,14 @@ import os import warnings from typing import Union, Dict +from .csv import CSVLoader from .json import JsonLoader from .loader import Loader from .. import DataBundle +from ..utils import check_loader_paths from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance -from .csv import CSVLoader -from ..utils import check_loader_paths class MNLILoader(Loader): @@ -348,8 +348,9 @@ class CNXNLILoader(Loader): .. csv-table:: :header: "raw_chars1", "raw_chars2", "target" + "从概念上看,奶油收入有两个基本方面产品和地理.", "产品和地理是什么使奶油抹霜工作.", "1" - ""...", "...", "..." + "...", "...", "..." """ @@ -412,6 +413,7 @@ class BQCorpusLoader(Loader): .. csv-table:: :header: "raw_chars1", "raw_chars2", "target" + "不是邀请的如何贷款?", "我不是你们邀请的客人可以贷款吗?", "1" "如何满足微粒银行的审核", "建设银行有微粒贷的资格吗", "0" "...", "...", "..." @@ -458,9 +460,10 @@ class LCQMCLoader(Loader): .. csv-table:: :header: "raw_chars1", "raw_chars2", "target" + "喜欢打篮球的男生喜欢什么样的女生?", "爱打篮球的男生喜欢什么样的女生?", "1" "晚上睡觉带着耳机听音乐有什么害处吗?", "妇可以戴耳机听音乐吗?", "0" - ""...", "...", "..." + "...", "...", "..." """ diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 212f9e66..e30978be 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -9,9 +9,9 @@ Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``proce """ __all__ = [ "Pipe", - + "CWSPipe", - + "YelpFullPipe", "YelpPolarityPipe", "SSTPipe", @@ -20,35 +20,46 @@ __all__ = [ "ChnSentiCorpPipe", "THUCNewsPipe", "WeiboSenti100kPipe", - + "Conll2003NERPipe", "OntoNotesNERPipe", "MsraNERPipe", "WeiboNERPipe", "PeopleDailyPipe", "Conll2003Pipe", - + "MatchingBertPipe", "RTEBertPipe", "SNLIBertPipe", "QuoraBertPipe", "QNLIBertPipe", "MNLIBertPipe", + "CNXNLIBertPipe", + "BQCorpusBertPipe", + "LCQMCBertPipe", "MatchingPipe", "RTEPipe", "SNLIPipe", "QuoraPipe", "QNLIPipe", "MNLIPipe", - + "LCQMCPipe", + "CNXNLIPipe", + "BQCorpusPipe", + "RenamePipe", + "GranularizePipe", + "MachingTruncatePipe", + "CoReferencePipe" ] -from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe +from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ + WeiboSenti100kPipe from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe -from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ - MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe -from .pipe import Pipe from .conll import Conll2003Pipe -from .cws import CWSPipe from .coreference import CoReferencePipe +from .cws import CWSPipe +from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ + MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ + LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe +from .pipe import Pipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index 1c44cc23..ab31c9de 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -21,11 +21,11 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta from ..data_bundle import DataBundle from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader +from ...core._logger import logger from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance from ...core.vocabulary import Vocabulary -from ...core._logger import logger nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') @@ -718,6 +718,7 @@ class THUCNewsPipe(_CLSPipe): .. csv-table:: :header: "raw_words", "target" + "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育" "...", "..." @@ -826,6 +827,7 @@ class WeiboSenti100kPipe(_CLSPipe): .. csv-table:: :header: "raw_chars", "target" + "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0" "...", "..." diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 90cf17df..dac21dca 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -16,20 +16,24 @@ __all__ = [ "QuoraPipe", "QNLIPipe", "MNLIPipe", + "LCQMCPipe", "CNXNLIPipe", "BQCorpusPipe", - "LCQMCPipe", + "RenamePipe", + "GranularizePipe", + "MachingTruncatePipe", ] import warnings from .pipe import Pipe from .utils import get_tokenizer -from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, LCQMCLoader +from ..data_bundle import DataBundle +from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, \ + LCQMCLoader +from ...core._logger import logger from ...core.const import Const from ...core.vocabulary import Vocabulary -from ...core._logger import logger -from ..data_bundle import DataBundle class MatchingBertPipe(Pipe): @@ -145,7 +149,7 @@ class MatchingBertPipe(Pipe): f"data set but not in train data set!." warnings.warn(warn_msg) logger.warning(warn_msg) - + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if dataset.has_field(Const.TARGET)] target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) @@ -294,7 +298,7 @@ class MatchingPipe(Pipe): f"data set but not in train data set!." warnings.warn(warn_msg) logger.warning(warn_msg) - + has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if dataset.has_field(Const.TARGET)] target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET) @@ -345,8 +349,9 @@ class MNLIPipe(MatchingPipe): data_bundle = MNLILoader().load(paths) return self.process(data_bundle) + class LCQMCPipe(MatchingPipe): - def process_from_file(self, paths = None): + def process_from_file(self, paths=None): data_bundle = LCQMCLoader().load(paths) data_bundle = RenamePipe().process(data_bundle) data_bundle = self.process(data_bundle) @@ -358,14 +363,14 @@ class CNXNLIPipe(MatchingPipe): def process_from_file(self, paths=None): data_bundle = CNXNLILoader().load(paths) data_bundle = GranularizePipe(task='XNLI').process(data_bundle) - data_bundle = RenamePipe().process(data_bundle) #使中文数据的field + data_bundle = RenamePipe().process(data_bundle) # 使中文数据的field data_bundle = self.process(data_bundle) data_bundle = RenamePipe().process(data_bundle) return data_bundle class BQCorpusPipe(MatchingPipe): - def process_from_file(self, paths = None): + def process_from_file(self, paths=None): data_bundle = BQCorpusLoader().load(paths) data_bundle = RenamePipe().process(data_bundle) data_bundle = self.process(data_bundle) @@ -374,12 +379,12 @@ class BQCorpusPipe(MatchingPipe): class RenamePipe(Pipe): - def __init__(self, task = 'cn-nli'): + def __init__(self, task='cn-nli'): super().__init__() self.task = task - + def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset - if(self.task == 'cn-nli'): + if (self.task == 'cn-nli'): for name, dataset in data_bundle.datasets.items(): if (dataset.has_field(Const.RAW_CHARS(0))): dataset.rename_field(Const.RAW_CHARS(0), Const.RAW_WORDS(0)) # RAW_CHARS->RAW_WORDS @@ -392,12 +397,12 @@ class RenamePipe(Pipe): else: raise RuntimeError( "field name of dataset is not qualified. It should have ether RAW_CHARS or WORDS") - elif(self.task == 'cn-nli-bert'): + elif (self.task == 'cn-nli-bert'): for name, dataset in data_bundle.datasets.items(): if (dataset.has_field(Const.RAW_CHARS(0))): dataset.rename_field(Const.RAW_CHARS(0), Const.RAW_WORDS(0)) # RAW_CHARS->RAW_WORDS dataset.rename_field(Const.RAW_CHARS(1), Const.RAW_WORDS(1)) - elif(dataset.has_field(Const.RAW_WORDS(0))): + elif (dataset.has_field(Const.RAW_WORDS(0))): dataset.rename_field(Const.RAW_WORDS(0), Const.RAW_CHARS(0)) dataset.rename_field(Const.RAW_WORDS(1), Const.RAW_CHARS(1)) dataset.rename_field(Const.INPUT, Const.CHAR_INPUT) @@ -409,15 +414,15 @@ class RenamePipe(Pipe): raise RuntimeError( "Only support task='cn-nli' or 'cn-nli-bert'" ) - + return data_bundle class GranularizePipe(Pipe): - def __init__(self, task = None): + def __init__(self, task=None): super().__init__() self.task = task - + def _granularize(self, data_bundle, tag_map): """ 该函数对data_bundle中'target'列中的内容进行转换。 @@ -434,21 +439,22 @@ class GranularizePipe(Pipe): dataset.drop(lambda ins: ins[Const.TARGET] == -100) data_bundle.set_dataset(dataset, name) return data_bundle - + def process(self, data_bundle: DataBundle): task_tag_dict = { - 'XNLI':{'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2} + 'XNLI': {'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2} } if self.task in task_tag_dict: - data_bundle = self._granularize(data_bundle=data_bundle, tag_map= task_tag_dict[self.task]) + data_bundle = self._granularize(data_bundle=data_bundle, tag_map=task_tag_dict[self.task]) else: raise RuntimeError(f"Only support {task_tag_dict.keys()} task_tag_map.") return data_bundle -class MachingTruncatePipe(Pipe): #truncate sentence for bert, modify seq_len +class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len def __init__(self): super().__init__() + def process(self, data_bundle: DataBundle): for name, dataset in data_bundle.datasets.items(): pass @@ -456,7 +462,7 @@ class MachingTruncatePipe(Pipe): #truncate sentence for bert, modify seq_len class LCQMCBertPipe(MatchingBertPipe): - def process_from_file(self, paths = None): + def process_from_file(self, paths=None): data_bundle = LCQMCLoader().load(paths) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) @@ -465,7 +471,7 @@ class LCQMCBertPipe(MatchingBertPipe): class BQCorpusBertPipe(MatchingBertPipe): - def process_from_file(self, paths = None): + def process_from_file(self, paths=None): data_bundle = BQCorpusLoader().load(paths) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) @@ -474,7 +480,7 @@ class BQCorpusBertPipe(MatchingBertPipe): class CNXNLIBertPipe(MatchingBertPipe): - def process_from_file(self, paths = None): + def process_from_file(self, paths=None): data_bundle = CNXNLILoader().load(paths) data_bundle = GranularizePipe(task='XNLI').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle)