
add five task finetune

master · 翎航 · 2 years ago · commit 0418786cbe
6 changed files with 101 additions and 14 deletions:

  1. modelscope/metainfo.py (+1 -0)
  2. modelscope/metrics/builder.py (+1 -0)
  3. modelscope/metrics/map_metric.py (+67 -0)
  4. modelscope/preprocessors/ofa/visual_grounding.py (+13 -5)
  5. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+13 -4)
  6. tests/trainers/test_ofa_trainer.py (+6 -5)

modelscope/metainfo.py (+1 -0)

@@ -402,6 +402,7 @@ class Metrics(object):

     # accuracy
     accuracy = 'accuracy'
+    multi_average_precision = 'mAP'
     audio_noise_metric = 'audio-noise-metric'

     # text gen


modelscope/metrics/builder.py (+1 -0)

@@ -24,6 +24,7 @@ class MetricKeys(object):
     ROUGE_1 = 'rouge-1'
     ROUGE_L = 'rouge-l'
     NED = 'ned'  # ocr metric
+    mAP = 'mAP'
     BatchAcc = 'inbatch_t2i_recall_at_1'




modelscope/metrics/map_metric.py (+67 -0)

@@ -0,0 +1,67 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(
+    group_key=default_group, module_name=Metrics.multi_average_precision)
+class AveragePrecisionMetric(Metric):
+    """The metric computation class for multi average precision.
+
+    This metric class calculates multi average precision over all the input batches.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+        self.thresh = kwargs.get('threshold', 0.5)
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        # Prefer the first task-specific output key that is present.
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        for truth in ground_truths:
+            self.labels.append(truth)
+        # Note: `truth` below is the last ground-truth element from the loop
+        # above; all entries are assumed to share one type.
+        for result in eval_results:
+            if isinstance(truth, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        scores = self._calculate_ap_score(self.preds, self.labels, self.thresh)
+        return {MetricKeys.mAP: scores.mean().item()}
+
+    def _calculate_ap_score(self, preds, labels, thresh=0.5):
+        hyps = np.array(preds)
+        refs = np.array(labels)
+        # Intersection box: max of the top-left corners, min of the bottom-right.
+        a = np.where(hyps[:, :2] < refs[:, :2], refs[:, :2], hyps[:, :2])
+        b = np.where(hyps[:, 2:] < refs[:, 2:], hyps[:, 2:], refs[:, 2:])
+        interacts = np.concatenate([a, b], axis=1)
+        area_predictions = (hyps[:, 2] - hyps[:, 0]) * (
+            hyps[:, 3] - hyps[:, 1])
+        area_targets = (refs[:, 2] - refs[:, 0]) * (refs[:, 3] - refs[:, 1])
+        interacts_w = interacts[:, 2] - interacts[:, 0]
+        interacts_h = interacts[:, 3] - interacts[:, 1]
+        area_interacts = interacts_w * interacts_h
+        # IoU = intersection / union; a prediction counts as a hit when
+        # IoU >= thresh and the intersection box is non-degenerate.
+        ious = area_interacts / (
+            area_predictions + area_targets - area_interacts + 1e-6)
+        return (ious >= thresh) & (interacts_w > 0) & (interacts_h > 0)
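A quick way to sanity-check the new metric is to feed it a single ground-truth box and a slightly shifted prediction. The snippet below is a minimal sketch, not part of the commit; it assumes an installed modelscope environment and that the base Metric class accepts arbitrary keyword arguments, as the *args, **kwargs signature above suggests. Note that add() first looks the label key up in outputs as well, so the fake model output carries a labels entry alongside the boxes it actually scores.

# Hypothetical smoke test for AveragePrecisionMetric (illustration only).
from modelscope.metrics.map_metric import AveragePrecisionMetric
from modelscope.outputs import OutputKeys

metric = AveragePrecisionMetric(threshold=0.5)

# One sample: a ground-truth box (x0, y0, x1, y1) and a shifted prediction.
inputs = {OutputKeys.LABELS: [[0.0, 0.0, 10.0, 10.0]]}
outputs = {
    OutputKeys.LABELS: [[0.0, 0.0, 10.0, 10.0]],  # looked up first by add()
    OutputKeys.BOXES: [[1.0, 1.0, 10.0, 10.0]],  # preferred when present
}
metric.add(outputs, inputs)

# Intersection is 9 * 9 = 81, union is 100, so IoU = 0.81 >= 0.5: a hit.
print(metric.evaluate())  # {'mAP': 1.0}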

modelscope/preprocessors/ofa/visual_grounding.py (+13 -5)

@@ -9,6 +9,7 @@ from torchvision import transforms
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
+from .utils import transforms as T


 class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):

@@ -29,13 +30,14 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
         super(OfaVisualGroundingPreprocessor,
               self).__init__(cfg, model_dir, mode, *args, **kwargs)

+        self.num_bins = self.cfg.model.get('num_bins', 1000)
         if self.mode == ModeKeys.TRAIN:
             # for positioning
-            self.positioning_transform = transforms.Compose([
-                transforms.RandomResize([self.patch_image_size],
-                                        max_size=self.patch_image_size),
-                transforms.ToTensor(),
-                transforms.Normalize(
+            self.positioning_transform = T.Compose([
+                T.RandomResize([self.patch_image_size],
+                               max_size=self.patch_image_size),
+                T.ToTensor(),
+                T.Normalize(
                     mean=self.mean,
                     std=self.std,
                     max_image_size=self.max_image_size)

@@ -130,4 +132,10 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
             'w_resize_ratio': w_resize_ratio,
             'h_resize_ratio': h_resize_ratio,
         }
+
+        if 'region_coord' in self.column_map and self.column_map[
+                'region_coord'] in data:
+            x0, y0, x1, y1 = data[
+                self.column_map['region_coord']].strip().split(',')
+            sample['label'] = [float(x0), float(y0), float(x1), float(y1)]
         return sample
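The new tail of the preprocessor parses a comma-separated region_coord column into the float box stored in sample['label'], i.e. the ground truth the mAP metric above compares against. Below is a standalone sketch of just that parsing step, with a hypothetical column mapping and data row:

# Hypothetical illustration of the region_coord parsing added above.
column_map = {'region_coord': 'region_coord'}  # assumed identity mapping
data = {'region_coord': '10.5,20.0,100.0,200.0'}  # made-up x0,y0,x1,y1 string

x0, y0, x1, y1 = data[column_map['region_coord']].strip().split(',')
label = [float(x0), float(y0), float(x1), float(y1)]
print(label)  # [10.5, 20.0, 100.0, 200.0]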

modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+13 -4)

@@ -34,6 +34,7 @@ class OFATrainer(EpochBasedTrainer):
         self,
         model: Optional[Union[TorchModel, nn.Module, str]] = None,
         cfg_file: Optional[str] = None,
+        cfg_modify_fn: Optional[Callable] = None,
         arg_parse_fn: Optional[Callable] = None,
         data_collator: Optional[Union[Callable, Dict[str,
                                                      Callable]]] = None,

@@ -49,7 +50,8 @@ class OFATrainer(EpochBasedTrainer):
             **kwargs):
         model = Model.from_pretrained(model, revision=model_revision)
         model_dir = model.model_dir
-        cfg = Config.from_file(cfg_file)
+        self.cfg_modify_fn = cfg_modify_fn
+        cfg = self.rebuild_config(Config.from_file(cfg_file))
         if 'work_dir' not in kwargs or len(kwargs['work_dir']) == 0:
             work_dir = cfg.train.work_dir
         else:

@@ -57,10 +59,12 @@ class OFATrainer(EpochBasedTrainer):
         tokenizer_files = {
             'zh': [
                 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json'
+                'config.json', 'ans2label.json'
             ],
-            'en':
-            ['tokenizer.json', 'vocab.json', 'merges.txt', 'config.json'],
+            'en': [
+                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
+                'ans2label.json'
+            ],
         }
         for filename in tokenizer_files[cfg.model.get('language', 'en')]:
             finetune_file = os.path.join(work_dir, filename)

@@ -127,6 +131,11 @@ class OFATrainer(EpochBasedTrainer):
             **kwargs,
         )

+    def rebuild_config(self, cfg: Config):
+        if self.cfg_modify_fn is not None:
+            cfg = self.cfg_modify_fn(cfg)
+        return cfg
+
     def train_step(self, model, inputs):
         model.train()
         loss, sample_size, logging_output = self.criterion(model, inputs)
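The new cfg_modify_fn hook is applied exactly once, inside rebuild_config(), before the parsed configuration is used anywhere in the trainer. Below is a minimal sketch of how a caller might wire it up; the config path and the exact shape of cfg.train / cfg.evaluation are assumptions for illustration, not part of the commit (the model id is the one used in the test below):

# Hypothetical usage of the new cfg_modify_fn hook (names are illustrative).
from modelscope.trainers import build_trainer


def cfg_modify_fn(cfg):
    # Called from OFATrainer.rebuild_config() right after Config.from_file().
    cfg.train.max_epochs = 3  # assumed config field
    cfg.evaluation.metrics = ['mAP']  # the metric name registered above
    return cfg


trainer = build_trainer(
    name='ofa',
    default_args=dict(
        model='damo/ofa_ocr-recognition_scene_base_zh',
        work_dir='./workspace',
        cfg_file='./workspace/configuration.json',  # assumed local config path
        cfg_modify_fn=cfg_modify_fn))
trainer.train()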


tests/trainers/test_ofa_trainer.py (+6 -5)

@@ -5,10 +5,10 @@ import unittest

 import json

-from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.hub import read_config
 from modelscope.utils.test_utils import test_level


@@ -73,11 +73,12 @@ class TestOfaTrainer(unittest.TestCase):
     def test_trainer_std(self):
         WORKSPACE = './workspace/ckpts/recognition'
         os.makedirs(WORKSPACE, exist_ok=True)
-        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
-        with open(config_file, 'w') as writer:
-            json.dump(self.finetune_cfg, writer)

         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
+        cfg = read_config(pretrained_model)
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        cfg.dump(config_file)
+
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,

@@ -94,7 +95,7 @@ class TestOfaTrainer(unittest.TestCase):
                 split='test[:20]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             cfg_file=config_file)
-        trainer = build_trainer(name=Trainers.ofa, default_args=args)
+        trainer = build_trainer(name='ofa', default_args=args)
         trainer.train()

         self.assertIn(

