diff --git a/data/test/regression/sbert-base-tnews.bin b/data/test/regression/sbert-base-tnews.bin new file mode 100644 index 00000000..1546860f --- /dev/null +++ b/data/test/regression/sbert-base-tnews.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 +size 151572 diff --git a/data/test/regression/sbert_nli.bin b/data/test/regression/sbert_nli.bin index a5f680bb..68efb778 100644 --- a/data/test/regression/sbert_nli.bin +++ b/data/test/regression/sbert_nli.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 -size 62231 +oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b +size 61741 diff --git a/data/test/regression/sbert_sen_sim.bin b/data/test/regression/sbert_sen_sim.bin index a59cbe0b..362f762c 100644 --- a/data/test/regression/sbert_sen_sim.bin +++ b/data/test/regression/sbert_sen_sim.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a -size 62235 +oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 +size 61745 diff --git a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py index dc1e2b92..52dab4bc 100644 --- a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py +++ b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py @@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): def generate_dummy_inputs(self, shape: Tuple = None, + pair: bool = False, **kwargs) -> Dict[str, Any]: """Generate dummy inputs for model exportation to onnx or other formats by tracing. @param shape: A tuple of input shape which should have at most two dimensions. shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. shape = (8, 128) batch_size=8, sequence_length=128, which will override the config of the preprocessor. + @param pair: Whether to generate sentence pairs (True) or single sentences (False) as dummy inputs. @return: Dummy inputs. """ @@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): **sequence_length }) preprocessor: Preprocessor = build_preprocessor(cfg, field_name) - if preprocessor.pair: + if pair: first_sequence = preprocessor.tokenizer.unk_token second_sequence = preprocessor.tokenizer.unk_token else: diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index ee4d2840..1c8e16d7 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -32,6 +32,7 @@ task_default_metrics = { Tasks.sentiment_classification: [Metrics.seq_cls_metric], Tasks.token_classification: [Metrics.token_cls_metric], Tasks.text_generation: [Metrics.text_gen_metric], + Tasks.text_classification: [Metrics.seq_cls_metric], Tasks.image_denoising: [Metrics.image_denoise_metric], Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], Tasks.image_portrait_enhancement: diff --git a/modelscope/models/nlp/bert/modeling_bert.py b/modelscope/models/nlp/bert/modeling_bert.py index e91a6433..7c1dfcf5 100755 --- a/modelscope/models/nlp/bert/modeling_bert.py +++ b/modelscope/models/nlp/bert/modeling_bert.py @@ -15,7 +15,6 @@ """PyTorch BERT model.
""" import math -import os import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer) -from modelscope.models.base import TorchModel from modelscope.utils.logger import get_logger from .configuration_bert import BertConfig @@ -50,81 +48,6 @@ logger = get_logger(__name__) _CONFIG_FOR_DOC = 'BertConfig' -def load_tf_weights_in_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' - 'https://www.tensorflow.org/install/ for installation instructions.' - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f'Converting TensorFlow checkpoint from {tf_path}') - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f'Loading TF weight {name} with shape {shape}') - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in [ - 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', - 'AdamWeightDecayOptimizer_1', 'global_step' - ] for n in name): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - scope_names = re.split(r'_(\d+)', m_name) - else: - scope_names = [m_name] - if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif scope_names[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - elif scope_names[0] == 'squad': - pointer = getattr(pointer, 'classifier') - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError( - f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f'Initialize PyTorch weight {name}') - pointer.data = torch.from_numpy(array) - return model - - class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel): """ config_class = BertConfig - load_tf_weights = load_tf_weights_in_bert base_model_prefix = 'bert' supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r'position_ids'] diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py index 267dbb8c..bc96f569 100644 --- a/modelscope/preprocessors/nlp/nlp_base.py +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -2,7 +2,7 @@ import os.path as osp import re -from typing import Any, Dict, Iterable, 
Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import sentencepiece as spm @@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): return isinstance(label, str) or isinstance(label, int) if labels is not None: - if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ + if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ and self.label2id is not None: output[OutputKeys.LABELS] = [ self.label2id[str(label)] for label in labels @@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get( - 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') + kwargs['padding'] = kwargs.get('padding', 'max_length') kwargs['max_length'] = kwargs.pop('sequence_length', 128) super().__init__(model_dir, mode=mode, **kwargs) diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py index ca0ec01b..32fb0250 100644 --- a/modelscope/trainers/hooks/lr_scheduler_hook.py +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -47,7 +47,7 @@ class LrSchedulerHook(Hook): return lr def before_train_iter(self, trainer): - if not self.by_epoch: + if not self.by_epoch and trainer.iter > 0: if self.warmup_lr_scheduler is not None: self.warmup_lr_scheduler.step() else: diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 9eaff762..61d11aa6 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -651,7 +651,7 @@ class EpochBasedTrainer(BaseTrainer): # TODO: support MsDataset load for cv if hasattr(data_cfg, 'name'): dataset = MsDataset.load( - dataset_name=data_cfg.name, + dataset_name=data_cfg.pop('name'), **data_cfg, ) cfg = ConfigDict(type=self.cfg.model.type, mode=mode) diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py index 47bbadfe..3c1e5c1c 100644 --- a/modelscope/utils/regress_test_utils.py +++ b/modelscope/utils/regress_test_utils.py @@ -65,7 +65,8 @@ class RegressTool: def monitor_module_single_forward(self, module: nn.Module, file_name: str, - compare_fn=None): + compare_fn=None, + **kwargs): """Monitor a pytorch module in a single forward. @param module: A torch module @@ -107,7 +108,7 @@ class RegressTool: baseline = os.path.join(tempfile.gettempdir(), name) self.load(baseline, name) with open(baseline, 'rb') as f: - baseline_json = pickle.load(f) + base = pickle.load(f) class NumpyEncoder(json.JSONEncoder): """Special json encoder for numpy types @@ -122,9 +123,9 @@ class RegressTool: return obj.tolist() return json.JSONEncoder.default(self, obj) - print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') + print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') - if not compare_io_and_print(baseline_json, io_json, compare_fn): + if not compare_io_and_print(base, io_json, compare_fn, **kwargs): raise ValueError('Result not match!') @contextlib.contextmanager @@ -136,7 +137,8 @@ class RegressTool: ignore_keys=None, compare_random=True, reset_dropout=True, - lazy_stop_callback=None): + lazy_stop_callback=None, + **kwargs): """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. 
This is usually useful when you try to change some dangerous code @@ -265,14 +267,15 @@ class RegressTool: baseline_json = pickle.load(f) if level == 'strict' and not compare_io_and_print( - baseline_json['forward'], io_json, compare_fn): + baseline_json['forward'], io_json, compare_fn, **kwargs): raise RuntimeError('Forward not match!') if not compare_backward_and_print( baseline_json['backward'], bw_json, compare_fn=compare_fn, ignore_keys=ignore_keys, - level=level): + level=level, + **kwargs): raise RuntimeError('Backward not match!') cfg_opt1 = { 'optimizer': baseline_json['optimizer'], @@ -286,7 +289,8 @@ class RegressTool: 'cfg': summary['cfg'], 'state': None if not compare_random else summary['state'] } - if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): + if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, + **kwargs): raise RuntimeError('Cfg or optimizers not match!') @@ -303,7 +307,8 @@ class MsRegressTool(RegressTool): compare_fn=None, ignore_keys=None, compare_random=True, - lazy_stop_callback=None): + lazy_stop_callback=None, + **kwargs): if lazy_stop_callback is None: @@ -319,7 +324,7 @@ class MsRegressTool(RegressTool): trainer.register_hook(EarlyStopHook()) - def _train_loop(trainer, *args, **kwargs): + def _train_loop(trainer, *args_train, **kwargs_train): with self.monitor_module_train( trainer, file_name, @@ -327,9 +332,11 @@ class MsRegressTool(RegressTool): compare_fn=compare_fn, ignore_keys=ignore_keys, compare_random=compare_random, - lazy_stop_callback=lazy_stop_callback): + lazy_stop_callback=lazy_stop_callback, + **kwargs): try: - return trainer.train_loop_origin(*args, **kwargs) + return trainer.train_loop_origin(*args_train, + **kwargs_train) except MsRegressTool.EarlyStopError: pass @@ -530,7 +537,8 @@ def compare_arguments_nested(print_content, ) return False if not all([ - compare_arguments_nested(None, sub_arg1, sub_arg2) + compare_arguments_nested( + None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) for sub_arg1, sub_arg2 in zip(arg1, arg2) ]): if print_content is not None: @@ -551,7 +559,8 @@ def compare_arguments_nested(print_content, print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') return False if not all([ - compare_arguments_nested(None, arg1[key], arg2[key]) + compare_arguments_nested( + None, arg1[key], arg2[key], rtol=rtol, atol=atol) for key in keys1 ]): if print_content is not None: @@ -574,7 +583,7 @@ def compare_arguments_nested(print_content, raise ValueError(f'type not supported: {type1}') -def compare_io_and_print(baseline_json, io_json, compare_fn=None): +def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): else: match = compare_arguments_nested( f'unmatched module {key} input args', v1input['args'], - v2input['args']) and match + v2input['args'], **kwargs) and match match = compare_arguments_nested( f'unmatched module {key} input kwargs', v1input['kwargs'], - v2input['kwargs']) and match + v2input['kwargs'], **kwargs) and match v1output = numpify_tensor_nested(v1['output']) v2output = numpify_tensor_nested(v2['output']) res = compare_fn(v1output, v2output, key, 'output') @@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): ) match = match and res else: - match = compare_arguments_nested(f'unmatched module {key} outputs', - v1output, v2output) and match + match = compare_arguments_nested( + 
f'unmatched module {key} outputs', + arg1=v1output, + arg2=v2output, + **kwargs) and match return match @@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json, bw_json, level, ignore_keys=None, - compare_fn=None): + compare_fn=None, + **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json, data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ 'grad'], bw_json[key]['data_after'] match = compare_arguments_nested( - f'unmatched module {key} tensor data', data1, data2) and match + f'unmatched module {key} tensor data', + arg1=data1, + arg2=data2, + **kwargs) and match if level == 'strict': match = compare_arguments_nested( - f'unmatched module {key} grad data', grad1, - grad2) and match + f'unmatched module {key} grad data', + arg1=grad1, + arg2=grad2, + **kwargs) and match match = compare_arguments_nested( f'unmatched module {key} data after step', data_after1, - data_after2) and match + data_after2, **kwargs) and match return match -def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): +def compare_cfg_and_optimizers(baseline_json, + cfg_json, + compare_fn=None, + **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): print( f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" ) - match = compare_arguments_nested('unmatched optimizer defaults', - optimizer1['defaults'], - optimizer2['defaults']) and match - match = compare_arguments_nested('unmatched optimizer state_dict', - optimizer1['state_dict'], - optimizer2['state_dict']) and match + match = compare_arguments_nested( + 'unmatched optimizer defaults', optimizer1['defaults'], + optimizer2['defaults'], **kwargs) and match + match = compare_arguments_nested( + 'unmatched optimizer state_dict', optimizer1['state_dict'], + optimizer2['state_dict'], **kwargs) and match res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') if res is not None: @@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): print( f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" ) - match = compare_arguments_nested('unmatched lr_scheduler state_dict', - lr_scheduler1['state_dict'], - lr_scheduler2['state_dict']) and match + match = compare_arguments_nested( + 'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], + lr_scheduler2['state_dict'], **kwargs) and match res = compare_fn(cfg1, cfg2, None, 'cfg') if res is not None: print(f'cfg compared with user compare_fn with result:{res}\n') match = match and res else: - match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match + match = compare_arguments_nested( + 'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match res = compare_fn(state1, state2, None, 'state') if res is not None: @@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): match = match and res else: match = compare_arguments_nested('unmatched random state', state1, - state2) and match + state2, **kwargs) and match return match diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index 91a3b5c5..1e537e93 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def 
test_ms_csv_basic(self): ms_ds_train = MsDataset.load( - 'afqmc_small', namespace='userxiaoming', split='train') + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(5)) print(next(iter(ms_ds_train))) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index f2adfa22..27db1f18 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ calculate_fisher from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.data_utils import to_device -from modelscope.utils.regress_test_utils import MsRegressTool +from modelscope.utils.regress_test_utils import (MsRegressTool, + compare_arguments_nested) from modelscope.utils.test_utils import test_level @@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase): def test_trainer_repeatable(self): import torch # noqa + def compare_fn(value1, value2, key, type): + # Ignore the differences between optimizers of two torch versions + if type != 'optimizer': + return None + + match = (value1['type'] == value2['type']) + shared_defaults = set(value1['defaults'].keys()).intersection( + set(value2['defaults'].keys())) + match = all([ + compare_arguments_nested(f'Optimizer defaults {key} not match', + value1['defaults'][key], + value2['defaults'][key]) + for key in shared_defaults + ]) and match + match = (len(value1['state_dict']['param_groups']) == len( + value2['state_dict']['param_groups'])) and match + for group1, group2 in zip(value1['state_dict']['param_groups'], + value2['state_dict']['param_groups']): + shared_keys = set(group1.keys()).intersection( + set(group2.keys())) + match = all([ + compare_arguments_nested( + f'Optimizer param_groups {key} not match', group1[key], + group2[key]) for key in shared_keys + ]) and match + return match + def cfg_modify_fn(cfg): cfg.task = 'nli' cfg['preprocessor'] = {'type': 'nli-tokenizer'} @@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): name=Trainers.nlp_base_trainer, default_args=kwargs) with self.regress_tool.monitor_ms_train( - trainer, 'sbert-base-tnews', level='strict'): + trainer, 'sbert-base-tnews', level='strict', + compare_fn=compare_fn): trainer.train() def finetune(self, diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index 6030ada9..8357e778 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase): os.makedirs(self.tmp_dir) self.dataset = MsDataset.load( - 'afqmc_small', namespace='userxiaoming', split='train') + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) def tearDown(self): shutil.rmtree(self.tmp_dir) @@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase): output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) pipeline_sentence_similarity(output_dir) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 3, 'skip test in current test level') def test_trainer_with_backbone_head(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' kwargs = dict( @@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase): model_id = 
'damo/nlp_structbert_sentiment-classification_chinese-base' cfg = read_config(model_id, revision='beta') cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg_file = os.path.join(self.tmp_dir, 'config.json') cfg.dump(cfg_file) @@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase): checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) self.assertTrue(Metrics.accuracy in eval_results) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_configured_datasets(self): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg.dataset = { 'train': { - 'name': 'afqmc_small', + 'name': 'clue', + 'subset_name': 'afqmc', 'split': 'train', - 'namespace': 'userxiaoming' }, 'val': { - 'name': 'afqmc_small', + 'name': 'clue', + 'subset_name': 'afqmc', 'split': 'train', - 'namespace': 'userxiaoming' }, } cfg_file = os.path.join(self.tmp_dir, 'config.json') @@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) cfg.train.max_epochs = 3 + cfg.preprocessor.first_sequence = 'sentence1' + cfg.preprocessor.second_sequence = 'sentence2' + cfg.preprocessor.label = 'label' + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg_file = os.path.join(self.tmp_dir, 'config.json') cfg.dump(cfg_file)
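Note on the tolerance plumbing above: a minimal usage sketch follows (the sample arrays and tolerance values are illustrative assumptions; compare_arguments_nested and the rtol/atol keywords it now receives are the ones threaded through compare_io_and_print, compare_backward_and_print and compare_cfg_and_optimizers in this patch). A custom compare_fn, as in the new test, may return None for a given key to fall back to this default nested comparison.

    import numpy as np

    from modelscope.utils.regress_test_utils import compare_arguments_nested

    # Two nested results that differ only by small numeric drift,
    # e.g. between different torch versions.
    baseline = {'logits': np.array([0.1000, 0.2000, 0.7000])}
    latest = {'logits': np.array([0.1001, 0.1999, 0.7002])}

    # The **kwargs forwarded from monitor_module_single_forward / monitor_ms_train
    # end up here, so a relaxed tolerance can be set per regression test.
    match = compare_arguments_nested('unmatched outputs', baseline, latest,
                                     rtol=1e-3, atol=1e-3)
    print(match)  # expected True under the relaxed tolerance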