diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 8c9964b8..2df6f2a0 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -402,6 +402,7 @@ class Metrics(object):
 
     # accuracy
     accuracy = 'accuracy'
+    multi_average_precision = 'mAP'
     audio_noise_metric = 'audio-noise-metric'
 
     # text gen
diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py
index b9e402c5..e2fe67f8 100644
--- a/modelscope/metrics/builder.py
+++ b/modelscope/metrics/builder.py
@@ -24,6 +24,7 @@ class MetricKeys(object):
     ROUGE_1 = 'rouge-1'
     ROUGE_L = 'rouge-l'
     NED = 'ned'  # ocr metric
+    mAP = 'mAP'
     BatchAcc = 'inbatch_t2i_recall_at_1'
 
 
diff --git a/modelscope/metrics/map_metric.py b/modelscope/metrics/map_metric.py
new file mode 100644
index 00000000..aac76f22
--- /dev/null
+++ b/modelscope/metrics/map_metric.py
@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.multi_average_precision)
+class AveragePrecisionMetric(Metric):
+    """The metric computation class for multi average precision (mAP).
+
+    This metric class calculates multi average precision over all input batches.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+        self.thresh = kwargs.get('threshold', 0.5)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        for truth in ground_truths:
+            self.labels.append(truth)
+        for result in eval_results:
+            if isinstance(truth, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
+        return {MetricKeys.mAP: scores.mean().item()}
+
+    def _calculate_ap_score(self, preds, labels, thresh=0.5):
+        hyps = np.array(preds)
+        refs = np.array(labels)
+        a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
+        b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
+        interacts = np.concatenate([a, b], axis=1)
+        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
+            hyps[:, 3] - hyps[:, 1])
+        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
+        interacts_w = interacts[:, 2] - interacts[:, 0]
+        interacts_h = interacts[:, 3] - interacts[:, 1]
+        area_interacts = interacts_w * interacts_h
+        ious = area_interacts / (
+            area_predictions + area_targets - area_interacts + 1e-6)
+        return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
diff --git a/modelscope/preprocessors/ofa/visual_grounding.py b/modelscope/preprocessors/ofa/visual_grounding.py
index d9779fbe..2da79670 100644
--- a/modelscope/preprocessors/ofa/visual_grounding.py
+++ b/modelscope/preprocessors/ofa/visual_grounding.py
@@ -9,6 +9,7 @@ from torchvision import transforms
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
+from .utils import transforms as T
 
 
 class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
@@ -29,13 +30,14 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
 
         super(OfaVisualGroundingPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)
+        self.num_bins = self.cfg.model.get('num_bins', 1000)
         if self.mode == ModeKeys.TRAIN:
             # for positioning
-            self.positioning_transform = transforms.Compose([
-                transforms.RandomResize([self.patch_image_size],
-                                        max_size=self.patch_image_size),
-                transforms.ToTensor(),
-                transforms.Normalize(
+            self.positioning_transform = T.Compose([
+                T.RandomResize([self.patch_image_size],
+                               max_size=self.patch_image_size),
+                T.ToTensor(),
+                T.Normalize(
                     mean=self.mean,
                     std=self.std,
                     max_image_size=self.max_image_size)
@@ -130,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
             'w_resize_ratio': w_resize_ratio,
             'h_resize_ratio': h_resize_ratio,
         }
+
+        if 'region_coord' in self.column_map and self.column_map[
+                'region_coord'] in data:
+            x0, y0, x1, y1 = data[
+                self.column_map['region_coord']].strip().split(',')
+            sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
         return sample
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
index f8028c6c..71494768 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
@@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer):
             self,
             model: Optional[Union[TorchModel, nn.Module, str]] = None,
             cfg_file: Optional[str] = None,
+            cfg_modify_fn: Optional[Callable] = None,
             arg_parse_fn: Optional[Callable] = None,
             data_collator: Optional[Union[Callable, Dict[str,
                                                          Callable]]] = None,
@@ -49,7 +50,8 @@
             **kwargs):
         model = Model.from_pretrained(model, revision=model_revision)
         model_dir = model.model_dir
-        cfg = Config.from_file(cfg_file)
+        self.cfg_modify_fn = cfg_modify_fn
+        cfg = self.rebuild_config(Config.from_file(cfg_file))
         if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
             work_dir = cfg.train.work_dir
         else:
@@ -57,10 +59,12 @@
         tokenizer_files = {
             'zh': [
                 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json'
+                'config.json', 'ans2label.json'
+            ],
+            'en': [
+                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
+                'ans2label.json'
             ],
-            'en':
-            ['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'],
         }
         for filename in tokenizer_files[cfg.model.get('language', 'en')]:
             finetune_file = os.path.join(work_dir, filename)
@@ -127,6 +131,11 @@
             **kwargs,
         )
 
+    def rebuild_config(self, cfg: Config):
+        if self.cfg_modify_fn is not None:
+            cfg = self.cfg_modify_fn(cfg)
+        return cfg
+
     def train_step(self, model, inputs):
         model.train()
         loss, sample_size, logging_output = self.criterion(model, inputs)
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
index 85c21881..098416bb 100644
--- a/tests/trainers/test_ofa_trainer.py
+++ b/tests/trainers/test_ofa_trainer.py
@@ -5,10 +5,10 @@ import unittest
 
 import json
 
-from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.hub import read_config
 from modelscope.utils.test_utils import test_level
 
 
@@ -73,11 +73,12 @@ class TestOfaTrainer(unittest.TestCase):
     def test_trainer_std(self):
         WORKSPACE = './workspace/ckpts/recognition'
         os.makedirs(WORKSPACE, exist_ok=True)
-        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
-        with open(config_file, 'w') as writer:
-            json.dump(self.finetune_cfg, writer)
 
         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
+        cfg = read_config(pretrained_model)
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        cfg.dump(config_file)
+
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,
@@ -94,7 +95,7 @@ class TestOfaTrainer(unittest.TestCase):
                 split='test[:20]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             cfg_file=config_file)
-        trainer = build_trainer(name=Trainers.ofa, default_args=args)
+        trainer = build_trainer(name='ofa', default_args=args)
         trainer.train()
         self.assertIn(
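
Usage note (not part of the diff): a minimal sketch of how the new pieces fit together for a visual grounding fine-tune. It assumes a standard ModelScope configuration that lists metric names under evaluation.metrics; the model id, work_dir, and the final evaluate() call are illustrative placeholders rather than values taken from this change.

from modelscope.trainers import build_trainer


def cfg_modify_fn(cfg):
    # Evaluate with the newly registered mAP metric
    # (Metrics.multi_average_precision). Assumes the config exposes an
    # `evaluation.metrics` list, as typical ModelScope configs do.
    cfg.evaluation.metrics = ['mAP']
    return cfg


args = dict(
    model='damo/ofa_visual-grounding_refcoco_large_en',  # placeholder model id
    work_dir='./workspace/ckpts/grounding',  # placeholder work dir
    cfg_modify_fn=cfg_modify_fn,  # applied via OFATrainer.rebuild_config
)
trainer = build_trainer(name='ofa', default_args=args)
trainer.train()
trainer.evaluate()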