Browse Source

[to #42322933]add face 2d keypoints finetune test case

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10421808

    * add face 2d keypoints & human wholebody keypoint finrtune test case
master
shouzhou.bx yingda.chen 2 years ago
parent
commit
01d521dd78
4 changed files with 83 additions and 7 deletions
  1. +2
    -2
      modelscope/metainfo.py
  2. +9
    -4
      modelscope/msdatasets/cv/easycv_base.py
  3. +1
    -1
      requirements/cv.txt
  4. +71
    -0
      tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py

+ 2
- 2
modelscope/metainfo.py View File

@@ -452,9 +452,9 @@ class Datasets(object):
""" Names for different datasets.
"""
ClsDataset = 'ClsDataset'
Face2dKeypointsDataset = 'Face2dKeypointsDataset'
Face2dKeypointsDataset = 'FaceKeypointDataset'
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset'
HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset'
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset'
SegDataset = 'SegDataset'
DetDataset = 'DetDataset'
DetImagesMixDataset = 'DetImagesMixDataset'

+ 9
- 4
modelscope/msdatasets/cv/easycv_base.py View File

@@ -26,11 +26,16 @@ class EasyCVBaseDataset(object):
if self.split_config is not None:
self._update_data_source(kwargs['data_source'])

def _update_data_root(self, input_dict, data_root):
for k, v in input_dict.items():
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
input_dict.update(
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
elif isinstance(v, dict):
self._update_data_root(v, data_root)

def _update_data_source(self, data_source):
data_root = next(iter(self.split_config.values()))
data_root = data_root.rstrip(osp.sep)

for k, v in data_source.items():
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
data_source.update(
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
self._update_data_root(data_source, data_root)

+ 1
- 1
requirements/cv.txt View File

@@ -19,7 +19,7 @@ moviepy>=1.0.3
networkx>=2.5
numba
onnxruntime>=1.10
pai-easycv>=0.6.3.7
pai-easycv>=0.6.3.9
pandas
psutil
regex


+ 71
- 0
tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py View File

@@ -0,0 +1,71 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import tempfile
import unittest

import torch

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, LogKeys, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase):
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment'

def setUp(self):
self.logger = get_logger()
self.logger.info(('Testing %s.%s' %
(type(self).__name__, self._testMethodName)))

def _train(self, tmp_dir):
cfg_options = {'train.max_epochs': 2}

trainer_name = Trainers.easycv

train_dataset = MsDataset.load(
dataset_name='face_2d_keypoints_dataset',
namespace='modelscope',
split='train',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
eval_dataset = MsDataset.load(
dataset_name='face_2d_keypoints_dataset',
namespace='modelscope',
split='train',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)

kwargs = dict(
model=self.model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
work_dir=tmp_dir,
cfg_options=cfg_options)

trainer = build_trainer(trainer_name, kwargs)
trainer.train()

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_single_gpu(self):
temp_file_dir = tempfile.TemporaryDirectory()
tmp_dir = temp_file_dir.name
if not os.path.exists(tmp_dir):
os.makedirs(tmp_dir)

self._train(tmp_dir)

results_files = os.listdir(tmp_dir)
json_files = glob.glob(os.path.join(tmp_dir, '*.log.json'))
self.assertEqual(len(json_files), 1)
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)

temp_file_dir.cleanup()


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save