From acba1786b0ac263140e2743b51fc7fc7c36d0980 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Thu, 20 Oct 2022 15:29:34 +0800 Subject: [PATCH] [to #42322933] Fix bug in UT daily 1. Fix bugs in daily test 2. Fix a bug that the updating of lr is before the first time of updating of optimizer TODO this will still cause warnings when GA is above 1 3. Remove the judgement of mode in text-classification's preprocessor to fit the base trainer(Bug) Update some regression bins to fit the preprocessor 4. Update the regression tool to let outer code modify atol and rtol 5. Add the default metric for text-classification task 6. Remove the useless ckpt conversion method in bert to avoid the requirement of tf when loading modeling_bert Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10430764 --- data/test/regression/sbert-base-tnews.bin | 3 + data/test/regression/sbert_nli.bin | 4 +- data/test/regression/sbert_sen_sim.bin | 4 +- ...rt_for_sequence_classification_exporter.py | 4 +- modelscope/metrics/builder.py | 1 + modelscope/models/nlp/bert/modeling_bert.py | 78 --------------- modelscope/preprocessors/nlp/nlp_base.py | 7 +- .../trainers/hooks/lr_scheduler_hook.py | 2 +- modelscope/trainers/trainer.py | 2 +- modelscope/utils/regress_test_utils.py | 94 ++++++++++++------- tests/msdatasets/test_ms_dataset.py | 3 +- .../test_finetune_sequence_classification.py | 33 ++++++- tests/trainers/test_trainer_with_nlp.py | 24 +++-- 13 files changed, 124 insertions(+), 135 deletions(-) create mode 100644 data/test/regression/sbert-base-tnews.bin diff --git a/data/test/regression/sbert-base-tnews.bin b/data/test/regression/sbert-base-tnews.bin new file mode 100644 index 00000000..1546860f --- /dev/null +++ b/data/test/regression/sbert-base-tnews.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0 +size 151572 diff --git a/data/test/regression/sbert_nli.bin b/data/test/regression/sbert_nli.bin index a5f680bb..68efb778 100644 --- a/data/test/regression/sbert_nli.bin +++ b/data/test/regression/sbert_nli.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62 -size 62231 +oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b +size 61741 diff --git a/data/test/regression/sbert_sen_sim.bin b/data/test/regression/sbert_sen_sim.bin index a59cbe0b..362f762c 100644 --- a/data/test/regression/sbert_sen_sim.bin +++ b/data/test/regression/sbert_sen_sim.bin @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a -size 62235 +oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41 +size 61745 diff --git a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py index dc1e2b92..52dab4bc 100644 --- a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py +++ b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py @@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): def generate_dummy_inputs(self, shape: Tuple = None, + pair: bool = False, **kwargs) -> Dict[str, Any]: """Generate dummy inputs for model exportation to onnx or other formats by tracing. @param shape: A tuple of input shape which should have at most two dimensions. shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor. shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor. + @param pair: Generate sentence pairs or single sentences for dummy inputs. @return: Dummy inputs. """ @@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter): **sequence_length }) preprocessor: Preprocessor = build_preprocessor(cfg, field_name) - if preprocessor.pair: + if pair: first_sequence = preprocessor.tokenizer.unk_token second_sequence = preprocessor.tokenizer.unk_token else: diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index ee4d2840..1c8e16d7 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -32,6 +32,7 @@ task_default_metrics = { Tasks.sentiment_classification: [Metrics.seq_cls_metric], Tasks.token_classification: [Metrics.token_cls_metric], Tasks.text_generation: [Metrics.text_gen_metric], + Tasks.text_classification: [Metrics.seq_cls_metric], Tasks.image_denoising: [Metrics.image_denoise_metric], Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric], Tasks.image_portrait_enhancement: diff --git a/modelscope/models/nlp/bert/modeling_bert.py b/modelscope/models/nlp/bert/modeling_bert.py index e91a6433..7c1dfcf5 100755 --- a/modelscope/models/nlp/bert/modeling_bert.py +++ b/modelscope/models/nlp/bert/modeling_bert.py @@ -15,7 +15,6 @@ """PyTorch BERT model. """ import math -import os import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer) -from modelscope.models.base import TorchModel from modelscope.utils.logger import get_logger from .configuration_bert import BertConfig @@ -50,81 +48,6 @@ logger = get_logger(__name__) _CONFIG_FOR_DOC = 'BertConfig' -def load_tf_weights_in_bert(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see ' - 'https://www.tensorflow.org/install/ for installation instructions.' - ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f'Converting TensorFlow checkpoint from {tf_path}') - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f'Loading TF weight {name} with shape {shape}') - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in [ - 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer', - 'AdamWeightDecayOptimizer_1', 'global_step' - ] for n in name): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - scope_names = re.split(r'_(\d+)', m_name) - else: - scope_names = [m_name] - if scope_names[0] == 'kernel' or scope_names[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif scope_names[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - elif scope_names[0] == 'squad': - pointer = getattr(pointer, 'classifier') - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - if pointer.shape != array.shape: - raise ValueError( - f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched' - ) - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f'Initialize PyTorch weight {name}') - pointer.data = torch.from_numpy(array) - return model - - class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel): """ config_class = BertConfig - load_tf_weights = load_tf_weights_in_bert base_model_prefix = 'bert' supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r'position_ids'] diff --git a/modelscope/preprocessors/nlp/nlp_base.py b/modelscope/preprocessors/nlp/nlp_base.py index 267dbb8c..bc96f569 100644 --- a/modelscope/preprocessors/nlp/nlp_base.py +++ b/modelscope/preprocessors/nlp/nlp_base.py @@ -2,7 +2,7 @@ import os.path as osp import re -from typing import Any, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import sentencepiece as spm @@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor): return isinstance(label, str) or isinstance(label, int) if labels is not None: - if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \ + if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \ and self.label2id is not None: output[OutputKeys.LABELS] = [ self.label2id[str(label)] for label in labels @@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase): def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs): kwargs['truncation'] = kwargs.get('truncation', True) - kwargs['padding'] = kwargs.get( - 'padding', False if mode == ModeKeys.INFERENCE else 'max_length') + kwargs['padding'] = kwargs.get('padding', 'max_length') kwargs['max_length'] = kwargs.pop('sequence_length', 128) super().__init__(model_dir, mode=mode, **kwargs) diff --git a/modelscope/trainers/hooks/lr_scheduler_hook.py b/modelscope/trainers/hooks/lr_scheduler_hook.py index ca0ec01b..32fb0250 100644 --- a/modelscope/trainers/hooks/lr_scheduler_hook.py +++ b/modelscope/trainers/hooks/lr_scheduler_hook.py @@ -47,7 +47,7 @@ class LrSchedulerHook(Hook): return lr def before_train_iter(self, trainer): - if not self.by_epoch: + if not self.by_epoch and trainer.iter > 0: if self.warmup_lr_scheduler is not None: self.warmup_lr_scheduler.step() else: diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 9eaff762..61d11aa6 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -651,7 +651,7 @@ class EpochBasedTrainer(BaseTrainer): # TODO: support MsDataset load for cv if hasattr(data_cfg, 'name'): dataset = MsDataset.load( - dataset_name=data_cfg.name, + dataset_name=data_cfg.pop('name'), **data_cfg, ) cfg = ConfigDict(type=self.cfg.model.type, mode=mode) diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py index 47bbadfe..3c1e5c1c 100644 --- a/modelscope/utils/regress_test_utils.py +++ b/modelscope/utils/regress_test_utils.py @@ -65,7 +65,8 @@ class RegressTool: def monitor_module_single_forward(self, module: nn.Module, file_name: str, - compare_fn=None): + compare_fn=None, + **kwargs): """Monitor a pytorch module in a single forward. @param module: A torch module @@ -107,7 +108,7 @@ class RegressTool: baseline = os.path.join(tempfile.gettempdir(), name) self.load(baseline, name) with open(baseline, 'rb') as f: - baseline_json = pickle.load(f) + base = pickle.load(f) class NumpyEncoder(json.JSONEncoder): """Special json encoder for numpy types @@ -122,9 +123,9 @@ class RegressTool: return obj.tolist() return json.JSONEncoder.default(self, obj) - print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}') + print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}') print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}') - if not compare_io_and_print(baseline_json, io_json, compare_fn): + if not compare_io_and_print(base, io_json, compare_fn, **kwargs): raise ValueError('Result not match!') @contextlib.contextmanager @@ -136,7 +137,8 @@ class RegressTool: ignore_keys=None, compare_random=True, reset_dropout=True, - lazy_stop_callback=None): + lazy_stop_callback=None, + **kwargs): """Monitor a pytorch module's backward data and cfg data within a step of the optimizer. This is usually useful when you try to change some dangerous code @@ -265,14 +267,15 @@ class RegressTool: baseline_json = pickle.load(f) if level == 'strict' and not compare_io_and_print( - baseline_json['forward'], io_json, compare_fn): + baseline_json['forward'], io_json, compare_fn, **kwargs): raise RuntimeError('Forward not match!') if not compare_backward_and_print( baseline_json['backward'], bw_json, compare_fn=compare_fn, ignore_keys=ignore_keys, - level=level): + level=level, + **kwargs): raise RuntimeError('Backward not match!') cfg_opt1 = { 'optimizer': baseline_json['optimizer'], @@ -286,7 +289,8 @@ class RegressTool: 'cfg': summary['cfg'], 'state': None if not compare_random else summary['state'] } - if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn): + if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn, + **kwargs): raise RuntimeError('Cfg or optimizers not match!') @@ -303,7 +307,8 @@ class MsRegressTool(RegressTool): compare_fn=None, ignore_keys=None, compare_random=True, - lazy_stop_callback=None): + lazy_stop_callback=None, + **kwargs): if lazy_stop_callback is None: @@ -319,7 +324,7 @@ class MsRegressTool(RegressTool): trainer.register_hook(EarlyStopHook()) - def _train_loop(trainer, *args, **kwargs): + def _train_loop(trainer, *args_train, **kwargs_train): with self.monitor_module_train( trainer, file_name, @@ -327,9 +332,11 @@ class MsRegressTool(RegressTool): compare_fn=compare_fn, ignore_keys=ignore_keys, compare_random=compare_random, - lazy_stop_callback=lazy_stop_callback): + lazy_stop_callback=lazy_stop_callback, + **kwargs): try: - return trainer.train_loop_origin(*args, **kwargs) + return trainer.train_loop_origin(*args_train, + **kwargs_train) except MsRegressTool.EarlyStopError: pass @@ -530,7 +537,8 @@ def compare_arguments_nested(print_content, ) return False if not all([ - compare_arguments_nested(None, sub_arg1, sub_arg2) + compare_arguments_nested( + None, sub_arg1, sub_arg2, rtol=rtol, atol=atol) for sub_arg1, sub_arg2 in zip(arg1, arg2) ]): if print_content is not None: @@ -551,7 +559,8 @@ def compare_arguments_nested(print_content, print(f'{print_content}, key diff:{set(keys1) - set(keys2)}') return False if not all([ - compare_arguments_nested(None, arg1[key], arg2[key]) + compare_arguments_nested( + None, arg1[key], arg2[key], rtol=rtol, atol=atol) for key in keys1 ]): if print_content is not None: @@ -574,7 +583,7 @@ def compare_arguments_nested(print_content, raise ValueError(f'type not supported: {type1}') -def compare_io_and_print(baseline_json, io_json, compare_fn=None): +def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): else: match = compare_arguments_nested( f'unmatched module {key} input args', v1input['args'], - v2input['args']) and match + v2input['args'], **kwargs) and match match = compare_arguments_nested( f'unmatched module {key} input kwargs', v1input['kwargs'], - v2input['kwargs']) and match + v2input['kwargs'], **kwargs) and match v1output = numpify_tensor_nested(v1['output']) v2output = numpify_tensor_nested(v2['output']) res = compare_fn(v1output, v2output, key, 'output') @@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None): ) match = match and res else: - match = compare_arguments_nested(f'unmatched module {key} outputs', - v1output, v2output) and match + match = compare_arguments_nested( + f'unmatched module {key} outputs', + arg1=v1output, + arg2=v2output, + **kwargs) and match return match @@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json, bw_json, level, ignore_keys=None, - compare_fn=None): + compare_fn=None, + **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json, data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][ 'grad'], bw_json[key]['data_after'] match = compare_arguments_nested( - f'unmatched module {key} tensor data', data1, data2) and match + f'unmatched module {key} tensor data', + arg1=data1, + arg2=data2, + **kwargs) and match if level == 'strict': match = compare_arguments_nested( - f'unmatched module {key} grad data', grad1, - grad2) and match + f'unmatched module {key} grad data', + arg1=grad1, + arg2=grad2, + **kwargs) and match match = compare_arguments_nested( f'unmatched module {key} data after step', data_after1, - data_after2) and match + data_after2, **kwargs) and match return match -def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): +def compare_cfg_and_optimizers(baseline_json, + cfg_json, + compare_fn=None, + **kwargs): if compare_fn is None: def compare_fn(*args, **kwargs): @@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): print( f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}" ) - match = compare_arguments_nested('unmatched optimizer defaults', - optimizer1['defaults'], - optimizer2['defaults']) and match - match = compare_arguments_nested('unmatched optimizer state_dict', - optimizer1['state_dict'], - optimizer2['state_dict']) and match + match = compare_arguments_nested( + 'unmatched optimizer defaults', optimizer1['defaults'], + optimizer2['defaults'], **kwargs) and match + match = compare_arguments_nested( + 'unmatched optimizer state_dict', optimizer1['state_dict'], + optimizer2['state_dict'], **kwargs) and match res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler') if res is not None: @@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): print( f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}" ) - match = compare_arguments_nested('unmatched lr_scheduler state_dict', - lr_scheduler1['state_dict'], - lr_scheduler2['state_dict']) and match + match = compare_arguments_nested( + 'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'], + lr_scheduler2['state_dict'], **kwargs) and match res = compare_fn(cfg1, cfg2, None, 'cfg') if res is not None: print(f'cfg compared with user compare_fn with result:{res}\n') match = match and res else: - match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match + match = compare_arguments_nested( + 'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match res = compare_fn(state1, state2, None, 'state') if res is not None: @@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None): match = match and res else: match = compare_arguments_nested('unmatched random state', state1, - state2) and match + state2, **kwargs) and match return match diff --git a/tests/msdatasets/test_ms_dataset.py b/tests/msdatasets/test_ms_dataset.py index 91a3b5c5..1e537e93 100644 --- a/tests/msdatasets/test_ms_dataset.py +++ b/tests/msdatasets/test_ms_dataset.py @@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase): @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_ms_csv_basic(self): ms_ds_train = MsDataset.load( - 'afqmc_small', namespace='userxiaoming', split='train') + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(5)) print(next(iter(ms_ds_train))) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index f2adfa22..27db1f18 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ calculate_fisher from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.data_utils import to_device -from modelscope.utils.regress_test_utils import MsRegressTool +from modelscope.utils.regress_test_utils import (MsRegressTool, + compare_arguments_nested) from modelscope.utils.test_utils import test_level @@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase): def test_trainer_repeatable(self): import torch # noqa + def compare_fn(value1, value2, key, type): + # Ignore the differences between optimizers of two torch versions + if type != 'optimizer': + return None + + match = (value1['type'] == value2['type']) + shared_defaults = set(value1['defaults'].keys()).intersection( + set(value2['defaults'].keys())) + match = all([ + compare_arguments_nested(f'Optimizer defaults {key} not match', + value1['defaults'][key], + value2['defaults'][key]) + for key in shared_defaults + ]) and match + match = (len(value1['state_dict']['param_groups']) == len( + value2['state_dict']['param_groups'])) and match + for group1, group2 in zip(value1['state_dict']['param_groups'], + value2['state_dict']['param_groups']): + shared_keys = set(group1.keys()).intersection( + set(group2.keys())) + match = all([ + compare_arguments_nested( + f'Optimizer param_groups {key} not match', group1[key], + group2[key]) for key in shared_keys + ]) and match + return match + def cfg_modify_fn(cfg): cfg.task = 'nli' cfg['preprocessor'] = {'type': 'nli-tokenizer'} @@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): name=Trainers.nlp_base_trainer, default_args=kwargs) with self.regress_tool.monitor_ms_train( - trainer, 'sbert-base-tnews', level='strict'): + trainer, 'sbert-base-tnews', level='strict', + compare_fn=compare_fn): trainer.train() def finetune(self, diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index 6030ada9..8357e778 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase): os.makedirs(self.tmp_dir) self.dataset = MsDataset.load( - 'afqmc_small', namespace='userxiaoming', split='train') + 'clue', subset_name='afqmc', + split='train').to_hf_dataset().select(range(2)) def tearDown(self): shutil.rmtree(self.tmp_dir) @@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase): output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) pipeline_sentence_similarity(output_dir) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 3, 'skip test in current test level') def test_trainer_with_backbone_head(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' kwargs = dict( @@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' cfg = read_config(model_id, revision='beta') cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg_file = os.path.join(self.tmp_dir, 'config.json') cfg.dump(cfg_file) @@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase): checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth')) self.assertTrue(Metrics.accuracy in eval_results) - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_trainer_with_configured_datasets(self): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) cfg.train.max_epochs = 20 + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg.dataset = { 'train': { - 'name': 'afqmc_small', + 'name': 'clue', + 'subset_name': 'afqmc', 'split': 'train', - 'namespace': 'userxiaoming' }, 'val': { - 'name': 'afqmc_small', + 'name': 'clue', + 'subset_name': 'afqmc', 'split': 'train', - 'namespace': 'userxiaoming' }, } cfg_file = os.path.join(self.tmp_dir, 'config.json') @@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' cfg: Config = read_config(model_id) cfg.train.max_epochs = 3 + cfg.preprocessor.first_sequence = 'sentence1' + cfg.preprocessor.second_sequence = 'sentence2' + cfg.preprocessor.label = 'label' + cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1} + cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1} cfg.train.work_dir = self.tmp_dir cfg_file = os.path.join(self.tmp_dir, 'config.json') cfg.dump(cfg_file)