From 212cf533184cafbebb27808a39a64651c983fb31 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Thu, 27 Oct 2022 19:49:21 +0800
Subject: [PATCH] [to #42322933] Fix some bugs

1. Add an F1 score to the sequence classification metric
2. Fix a bug where the evaluate method in the trainer did not support a plain pytorch_model.bin
3. Fix a bug in the evaluation of the Veco trainer
4. Add tips for when the lr_scheduler in the trainer needs a newer torch version
5. Add some comments

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10532230
---
 .../metrics/sequence_classification_metric.py    |  9 +++-
 modelscope/models/base/base_model.py             | 22 +++++++++-
 .../unifold/modules/structure_module.py          |  4 +-
 modelscope/preprocessors/base.py                 | 44 ++++++++++++++++++-
 modelscope/trainers/hooks/checkpoint_hook.py     |  9 ++--
 modelscope/trainers/nlp_trainer.py               |  6 ++-
 modelscope/trainers/trainer.py                   | 19 +++++++-
 modelscope/utils/checkpoint.py                   |  4 +-
 .../test_text_classification_metrics.py          | 32 ++++++++++++++
 .../test_finetune_sequence_classification.py     |  2 +-
 tests/trainers/test_trainer_with_nlp.py          | 20 ++++++++-
 11 files changed, 153 insertions(+), 18 deletions(-)
 create mode 100644 tests/metrics/test_text_classification_metrics.py

diff --git a/modelscope/metrics/sequence_classification_metric.py b/modelscope/metrics/sequence_classification_metric.py
index 51a829ef..1fe1c329 100644
--- a/modelscope/metrics/sequence_classification_metric.py
+++ b/modelscope/metrics/sequence_classification_metric.py
@@ -3,6 +3,7 @@
 from typing import Dict

 import numpy as np
+from sklearn.metrics import accuracy_score, f1_score

 from modelscope.metainfo import Metrics
 from modelscope.outputs import OutputKeys
@@ -41,5 +42,11 @@ class SequenceClassificationMetric(Metric):
         preds = np.argmax(preds, axis=1)
         return {
             MetricKeys.ACCURACY:
-            (preds == labels).astype(np.float32).mean().item()
+            accuracy_score(labels, preds),
+            MetricKeys.F1:
+            f1_score(
+                labels,
+                preds,
+                average='micro' if any(label > 1
+                                       for label in labels) else None),
         }
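A note on the averaging rule above: when any label is greater than 1 the task is treated as multiclass and `average='micro'` is used, and for single-label multiclass predictions micro-averaged F1 is by definition equal to plain accuracy, so the two keys agree. A minimal standalone check (not part of the patch; the arrays are invented):

    import numpy as np
    from sklearn.metrics import accuracy_score, f1_score

    labels = np.array([0, 1, 2, 2, 0, 1, 2, 2])
    preds = np.array([0, 1, 2, 2, 2, 0, 0, 0])  # as if from np.argmax(logits)

    # Micro-averaging pools all TP/FP/FN counts across classes, so in the
    # single-label multiclass case both metrics reduce to the fraction of
    # exact matches.
    assert np.isclose(accuracy_score(labels, preds), 0.5)
    assert np.isclose(f1_score(labels, preds, average='micro'), 0.5)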
diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py
index 1246551e..e01d1f05 100644
--- a/modelscope/models/base/base_model.py
+++ b/modelscope/models/base/base_model.py
@@ -67,8 +67,28 @@ class Model(ABC):
                         cfg_dict: Config = None,
                         device: str = None,
                         **kwargs):
-        """ Instantiate a model from local directory or remote model repo. Note
+        """Instantiate a model from a local directory or a remote model repo. Note
         that when loading from remote, the model revision can be specified.
+
+        Args:
+            model_name_or_path(str): A model directory or a model id to be loaded.
+            revision(str, `optional`): The revision used when model_name_or_path is
+                a model id of the remote hub. Defaults to `master`.
+            cfg_dict(Config, `optional`): An optional model config. If provided, it will
+                replace the config read out of the `model_name_or_path`.
+            device(str, `optional`): The device on which to load the model.
+            **kwargs:
+                task(str, `optional`): The `Tasks` enumeration value to replace the task
+                    value read out of the config in the `model_name_or_path`. This is
+                    useful when the model to be loaded is not of the same type as the
+                    model saved. For example, load a `backbone` into a
+                    `text-classification` model.
+                Other kwargs will be directly fed into the `model` key, to replace the
+                    default configs.
+
+        Returns:
+            A model instance.
+
+        Examples:
+            >>> from modelscope.models import Model
+            >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification')
         """
         prefetched = kwargs.get('model_prefetched')
         if prefetched is not None:
diff --git a/modelscope/models/science/unifold/modules/structure_module.py b/modelscope/models/science/unifold/modules/structure_module.py
index 5d4da30b..4872d5c6 100644
--- a/modelscope/models/science/unifold/modules/structure_module.py
+++ b/modelscope/models/science/unifold/modules/structure_module.py
@@ -288,8 +288,8 @@ class InvariantPointAttention(nn.Module):
         pt_att *= pt_att
         pt_att = pt_att.sum(dim=-1)

-        head_weights = self.softplus(self.head_weights).view(
-            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))
+        head_weights = self.softplus(self.head_weights).view(  # noqa
+            *((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))  # noqa
         head_weights = head_weights * math.sqrt(
             1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
         pt_att *= head_weights * (-0.5)
diff --git a/modelscope/preprocessors/base.py b/modelscope/preprocessors/base.py
index db14ba47..be62ebb4 100644
--- a/modelscope/preprocessors/base.py
+++ b/modelscope/preprocessors/base.py
@@ -147,8 +147,50 @@ class Preprocessor(ABC):
                         cfg_dict: Config = None,
                         preprocessor_mode=ModeKeys.INFERENCE,
                         **kwargs):
-        """ Instantiate a model from local directory or remote model repo. Note
+        """Instantiate a preprocessor from a local directory or a remote model repo. Note
         that when loading from remote, the model revision can be specified.

+        Args:
+            model_name_or_path(str): A model directory or a model id from which to load
+                the preprocessor.
+            revision(str, `optional`): The revision used when model_name_or_path is
+                a model id of the remote hub. Defaults to `master`.
+            cfg_dict(Config, `optional`): An optional config. If provided, it will replace
+                the config read out of the `model_name_or_path`.
+            preprocessor_mode(str, `optional`): Specify the working mode of the
+                preprocessor; can be `train`, `eval`, or `inference`. Defaults to
+                `inference`.
+                The preprocessor field in the config may contain two sub-preprocessors:
+                >>> {
+                >>>     "train": {
+                >>>         "type": "some-train-preprocessor"
+                >>>     },
+                >>>     "val": {
+                >>>         "type": "some-eval-preprocessor"
+                >>>     }
+                >>> }
+                In this scenario, the `train` preprocessor will be loaded in the `train`
+                mode, and the `val` preprocessor will be loaded in the `eval` or
+                `inference` mode. The `mode` field in the preprocessor class will be
+                assigned in all the modes.
+                Or just one:
+                >>> {
+                >>>     "type": "some-train-preprocessor"
+                >>> }
+                In this scenario, the sole preprocessor will be loaded in all the modes,
+                and the `mode` field in the preprocessor class will be assigned.
+
+            **kwargs:
+                task(str, `optional`): The `Tasks` enumeration value to replace the task
+                    value read out of the config in the `model_name_or_path`.
+                    This is useful when the preprocessor does not have a `type` field and
+                    the task to be used is not the same as the task with which the model
+                    was saved.
+                Other kwargs will be directly fed into the preprocessor, to replace the
+                    default configs.
+
+        Returns:
+            The preprocessor instance.
+
+        Examples:
+            >>> from modelscope.preprocessors import Preprocessor
+            >>> Preprocessor.from_pretrained('damo/nlp_debertav2_fill-mask_chinese-base')
+
         """
         if not os.path.exists(model_name_or_path):
             model_dir = snapshot_download(
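The two docstrings above describe the same loading pattern from the caller's side. A short usage sketch combining them (not part of the patch; the model id is the one already used in the docstring, and the `ModeKeys` import path is an assumption):

    from modelscope.models import Model
    from modelscope.preprocessors import Preprocessor
    from modelscope.utils.constant import ModeKeys  # assumed import path

    # Override the task recorded in the model's configuration: load a
    # backbone checkpoint as a text-classification model.
    model = Model.from_pretrained(
        'damo/nlp_structbert_backbone_base_std', task='text-classification')

    # Load the preprocessor in train mode; with a train/val preprocessor
    # config this would select the "train" sub-preprocessor.
    preprocessor = Preprocessor.from_pretrained(
        'damo/nlp_structbert_backbone_base_std',
        preprocessor_mode=ModeKeys.TRAIN)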
diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py
index 9b86d5b5..47bd84c4 100644
--- a/modelscope/trainers/hooks/checkpoint_hook.py
+++ b/modelscope/trainers/hooks/checkpoint_hook.py
@@ -101,8 +101,9 @@ class CheckpointHook(Hook):
             model = trainer.model.module
         else:
             model = trainer.model
-        meta = load_checkpoint(filename, model, trainer.optimizer,
-                               trainer.lr_scheduler)
+        meta = load_checkpoint(filename, model,
+                               getattr(trainer, 'optimizer', None),
+                               getattr(trainer, 'lr_scheduler', None))
         trainer._epoch = meta.get('epoch', trainer._epoch)
         trainer._iter = meta.get('iter', trainer._iter)
         trainer._inner_iter = meta.get('inner_iter', trainer._inner_iter)
@@ -111,7 +112,7 @@
             # hook: Hook
             key = f'{hook.__class__}-{i}'
             if key in meta and hasattr(hook, 'load_state_dict'):
-                hook.load_state_dict(meta[key])
+                hook.load_state_dict(meta.get(key, {}))
             else:
                 trainer.logger.warn(
                     f'The state_dict of hook {hook.__class__} at index {i} is not found in the checkpoint file.'
@@ -123,7 +124,7 @@
                 f'The modelscope version of loaded checkpoint does not match the runtime version. '
                 f'The saved version: {version}, runtime version: {__version__}'
             )
-        trainer.logger.warn(
+        trainer.logger.info(
             f'Checkpoint {filename} saving time: {meta.get("time")}')
         return meta
diff --git a/modelscope/trainers/nlp_trainer.py b/modelscope/trainers/nlp_trainer.py
index a19e7c7b..a92a3706 100644
--- a/modelscope/trainers/nlp_trainer.py
+++ b/modelscope/trainers/nlp_trainer.py
@@ -646,7 +646,9 @@ class VecoTrainer(NlpEpochBasedTrainer):
             break

         for metric_name in self.metrics:
-            metric_values[metric_name] = np.average(
-                [m[metric_name] for m in metric_values.values()])
+            all_metrics = [m[metric_name] for m in metric_values.values()]
+            for key in all_metrics[0].keys():
+                metric_values[key] = np.average(
+                    [metric[key] for metric in all_metrics])

         return metric_values
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index f660a55a..e1fd7522 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -667,10 +667,25 @@ class EpochBasedTrainer(BaseTrainer):
         return dataset

     def build_optimizer(self, cfg: ConfigDict, default_args: dict = None):
-        return build_optimizer(self.model, cfg=cfg, default_args=default_args)
+        try:
+            return build_optimizer(
+                self.model, cfg=cfg, default_args=default_args)
+        except KeyError as e:
+            self.logger.error(
+                f'Build optimizer error: the optimizer {cfg} is a native torch optimizer; '
+                f'please check whether your torch version ({torch.__version__}) matches the config.'
+            )
+            raise e

     def build_lr_scheduler(self, cfg: ConfigDict, default_args: dict = None):
-        return build_lr_scheduler(cfg=cfg, default_args=default_args)
+        try:
+            return build_lr_scheduler(cfg=cfg, default_args=default_args)
+        except KeyError as e:
+            self.logger.error(
+                f'Build lr_scheduler error: the lr_scheduler {cfg} is a native torch lr_scheduler; '
+                f'please check whether your torch version ({torch.__version__}) matches the config.'
+            )
+            raise e

     def create_optimizer_and_scheduler(self):
         """ Create optimizer and lr scheduler
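The KeyError branch above fires when the scheduler type named in the config is missing from the installed torch. A standalone illustration of that failure mode (plain torch, no modelscope code; the scheduler name stands in for a hypothetical config value):

    import torch

    name = 'CosineAnnealingLR'  # hypothetical value read from a trainer config
    scheduler_cls = getattr(torch.optim.lr_scheduler, name, None)
    if scheduler_cls is None:
        # Mirrors the tip added above: the config may require a newer torch.
        raise KeyError(f'{name} not found in torch {torch.__version__}; '
                       'this lr_scheduler may require a newer torch version.')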
diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py
index 2a7520f2..5acaa411 100644
--- a/modelscope/utils/checkpoint.py
+++ b/modelscope/utils/checkpoint.py
@@ -134,9 +134,7 @@ def load_checkpoint(filename,
         state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
             'state_dict']
         model.load_state_dict(state_dict)
-
-    if 'meta' in checkpoint:
-        return checkpoint.get('meta', {})
+    return checkpoint.get('meta', {})


 def save_pretrained(model,
diff --git a/tests/metrics/test_text_classification_metrics.py b/tests/metrics/test_text_classification_metrics.py
new file mode 100644
index 00000000..d0a4cee1
--- /dev/null
+++ b/tests/metrics/test_text_classification_metrics.py
@@ -0,0 +1,32 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+
+import numpy as np
+
+from modelscope.metrics.sequence_classification_metric import \
+    SequenceClassificationMetric
+from modelscope.utils.test_utils import test_level
+
+
+class TestTextClsMetrics(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_value(self):
+        metric = SequenceClassificationMetric()
+        outputs = {
+            'logits':
+            np.array([[2.0, 1.0, 0.5], [1.0, 1.5, 1.0], [2.0, 1.0, 3.0],
+                      [2.4, 1.5, 4.0], [2.0, 1.0, 3.0], [2.4, 1.5, 1.7],
+                      [2.0, 1.0, 0.5], [2.4, 1.5, 0.5]])
+        }
+        inputs = {'labels': np.array([0, 1, 2, 2, 0, 1, 2, 2])}
+        metric.add(outputs, inputs)
+        ret = metric.evaluate()
+        self.assertTrue(np.isclose(ret['f1'], 0.5))
+        self.assertTrue(np.isclose(ret['accuracy'], 0.5))
+        print(ret)
+
+
+if __name__ == '__main__':
+    unittest.main()
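With the early return removed, load_checkpoint now always yields a meta dict, returning an empty one for a checkpoint that carries no 'meta' key. A minimal sketch of why that matters for item 2 of this patch (assuming optimizer and lr_scheduler default to None in load_checkpoint's signature):

    import torch
    from torch import nn

    from modelscope.utils.checkpoint import load_checkpoint

    model = nn.Linear(4, 2)
    torch.save(model.state_dict(), 'pytorch_model.bin')  # a pure weights file

    # Previously this path returned None for a plain state_dict file; now
    # callers such as trainer.evaluate() get an empty dict and can use
    # meta.get(...) with defaults instead of guarding against None.
    meta = load_checkpoint('pytorch_model.bin', model)
    assert meta == {}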
diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py
index ae780793..02dd9d2f 100644
--- a/tests/trainers/test_finetune_sequence_classification.py
+++ b/tests/trainers/test_finetune_sequence_classification.py
@@ -346,7 +346,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         train_datasets = []
         from datasets import DownloadConfig
         dc = DownloadConfig()
-        dc.local_files_only = True
+        dc.local_files_only = False
         for lang in langs:
             train_datasets.append(
                 load_dataset('xnli', lang, split='train', download_config=dc))
diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py
index 8aaa42a3..66aedfd8 100644
--- a/tests/trainers/test_trainer_with_nlp.py
+++ b/tests/trainers/test_trainer_with_nlp.py
@@ -223,13 +223,31 @@ class TestTrainerWithNlp(unittest.TestCase):
                 trainer, 'trainer_continue_train', level='strict'):
             trainer.train(os.path.join(self.tmp_dir, 'iter_3.pth'))

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_trainer_with_evaluation(self):
+        tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(tmp_dir):
+            os.makedirs(tmp_dir)
+
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
+        cache_path = snapshot_download(model_id)
+        model = SbertForSequenceClassification.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            eval_dataset=self.dataset,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        print(trainer.evaluate(cache_path + '/pytorch_model.bin'))
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(tmp_dir):
             os.makedirs(tmp_dir)

-        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
         cache_path = snapshot_download(model_id)
         model = SbertForSequenceClassification.from_pretrained(cache_path)
         kwargs = dict(