Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9781849 * support EasyCV (target branch: master)
@@ -2,7 +2,6 @@ | |||
"framework": "pytorch", | |||
"task": "image_classification", | |||
"work_dir": "./work_dir", | |||
"model": { | |||
"type": "classification", | |||
@@ -119,6 +118,7 @@ | |||
}, | |||
"train": { | |||
"work_dir": "./work_dir", | |||
"dataloader": { | |||
"batch_size_per_gpu": 2, | |||
"workers_per_gpu": 1 | |||
@@ -0,0 +1,3 @@ | |||
version https://git-lfs.github.com/spec/v1 | |||
oid sha256:af6fa61274e497ecc170de5adc4b8e7ac89eba2bc22a6aa119b08ec7adbe9459 | |||
size 146140 |
@@ -1,5 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import json | |||
import jsonplus | |||
import numpy as np | |||
from .base import FormatHandler | |||
@@ -22,14 +22,14 @@ def set_default(obj): | |||
class JsonHandler(FormatHandler): | |||
"""Use jsonplus, serialization of Python types to JSON that "just works".""" | |||
def load(self, file): | |||
return json.load(file) | |||
return jsonplus.loads(file.read()) | |||
def dump(self, obj, file, **kwargs): | |||
kwargs.setdefault('default', set_default) | |||
json.dump(obj, file, **kwargs) | |||
file.write(self.dumps(obj, **kwargs)) | |||
def dumps(self, obj, **kwargs): | |||
kwargs.setdefault('default', set_default) | |||
return json.dumps(obj, **kwargs) | |||
return jsonplus.dumps(obj, **kwargs) |
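The switch from `json` to `jsonplus` matters because `jsonplus` can round-trip Python types that the standard library rejects. A minimal sketch of the behaviour this buys, assuming the `jsonplus` package added to the requirements below is installed:

```python
# Minimal sketch (not part of the diff): jsonplus serializes Python types that
# plain json.dumps would reject with a TypeError, e.g. datetime and set.
import datetime

import jsonplus

payload = {'when': datetime.datetime(2022, 8, 1, 12, 0), 'tags': {'cv', 'easycv'}}
text = jsonplus.dumps(payload)      # json.dumps(payload) would raise TypeError
restored = jsonplus.loads(text)
assert restored['when'] == payload['when']
```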
@@ -26,6 +26,10 @@ class Models(object): | |||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
# EasyCV models | |||
yolox = 'YOLOX' | |||
segformer = 'Segformer' | |||
# nlp models | |||
bert = 'bert' | |||
palm = 'palm-v2' | |||
@@ -92,6 +96,8 @@ class Pipelines(object): | |||
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | |||
human_detection = 'resnet18-human-detection' | |||
object_detection = 'vit-object-detection' | |||
easycv_detection = 'easycv-detection' | |||
easycv_segmentation = 'easycv-segmentation' | |||
salient_detection = 'u2net-salient-detection' | |||
image_classification = 'image-classification' | |||
face_detection = 'resnet-face-detection-scrfd10gkps' | |||
@@ -171,6 +177,7 @@ class Trainers(object): | |||
""" | |||
default = 'trainer' | |||
easycv = 'easycv' | |||
# multi-modal trainers | |||
clip_multi_modal_embedding = 'clip-multi-modal-embedding' | |||
@@ -307,3 +314,12 @@ class LR_Schedulers(object): | |||
LinearWarmup = 'LinearWarmup' | |||
ConstantWarmup = 'ConstantWarmup' | |||
ExponentialWarmup = 'ExponentialWarmup' | |||
class Datasets(object): | |||
""" Names for different datasets. | |||
""" | |||
ClsDataset = 'ClsDataset' | |||
SegDataset = 'SegDataset' | |||
DetDataset = 'DetDataset' | |||
DetImagesMixDataset = 'DetImagesMixDataset' |
@@ -1,4 +1,5 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict, Mapping, Union | |||
from modelscope.metainfo import Metrics | |||
from modelscope.utils.config import ConfigDict | |||
@@ -35,16 +36,19 @@ task_default_metrics = { | |||
} | |||
def build_metric(metric_name: str, | |||
def build_metric(metric_cfg: Union[str, Dict], | |||
field: str = default_group, | |||
default_args: dict = None): | |||
""" Build metric given metric_name and field. | |||
Args: | |||
metric_name (:obj:`str`): The metric name. | |||
metric_cfg (str | dict): The metric name or metric config dict. | |||
field (str, optional): The field of this metric, default value: 'default' for all fields. | |||
default_args (dict, optional): Default initialization arguments. | |||
""" | |||
cfg = ConfigDict({'type': metric_name}) | |||
if isinstance(metric_cfg, Mapping): | |||
assert 'type' in metric_cfg | |||
else: | |||
metric_cfg = ConfigDict({'type': metric_cfg}) | |||
return build_from_cfg( | |||
cfg, METRICS, group_key=field, default_args=default_args) | |||
metric_cfg, METRICS, group_key=field, default_args=default_args) |
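With this change `build_metric` accepts either a bare metric name or a config dict carrying a `type` key. A hedged usage sketch, assuming the `accuracy` name registered under `Metrics`:

```python
# Hypothetical usage of the new signature; both forms resolve to the same
# registered metric class.
from modelscope.metainfo import Metrics
from modelscope.metrics.builder import build_metric

metric_a = build_metric(Metrics.accuracy)            # string form, as before
metric_b = build_metric({'type': Metrics.accuracy})  # dict form with a 'type' key
```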
@@ -0,0 +1,25 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from easycv.models.base import BaseModel | |||
from easycv.utils.ms_utils import EasyCVMeta | |||
from modelscope.models.base import TorchModel | |||
class EasyCVBaseModel(BaseModel, TorchModel): | |||
"""Base model for EasyCV.""" | |||
def __init__(self, model_dir=None, args=(), kwargs={}): | |||
kwargs.pop(EasyCVMeta.ARCH, None) # pop useless keys | |||
BaseModel.__init__(self) | |||
TorchModel.__init__(self, model_dir=model_dir) | |||
def forward(self, img, mode='train', **kwargs): | |||
if self.training: | |||
losses = self.forward_train(img, **kwargs) | |||
loss, log_vars = self._parse_losses(losses) | |||
return dict(loss=loss, log_vars=log_vars) | |||
else: | |||
return self.forward_test(img, **kwargs) | |||
def __call__(self, *args, **kwargs): | |||
return self.forward(*args, **kwargs) |
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .semantic_seg_model import SemanticSegmentation | |||
from .segformer import Segformer | |||
else: | |||
_import_structure = { | |||
'semantic_seg_model': ['SemanticSegmentation'], | |||
'segformer': ['Segformer'] | |||
} | |||
import sys | |||
@@ -0,0 +1,16 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from easycv.models.segmentation import EncoderDecoder | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.cv.easycv_base import EasyCVBaseModel | |||
from modelscope.utils.constant import Tasks | |||
@MODELS.register_module( | |||
group_key=Tasks.image_segmentation, module_name=Models.segformer) | |||
class Segformer(EasyCVBaseModel, EncoderDecoder): | |||
def __init__(self, model_dir=None, *args, **kwargs): | |||
EasyCVBaseModel.__init__(self, model_dir, args, kwargs) | |||
EncoderDecoder.__init__(self, *args, **kwargs) |
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .mmdet_model import DetectionModel | |||
from .yolox_pai import YOLOX | |||
else: | |||
_import_structure = { | |||
'mmdet_model': ['DetectionModel'], | |||
'yolox_pai': ['YOLOX'] | |||
} | |||
import sys | |||
@@ -0,0 +1,16 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from easycv.models.detection.detectors import YOLOX as _YOLOX | |||
from modelscope.metainfo import Models | |||
from modelscope.models.builder import MODELS | |||
from modelscope.models.cv.easycv_base import EasyCVBaseModel | |||
from modelscope.utils.constant import Tasks | |||
@MODELS.register_module( | |||
group_key=Tasks.image_object_detection, module_name=Models.yolox) | |||
class YOLOX(EasyCVBaseModel, _YOLOX): | |||
def __init__(self, model_dir=None, *args, **kwargs): | |||
EasyCVBaseModel.__init__(self, model_dir, args, kwargs) | |||
_YOLOX.__init__(self, *args, **kwargs) |
@@ -1 +1,3 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from . import cv | |||
from .ms_dataset import MsDataset |
@@ -0,0 +1,3 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from . import (image_classification, image_semantic_segmentation, | |||
object_detection) |
@@ -0,0 +1,20 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .classification_dataset import ClsDataset | |||
else: | |||
_import_structure = {'classification_dataset': ['ClsDataset']} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,19 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from easycv.datasets.classification import ClsDataset as _ClsDataset | |||
from modelscope.metainfo import Datasets | |||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
from modelscope.utils.constant import Tasks | |||
@TASK_DATASETS.register_module( | |||
group_key=Tasks.image_classification, module_name=Datasets.ClsDataset) | |||
class ClsDataset(_ClsDataset): | |||
"""EasyCV dataset for classification. | |||
For more details, please refer to : | |||
https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/classification/raw.py . | |||
Args: | |||
data_source: Data source config to parse input data. | |||
pipeline: Sequence of transform object or config dict to be composed. | |||
""" |
@@ -0,0 +1,20 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .segmentation_dataset import SegDataset | |||
else: | |||
_import_structure = {'segmentation_dataset': ['SegDataset']} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,21 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from easycv.datasets.segmentation import SegDataset as _SegDataset | |||
from modelscope.metainfo import Datasets | |||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||
from modelscope.utils.constant import Tasks | |||
@TASK_DATASETS.register_module( | |||
group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset) | |||
class SegDataset(_SegDataset): | |||
"""EasyCV dataset for Sementic segmentation. | |||
For more details, please refer to : | |||
https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/segmentation/raw.py . | |||
Args: | |||
data_source: Data source config to parse input data. | |||
pipeline: Sequence of transform object or config dict to be composed. | |||
ignore_index (int): Label index to be ignored. | |||
profiling: If set True, will print transform time. | |||
""" |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .easycv_detection import DetDataset, DetImagesMixDataset | |||
else: | |||
_import_structure = { | |||
'easycv_detection': ['DetDataset', 'DetImagesMixDataset'] | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,49 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from easycv.datasets.detection import DetDataset as _DetDataset | |||
from easycv.datasets.detection import \ | |||
DetImagesMixDataset as _DetImagesMixDataset | |||
from modelscope.metainfo import Datasets | |||
from modelscope.msdatasets.task_datasets import TASK_DATASETS | |||
from modelscope.utils.constant import Tasks | |||
@TASK_DATASETS.register_module( | |||
group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset) | |||
class DetDataset(_DetDataset): | |||
"""EasyCV dataset for object detection. | |||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py . | |||
Args: | |||
data_source: Data source config to parse input data. | |||
pipeline: Transform config list | |||
profiling: If set True, will print pipeline time | |||
classes: A list of class names, used in evaluation for result and groundtruth visualization | |||
""" | |||
@TASK_DATASETS.register_module( | |||
group_key=Tasks.image_object_detection, | |||
module_name=Datasets.DetImagesMixDataset) | |||
class DetImagesMixDataset(_DetImagesMixDataset): | |||
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset. | |||
Suitable for training on multiple images mixed data augmentation like | |||
mosaic and mixup. For the augmentation pipeline of mixed image data, | |||
the `get_indexes` method needs to be provided to obtain the image | |||
indexes, and you can set `skip_flags` to change the pipeline running | |||
process. At the same time, we provide the `dynamic_scale` parameter | |||
to dynamically change the output image size. | |||
output boxes format: cx, cy, w, h | |||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/mix.py . | |||
Args: | |||
data_source (:obj:`DetSourceCoco`): Data source config to parse input data. | |||
pipeline (Sequence[dict]): Sequence of transform object or | |||
config dict to be composed. | |||
dynamic_scale (tuple[int], optional): The image scale can be changed | |||
dynamically. Default to None. | |||
skip_type_keys (list[str], optional): Sequence of transform type | |||
strings to be skipped in the pipeline. Default to None. | |||
label_padding: output label padding to shape [N, 120, 5] | |||
""" |
@@ -240,9 +240,9 @@ class Pipeline(ABC): | |||
raise ValueError(f'Unsupported data type {type(data)}') | |||
def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | |||
preprocess_params = kwargs.get('preprocess_params') | |||
forward_params = kwargs.get('forward_params') | |||
postprocess_params = kwargs.get('postprocess_params') | |||
preprocess_params = kwargs.get('preprocess_params', {}) | |||
forward_params = kwargs.get('forward_params', {}) | |||
postprocess_params = kwargs.get('postprocess_params', {}) | |||
out = self.preprocess(input, **preprocess_params) | |||
with device_placement(self.framework, self.device_name): | |||
@@ -39,7 +39,7 @@ if TYPE_CHECKING: | |||
from .tinynas_classification_pipeline import TinynasClassificationPipeline | |||
from .video_category_pipeline import VideoCategoryPipeline | |||
from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | |||
else: | |||
_import_structure = { | |||
'action_recognition_pipeline': ['ActionRecognitionPipeline'], | |||
@@ -84,6 +84,8 @@ else: | |||
'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | |||
'video_category_pipeline': ['VideoCategoryPipeline'], | |||
'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | |||
'easycv_pipelines': | |||
['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'] | |||
} | |||
import sys | |||
@@ -0,0 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .detection_pipeline import EasyCVDetectionPipeline | |||
from .segmentation_pipeline import EasyCVSegmentationPipeline | |||
else: | |||
_import_structure = { | |||
'detection_pipeline': ['EasyCVDetectionPipeline'], | |||
'segmentation_pipeline': ['EasyCVSegmentationPipeline'] | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,95 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import glob | |||
import os | |||
import os.path as osp | |||
from typing import Any | |||
from easycv.utils.ms_utils import EasyCVMeta | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.pipelines.util import is_official_hub_path | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
class EasyCVPipeline(object): | |||
"""Base pipeline for EasyCV. | |||
Loads a configuration file in ModelScope style by default, | |||
but actually uses the EasyCV predictor API to run prediction. | |||
So here we do some adaptation work between the configuration and the predict API. | |||
""" | |||
def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||
""" | |||
model (str): model id on modelscope hub or local model path. | |||
model_file_pattern (str): model file pattern. | |||
""" | |||
self.model_file_pattern = model_file_pattern | |||
assert isinstance(model, str) | |||
if osp.exists(model): | |||
model_dir = model | |||
else: | |||
assert is_official_hub_path( | |||
model), 'Only local model paths and official hub paths are supported!' | |||
model_dir = snapshot_download( | |||
model_id=model, revision=DEFAULT_MODEL_REVISION) | |||
assert osp.isdir(model_dir) | |||
model_files = glob.glob( | |||
os.path.join(model_dir, self.model_file_pattern)) | |||
assert len( | |||
model_files | |||
) == 1, f'Expected exactly one model file, but found {len(model_files)}: {model_files}' | |||
model_path = model_files[0] | |||
self.model_path = model_path | |||
# get configuration file from source model dir | |||
self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
assert os.path.exists( | |||
self.config_file | |||
), f'Cannot find "{ModelFile.CONFIGURATION}" in the model directory!' | |||
self.cfg = Config.from_file(self.config_file) | |||
self.predict_op = self._build_predict_op() | |||
def _build_predict_op(self): | |||
"""Build EasyCV predictor.""" | |||
from easycv.predictors.builder import build_predictor | |||
easycv_config = self._to_easycv_config() | |||
pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { | |||
'model_path': self.model_path, | |||
'config_file': easycv_config | |||
}) | |||
return pipeline_op | |||
def _to_easycv_config(self): | |||
"""Adapt to EasyCV predictor.""" | |||
# TODO: refine config compatibility problems | |||
easycv_arch = self.cfg.model.pop(EasyCVMeta.ARCH, None) | |||
model_cfg = self.cfg.model | |||
# Revert to the configuration of easycv | |||
if easycv_arch is not None: | |||
model_cfg.update(easycv_arch) | |||
easycv_config = Config(dict(model=model_cfg)) | |||
reserved_keys = [] | |||
if hasattr(self.cfg, EasyCVMeta.META): | |||
easycv_meta_cfg = getattr(self.cfg, EasyCVMeta.META) | |||
reserved_keys = easycv_meta_cfg.get(EasyCVMeta.RESERVED_KEYS, []) | |||
for key in reserved_keys: | |||
easycv_config.merge_from_dict({key: getattr(self.cfg, key)}) | |||
if 'test_pipeline' not in reserved_keys: | |||
easycv_config.merge_from_dict( | |||
{'test_pipeline': self.cfg.dataset.val.get('pipeline', [])}) | |||
return easycv_config | |||
def __call__(self, inputs) -> Any: | |||
# TODO: support image url | |||
return self.predict_op(inputs) |
@@ -0,0 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.utils.constant import Tasks | |||
from .base import EasyCVPipeline | |||
@PIPELINES.register_module( | |||
Tasks.image_object_detection, module_name=Pipelines.easycv_detection) | |||
class EasyCVDetectionPipeline(EasyCVPipeline): | |||
"""Pipeline for easycv detection task.""" | |||
def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||
""" | |||
model (str): model id on modelscope hub or local model path. | |||
model_file_pattern (str): model file pattern. | |||
""" | |||
super(EasyCVDetectionPipeline, self).__init__( | |||
model=model, | |||
model_file_pattern=model_file_pattern, | |||
*args, | |||
**kwargs) |
@@ -0,0 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.utils.constant import Tasks | |||
from .base import EasyCVPipeline | |||
@PIPELINES.register_module( | |||
Tasks.image_segmentation, module_name=Pipelines.easycv_segmentation) | |||
class EasyCVSegmentationPipeline(EasyCVPipeline): | |||
"""Pipeline for easycv segmentation task.""" | |||
def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||
""" | |||
model (str): model id on modelscope hub or local model path. | |||
model_file_pattern (str): model file pattern. | |||
""" | |||
super(EasyCVSegmentationPipeline, self).__init__( | |||
model=model, | |||
model_file_pattern=model_file_pattern, | |||
*args, | |||
**kwargs) |
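Once registered, both pipelines are reachable through the standard `pipeline()` factory. A hedged sketch: the segmentation call mirrors the pipeline test added at the end of this PR, while the detection model id and image path are hypothetical:

```python
# Hedged usage sketch; the detection model id and image path are hypothetical,
# the segmentation call mirrors the EasyCV segmentation pipeline test below.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(task=Tasks.image_object_detection,
                    model='EasyCV/some-yolox-model')        # hypothetical id
segmentor = pipeline(task=Tasks.image_segmentation,
                     model='EasyCV/EasyCV-Segformer-b0')    # id used in the test below

det_results = detector('data/test/images/some_image.jpg')   # hypothetical local path
seg_results = segmentor('data/test/images/image_segmentation.jpg')
```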
@@ -0,0 +1,175 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from functools import partial | |||
from typing import Callable, Optional, Tuple, Union | |||
import torch | |||
from torch import nn | |||
from torch.utils.data import Dataset | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.base import TorchModel | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.preprocessors import Preprocessor | |||
from modelscope.trainers import EpochBasedTrainer | |||
from modelscope.trainers.base import TRAINERS | |||
from modelscope.trainers.easycv.utils import register_util | |||
from modelscope.trainers.hooks import HOOKS | |||
from modelscope.trainers.parallel.builder import build_parallel | |||
from modelscope.trainers.parallel.utils import is_parallel | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||
from modelscope.utils.import_utils import LazyImportModule | |||
from modelscope.utils.registry import default_group | |||
@TRAINERS.register_module(module_name=Trainers.easycv) | |||
class EasyCVEpochBasedTrainer(EpochBasedTrainer): | |||
"""Epoch based Trainer for EasyCV. | |||
Args: | |||
task: Task name. | |||
cfg_file(str): The config file of EasyCV. | |||
model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir | |||
or a model id. If model is None, build_model method will be called. | |||
train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): | |||
The dataset to use for training. | |||
Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a | |||
distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a | |||
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will | |||
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally | |||
sets the seed of the RNGs used. | |||
eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation. | |||
preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor. | |||
NOTE: If the dataset has already been preprocessed by the user's custom code before being fed into | |||
this trainer, this parameter should be None and the 'preprocessor' key should be removed from the cfg_file. | |||
Otherwise the preprocessor will be instantiated from the cfg_file or taken from this parameter, and | |||
this preprocessing action will be executed every time the dataset's __getitem__ is called. | |||
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple | |||
containing the optimizer and the scheduler to use. | |||
max_epochs (int, optional): Total training epochs. | |||
""" | |||
def __init__( | |||
self, | |||
task: str, | |||
cfg_file: Optional[str] = None, | |||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||
arg_parse_fn: Optional[Callable] = None, | |||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
preprocessor: Optional[Preprocessor] = None, | |||
optimizers: Tuple[torch.optim.Optimizer, | |||
torch.optim.lr_scheduler._LRScheduler] = (None, | |||
None), | |||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||
**kwargs): | |||
self.task = task | |||
register_util.register_parallel() | |||
register_util.register_part_mmcv_hooks_to_ms() | |||
super(EasyCVEpochBasedTrainer, self).__init__( | |||
model=model, | |||
cfg_file=cfg_file, | |||
arg_parse_fn=arg_parse_fn, | |||
preprocessor=preprocessor, | |||
optimizers=optimizers, | |||
model_revision=model_revision, | |||
train_dataset=train_dataset, | |||
eval_dataset=eval_dataset, | |||
**kwargs) | |||
# reset data_collator | |||
from mmcv.parallel import collate | |||
self.train_data_collator = partial( | |||
collate, | |||
samples_per_gpu=self.cfg.train.dataloader.batch_size_per_gpu) | |||
self.eval_data_collator = partial( | |||
collate, | |||
samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu) | |||
# Register easycv hooks dynamically. If the hook already exists in modelscope, | |||
# the hook in modelscope will be used, otherwise register easycv hook into ms. | |||
# We must manually trigger lazy import to detect whether the hook is in modelscope. | |||
# TODO: use ast index to detect whether the hook is in modelscope | |||
for h_i in self.cfg.train.get('hooks', []): | |||
sig = ('HOOKS', default_group, h_i['type']) | |||
LazyImportModule.import_module(sig) | |||
if h_i['type'] not in HOOKS._modules[default_group]: | |||
if h_i['type'] in [ | |||
'TensorboardLoggerHookV2', 'WandbLoggerHookV2' | |||
]: | |||
raise ValueError( | |||
'Hook %s is not supported yet, we will support it in the future!' | |||
% h_i['type']) | |||
register_util.register_hook_to_ms(h_i['type'], self.logger) | |||
# reset parallel | |||
if not self._dist: | |||
assert not is_parallel( | |||
self.model | |||
), 'Models wrapped by a custom parallel module are not supported outside distributed mode!' | |||
dp_cfg = dict( | |||
type='MMDataParallel', | |||
module=self.model, | |||
device_ids=[torch.cuda.current_device()]) | |||
self.model = build_parallel(dp_cfg) | |||
def create_optimizer_and_scheduler(self): | |||
""" Create optimizer and lr scheduler | |||
""" | |||
optimizer, lr_scheduler = self.optimizers | |||
if optimizer is None: | |||
optimizer_cfg = self.cfg.train.get('optimizer', None) | |||
else: | |||
optimizer_cfg = None | |||
optim_options = {} | |||
if optimizer_cfg is not None: | |||
optim_options = optimizer_cfg.pop('options', {}) | |||
from easycv.apis.train import build_optimizer | |||
optimizer = build_optimizer(self.model, optimizer_cfg) | |||
if lr_scheduler is None: | |||
lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None) | |||
else: | |||
lr_scheduler_cfg = None | |||
lr_options = {} | |||
# Adapt to mmcv lr scheduler hook. | |||
# Please refer to: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py | |||
if lr_scheduler_cfg is not None: | |||
assert optimizer is not None | |||
lr_options = lr_scheduler_cfg.pop('options', {}) | |||
assert 'policy' in lr_scheduler_cfg | |||
policy_type = lr_scheduler_cfg.pop('policy') | |||
if policy_type == policy_type.lower(): | |||
policy_type = policy_type.title() | |||
hook_type = policy_type + 'LrUpdaterHook' | |||
lr_scheduler_cfg['type'] = hook_type | |||
self.cfg.train.lr_scheduler_hook = lr_scheduler_cfg | |||
self.optimizer = optimizer | |||
self.lr_scheduler = lr_scheduler | |||
return self.optimizer, self.lr_scheduler, optim_options, lr_options | |||
def to_parallel(self, model) -> Union[nn.Module, TorchModel]: | |||
if self.cfg.get('parallel', None) is not None: | |||
self.cfg.parallel.update( | |||
dict(module=model, device_ids=[torch.cuda.current_device()])) | |||
return build_parallel(self.cfg.parallel) | |||
dp_cfg = dict( | |||
type='MMDistributedDataParallel', | |||
module=model, | |||
device_ids=[torch.cuda.current_device()]) | |||
return build_parallel(dp_cfg) | |||
def rebuild_config(self, cfg: Config): | |||
cfg.task = self.task | |||
return cfg |
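The registered trainer is then built through the trainer registry like any other. A sketch following the pattern of the Segformer training test added further down in this PR; `work_dir` is illustrative:

```python
# Sketch of launching the EasyCV trainer, mirroring the Segformer test below;
# work_dir is illustrative.
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

kwargs = dict(
    task=Tasks.image_segmentation,
    model='EasyCV/EasyCV-Segformer-b0',  # model id also used in the tests below
    work_dir='./work_dir')
trainer = build_trainer(Trainers.easycv, kwargs)
trainer.train()
```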
@@ -0,0 +1,21 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .hooks import AddLrLogHook | |||
from .metric import EasyCVMetric | |||
else: | |||
_import_structure = {'hooks': ['AddLrLogHook'], 'metric': ['EasyCVMetric']} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,29 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from modelscope.trainers.hooks import HOOKS, Priority | |||
from modelscope.trainers.hooks.lr_scheduler_hook import LrSchedulerHook | |||
from modelscope.utils.constant import LogKeys | |||
@HOOKS.register_module(module_name='AddLrLogHook') | |||
class AddLrLogHook(LrSchedulerHook): | |||
"""For EasyCV to adapt to ModelScope, the lr log of EasyCV is added in the trainer, | |||
but the trainer of ModelScope does not and it is added in the lr scheduler hook. | |||
But The lr scheduler hook used by EasyCV is the hook of mmcv, and there is no lr log. | |||
It will be deleted in the future. | |||
""" | |||
PRIORITY = Priority.NORMAL | |||
def __init__(self): | |||
pass | |||
def before_run(self, trainer): | |||
pass | |||
def before_train_iter(self, trainer): | |||
trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) | |||
def before_train_epoch(self, trainer): | |||
trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) | |||
def after_train_epoch(self, trainer): | |||
pass |
@@ -0,0 +1,52 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import itertools | |||
from typing import Dict | |||
import numpy as np | |||
import torch | |||
from modelscope.metrics.base import Metric | |||
from modelscope.metrics.builder import METRICS | |||
@METRICS.register_module(module_name='EasyCVMetric') | |||
class EasyCVMetric(Metric): | |||
"""Adapt to ModelScope Metric for EasyCV evaluator. | |||
""" | |||
def __init__(self, trainer=None, evaluators=None, *args, **kwargs): | |||
from easycv.core.evaluation.builder import build_evaluator | |||
self.trainer = trainer | |||
self.evaluators = build_evaluator(evaluators) | |||
self.preds = [] | |||
self.groundtruths = [] | |||
def add(self, outputs: Dict, inputs: Dict): | |||
self.preds.append(outputs) | |||
del inputs | |||
def evaluate(self): | |||
results = {} | |||
for _, batch in enumerate(self.preds): | |||
for k, v in batch.items(): | |||
if k not in results: | |||
results[k] = [] | |||
results[k].append(v) | |||
for k, v in results.items(): | |||
if len(v) == 0: | |||
raise ValueError(f'empty result for {k}') | |||
if isinstance(v[0], torch.Tensor): | |||
results[k] = torch.cat(v, 0) | |||
elif isinstance(v[0], (list, np.ndarray)): | |||
results[k] = list(itertools.chain.from_iterable(v)) | |||
else: | |||
raise ValueError( | |||
f'values of the batch prediction dict must be tensors or lists, but {k} has type {type(v[0])}' | |||
) | |||
metric_values = self.trainer.eval_dataset.evaluate( | |||
results, self.evaluators) | |||
return metric_values |
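Together with the `build_metric` change above, this lets a configuration declare the metric as a dict config. A hedged sketch of what such an evaluation section could look like; the evaluator arguments are illustrative, not taken from the PR:

```python
# Hedged sketch of an evaluation section using EasyCVMetric as a dict config,
# which the updated build_metric now accepts; evaluator arguments are
# illustrative.
evaluation = dict(
    dataloader=dict(batch_size_per_gpu=2, workers_per_gpu=1),
    metrics=[
        dict(
            type='EasyCVMetric',
            evaluators=[dict(type='CocoDetectionEvaluator')]),
    ])
```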
@@ -0,0 +1,59 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import inspect | |||
import logging | |||
from modelscope.trainers.hooks import HOOKS | |||
from modelscope.trainers.parallel.builder import PARALLEL | |||
def register_parallel(): | |||
from mmcv.parallel import MMDistributedDataParallel, MMDataParallel | |||
PARALLEL.register_module( | |||
module_name='MMDistributedDataParallel', | |||
module_cls=MMDistributedDataParallel) | |||
PARALLEL.register_module( | |||
module_name='MMDataParallel', module_cls=MMDataParallel) | |||
def register_hook_to_ms(hook_name, logger=None): | |||
"""Register EasyCV hook to ModelScope.""" | |||
from easycv.hooks import HOOKS as _EV_HOOKS | |||
if hook_name not in _EV_HOOKS._module_dict: | |||
raise ValueError( | |||
f'Not found hook "{hook_name}" in EasyCV hook registries!') | |||
obj = _EV_HOOKS._module_dict[hook_name] | |||
HOOKS.register_module(module_name=hook_name, module_cls=obj) | |||
log_str = f'Register hook "{hook_name}" to modelscope hooks.' | |||
logger.info(log_str) if logger is not None else logging.info(log_str) | |||
def register_part_mmcv_hooks_to_ms(): | |||
"""Register required mmcv hooks to ModelScope. | |||
Currently we only register the lr scheduler hooks from EasyCV and mmcv. | |||
Please refer to: | |||
EasyCV: https://github.com/alibaba/EasyCV/blob/master/easycv/hooks/lr_update_hook.py | |||
mmcv: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py | |||
""" | |||
from mmcv.runner.hooks import lr_updater | |||
from mmcv.runner.hooks import HOOKS as _MMCV_HOOKS | |||
from easycv.hooks import StepFixCosineAnnealingLrUpdaterHook, YOLOXLrUpdaterHook | |||
from easycv.hooks.logger import PreLoggerHook | |||
mmcv_hooks_in_easycv = [('StepFixCosineAnnealingLrUpdaterHook', | |||
StepFixCosineAnnealingLrUpdaterHook), | |||
('YOLOXLrUpdaterHook', YOLOXLrUpdaterHook), | |||
('PreLoggerHook', PreLoggerHook)] | |||
members = inspect.getmembers(lr_updater) | |||
members.extend(mmcv_hooks_in_easycv) | |||
for name, obj in members: | |||
if name in _MMCV_HOOKS._module_dict: | |||
HOOKS.register_module( | |||
module_name=name, | |||
module_cls=obj, | |||
) |
@@ -81,12 +81,19 @@ class CheckpointHook(Hook): | |||
if self.is_last_epoch(trainer) and self.by_epoch: | |||
output_dir = os.path.join(self.save_dir, | |||
ModelFile.TRAIN_OUTPUT_DIR) | |||
trainer.model.save_pretrained( | |||
output_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE, | |||
save_function=save_checkpoint, | |||
config=trainer.cfg.to_dict()) | |||
from modelscope.trainers.parallel.utils import is_parallel | |||
if is_parallel(trainer.model): | |||
model = trainer.model.module | |||
else: | |||
model = trainer.model | |||
if hasattr(model, 'save_pretrained'): | |||
model.save_pretrained( | |||
output_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE, | |||
save_function=save_checkpoint, | |||
config=trainer.cfg.to_dict()) | |||
def after_train_iter(self, trainer): | |||
if self.by_epoch: | |||
@@ -60,6 +60,18 @@ class LoggerHook(Hook): | |||
else: | |||
return False | |||
def fetch_tensor(self, trainer, n=0): | |||
"""Fetch latest n values or all values, process tensor type, convert to numpy for dump logs.""" | |||
assert n >= 0 | |||
for key in trainer.log_buffer.val_history: | |||
values = trainer.log_buffer.val_history[key][-n:] | |||
for i, v in enumerate(values): | |||
if isinstance(v, torch.Tensor): | |||
values[i] = v.clone().detach().cpu().numpy() | |||
trainer.log_buffer.val_history[key][-n:] = values | |||
def get_epoch(self, trainer): | |||
if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]: | |||
epoch = trainer.epoch + 1 | |||
@@ -88,11 +100,14 @@ class LoggerHook(Hook): | |||
def after_train_iter(self, trainer): | |||
if self.by_epoch and self.every_n_inner_iters(trainer, self.interval): | |||
self.fetch_tensor(trainer, self.interval) | |||
trainer.log_buffer.average(self.interval) | |||
elif not self.by_epoch and self.every_n_iters(trainer, self.interval): | |||
self.fetch_tensor(trainer, self.interval) | |||
trainer.log_buffer.average(self.interval) | |||
elif self.end_of_epoch(trainer) and not self.ignore_last: | |||
# not precise but more stable | |||
self.fetch_tensor(trainer, self.interval) | |||
trainer.log_buffer.average(self.interval) | |||
if trainer.log_buffer.ready: | |||
@@ -107,6 +122,7 @@ class LoggerHook(Hook): | |||
trainer.log_buffer.clear_output() | |||
def after_val_epoch(self, trainer): | |||
self.fetch_tensor(trainer) | |||
trainer.log_buffer.average() | |||
self.log(trainer) | |||
if self.reset_flag: | |||
@@ -26,7 +26,6 @@ from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||
TorchTaskDataset | |||
from modelscope.preprocessors.base import Preprocessor | |||
from modelscope.preprocessors.builder import build_preprocessor | |||
from modelscope.preprocessors.common import Compose | |||
from modelscope.trainers.hooks.builder import HOOKS | |||
from modelscope.trainers.hooks.priority import Priority, get_priority | |||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | |||
@@ -83,7 +82,8 @@ class EpochBasedTrainer(BaseTrainer): | |||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||
cfg_file: Optional[str] = None, | |||
arg_parse_fn: Optional[Callable] = None, | |||
data_collator: Optional[Callable] = None, | |||
data_collator: Optional[Union[Callable, Dict[str, | |||
Callable]]] = None, | |||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||
preprocessor: Optional[Union[Preprocessor, | |||
@@ -104,21 +104,24 @@ class EpochBasedTrainer(BaseTrainer): | |||
if cfg_file is None: | |||
cfg_file = os.path.join(self.model_dir, | |||
ModelFile.CONFIGURATION) | |||
self.model = self.build_model() | |||
else: | |||
assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class' | |||
assert isinstance( | |||
model, | |||
(TorchModel, nn.Module | |||
)), 'model should be either str, TorchModel or nn.Module.' | |||
assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' | |||
self.model_dir = os.path.dirname(cfg_file) | |||
self.model = model | |||
super().__init__(cfg_file, arg_parse_fn) | |||
# add default config | |||
self.cfg.merge_from_dict(self._get_default_config(), force=False) | |||
self.cfg = self.rebuild_config(self.cfg) | |||
if 'cfg_options' in kwargs: | |||
self.cfg.merge_from_dict(kwargs['cfg_options']) | |||
if isinstance(model, (TorchModel, nn.Module)): | |||
self.model = model | |||
else: | |||
self.model = self.build_model() | |||
if 'work_dir' in kwargs: | |||
self.work_dir = kwargs['work_dir'] | |||
else: | |||
@@ -162,7 +165,24 @@ class EpochBasedTrainer(BaseTrainer): | |||
mode=ModeKeys.EVAL, | |||
preprocessor=self.eval_preprocessor) | |||
self.data_collator = data_collator if data_collator is not None else default_collate | |||
self.train_data_collator, self.eval_data_collator = None, None | |||
if isinstance(data_collator, Mapping): | |||
if not (ConfigKeys.train in data_collator | |||
or ConfigKeys.val in data_collator): | |||
raise ValueError( | |||
f'data_collator dict must contain a `{ConfigKeys.train}` or `{ConfigKeys.val}` key!' | |||
) | |||
if ConfigKeys.train in data_collator: | |||
assert isinstance(data_collator[ConfigKeys.train], Callable) | |||
self.train_data_collator = data_collator[ConfigKeys.train] | |||
if ConfigKeys.val in data_collator: | |||
assert isinstance(data_collator[ConfigKeys.val], Callable) | |||
self.eval_data_collator = data_collator[ConfigKeys.val] | |||
else: | |||
collate_fn = default_collate if data_collator is None else data_collator | |||
self.train_data_collator = collate_fn | |||
self.eval_data_collator = collate_fn | |||
self.metrics = self.get_metrics() | |||
self._metric_values = None | |||
self.optimizers = optimizers | |||
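The dict form of `data_collator` splits collation between training and evaluation. A hedged usage sketch; the model id is hypothetical and `default_collate` simply stands in for real collate functions:

```python
# Hypothetical usage of the new dict-form data_collator; keys follow
# ConfigKeys.train / ConfigKeys.val, and default_collate stands in for any
# custom collate functions.
from torch.utils.data.dataloader import default_collate

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model='damo/some-model-id',  # hypothetical model id
    data_collator={
        'train': default_collate,  # ConfigKeys.train
        'val': default_collate,    # ConfigKeys.val
    })
trainer = build_trainer(Trainers.default, kwargs)
```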
@@ -364,7 +384,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
return train_preprocessor, eval_preprocessor | |||
def get_metrics(self) -> List[str]: | |||
def get_metrics(self) -> List[Union[str, Dict]]: | |||
"""Get the metric class types. | |||
The first choice will be the metrics configured in the config file, if not found, the default metrics will be | |||
@@ -384,7 +404,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
f'Metrics are needed in evaluation, please try to either ' | |||
f'add metrics in configuration.json or add the default metric for {self.cfg.task}.' | |||
) | |||
if isinstance(metrics, str): | |||
if isinstance(metrics, (str, Mapping)): | |||
metrics = [metrics] | |||
return metrics | |||
@@ -399,6 +419,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
self.train_dataset, | |||
dist=self._dist, | |||
seed=self._seed, | |||
collate_fn=self.train_data_collator, | |||
**self.cfg.train.get('dataloader', {})) | |||
self.data_loader = self.train_dataloader | |||
@@ -418,6 +439,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
self.eval_dataset, | |||
dist=self._dist, | |||
seed=self._seed, | |||
collate_fn=self.eval_data_collator, | |||
**self.cfg.evaluation.get('dataloader', {})) | |||
self.data_loader = self.eval_dataloader | |||
metric_classes = [build_metric(metric) for metric in self.metrics] | |||
@@ -440,7 +462,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
override this method in a subclass. | |||
""" | |||
model = Model.from_pretrained(self.model_dir) | |||
model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg) | |||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | |||
return model.model | |||
elif isinstance(model, nn.Module): | |||
@@ -552,6 +574,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
self.train_dataset, | |||
dist=self._dist, | |||
seed=self._seed, | |||
collate_fn=self.train_data_collator, | |||
**self.cfg.train.get('dataloader', {})) | |||
return data_loader | |||
@@ -569,9 +592,9 @@ class EpochBasedTrainer(BaseTrainer): | |||
mode=ModeKeys.EVAL, | |||
preprocessor=self.eval_preprocessor) | |||
batch_size = self.cfg.evaluation.batch_size | |||
workers = self.cfg.evaluation.workers | |||
shuffle = self.cfg.evaluation.get('shuffle', False) | |||
batch_size = self.cfg.evaluation.dataloader.batch_size_per_gpu | |||
workers = self.cfg.evaluation.dataloader.workers_per_gpu | |||
shuffle = self.cfg.evaluation.dataloader.get('shuffle', False) | |||
data_loader = self._build_dataloader_with_dataset( | |||
self.eval_dataset, | |||
batch_size_per_gpu=batch_size, | |||
@@ -580,25 +603,31 @@ class EpochBasedTrainer(BaseTrainer): | |||
dist=self._dist, | |||
seed=self._seed, | |||
persistent_workers=True, | |||
collate_fn=self.eval_data_collator, | |||
) | |||
return data_loader | |||
def build_dataset(self, data_cfg, mode, preprocessor=None): | |||
""" Build torch dataset object using data config | |||
""" | |||
dataset = MsDataset.load( | |||
dataset_name=data_cfg.name, | |||
split=data_cfg.split, | |||
subset_name=data_cfg.subset_name if hasattr( | |||
data_cfg, 'subset_name') else None, | |||
hub=data_cfg.hub if hasattr(data_cfg, 'hub') else Hubs.modelscope, | |||
**data_cfg, | |||
) | |||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | |||
torch_dataset = dataset.to_torch_dataset( | |||
task_data_config=cfg, | |||
task_name=self.cfg.task, | |||
preprocessors=self.preprocessor) | |||
# TODO: support MsDataset load for cv | |||
if hasattr(data_cfg, 'name'): | |||
dataset = MsDataset.load( | |||
dataset_name=data_cfg.name, | |||
split=data_cfg.split, | |||
subset_name=data_cfg.subset_name if hasattr( | |||
data_cfg, 'subset_name') else None, | |||
hub=data_cfg.hub | |||
if hasattr(data_cfg, 'hub') else Hubs.modelscope, | |||
**data_cfg, | |||
) | |||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | |||
torch_dataset = dataset.to_torch_dataset( | |||
task_data_config=cfg, | |||
task_name=self.cfg.task, | |||
preprocessors=self.preprocessor) | |||
else: | |||
torch_dataset = build_task_dataset(data_cfg, self.cfg.task) | |||
dataset = self.to_task_dataset(torch_dataset, mode) | |||
return dataset | |||
@@ -746,7 +775,6 @@ class EpochBasedTrainer(BaseTrainer): | |||
sampler=sampler, | |||
num_workers=num_workers, | |||
batch_sampler=batch_sampler, | |||
collate_fn=self.data_collator, | |||
pin_memory=kwargs.pop('pin_memory', False), | |||
worker_init_fn=init_fn, | |||
**kwargs) | |||
@@ -820,12 +848,14 @@ class EpochBasedTrainer(BaseTrainer): | |||
Args: | |||
hook (:obj:`Hook`): The hook to be registered. | |||
""" | |||
assert isinstance(hook, Hook) | |||
# insert the hook to a sorted list | |||
inserted = False | |||
for i in range(len(self._hooks) - 1, -1, -1): | |||
if get_priority(hook.PRIORITY) > get_priority( | |||
self._hooks[i].PRIORITY): | |||
p = hook.PRIORITY if hasattr(hook, 'PRIORITY') else Priority.NORMAL | |||
p_i = self._hooks[i].PRIORITY if hasattr( | |||
self._hooks[i], 'PRIORITY') else Priority.NORMAL | |||
if get_priority(p) > get_priority(p_i): | |||
self._hooks.insert(i + 1, hook) | |||
inserted = True | |||
break | |||
@@ -15,9 +15,9 @@ import json | |||
from modelscope import __version__ | |||
from modelscope.fileio.file import LocalStorage | |||
from modelscope.metainfo import (Heads, Hooks, LR_Schedulers, Metrics, Models, | |||
Optimizers, Pipelines, Preprocessors, | |||
TaskModels, Trainers) | |||
from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers, | |||
Metrics, Models, Optimizers, Pipelines, | |||
Preprocessors, TaskModels, Trainers) | |||
from modelscope.utils.constant import Fields, Tasks | |||
from modelscope.utils.file_utils import get_default_cache_dir | |||
from modelscope.utils.logger import get_logger | |||
@@ -32,8 +32,7 @@ MODELSCOPE_PATH = p.resolve().parents[1] | |||
REGISTER_MODULE = 'register_module' | |||
IGNORED_PACKAGES = ['modelscope', '.'] | |||
SCAN_SUB_FOLDERS = [ | |||
'models', 'metrics', 'pipelines', 'preprocessors', | |||
'msdatasets/task_datasets', 'trainers' | |||
'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets' | |||
] | |||
INDEXER_FILE = 'ast_indexer' | |||
DECORATOR_KEY = 'decorators' | |||
@@ -14,7 +14,7 @@ mmcls>=0.21.0 | |||
mmdet>=2.25.0 | |||
networkx>=2.5 | |||
onnxruntime>=1.10 | |||
pai-easycv>=0.5 | |||
pai-easycv>=0.6.0 | |||
pandas | |||
psutil | |||
regex | |||
@@ -4,6 +4,7 @@ easydict | |||
einops | |||
filelock>=3.3.0 | |||
gast>=0.2.2 | |||
jsonplus | |||
numpy | |||
opencv-python | |||
oss2 | |||
@@ -0,0 +1,35 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
import numpy as np | |||
from PIL import Image | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class EasyCVSegmentationPipelineTest(unittest.TestCase): | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_segformer_b0(self): | |||
img_path = 'data/test/images/image_segmentation.jpg' | |||
model_id = 'EasyCV/EasyCV-Segformer-b0' | |||
img = np.asarray(Image.open(img_path)) | |||
segmentor = pipeline(task=Tasks.image_segmentation, model=model_id) | |||
outputs = segmentor(img_path) | |||
self.assertEqual(len(outputs), 1) | |||
results = outputs[0] | |||
self.assertListEqual( | |||
list(img.shape)[:2], list(results['seg_pred'][0].shape)) | |||
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(), | |||
[161 for i in range(10)]) | |||
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(), | |||
[133 for i in range(10)]) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,244 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import glob | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
import json | |||
import requests | |||
import torch | |||
from modelscope.metainfo import Models, Pipelines, Trainers | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.config import Config | |||
from modelscope.utils.constant import LogKeys, ModeKeys, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import DistributedTestCase, test_level | |||
from modelscope.utils.torch_utils import is_master | |||
def _download_data(url, save_dir): | |||
r = requests.get(url, verify=True) | |||
if not os.path.exists(save_dir): | |||
os.makedirs(save_dir) | |||
zip_name = os.path.split(url)[-1] | |||
save_path = os.path.join(save_dir, zip_name) | |||
with open(save_path, 'wb') as f: | |||
f.write(r.content) | |||
unpack_dir = os.path.join(save_dir, os.path.splitext(zip_name)[0]) | |||
shutil.unpack_archive(save_path, unpack_dir) | |||
def train_func(work_dir, dist=False, log_config=3, imgs_per_gpu=4): | |||
import easycv | |||
config_path = os.path.join( | |||
os.path.dirname(easycv.__file__), | |||
'configs/detection/yolox/yolox_s_8xb16_300e_coco.py') | |||
data_dir = os.path.join(work_dir, 'small_coco_test') | |||
url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/datasets/small_coco.zip' | |||
if is_master(): | |||
_download_data(url, data_dir) | |||
import time | |||
time.sleep(1) | |||
cfg = Config.from_file(config_path) | |||
cfg.work_dir = work_dir | |||
cfg.total_epochs = 2 | |||
cfg.checkpoint_config.interval = 1 | |||
cfg.eval_config.interval = 1 | |||
cfg.log_config = dict( | |||
interval=log_config, | |||
hooks=[ | |||
dict(type='TextLoggerHook'), | |||
dict(type='TensorboardLoggerHook') | |||
]) | |||
cfg.data.train.data_source.ann_file = os.path.join( | |||
data_dir, 'small_coco/small_coco/instances_train2017_20.json') | |||
cfg.data.train.data_source.img_prefix = os.path.join( | |||
data_dir, 'small_coco/small_coco/train2017') | |||
cfg.data.val.data_source.ann_file = os.path.join( | |||
data_dir, 'small_coco/small_coco/instances_val2017_20.json') | |||
cfg.data.val.data_source.img_prefix = os.path.join( | |||
data_dir, 'small_coco/small_coco/val2017') | |||
cfg.data.imgs_per_gpu = imgs_per_gpu | |||
cfg.data.workers_per_gpu = 2 | |||
cfg.data.val.imgs_per_gpu = 2 | |||
ms_cfg_file = os.path.join(work_dir, 'ms_yolox_s_8xb16_300e_coco.json') | |||
from easycv.utils.ms_utils import to_ms_config | |||
if is_master(): | |||
to_ms_config( | |||
cfg, | |||
dump=True, | |||
task=Tasks.image_object_detection, | |||
ms_model_name=Models.yolox, | |||
pipeline_name=Pipelines.easycv_detection, | |||
save_path=ms_cfg_file) | |||
trainer_name = Trainers.easycv | |||
kwargs = dict( | |||
task=Tasks.image_object_detection, | |||
cfg_file=ms_cfg_file, | |||
launcher='pytorch' if dist else None) | |||
trainer = build_trainer(trainer_name, kwargs) | |||
trainer.train() | |||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||
class EasyCVTrainerTestSingleGpu(unittest.TestCase): | |||
def setUp(self): | |||
self.logger = get_logger() | |||
self.logger.info(('Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(self.tmp_dir): | |||
os.makedirs(self.tmp_dir) | |||
def tearDown(self): | |||
super().tearDown() | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
@unittest.skipIf( | |||
True, 'The test cases all run in the master process, which ' | |||
'causes registry conflicts; they should run in a subprocess.') | |||
def test_single_gpu(self): | |||
# TODO: run in subprocess | |||
train_func(self.tmp_dir) | |||
results_files = os.listdir(self.tmp_dir) | |||
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
self.assertEqual(len(json_files), 1) | |||
with open(json_files[0], 'r') as f: | |||
lines = [i.strip() for i in f.readlines()] | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.TRAIN, | |||
LogKeys.EPOCH: 1, | |||
LogKeys.ITER: 3, | |||
LogKeys.LR: 0.00013 | |||
}, json.loads(lines[0])) | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.EVAL, | |||
LogKeys.EPOCH: 1, | |||
LogKeys.ITER: 10 | |||
}, json.loads(lines[1])) | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.TRAIN, | |||
LogKeys.EPOCH: 2, | |||
LogKeys.ITER: 3, | |||
LogKeys.LR: 0.00157 | |||
}, json.loads(lines[2])) | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.EVAL, | |||
LogKeys.EPOCH: 2, | |||
LogKeys.ITER: 10 | |||
}, json.loads(lines[3])) | |||
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||
for i in [0, 2]: | |||
self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) | |||
self.assertIn(LogKeys.ITER_TIME, lines[i]) | |||
self.assertIn(LogKeys.MEMORY, lines[i]) | |||
self.assertIn('total_loss', lines[i]) | |||
for i in [1, 3]: | |||
self.assertIn( | |||
'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP', | |||
lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP', lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i]) | |||
@unittest.skipIf(not torch.cuda.is_available() | |||
or torch.cuda.device_count() <= 1, 'distributed unittest') | |||
class EasyCVTrainerTestMultiGpus(DistributedTestCase): | |||
def setUp(self): | |||
self.logger = get_logger() | |||
self.logger.info(('Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(self.tmp_dir): | |||
os.makedirs(self.tmp_dir) | |||
def tearDown(self): | |||
super().tearDown() | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_multi_gpus(self): | |||
self.start( | |||
train_func, | |||
num_gpus=2, | |||
work_dir=self.tmp_dir, | |||
dist=True, | |||
log_config=2, | |||
imgs_per_gpu=5) | |||
results_files = os.listdir(self.tmp_dir) | |||
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
self.assertEqual(len(json_files), 1) | |||
with open(json_files[0], 'r') as f: | |||
lines = [i.strip() for i in f.readlines()] | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.TRAIN, | |||
LogKeys.EPOCH: 1, | |||
LogKeys.ITER: 2, | |||
LogKeys.LR: 0.0002 | |||
}, json.loads(lines[0])) | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.EVAL, | |||
LogKeys.EPOCH: 1, | |||
LogKeys.ITER: 5 | |||
}, json.loads(lines[1])) | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.TRAIN, | |||
LogKeys.EPOCH: 2, | |||
LogKeys.ITER: 2, | |||
LogKeys.LR: 0.0018 | |||
}, json.loads(lines[2])) | |||
self.assertDictContainsSubset( | |||
{ | |||
LogKeys.MODE: ModeKeys.EVAL, | |||
LogKeys.EPOCH: 2, | |||
LogKeys.ITER: 5 | |||
}, json.loads(lines[3])) | |||
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||
for i in [0, 2]: | |||
self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) | |||
self.assertIn(LogKeys.ITER_TIME, lines[i]) | |||
self.assertIn(LogKeys.MEMORY, lines[i]) | |||
self.assertIn('total_loss', lines[i]) | |||
for i in [1, 3]: | |||
self.assertIn( | |||
'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP', | |||
lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP', lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i]) | |||
self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i]) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,99 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import glob | |||
import os | |||
import shutil | |||
import tempfile | |||
import unittest | |||
import requests | |||
import torch | |||
from modelscope.metainfo import Trainers | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import LogKeys, Tasks | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.test_utils import test_level | |||
from modelscope.utils.torch_utils import is_master | |||
def _download_data(url, save_dir): | |||
r = requests.get(url, verify=True) | |||
if not os.path.exists(save_dir): | |||
os.makedirs(save_dir) | |||
zip_name = os.path.split(url)[-1] | |||
save_path = os.path.join(save_dir, zip_name) | |||
with open(save_path, 'wb') as f: | |||
f.write(r.content) | |||
unpack_dir = os.path.join(save_dir, os.path.splitext(zip_name)[0]) | |||
shutil.unpack_archive(save_path, unpack_dir) | |||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||
class EasyCVTrainerTestSegformer(unittest.TestCase): | |||
def setUp(self): | |||
self.logger = get_logger() | |||
self.logger.info(('Testing %s.%s' % | |||
(type(self).__name__, self._testMethodName))) | |||
self.tmp_dir = tempfile.TemporaryDirectory().name | |||
if not os.path.exists(self.tmp_dir): | |||
os.makedirs(self.tmp_dir) | |||
def tearDown(self): | |||
super().tearDown() | |||
shutil.rmtree(self.tmp_dir, ignore_errors=True) | |||
def _train(self): | |||
from modelscope.trainers.easycv.trainer import EasyCVEpochBasedTrainer | |||
url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/datasets/small_coco_stuff164k.zip' | |||
data_dir = os.path.join(self.tmp_dir, 'data') | |||
if is_master(): | |||
_download_data(url, data_dir) | |||
# adapt to distributed mode | |||
from easycv.utils.test_util import pseudo_dist_init | |||
pseudo_dist_init() | |||
root_path = os.path.join(data_dir, 'small_coco_stuff164k') | |||
cfg_options = { | |||
'train.max_epochs': | |||
2, | |||
'dataset.train.data_source.img_root': | |||
os.path.join(root_path, 'train2017'), | |||
'dataset.train.data_source.label_root': | |||
os.path.join(root_path, 'annotations/train2017'), | |||
'dataset.train.data_source.split': | |||
os.path.join(root_path, 'train.txt'), | |||
'dataset.val.data_source.img_root': | |||
os.path.join(root_path, 'val2017'), | |||
'dataset.val.data_source.label_root': | |||
os.path.join(root_path, 'annotations/val2017'), | |||
'dataset.val.data_source.split': | |||
os.path.join(root_path, 'val.txt'), | |||
} | |||
trainer_name = Trainers.easycv | |||
kwargs = dict( | |||
task=Tasks.image_segmentation, | |||
model='EasyCV/EasyCV-Segformer-b0', | |||
work_dir=self.tmp_dir, | |||
cfg_options=cfg_options) | |||
trainer = build_trainer(trainer_name, kwargs) | |||
trainer.train() | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_single_gpu_segformer(self): | |||
self._train() | |||
results_files = os.listdir(self.tmp_dir) | |||
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
self.assertEqual(len(json_files), 1) | |||
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -4,6 +4,8 @@ import copy | |||
import tempfile | |||
import unittest | |||
import json | |||
from modelscope.utils.config import Config, check_config | |||
obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}} | |||
@@ -43,7 +45,8 @@ class ConfigTest(unittest.TestCase): | |||
self.assertEqual(pretty_text, cfg.dump()) | |||
cfg.dump(ofile.name) | |||
with open(ofile.name, 'r') as infile: | |||
self.assertEqual(json_str, infile.read()) | |||
self.assertDictEqual( | |||
json.loads(json_str), json.loads(infile.read())) | |||
with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile: | |||
cfg.dump(ofile.name) | |||