Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10421808 * add face 2d keypoints & human wholebody keypoint finrtune test casemaster
@@ -452,9 +452,9 @@ class Datasets(object): | |||||
""" Names for different datasets. | """ Names for different datasets. | ||||
""" | """ | ||||
ClsDataset = 'ClsDataset' | ClsDataset = 'ClsDataset' | ||||
Face2dKeypointsDataset = 'Face2dKeypointsDataset' | |||||
Face2dKeypointsDataset = 'FaceKeypointDataset' | |||||
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset' | ||||
HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset' | |||||
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset' | |||||
SegDataset = 'SegDataset' | SegDataset = 'SegDataset' | ||||
DetDataset = 'DetDataset' | DetDataset = 'DetDataset' | ||||
DetImagesMixDataset = 'DetImagesMixDataset' | DetImagesMixDataset = 'DetImagesMixDataset' |
@@ -26,11 +26,16 @@ class EasyCVBaseDataset(object): | |||||
if self.split_config is not None: | if self.split_config is not None: | ||||
self._update_data_source(kwargs['data_source']) | self._update_data_source(kwargs['data_source']) | ||||
def _update_data_root(self, input_dict, data_root): | |||||
for k, v in input_dict.items(): | |||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
input_dict.update( | |||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
elif isinstance(v, dict): | |||||
self._update_data_root(v, data_root) | |||||
def _update_data_source(self, data_source): | def _update_data_source(self, data_source): | ||||
data_root = next(iter(self.split_config.values())) | data_root = next(iter(self.split_config.values())) | ||||
data_root = data_root.rstrip(osp.sep) | data_root = data_root.rstrip(osp.sep) | ||||
for k, v in data_source.items(): | |||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v: | |||||
data_source.update( | |||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)}) | |||||
self._update_data_root(data_source, data_root) |
@@ -19,7 +19,7 @@ moviepy>=1.0.3 | |||||
networkx>=2.5 | networkx>=2.5 | ||||
numba | numba | ||||
onnxruntime>=1.10 | onnxruntime>=1.10 | ||||
pai-easycv>=0.6.3.7 | |||||
pai-easycv>=0.6.3.9 | |||||
pandas | pandas | ||||
psutil | psutil | ||||
regex | regex | ||||
@@ -0,0 +1,71 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import glob | |||||
import os | |||||
import shutil | |||||
import tempfile | |||||
import unittest | |||||
import torch | |||||
from modelscope.metainfo import Trainers | |||||
from modelscope.msdatasets import MsDataset | |||||
from modelscope.trainers import build_trainer | |||||
from modelscope.utils.constant import DownloadMode, LogKeys, Tasks | |||||
from modelscope.utils.logger import get_logger | |||||
from modelscope.utils.test_utils import test_level | |||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||||
class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase): | |||||
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment' | |||||
def setUp(self): | |||||
self.logger = get_logger() | |||||
self.logger.info(('Testing %s.%s' % | |||||
(type(self).__name__, self._testMethodName))) | |||||
def _train(self, tmp_dir): | |||||
cfg_options = {'train.max_epochs': 2} | |||||
trainer_name = Trainers.easycv | |||||
train_dataset = MsDataset.load( | |||||
dataset_name='face_2d_keypoints_dataset', | |||||
namespace='modelscope', | |||||
split='train', | |||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||||
eval_dataset = MsDataset.load( | |||||
dataset_name='face_2d_keypoints_dataset', | |||||
namespace='modelscope', | |||||
split='train', | |||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||||
kwargs = dict( | |||||
model=self.model_id, | |||||
train_dataset=train_dataset, | |||||
eval_dataset=eval_dataset, | |||||
work_dir=tmp_dir, | |||||
cfg_options=cfg_options) | |||||
trainer = build_trainer(trainer_name, kwargs) | |||||
trainer.train() | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_trainer_single_gpu(self): | |||||
temp_file_dir = tempfile.TemporaryDirectory() | |||||
tmp_dir = temp_file_dir.name | |||||
if not os.path.exists(tmp_dir): | |||||
os.makedirs(tmp_dir) | |||||
self._train(tmp_dir) | |||||
results_files = os.listdir(tmp_dir) | |||||
json_files = glob.glob(os.path.join(tmp_dir, '*.log.json')) | |||||
self.assertEqual(len(json_files), 1) | |||||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||||
temp_file_dir.cleanup() | |||||
if __name__ == '__main__': | |||||
unittest.main() |