diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 6c8d91fa..31bef3b8 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -452,9 +452,9 @@ class Datasets(object): """ Names for different datasets. """ ClsDataset = 'ClsDataset' - Face2dKeypointsDataset = 'Face2dKeypointsDataset' + Face2dKeypointsDataset = 'FaceKeypointDataset' HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' - HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' + HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' SegDataset = 'SegDataset' DetDataset = 'DetDataset' DetImagesMixDataset = 'DetImagesMixDataset' diff --git a/modelscope/msdatasets/cv/easycv_base.py b/modelscope/msdatasets/cv/easycv_base.py index a45827a3..7b6df6e0 100644 --- a/modelscope/msdatasets/cv/easycv_base.py +++ b/modelscope/msdatasets/cv/easycv_base.py @@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): if self.split_config is not None: self._update_data_source(kwargs['data_source']) + def _update_data_root(self, input_dict, data_root): + for k, v in input_dict.items(): + if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: + input_dict.update( + {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) + elif isinstance(v, dict): + self._update_data_root(v, data_root) + def _update_data_source(self, data_source): data_root = next(iter(self.split_config.values())) data_root = data_root.rstrip(osp.sep) - for k, v in data_source.items(): - if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: - data_source.update( - {k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) + self._update_data_root(data_source, data_root) diff --git a/requirements/cv.txt b/requirements/cv.txt index d23fab3a..f29b296b 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -19,7 +19,7 @@ moviepy>=1.0.3 networkx>=2.5 numba onnxruntime>=1.10 -pai-easycv>=0.6.3.7 +pai-easycv>=0.6.3.9 pandas psutil regex diff --git a/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py b/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py new file mode 100644 index 00000000..4dffa998 --- /dev/null +++ b/tests/trainers/easycv/test_easycv_trainer_face_2d_keypoints.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import os +import shutil +import tempfile +import unittest + +import torch + +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.constant import DownloadMode, LogKeys, Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + + +@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') +class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): + model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' + + def setUp(self): + self.logger = get_logger() + self.logger.info(('Testing %s.%s' % + (type(self).__name__, self._testMethodName))) + + def _train(self, tmp_dir): + cfg_options = {'train.max_epochs': 2} + + trainer_name = Trainers.easycv + + train_dataset = MsDataset.load( + dataset_name='face_2d_keypoints_dataset', + namespace='modelscope', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + eval_dataset = MsDataset.load( + dataset_name='face_2d_keypoints_dataset', + namespace='modelscope', + split='train', + download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + + kwargs = dict( + model=self.model_id, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + work_dir=tmp_dir, + cfg_options=cfg_options) + + trainer = build_trainer(trainer_name, kwargs) + trainer.train() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_trainer_single_gpu(self): + temp_file_dir = tempfile.TemporaryDirectory() + tmp_dir = temp_file_dir.name + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + + self._train(tmp_dir) + + results_files = os.listdir(tmp_dir) + json_files = glob.glob(os.path.join(tmp_dir, '*.log.json')) + self.assertEqual(len(json_files), 1) + self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + + temp_file_dir.cleanup() + + +if __name__ == '__main__': + unittest.main()