Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9781849 * support EasyCV (master)
@@ -2,7 +2,6 @@ | |||||
"framework": "pytorch", | "framework": "pytorch", | ||||
"task": "image_classification", | "task": "image_classification", | ||||
"work_dir": "./work_dir", | |||||
"model": { | "model": { | ||||
"type": "classification", | "type": "classification", | ||||
@@ -119,6 +118,7 @@ | |||||
}, | }, | ||||
"train": { | "train": { | ||||
"work_dir": "./work_dir", | |||||
"dataloader": { | "dataloader": { | ||||
"batch_size_per_gpu": 2, | "batch_size_per_gpu": 2, | ||||
"workers_per_gpu": 1 | "workers_per_gpu": 1 | ||||
@@ -0,0 +1,3 @@ | |||||
version https://git-lfs.github.com/spec/v1 | |||||
oid sha256:af6fa61274e497ecc170de5adc4b8e7ac89eba2bc22a6aa119b08ec7adbe9459 | |||||
size 146140 |
@@ -1,5 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import json | |||||
import jsonplus | |||||
import numpy as np | import numpy as np | ||||
from .base import FormatHandler | from .base import FormatHandler | ||||
@@ -22,14 +22,14 @@ def set_default(obj): | |||||
class JsonHandler(FormatHandler): | class JsonHandler(FormatHandler): | ||||
"""Use jsonplus, serialization of Python types to JSON that "just works".""" | |||||
def load(self, file): | def load(self, file): | ||||
return json.load(file) | |||||
return jsonplus.loads(file.read()) | |||||
def dump(self, obj, file, **kwargs): | def dump(self, obj, file, **kwargs): | ||||
kwargs.setdefault('default', set_default) | |||||
json.dump(obj, file, **kwargs) | |||||
file.write(self.dumps(obj, **kwargs)) | |||||
def dumps(self, obj, **kwargs): | def dumps(self, obj, **kwargs): | ||||
kwargs.setdefault('default', set_default) | kwargs.setdefault('default', set_default) | ||||
return json.dumps(obj, **kwargs) | |||||
return jsonplus.dumps(obj, **kwargs) |
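For context on the fileio change above: jsonplus serializes Python types that the standard json module rejects, which is why the handler now round-trips through jsonplus while keeping the existing set_default fallback. A minimal sketch, assuming jsonplus behaves as advertised for datetime and set values:

```python
import datetime

import jsonplus

# Types like datetime and set would raise TypeError with the standard json module;
# jsonplus encodes them with type tags and restores them on load.
payload = {'when': datetime.datetime(2022, 8, 1, 12, 0), 'tags': {'easycv', 'detection'}}
text = jsonplus.dumps(payload)
restored = jsonplus.loads(text)
assert restored['when'] == payload['when']
```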
@@ -26,6 +26,10 @@ class Models(object): | |||||
swinL_semantic_segmentation = 'swinL-semantic-segmentation' | swinL_semantic_segmentation = 'swinL-semantic-segmentation' | ||||
vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | ||||
# EasyCV models | |||||
yolox = 'YOLOX' | |||||
segformer = 'Segformer' | |||||
# nlp models | # nlp models | ||||
bert = 'bert' | bert = 'bert' | ||||
palm = 'palm-v2' | palm = 'palm-v2' | ||||
@@ -92,6 +96,8 @@ class Pipelines(object): | |||||
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image' | ||||
human_detection = 'resnet18-human-detection' | human_detection = 'resnet18-human-detection' | ||||
object_detection = 'vit-object-detection' | object_detection = 'vit-object-detection' | ||||
easycv_detection = 'easycv-detection' | |||||
easycv_segmentation = 'easycv-segmentation' | |||||
salient_detection = 'u2net-salient-detection' | salient_detection = 'u2net-salient-detection' | ||||
image_classification = 'image-classification' | image_classification = 'image-classification' | ||||
face_detection = 'resnet-face-detection-scrfd10gkps' | face_detection = 'resnet-face-detection-scrfd10gkps' | ||||
@@ -171,6 +177,7 @@ class Trainers(object): | |||||
""" | """ | ||||
default = 'trainer' | default = 'trainer' | ||||
easycv = 'easycv' | |||||
# multi-modal trainers | # multi-modal trainers | ||||
clip_multi_modal_embedding = 'clip-multi-modal-embedding' | clip_multi_modal_embedding = 'clip-multi-modal-embedding' | ||||
@@ -307,3 +314,12 @@ class LR_Schedulers(object): | |||||
LinearWarmup = 'LinearWarmup' | LinearWarmup = 'LinearWarmup' | ||||
ConstantWarmup = 'ConstantWarmup' | ConstantWarmup = 'ConstantWarmup' | ||||
ExponentialWarmup = 'ExponentialWarmup' | ExponentialWarmup = 'ExponentialWarmup' | ||||
class Datasets(object): | |||||
""" Names for different datasets. | |||||
""" | |||||
ClsDataset = 'ClsDataset' | |||||
SegDataset = 'SegDataset' | |||||
DetDataset = 'DetDataset' | |||||
DetImagesMixDataset = 'DetImagesMixDataset' |
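These new Datasets names are what configuration files reference; they resolve through the task-dataset registry added below. A minimal sketch, assuming build_task_dataset is exposed from the builder module and using illustrative EasyCV-style data_source values (paths are placeholders):

```python
from modelscope.metainfo import Datasets
from modelscope.msdatasets.task_datasets.builder import build_task_dataset
from modelscope.utils.constant import Tasks

# Illustrative config: data_source/pipeline follow EasyCV conventions.
data_cfg = dict(
    type=Datasets.DetDataset,
    data_source=dict(
        type='DetSourceCoco',
        ann_file='path/to/instances_train2017.json',
        img_prefix='path/to/train2017'),
    pipeline=[])
dataset = build_task_dataset(data_cfg, Tasks.image_object_detection)
```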
@@ -1,4 +1,5 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
from typing import Dict, Mapping, Union | |||||
from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
from modelscope.utils.config import ConfigDict | from modelscope.utils.config import ConfigDict | ||||
@@ -35,16 +36,19 @@ task_default_metrics = { | |||||
} | } | ||||
def build_metric(metric_name: str, | |||||
def build_metric(metric_cfg: Union[str, Dict], | |||||
field: str = default_group, | field: str = default_group, | ||||
default_args: dict = None): | default_args: dict = None): | ||||
""" Build metric given metric_name and field. | """ Build metric given metric_name and field. | ||||
Args: | Args: | ||||
metric_name (:obj:`str`): The metric name. | |||||
metric_cfg (str | dict): The metric name or a metric config dict containing a 'type' key. | |||||
field (str, optional): The field of this metric, default value: 'default' for all fields. | field (str, optional): The field of this metric, default value: 'default' for all fields. | ||||
default_args (dict, optional): Default initialization arguments. | default_args (dict, optional): Default initialization arguments. | ||||
""" | """ | ||||
cfg = ConfigDict({'type': metric_name}) | |||||
if isinstance(metric_cfg, Mapping): | |||||
assert 'type' in metric_cfg | |||||
else: | |||||
metric_cfg = ConfigDict({'type': metric_cfg}) | |||||
return build_from_cfg( | return build_from_cfg( | ||||
cfg, METRICS, group_key=field, default_args=default_args) | |||||
metric_cfg, METRICS, group_key=field, default_args=default_args) |
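build_metric now accepts either a registered metric name (as before) or a full config dict that must carry a `type` key, so extra constructor arguments can be passed straight from the configuration. A short sketch; the evaluator entry is illustrative:

```python
from modelscope.metrics.builder import build_metric

# As before, a registered metric name (str) can be passed directly:
#   build_metric(metric_name)
# A dict now also works and forwards constructor arguments; it must contain 'type'.
# The evaluator config below is illustrative (any EasyCV evaluator config works).
metric = build_metric({
    'type': 'EasyCVMetric',
    'evaluators': [dict(type='ClsEvaluator', topk=(1, 5))]
})
```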
@@ -0,0 +1,25 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from easycv.models.base import BaseModel | |||||
from easycv.utils.ms_utils import EasyCVMeta | |||||
from modelscope.models.base import TorchModel | |||||
class EasyCVBaseModel(BaseModel, TorchModel): | |||||
"""Base model for EasyCV.""" | |||||
def __init__(self, model_dir=None, args=(), kwargs={}): | |||||
kwargs.pop(EasyCVMeta.ARCH, None) # pop useless keys | |||||
BaseModel.__init__(self) | |||||
TorchModel.__init__(self, model_dir=model_dir) | |||||
def forward(self, img, mode='train', **kwargs): | |||||
if self.training: | |||||
losses = self.forward_train(img, **kwargs) | |||||
loss, log_vars = self._parse_losses(losses) | |||||
return dict(loss=loss, log_vars=log_vars) | |||||
else: | |||||
return self.forward_test(img, **kwargs) | |||||
def __call__(self, *args, **kwargs): | |||||
return self.forward(*args, **kwargs) |
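The two wrappers added below (Segformer, YOLOX) show the intended use of this base class: multiple-inherit from EasyCVBaseModel and the EasyCV model, register under a ModelScope task, and let forward() dispatch to EasyCV's forward_train/forward_test. A hedged sketch of how a further EasyCV model could be wrapped the same way; the EasyCV import path and the registered module name are illustrative only, not part of this PR:

```python
# Illustrative only: wrapping a hypothetical additional EasyCV model.
from easycv.models.classification.classification import Classification as _Classification

from modelscope.models.builder import MODELS
from modelscope.models.cv.easycv_base import EasyCVBaseModel
from modelscope.utils.constant import Tasks


@MODELS.register_module(
    group_key=Tasks.image_classification, module_name='easycv-classification')
class EasyCVClassification(EasyCVBaseModel, _Classification):

    def __init__(self, model_dir=None, *args, **kwargs):
        EasyCVBaseModel.__init__(self, model_dir, args, kwargs)
        _Classification.__init__(self, *args, **kwargs)
```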
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
from .semantic_seg_model import SemanticSegmentation | from .semantic_seg_model import SemanticSegmentation | ||||
from .segformer import Segformer | |||||
else: | else: | ||||
_import_structure = { | _import_structure = { | ||||
'semantic_seg_model': ['SemanticSegmentation'], | 'semantic_seg_model': ['SemanticSegmentation'], | ||||
'segformer': ['Segformer'] | |||||
} | } | ||||
import sys | import sys | ||||
@@ -0,0 +1,16 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from easycv.models.segmentation import EncoderDecoder | |||||
from modelscope.metainfo import Models | |||||
from modelscope.models.builder import MODELS | |||||
from modelscope.models.cv.easycv_base import EasyCVBaseModel | |||||
from modelscope.utils.constant import Tasks | |||||
@MODELS.register_module( | |||||
group_key=Tasks.image_segmentation, module_name=Models.segformer) | |||||
class Segformer(EasyCVBaseModel, EncoderDecoder): | |||||
def __init__(self, model_dir=None, *args, **kwargs): | |||||
EasyCVBaseModel.__init__(self, model_dir, args, kwargs) | |||||
EncoderDecoder.__init__(self, *args, **kwargs) |
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | if TYPE_CHECKING: | ||||
from .mmdet_model import DetectionModel | from .mmdet_model import DetectionModel | ||||
from .yolox_pai import YOLOX | |||||
else: | else: | ||||
_import_structure = { | _import_structure = { | ||||
'mmdet_model': ['DetectionModel'], | 'mmdet_model': ['DetectionModel'], | ||||
'yolox_pai': ['YOLOX'] | |||||
} | } | ||||
import sys | import sys | ||||
@@ -0,0 +1,16 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from easycv.models.detection.detectors import YOLOX as _YOLOX | |||||
from modelscope.metainfo import Models | |||||
from modelscope.models.builder import MODELS | |||||
from modelscope.models.cv.easycv_base import EasyCVBaseModel | |||||
from modelscope.utils.constant import Tasks | |||||
@MODELS.register_module( | |||||
group_key=Tasks.image_object_detection, module_name=Models.yolox) | |||||
class YOLOX(EasyCVBaseModel, _YOLOX): | |||||
def __init__(self, model_dir=None, *args, **kwargs): | |||||
EasyCVBaseModel.__init__(self, model_dir, args, kwargs) | |||||
_YOLOX.__init__(self, *args, **kwargs) |
@@ -1 +1,3 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from . import cv | |||||
from .ms_dataset import MsDataset | from .ms_dataset import MsDataset |
@@ -0,0 +1,3 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from . import (image_classification, image_semantic_segmentation, | |||||
object_detection) |
@@ -0,0 +1,20 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import TYPE_CHECKING | |||||
from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | |||||
from .classification_dataset import ClsDataset | |||||
else: | |||||
_import_structure = {'classification_dataset': ['ClsDataset']} | |||||
import sys | |||||
sys.modules[__name__] = LazyImportModule( | |||||
__name__, | |||||
globals()['__file__'], | |||||
_import_structure, | |||||
module_spec=__spec__, | |||||
extra_objects={}, | |||||
) |
@@ -0,0 +1,19 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from easycv.datasets.classification import ClsDataset as _ClsDataset | |||||
from modelscope.metainfo import Datasets | |||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
from modelscope.utils.constant import Tasks | |||||
@TASK_DATASETS.register_module( | |||||
group_key=Tasks.image_classification, module_name=Datasets.ClsDataset) | |||||
class ClsDataset(_ClsDataset): | |||||
"""EasyCV dataset for classification. | |||||
For more details, please refer to : | |||||
https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/classification/raw.py . | |||||
Args: | |||||
data_source: Data source config to parse input data. | |||||
pipeline: Sequence of transform object or config dict to be composed. | |||||
""" |
@@ -0,0 +1,20 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import TYPE_CHECKING | |||||
from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | |||||
from .segmentation_dataset import SegDataset | |||||
else: | |||||
_import_structure = {'easycv_segmentation': ['SegDataset']} | |||||
import sys | |||||
sys.modules[__name__] = LazyImportModule( | |||||
__name__, | |||||
globals()['__file__'], | |||||
_import_structure, | |||||
module_spec=__spec__, | |||||
extra_objects={}, | |||||
) |
@@ -0,0 +1,21 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from easycv.datasets.segmentation import SegDataset as _SegDataset | |||||
from modelscope.metainfo import Datasets | |||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS | |||||
from modelscope.utils.constant import Tasks | |||||
@TASK_DATASETS.register_module( | |||||
group_key=Tasks.image_segmentation, module_name=Datasets.SegDataset) | |||||
class SegDataset(_SegDataset): | |||||
"""EasyCV dataset for Sementic segmentation. | |||||
For more details, please refer to : | |||||
https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/segmentation/raw.py . | |||||
Args: | |||||
data_source: Data source config to parse input data. | |||||
pipeline: Sequence of transform object or config dict to be composed. | |||||
ignore_index (int): Label index to be ignored. | |||||
profiling: If set True, will print transform time. | |||||
""" |
@@ -0,0 +1,22 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import TYPE_CHECKING | |||||
from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | |||||
from .easycv_detection import DetDataset, DetImagesMixDataset | |||||
else: | |||||
_import_structure = { | |||||
'easycv_detection': ['DetDataset', 'DetImagesMixDataset'] | |||||
} | |||||
import sys | |||||
sys.modules[__name__] = LazyImportModule( | |||||
__name__, | |||||
globals()['__file__'], | |||||
_import_structure, | |||||
module_spec=__spec__, | |||||
extra_objects={}, | |||||
) |
@@ -0,0 +1,49 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from easycv.datasets.detection import DetDataset as _DetDataset | |||||
from easycv.datasets.detection import \ | |||||
DetImagesMixDataset as _DetImagesMixDataset | |||||
from modelscope.metainfo import Datasets | |||||
from modelscope.msdatasets.task_datasets import TASK_DATASETS | |||||
from modelscope.utils.constant import Tasks | |||||
@TASK_DATASETS.register_module( | |||||
group_key=Tasks.image_object_detection, module_name=Datasets.DetDataset) | |||||
class DetDataset(_DetDataset): | |||||
"""EasyCV dataset for object detection. | |||||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/raw.py . | |||||
Args: | |||||
data_source: Data source config to parse input data. | |||||
pipeline: Transform config list | |||||
profiling: If set True, will print pipeline time | |||||
classes: A list of class names, used in evaluation for result and groundtruth visualization | |||||
""" | |||||
@TASK_DATASETS.register_module( | |||||
group_key=Tasks.image_object_detection, | |||||
module_name=Datasets.DetImagesMixDataset) | |||||
class DetImagesMixDataset(_DetImagesMixDataset): | |||||
"""EasyCV dataset for object detection, a wrapper of multiple images mixed dataset. | |||||
Suitable for training on multiple images mixed data augmentation like | |||||
mosaic and mixup. For the augmentation pipeline of mixed image data, | |||||
the `get_indexes` method needs to be provided to obtain the image | |||||
indexes, and you can set `skip_flags` to change the pipeline running | |||||
process. At the same time, we provide the `dynamic_scale` parameter | |||||
to dynamically change the output image size. | |||||
output boxes format: cx, cy, w, h | |||||
For more details, please refer to https://github.com/alibaba/EasyCV/blob/master/easycv/datasets/detection/mix.py . | |||||
Args: | |||||
data_source (:obj:`DetSourceCoco`): Data source config to parse input data. | |||||
pipeline (Sequence[dict]): Sequence of transform object or | |||||
config dict to be composed. | |||||
dynamic_scale (tuple[int], optional): The image scale can be changed | |||||
dynamically. Defaults to None. | |||||
skip_type_keys (list[str], optional): Sequence of transform type strings | |||||
to be skipped in the pipeline. Defaults to None. | |||||
label_padding: pad output labels to shape [N, 120, 5] | |||||
""" |
@@ -240,9 +240,9 @@ class Pipeline(ABC): | |||||
raise ValueError(f'Unsupported data type {type(data)}') | raise ValueError(f'Unsupported data type {type(data)}') | ||||
def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]: | ||||
preprocess_params = kwargs.get('preprocess_params') | |||||
forward_params = kwargs.get('forward_params') | |||||
postprocess_params = kwargs.get('postprocess_params') | |||||
preprocess_params = kwargs.get('preprocess_params', {}) | |||||
forward_params = kwargs.get('forward_params', {}) | |||||
postprocess_params = kwargs.get('postprocess_params', {}) | |||||
out = self.preprocess(input, **preprocess_params) | out = self.preprocess(input, **preprocess_params) | ||||
with device_placement(self.framework, self.device_name): | with device_placement(self.framework, self.device_name): | ||||
@@ -39,7 +39,7 @@ if TYPE_CHECKING: | |||||
from .tinynas_classification_pipeline import TinynasClassificationPipeline | from .tinynas_classification_pipeline import TinynasClassificationPipeline | ||||
from .video_category_pipeline import VideoCategoryPipeline | from .video_category_pipeline import VideoCategoryPipeline | ||||
from .virtual_try_on_pipeline import VirtualTryonPipeline | from .virtual_try_on_pipeline import VirtualTryonPipeline | ||||
from .easycv_pipelines import EasyCVDetectionPipeline, EasyCVSegmentationPipeline | |||||
else: | else: | ||||
_import_structure = { | _import_structure = { | ||||
'action_recognition_pipeline': ['ActionRecognitionPipeline'], | 'action_recognition_pipeline': ['ActionRecognitionPipeline'], | ||||
@@ -84,6 +84,8 @@ else: | |||||
'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | ||||
'video_category_pipeline': ['VideoCategoryPipeline'], | 'video_category_pipeline': ['VideoCategoryPipeline'], | ||||
'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | ||||
'easycv_pipeline': | |||||
['EasyCVDetectionPipeline', 'EasyCVSegmentationPipeline'] | |||||
} | } | ||||
import sys | import sys | ||||
@@ -0,0 +1,23 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import TYPE_CHECKING | |||||
from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | |||||
from .detection_pipeline import EasyCVDetectionPipeline | |||||
from .segmentation_pipeline import EasyCVSegmentationPipeline | |||||
else: | |||||
_import_structure = { | |||||
'detection_pipeline': ['EasyCVDetectionPipeline'], | |||||
'segmentation_pipeline': ['EasyCVSegmentationPipeline'] | |||||
} | |||||
import sys | |||||
sys.modules[__name__] = LazyImportModule( | |||||
__name__, | |||||
globals()['__file__'], | |||||
_import_structure, | |||||
module_spec=__spec__, | |||||
extra_objects={}, | |||||
) |
@@ -0,0 +1,95 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import glob | |||||
import os | |||||
import os.path as osp | |||||
from typing import Any | |||||
from easycv.utils.ms_utils import EasyCVMeta | |||||
from modelscope.hub.snapshot_download import snapshot_download | |||||
from modelscope.pipelines.util import is_official_hub_path | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
class EasyCVPipeline(object): | |||||
"""Base pipeline for EasyCV. | |||||
It loads a ModelScope-style configuration file by default, | |||||
but actually uses the EasyCV predictor API for prediction, | |||||
so some adaptation work is done here for the configuration and the predict API. | |||||
""" | |||||
def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||||
""" | |||||
model (str): model id on modelscope hub or local model path. | |||||
model_file_pattern (str): model file pattern. | |||||
""" | |||||
self.model_file_pattern = model_file_pattern | |||||
assert isinstance(model, str) | |||||
if osp.exists(model): | |||||
model_dir = model | |||||
else: | |||||
assert is_official_hub_path( | |||||
model), 'Only local model paths and official hub paths are supported!' | |||||
model_dir = snapshot_download( | |||||
model_id=model, revision=DEFAULT_MODEL_REVISION) | |||||
assert osp.isdir(model_dir) | |||||
model_files = glob.glob( | |||||
os.path.join(model_dir, self.model_file_pattern)) | |||||
assert len( | |||||
model_files | |||||
) == 1, f'Need one model file, but find {len(model_files)}: {model_files}' | |||||
model_path = model_files[0] | |||||
self.model_path = model_path | |||||
# get configuration file from source model dir | |||||
self.config_file = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||||
assert os.path.exists( | |||||
self.config_file | |||||
), f'Cannot find "{ModelFile.CONFIGURATION}" in the model directory!' | |||||
self.cfg = Config.from_file(self.config_file) | |||||
self.predict_op = self._build_predict_op() | |||||
def _build_predict_op(self): | |||||
"""Build EasyCV predictor.""" | |||||
from easycv.predictors.builder import build_predictor | |||||
easycv_config = self._to_easycv_config() | |||||
pipeline_op = build_predictor(self.cfg.pipeline.predictor_config, { | |||||
'model_path': self.model_path, | |||||
'config_file': easycv_config | |||||
}) | |||||
return pipeline_op | |||||
def _to_easycv_config(self): | |||||
"""Adapt to EasyCV predictor.""" | |||||
# TODO: refine config compatibility problems | |||||
easycv_arch = self.cfg.model.pop(EasyCVMeta.ARCH, None) | |||||
model_cfg = self.cfg.model | |||||
# Revert to the configuration of easycv | |||||
if easycv_arch is not None: | |||||
model_cfg.update(easycv_arch) | |||||
easycv_config = Config(dict(model=model_cfg)) | |||||
reserved_keys = [] | |||||
if hasattr(self.cfg, EasyCVMeta.META): | |||||
easycv_meta_cfg = getattr(self.cfg, EasyCVMeta.META) | |||||
reserved_keys = easycv_meta_cfg.get(EasyCVMeta.RESERVED_KEYS, []) | |||||
for key in reserved_keys: | |||||
easycv_config.merge_from_dict({key: getattr(self.cfg, key)}) | |||||
if 'test_pipeline' not in reserved_keys: | |||||
easycv_config.merge_from_dict( | |||||
{'test_pipeline': self.cfg.dataset.val.get('pipeline', [])}) | |||||
return easycv_config | |||||
def __call__(self, inputs) -> Any: | |||||
# TODO: support image url | |||||
return self.predict_op(inputs) |
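A rough usage sketch of the base pipeline on its own, assuming the module lives at modelscope/pipelines/cv/easycv_pipelines/base.py as the package __init__ below suggests; the paths are illustrative, and the model directory must contain exactly one *.pt checkpoint plus configuration.json:

```python
from modelscope.pipelines.cv.easycv_pipelines.base import EasyCVPipeline

# Illustrative local path; hub model ids are also accepted and will be snapshot-downloaded.
pipe = EasyCVPipeline(model='path/to/local_model_dir', model_file_pattern='*.pt')
result = pipe('path/to/test.jpg')   # forwarded to the EasyCV predictor
```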
@@ -0,0 +1,23 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from modelscope.metainfo import Pipelines | |||||
from modelscope.pipelines.builder import PIPELINES | |||||
from modelscope.utils.constant import Tasks | |||||
from .base import EasyCVPipeline | |||||
@PIPELINES.register_module( | |||||
Tasks.image_object_detection, module_name=Pipelines.easycv_detection) | |||||
class EasyCVDetectionPipeline(EasyCVPipeline): | |||||
"""Pipeline for easycv detection task.""" | |||||
def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||||
""" | |||||
model (str): model id on modelscope hub or local model path. | |||||
model_file_pattern (str): model file pattern. | |||||
""" | |||||
super(EasyCVDetectionPipeline, self).__init__( | |||||
model=model, | |||||
model_file_pattern=model_file_pattern, | |||||
*args, | |||||
**kwargs) |
@@ -0,0 +1,23 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from modelscope.metainfo import Pipelines | |||||
from modelscope.pipelines.builder import PIPELINES | |||||
from modelscope.utils.constant import Tasks | |||||
from .base import EasyCVPipeline | |||||
@PIPELINES.register_module( | |||||
Tasks.image_segmentation, module_name=Pipelines.easycv_segmentation) | |||||
class EasyCVSegmentationPipeline(EasyCVPipeline): | |||||
"""Pipeline for easycv segmentation task.""" | |||||
def __init__(self, model: str, model_file_pattern='*.pt', *args, **kwargs): | |||||
""" | |||||
model (str): model id on modelscope hub or local model path. | |||||
model_file_pattern (str): model file pattern. | |||||
""" | |||||
super(EasyCVSegmentationPipeline, self).__init__( | |||||
model=model, | |||||
model_file_pattern=model_file_pattern, | |||||
*args, | |||||
**kwargs) |
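End-to-end usage goes through the normal pipeline factory, which resolves Pipelines.easycv_segmentation for the model's configured pipeline; the model id, test image path and output keys below come from the test added at the end of this PR:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

segmentor = pipeline(task=Tasks.image_segmentation, model='EasyCV/EasyCV-Segformer-b0')
outputs = segmentor('data/test/images/image_segmentation.jpg')
# outputs[0]['seg_pred'][0] is a per-pixel class-id map with the input image's height and width
print(outputs[0]['seg_pred'][0].shape)
```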
@@ -0,0 +1,175 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from functools import partial | |||||
from typing import Callable, Optional, Tuple, Union | |||||
import torch | |||||
from torch import nn | |||||
from torch.utils.data import Dataset | |||||
from modelscope.metainfo import Trainers | |||||
from modelscope.models.base import TorchModel | |||||
from modelscope.msdatasets import MsDataset | |||||
from modelscope.preprocessors import Preprocessor | |||||
from modelscope.trainers import EpochBasedTrainer | |||||
from modelscope.trainers.base import TRAINERS | |||||
from modelscope.trainers.easycv.utils import register_util | |||||
from modelscope.trainers.hooks import HOOKS | |||||
from modelscope.trainers.parallel.builder import build_parallel | |||||
from modelscope.trainers.parallel.utils import is_parallel | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION | |||||
from modelscope.utils.import_utils import LazyImportModule | |||||
from modelscope.utils.registry import default_group | |||||
@TRAINERS.register_module(module_name=Trainers.easycv) | |||||
class EasyCVEpochBasedTrainer(EpochBasedTrainer): | |||||
"""Epoch based Trainer for EasyCV. | |||||
Args: | |||||
task: Task name. | |||||
cfg_file(str): The config file of EasyCV. | |||||
model (:obj:`torch.nn.Module` or :obj:`TorchModel` or `str`): The model to be run, or a valid model dir | |||||
or a model id. If model is None, build_model method will be called. | |||||
train_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): | |||||
The dataset to use for training. | |||||
Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a | |||||
distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a | |||||
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will | |||||
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally | |||||
sets the seed of the RNGs used. | |||||
eval_dataset (`MsDataset` or `torch.utils.data.Dataset`, *optional*): The dataset to use for evaluation. | |||||
preprocessor (:obj:`Preprocessor`, *optional*): The optional preprocessor. | |||||
NOTE: If the preprocessor has already been applied to the dataset by the user's custom code before it is fed into this trainer, | |||||
this parameter should be None, and the 'preprocessor' key should be removed from the cfg_file. | |||||
Otherwise the preprocessor will be instantiated from the cfg_file or assigned from this parameter, and | |||||
the preprocessing will be executed every time the dataset's __getitem__ is called. | |||||
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`, *optional*): A tuple | |||||
containing the optimizer and the scheduler to use. | |||||
max_epochs (int, optional): Total training epochs. | |||||
""" | |||||
def __init__( | |||||
self, | |||||
task: str, | |||||
cfg_file: Optional[str] = None, | |||||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | |||||
arg_parse_fn: Optional[Callable] = None, | |||||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | |||||
preprocessor: Optional[Preprocessor] = None, | |||||
optimizers: Tuple[torch.optim.Optimizer, | |||||
torch.optim.lr_scheduler._LRScheduler] = (None, | |||||
None), | |||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION, | |||||
**kwargs): | |||||
self.task = task | |||||
register_util.register_parallel() | |||||
register_util.register_part_mmcv_hooks_to_ms() | |||||
super(EasyCVEpochBasedTrainer, self).__init__( | |||||
model=model, | |||||
cfg_file=cfg_file, | |||||
arg_parse_fn=arg_parse_fn, | |||||
preprocessor=preprocessor, | |||||
optimizers=optimizers, | |||||
model_revision=model_revision, | |||||
train_dataset=train_dataset, | |||||
eval_dataset=eval_dataset, | |||||
**kwargs) | |||||
# reset data_collator | |||||
from mmcv.parallel import collate | |||||
self.train_data_collator = partial( | |||||
collate, | |||||
samples_per_gpu=self.cfg.train.dataloader.batch_size_per_gpu) | |||||
self.eval_data_collator = partial( | |||||
collate, | |||||
samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu) | |||||
# Register EasyCV hooks dynamically. If a hook already exists in ModelScope, | |||||
# the ModelScope hook will be used; otherwise the EasyCV hook is registered into ModelScope. | |||||
# We must manually trigger lazy import to detect whether the hook is in ModelScope. | |||||
# TODO: use ast index to detect whether the hook is in modelscope | |||||
for h_i in self.cfg.train.get('hooks', []): | |||||
sig = ('HOOKS', default_group, h_i['type']) | |||||
LazyImportModule.import_module(sig) | |||||
if h_i['type'] not in HOOKS._modules[default_group]: | |||||
if h_i['type'] in [ | |||||
'TensorboardLoggerHookV2', 'WandbLoggerHookV2' | |||||
]: | |||||
raise ValueError( | |||||
'Hook %s is not supported yet, it will be supported in the future!' | |||||
% h_i['type']) | |||||
register_util.register_hook_to_ms(h_i['type'], self.logger) | |||||
# reset parallel | |||||
if not self._dist: | |||||
assert not is_parallel( | |||||
self.model | |||||
), 'A model wrapped by a custom parallel wrapper is not supported in non-distributed mode!' | |||||
dp_cfg = dict( | |||||
type='MMDataParallel', | |||||
module=self.model, | |||||
device_ids=[torch.cuda.current_device()]) | |||||
self.model = build_parallel(dp_cfg) | |||||
def create_optimizer_and_scheduler(self): | |||||
""" Create optimizer and lr scheduler | |||||
""" | |||||
optimizer, lr_scheduler = self.optimizers | |||||
if optimizer is None: | |||||
optimizer_cfg = self.cfg.train.get('optimizer', None) | |||||
else: | |||||
optimizer_cfg = None | |||||
optim_options = {} | |||||
if optimizer_cfg is not None: | |||||
optim_options = optimizer_cfg.pop('options', {}) | |||||
from easycv.apis.train import build_optimizer | |||||
optimizer = build_optimizer(self.model, optimizer_cfg) | |||||
if lr_scheduler is None: | |||||
lr_scheduler_cfg = self.cfg.train.get('lr_scheduler', None) | |||||
else: | |||||
lr_scheduler_cfg = None | |||||
lr_options = {} | |||||
# Adapt to mmcv lr scheduler hook. | |||||
# Please refer to: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py | |||||
if lr_scheduler_cfg is not None: | |||||
assert optimizer is not None | |||||
lr_options = lr_scheduler_cfg.pop('options', {}) | |||||
assert 'policy' in lr_scheduler_cfg | |||||
policy_type = lr_scheduler_cfg.pop('policy') | |||||
if policy_type == policy_type.lower(): | |||||
policy_type = policy_type.title() | |||||
hook_type = policy_type + 'LrUpdaterHook' | |||||
lr_scheduler_cfg['type'] = hook_type | |||||
self.cfg.train.lr_scheduler_hook = lr_scheduler_cfg | |||||
self.optimizer = optimizer | |||||
self.lr_scheduler = lr_scheduler | |||||
return self.optimizer, self.lr_scheduler, optim_options, lr_options | |||||
def to_parallel(self, model) -> Union[nn.Module, TorchModel]: | |||||
if self.cfg.get('parallel', None) is not None: | |||||
self.cfg.parallel.update( | |||||
dict(module=model, device_ids=[torch.cuda.current_device()])) | |||||
return build_parallel(self.cfg.parallel) | |||||
dp_cfg = dict( | |||||
type='MMDistributedDataParallel', | |||||
module=model, | |||||
device_ids=[torch.cuda.current_device()]) | |||||
return build_parallel(dp_cfg) | |||||
def rebuild_config(self, cfg: Config): | |||||
cfg.task = self.task | |||||
return cfg |
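Typical usage mirrors the single-GPU path of the trainer test at the end of this PR: convert an EasyCV config to ModelScope format with easycv.utils.ms_utils.to_ms_config, then build the trainer by name. The config path below is illustrative:

```python
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

kwargs = dict(
    task=Tasks.image_object_detection,
    cfg_file='path/to/ms_yolox_s_8xb16_300e_coco.json',  # produced by to_ms_config
    launcher=None)  # or 'pytorch' when launched in distributed mode
trainer = build_trainer(Trainers.easycv, kwargs)
trainer.train()
```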
@@ -0,0 +1,21 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from typing import TYPE_CHECKING | |||||
from modelscope.utils.import_utils import LazyImportModule | |||||
if TYPE_CHECKING: | |||||
from .hooks import AddLrLogHook | |||||
from .metric import EasyCVMetric | |||||
else: | |||||
_import_structure = {'hooks': ['AddLrLogHook'], 'metric': ['EasyCVMetric']} | |||||
import sys | |||||
sys.modules[__name__] = LazyImportModule( | |||||
__name__, | |||||
globals()['__file__'], | |||||
_import_structure, | |||||
module_spec=__spec__, | |||||
extra_objects={}, | |||||
) |
@@ -0,0 +1,29 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
from modelscope.trainers.hooks import HOOKS, Priority | |||||
from modelscope.trainers.hooks.lr_scheduler_hook import LrSchedulerHook | |||||
from modelscope.utils.constant import LogKeys | |||||
@HOOKS.register_module(module_name='AddLrLogHook') | |||||
class AddLrLogHook(LrSchedulerHook): | |||||
"""For EasyCV to adapt to ModelScope, the lr log of EasyCV is added in the trainer, | |||||
but the trainer of ModelScope does not and it is added in the lr scheduler hook. | |||||
But The lr scheduler hook used by EasyCV is the hook of mmcv, and there is no lr log. | |||||
It will be deleted in the future. | |||||
""" | |||||
PRIORITY = Priority.NORMAL | |||||
def __init__(self): | |||||
pass | |||||
def before_run(self, trainer): | |||||
pass | |||||
def before_train_iter(self, trainer): | |||||
trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) | |||||
def before_train_epoch(self, trainer): | |||||
trainer.log_buffer.output[LogKeys.LR] = self._get_log_lr(trainer) | |||||
def after_train_epoch(self, trainer): | |||||
pass |
@@ -0,0 +1,52 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import itertools | |||||
from typing import Dict | |||||
import numpy as np | |||||
import torch | |||||
from modelscope.metrics.base import Metric | |||||
from modelscope.metrics.builder import METRICS | |||||
@METRICS.register_module(module_name='EasyCVMetric') | |||||
class EasyCVMetric(Metric): | |||||
"""Adapt to ModelScope Metric for EasyCV evaluator. | |||||
""" | |||||
def __init__(self, trainer=None, evaluators=None, *args, **kwargs): | |||||
from easycv.core.evaluation.builder import build_evaluator | |||||
self.trainer = trainer | |||||
self.evaluators = build_evaluator(evaluators) | |||||
self.preds = [] | |||||
self.grountruths = [] | |||||
def add(self, outputs: Dict, inputs: Dict): | |||||
self.preds.append(outputs) | |||||
del inputs | |||||
def evaluate(self): | |||||
results = {} | |||||
for _, batch in enumerate(self.preds): | |||||
for k, v in batch.items(): | |||||
if k not in results: | |||||
results[k] = [] | |||||
results[k].append(v) | |||||
for k, v in results.items(): | |||||
if len(v) == 0: | |||||
raise ValueError(f'empty result for {k}') | |||||
if isinstance(v[0], torch.Tensor): | |||||
results[k] = torch.cat(v, 0) | |||||
elif isinstance(v[0], (list, np.ndarray)): | |||||
results[k] = list(itertools.chain.from_iterable(v)) | |||||
else: | |||||
raise ValueError( | |||||
f'values of the batch prediction dict should only be tensors or lists, but {k} has type {type(v[0])}' | |||||
) | |||||
metric_values = self.trainer.eval_dataset.evaluate( | |||||
results, self.evaluators) | |||||
return metric_values |
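In a training configuration the metric is referenced by its registered name together with the EasyCV evaluator configs it should build (get_metrics and build_metric below now accept such dicts); the evaluator type and classes here are illustrative:

```python
# Illustrative evaluation.metrics entry; 'evaluators' is forwarded to
# easycv.core.evaluation.builder.build_evaluator by EasyCVMetric.__init__.
metrics_cfg = [{
    'type': 'EasyCVMetric',
    'evaluators': [dict(type='CocoDetectionEvaluator', classes=['person', 'car'])],
}]
```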
@@ -0,0 +1,59 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import inspect | |||||
import logging | |||||
from modelscope.trainers.hooks import HOOKS | |||||
from modelscope.trainers.parallel.builder import PARALLEL | |||||
def register_parallel(): | |||||
from mmcv.parallel import MMDistributedDataParallel, MMDataParallel | |||||
PARALLEL.register_module( | |||||
module_name='MMDistributedDataParallel', | |||||
module_cls=MMDistributedDataParallel) | |||||
PARALLEL.register_module( | |||||
module_name='MMDataParallel', module_cls=MMDataParallel) | |||||
def register_hook_to_ms(hook_name, logger=None): | |||||
"""Register EasyCV hook to ModelScope.""" | |||||
from easycv.hooks import HOOKS as _EV_HOOKS | |||||
if hook_name not in _EV_HOOKS._module_dict: | |||||
raise ValueError( | |||||
f'Not found hook "{hook_name}" in EasyCV hook registries!') | |||||
obj = _EV_HOOKS._module_dict[hook_name] | |||||
HOOKS.register_module(module_name=hook_name, module_cls=obj) | |||||
log_str = f'Register hook "{hook_name}" to modelscope hooks.' | |||||
logger.info(log_str) if logger is not None else logging.info(log_str) | |||||
def register_part_mmcv_hooks_to_ms(): | |||||
"""Register required mmcv hooks to ModelScope. | |||||
Currently we only register the lr scheduler hooks from EasyCV and mmcv (plus EasyCV's PreLoggerHook). | |||||
Please refer to: | |||||
EasyCV: https://github.com/alibaba/EasyCV/blob/master/easycv/hooks/lr_update_hook.py | |||||
mmcv: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/lr_updater.py | |||||
""" | |||||
from mmcv.runner.hooks import lr_updater | |||||
from mmcv.runner.hooks import HOOKS as _MMCV_HOOKS | |||||
from easycv.hooks import StepFixCosineAnnealingLrUpdaterHook, YOLOXLrUpdaterHook | |||||
from easycv.hooks.logger import PreLoggerHook | |||||
mmcv_hooks_in_easycv = [('StepFixCosineAnnealingLrUpdaterHook', | |||||
StepFixCosineAnnealingLrUpdaterHook), | |||||
('YOLOXLrUpdaterHook', YOLOXLrUpdaterHook), | |||||
('PreLoggerHook', PreLoggerHook)] | |||||
members = inspect.getmembers(lr_updater) | |||||
members.extend(mmcv_hooks_in_easycv) | |||||
for name, obj in members: | |||||
if name in _MMCV_HOOKS._module_dict: | |||||
HOOKS.register_module( | |||||
module_name=name, | |||||
module_cls=obj, | |||||
) |
@@ -81,12 +81,19 @@ class CheckpointHook(Hook): | |||||
if self.is_last_epoch(trainer) and self.by_epoch: | if self.is_last_epoch(trainer) and self.by_epoch: | ||||
output_dir = os.path.join(self.save_dir, | output_dir = os.path.join(self.save_dir, | ||||
ModelFile.TRAIN_OUTPUT_DIR) | ModelFile.TRAIN_OUTPUT_DIR) | ||||
trainer.model.save_pretrained( | |||||
output_dir, | |||||
ModelFile.TORCH_MODEL_BIN_FILE, | |||||
save_function=save_checkpoint, | |||||
config=trainer.cfg.to_dict()) | |||||
from modelscope.trainers.parallel.utils import is_parallel | |||||
if is_parallel(trainer.model): | |||||
model = trainer.model.module | |||||
else: | |||||
model = trainer.model | |||||
if hasattr(model, 'save_pretrained'): | |||||
model.save_pretrained( | |||||
output_dir, | |||||
ModelFile.TORCH_MODEL_BIN_FILE, | |||||
save_function=save_checkpoint, | |||||
config=trainer.cfg.to_dict()) | |||||
def after_train_iter(self, trainer): | def after_train_iter(self, trainer): | ||||
if self.by_epoch: | if self.by_epoch: | ||||
@@ -60,6 +60,18 @@ class LoggerHook(Hook): | |||||
else: | else: | ||||
return False | return False | ||||
def fetch_tensor(self, trainer, n=0): | |||||
"""Fetch latest n values or all values, process tensor type, convert to numpy for dump logs.""" | |||||
assert n >= 0 | |||||
for key in trainer.log_buffer.val_history: | |||||
values = trainer.log_buffer.val_history[key][-n:] | |||||
for i, v in enumerate(values): | |||||
if isinstance(v, torch.Tensor): | |||||
values[i] = v.clone().detach().cpu().numpy() | |||||
trainer.log_buffer.val_history[key][-n:] = values | |||||
def get_epoch(self, trainer): | def get_epoch(self, trainer): | ||||
if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]: | if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]: | ||||
epoch = trainer.epoch + 1 | epoch = trainer.epoch + 1 | ||||
@@ -88,11 +100,14 @@ class LoggerHook(Hook): | |||||
def after_train_iter(self, trainer): | def after_train_iter(self, trainer): | ||||
if self.by_epoch and self.every_n_inner_iters(trainer, self.interval): | if self.by_epoch and self.every_n_inner_iters(trainer, self.interval): | ||||
self.fetch_tensor(trainer, self.interval) | |||||
trainer.log_buffer.average(self.interval) | trainer.log_buffer.average(self.interval) | ||||
elif not self.by_epoch and self.every_n_iters(trainer, self.interval): | elif not self.by_epoch and self.every_n_iters(trainer, self.interval): | ||||
self.fetch_tensor(trainer, self.interval) | |||||
trainer.log_buffer.average(self.interval) | trainer.log_buffer.average(self.interval) | ||||
elif self.end_of_epoch(trainer) and not self.ignore_last: | elif self.end_of_epoch(trainer) and not self.ignore_last: | ||||
# not precise but more stable | # not precise but more stable | ||||
self.fetch_tensor(trainer, self.interval) | |||||
trainer.log_buffer.average(self.interval) | trainer.log_buffer.average(self.interval) | ||||
if trainer.log_buffer.ready: | if trainer.log_buffer.ready: | ||||
@@ -107,6 +122,7 @@ class LoggerHook(Hook): | |||||
trainer.log_buffer.clear_output() | trainer.log_buffer.clear_output() | ||||
def after_val_epoch(self, trainer): | def after_val_epoch(self, trainer): | ||||
self.fetch_tensor(trainer) | |||||
trainer.log_buffer.average() | trainer.log_buffer.average() | ||||
self.log(trainer) | self.log(trainer) | ||||
if self.reset_flag: | if self.reset_flag: | ||||
@@ -26,7 +26,6 @@ from modelscope.msdatasets.task_datasets.torch_base_dataset import \ | |||||
TorchTaskDataset | TorchTaskDataset | ||||
from modelscope.preprocessors.base import Preprocessor | from modelscope.preprocessors.base import Preprocessor | ||||
from modelscope.preprocessors.builder import build_preprocessor | from modelscope.preprocessors.builder import build_preprocessor | ||||
from modelscope.preprocessors.common import Compose | |||||
from modelscope.trainers.hooks.builder import HOOKS | from modelscope.trainers.hooks.builder import HOOKS | ||||
from modelscope.trainers.hooks.priority import Priority, get_priority | from modelscope.trainers.hooks.priority import Priority, get_priority | ||||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | from modelscope.trainers.lrscheduler.builder import build_lr_scheduler | ||||
@@ -83,7 +82,8 @@ class EpochBasedTrainer(BaseTrainer): | |||||
model: Optional[Union[TorchModel, nn.Module, str]] = None, | model: Optional[Union[TorchModel, nn.Module, str]] = None, | ||||
cfg_file: Optional[str] = None, | cfg_file: Optional[str] = None, | ||||
arg_parse_fn: Optional[Callable] = None, | arg_parse_fn: Optional[Callable] = None, | ||||
data_collator: Optional[Callable] = None, | |||||
data_collator: Optional[Union[Callable, Dict[str, | |||||
Callable]]] = None, | |||||
train_dataset: Optional[Union[MsDataset, Dataset]] = None, | train_dataset: Optional[Union[MsDataset, Dataset]] = None, | ||||
eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | eval_dataset: Optional[Union[MsDataset, Dataset]] = None, | ||||
preprocessor: Optional[Union[Preprocessor, | preprocessor: Optional[Union[Preprocessor, | ||||
@@ -104,21 +104,24 @@ class EpochBasedTrainer(BaseTrainer): | |||||
if cfg_file is None: | if cfg_file is None: | ||||
cfg_file = os.path.join(self.model_dir, | cfg_file = os.path.join(self.model_dir, | ||||
ModelFile.CONFIGURATION) | ModelFile.CONFIGURATION) | ||||
self.model = self.build_model() | |||||
else: | else: | ||||
assert cfg_file is not None, 'Config file should not be None if model is an nn.Module class' | |||||
assert isinstance( | |||||
model, | |||||
(TorchModel, nn.Module | |||||
)), 'model should be either str, TorchModel or nn.Module.' | |||||
assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' | |||||
self.model_dir = os.path.dirname(cfg_file) | self.model_dir = os.path.dirname(cfg_file) | ||||
self.model = model | |||||
super().__init__(cfg_file, arg_parse_fn) | super().__init__(cfg_file, arg_parse_fn) | ||||
# add default config | # add default config | ||||
self.cfg.merge_from_dict(self._get_default_config(), force=False) | self.cfg.merge_from_dict(self._get_default_config(), force=False) | ||||
self.cfg = self.rebuild_config(self.cfg) | self.cfg = self.rebuild_config(self.cfg) | ||||
if 'cfg_options' in kwargs: | |||||
self.cfg.merge_from_dict(kwargs['cfg_options']) | |||||
if isinstance(model, (TorchModel, nn.Module)): | |||||
self.model = model | |||||
else: | |||||
self.model = self.build_model() | |||||
if 'work_dir' in kwargs: | if 'work_dir' in kwargs: | ||||
self.work_dir = kwargs['work_dir'] | self.work_dir = kwargs['work_dir'] | ||||
else: | else: | ||||
@@ -162,7 +165,24 @@ class EpochBasedTrainer(BaseTrainer): | |||||
mode=ModeKeys.EVAL, | mode=ModeKeys.EVAL, | ||||
preprocessor=self.eval_preprocessor) | preprocessor=self.eval_preprocessor) | ||||
self.data_collator = data_collator if data_collator is not None else default_collate | |||||
self.train_data_collator, self.eval_data_collator = None, None | |||||
if isinstance(data_collator, Mapping): | |||||
if not (ConfigKeys.train in data_collator | |||||
or ConfigKeys.val in data_collator): | |||||
raise ValueError( | |||||
f'a dict data_collator must contain `{ConfigKeys.train}` and/or `{ConfigKeys.val}` keys!' | |||||
) | |||||
if ConfigKeys.train in data_collator: | |||||
assert isinstance(data_collator[ConfigKeys.train], Callable) | |||||
self.train_data_collator = data_collator[ConfigKeys.train] | |||||
if ConfigKeys.val in data_collator: | |||||
assert isinstance(data_collator[ConfigKeys.val], Callable) | |||||
self.eval_data_collator = data_collator[ConfigKeys.val] | |||||
else: | |||||
collate_fn = default_collate if data_collator is None else data_collator | |||||
self.train_data_collator = collate_fn | |||||
self.eval_data_collator = collate_fn | |||||
self.metrics = self.get_metrics() | self.metrics = self.get_metrics() | ||||
self._metric_values = None | self._metric_values = None | ||||
self.optimizers = optimizers | self.optimizers = optimizers | ||||
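With this change data_collator may be a single callable (used for both loops, as before) or a dict keyed per split. A minimal sketch, assuming ConfigKeys comes from modelscope.utils.constant and using mmcv's collate as the EasyCV trainer above does; the batch sizes are hypothetical:

```python
from functools import partial

from mmcv.parallel import collate

from modelscope.utils.constant import ConfigKeys

# Each split gets its own collate function.
data_collator = {
    ConfigKeys.train: partial(collate, samples_per_gpu=4),
    ConfigKeys.val: partial(collate, samples_per_gpu=2),
}
# trainer = EpochBasedTrainer(model=..., cfg_file=..., data_collator=data_collator, ...)
```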
@@ -364,7 +384,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
return train_preprocessor, eval_preprocessor | return train_preprocessor, eval_preprocessor | ||||
def get_metrics(self) -> List[str]: | |||||
def get_metrics(self) -> List[Union[str, Dict]]: | |||||
"""Get the metric class types. | """Get the metric class types. | ||||
The first choice will be the metrics configured in the config file, if not found, the default metrics will be | The first choice will be the metrics configured in the config file, if not found, the default metrics will be | ||||
@@ -384,7 +404,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
f'Metrics are needed in evaluation, please try to either ' | f'Metrics are needed in evaluation, please try to either ' | ||||
f'add metrics in configuration.json or add the default metric for {self.cfg.task}.' | f'add metrics in configuration.json or add the default metric for {self.cfg.task}.' | ||||
) | ) | ||||
if isinstance(metrics, str): | |||||
if isinstance(metrics, (str, Mapping)): | |||||
metrics = [metrics] | metrics = [metrics] | ||||
return metrics | return metrics | ||||
@@ -399,6 +419,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
self.train_dataset, | self.train_dataset, | ||||
dist=self._dist, | dist=self._dist, | ||||
seed=self._seed, | seed=self._seed, | ||||
collate_fn=self.train_data_collator, | |||||
**self.cfg.train.get('dataloader', {})) | **self.cfg.train.get('dataloader', {})) | ||||
self.data_loader = self.train_dataloader | self.data_loader = self.train_dataloader | ||||
@@ -418,6 +439,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
self.eval_dataset, | self.eval_dataset, | ||||
dist=self._dist, | dist=self._dist, | ||||
seed=self._seed, | seed=self._seed, | ||||
collate_fn=self.eval_data_collator, | |||||
**self.cfg.evaluation.get('dataloader', {})) | **self.cfg.evaluation.get('dataloader', {})) | ||||
self.data_loader = self.eval_dataloader | self.data_loader = self.eval_dataloader | ||||
metric_classes = [build_metric(metric) for metric in self.metrics] | metric_classes = [build_metric(metric) for metric in self.metrics] | ||||
@@ -440,7 +462,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
override this method in a subclass. | override this method in a subclass. | ||||
""" | """ | ||||
model = Model.from_pretrained(self.model_dir) | |||||
model = Model.from_pretrained(self.model_dir, cfg_dict=self.cfg) | |||||
if not isinstance(model, nn.Module) and hasattr(model, 'model'): | if not isinstance(model, nn.Module) and hasattr(model, 'model'): | ||||
return model.model | return model.model | ||||
elif isinstance(model, nn.Module): | elif isinstance(model, nn.Module): | ||||
@@ -552,6 +574,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
self.train_dataset, | self.train_dataset, | ||||
dist=self._dist, | dist=self._dist, | ||||
seed=self._seed, | seed=self._seed, | ||||
collate_fn=self.train_data_collator, | |||||
**self.cfg.train.get('dataloader', {})) | **self.cfg.train.get('dataloader', {})) | ||||
return data_loader | return data_loader | ||||
@@ -569,9 +592,9 @@ class EpochBasedTrainer(BaseTrainer): | |||||
mode=ModeKeys.EVAL, | mode=ModeKeys.EVAL, | ||||
preprocessor=self.eval_preprocessor) | preprocessor=self.eval_preprocessor) | ||||
batch_size = self.cfg.evaluation.batch_size | |||||
workers = self.cfg.evaluation.workers | |||||
shuffle = self.cfg.evaluation.get('shuffle', False) | |||||
batch_size = self.cfg.evaluation.dataloader.batch_size_per_gpu | |||||
workers = self.cfg.evaluation.dataloader.workers_per_gpu | |||||
shuffle = self.cfg.evaluation.dataloader.get('shuffle', False) | |||||
data_loader = self._build_dataloader_with_dataset( | data_loader = self._build_dataloader_with_dataset( | ||||
self.eval_dataset, | self.eval_dataset, | ||||
batch_size_per_gpu=batch_size, | batch_size_per_gpu=batch_size, | ||||
@@ -580,25 +603,31 @@ class EpochBasedTrainer(BaseTrainer): | |||||
dist=self._dist, | dist=self._dist, | ||||
seed=self._seed, | seed=self._seed, | ||||
persistent_workers=True, | persistent_workers=True, | ||||
collate_fn=self.eval_data_collator, | |||||
) | ) | ||||
return data_loader | return data_loader | ||||
def build_dataset(self, data_cfg, mode, preprocessor=None): | def build_dataset(self, data_cfg, mode, preprocessor=None): | ||||
""" Build torch dataset object using data config | """ Build torch dataset object using data config | ||||
""" | """ | ||||
dataset = MsDataset.load( | |||||
dataset_name=data_cfg.name, | |||||
split=data_cfg.split, | |||||
subset_name=data_cfg.subset_name if hasattr( | |||||
data_cfg, 'subset_name') else None, | |||||
hub=data_cfg.hub if hasattr(data_cfg, 'hub') else Hubs.modelscope, | |||||
**data_cfg, | |||||
) | |||||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | |||||
torch_dataset = dataset.to_torch_dataset( | |||||
task_data_config=cfg, | |||||
task_name=self.cfg.task, | |||||
preprocessors=self.preprocessor) | |||||
# TODO: support MsDataset load for cv | |||||
if hasattr(data_cfg, 'name'): | |||||
dataset = MsDataset.load( | |||||
dataset_name=data_cfg.name, | |||||
split=data_cfg.split, | |||||
subset_name=data_cfg.subset_name if hasattr( | |||||
data_cfg, 'subset_name') else None, | |||||
hub=data_cfg.hub | |||||
if hasattr(data_cfg, 'hub') else Hubs.modelscope, | |||||
**data_cfg, | |||||
) | |||||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode) | |||||
torch_dataset = dataset.to_torch_dataset( | |||||
task_data_config=cfg, | |||||
task_name=self.cfg.task, | |||||
preprocessors=self.preprocessor) | |||||
else: | |||||
torch_dataset = build_task_dataset(data_cfg, self.cfg.task) | |||||
dataset = self.to_task_dataset(torch_dataset, mode) | dataset = self.to_task_dataset(torch_dataset, mode) | ||||
return dataset | return dataset | ||||
@@ -746,7 +775,6 @@ class EpochBasedTrainer(BaseTrainer): | |||||
sampler=sampler, | sampler=sampler, | ||||
num_workers=num_workers, | num_workers=num_workers, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
collate_fn=self.data_collator, | |||||
pin_memory=kwargs.pop('pin_memory', False), | pin_memory=kwargs.pop('pin_memory', False), | ||||
worker_init_fn=init_fn, | worker_init_fn=init_fn, | ||||
**kwargs) | **kwargs) | ||||
@@ -820,12 +848,14 @@ class EpochBasedTrainer(BaseTrainer): | |||||
Args: | Args: | ||||
hook (:obj:`Hook`): The hook to be registered. | hook (:obj:`Hook`): The hook to be registered. | ||||
""" | """ | ||||
assert isinstance(hook, Hook) | |||||
# insert the hook to a sorted list | # insert the hook to a sorted list | ||||
inserted = False | inserted = False | ||||
for i in range(len(self._hooks) - 1, -1, -1): | for i in range(len(self._hooks) - 1, -1, -1): | ||||
if get_priority(hook.PRIORITY) > get_priority( | |||||
self._hooks[i].PRIORITY): | |||||
p = hook.PRIORITY if hasattr(hook, 'PRIORITY') else Priority.NORMAL | |||||
p_i = self._hooks[i].PRIORITY if hasattr( | |||||
self._hooks[i], 'PRIORITY') else Priority.NORMAL | |||||
if get_priority(p) > get_priority(p_i): | |||||
self._hooks.insert(i + 1, hook) | self._hooks.insert(i + 1, hook) | ||||
inserted = True | inserted = True | ||||
break | break | ||||
@@ -15,9 +15,9 @@ import json | |||||
from modelscope import __version__ | from modelscope import __version__ | ||||
from modelscope.fileio.file import LocalStorage | from modelscope.fileio.file import LocalStorage | ||||
from modelscope.metainfo import (Heads, Hooks, LR_Schedulers, Metrics, Models, | |||||
Optimizers, Pipelines, Preprocessors, | |||||
TaskModels, Trainers) | |||||
from modelscope.metainfo import (Datasets, Heads, Hooks, LR_Schedulers, | |||||
Metrics, Models, Optimizers, Pipelines, | |||||
Preprocessors, TaskModels, Trainers) | |||||
from modelscope.utils.constant import Fields, Tasks | from modelscope.utils.constant import Fields, Tasks | ||||
from modelscope.utils.file_utils import get_default_cache_dir | from modelscope.utils.file_utils import get_default_cache_dir | ||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
@@ -32,8 +32,7 @@ MODELSCOPE_PATH = p.resolve().parents[1] | |||||
REGISTER_MODULE = 'register_module' | REGISTER_MODULE = 'register_module' | ||||
IGNORED_PACKAGES = ['modelscope', '.'] | IGNORED_PACKAGES = ['modelscope', '.'] | ||||
SCAN_SUB_FOLDERS = [ | SCAN_SUB_FOLDERS = [ | ||||
'models', 'metrics', 'pipelines', 'preprocessors', | |||||
'msdatasets/task_datasets', 'trainers' | |||||
'models', 'metrics', 'pipelines', 'preprocessors', 'trainers', 'msdatasets' | |||||
] | ] | ||||
INDEXER_FILE = 'ast_indexer' | INDEXER_FILE = 'ast_indexer' | ||||
DECORATOR_KEY = 'decorators' | DECORATOR_KEY = 'decorators' | ||||
@@ -14,7 +14,7 @@ mmcls>=0.21.0 | |||||
mmdet>=2.25.0 | mmdet>=2.25.0 | ||||
networkx>=2.5 | networkx>=2.5 | ||||
onnxruntime>=1.10 | onnxruntime>=1.10 | ||||
pai-easycv>=0.5 | |||||
pai-easycv>=0.6.0 | |||||
pandas | pandas | ||||
psutil | psutil | ||||
regex | regex | ||||
@@ -4,6 +4,7 @@ easydict | |||||
einops | einops | ||||
filelock>=3.3.0 | filelock>=3.3.0 | ||||
gast>=0.2.2 | gast>=0.2.2 | ||||
jsonplus | |||||
numpy | numpy | ||||
opencv-python | opencv-python | ||||
oss2 | oss2 | ||||
@@ -0,0 +1,35 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import unittest | |||||
import numpy as np | |||||
from PIL import Image | |||||
from modelscope.metainfo import Pipelines | |||||
from modelscope.pipelines import pipeline | |||||
from modelscope.utils.constant import Tasks | |||||
from modelscope.utils.test_utils import test_level | |||||
class EasyCVSegmentationPipelineTest(unittest.TestCase): | |||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
def test_segformer_b0(self): | |||||
img_path = 'data/test/images/image_segmentation.jpg' | |||||
model_id = 'EasyCV/EasyCV-Segformer-b0' | |||||
img = np.asarray(Image.open(img_path)) | |||||
segmentor = pipeline(task=Tasks.image_segmentation, model=model_id) | |||||
outputs = segmentor(img_path) | |||||
self.assertEqual(len(outputs), 1) | |||||
results = outputs[0] | |||||
self.assertListEqual( | |||||
list(img.shape)[:2], list(results['seg_pred'][0].shape)) | |||||
self.assertListEqual(results['seg_pred'][0][1, :10].tolist(), | |||||
[161 for i in range(10)]) | |||||
self.assertListEqual(results['seg_pred'][0][-1, -10:].tolist(), | |||||
[133 for i in range(10)]) | |||||
if __name__ == '__main__': | |||||
unittest.main() |
@@ -0,0 +1,244 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||||
import glob | |||||
import os | |||||
import shutil | |||||
import tempfile | |||||
import unittest | |||||
import json | |||||
import requests | |||||
import torch | |||||
from modelscope.metainfo import Models, Pipelines, Trainers | |||||
from modelscope.trainers import build_trainer | |||||
from modelscope.utils.config import Config | |||||
from modelscope.utils.constant import LogKeys, ModeKeys, Tasks | |||||
from modelscope.utils.logger import get_logger | |||||
from modelscope.utils.test_utils import DistributedTestCase, test_level | |||||
from modelscope.utils.torch_utils import is_master | |||||
def _download_data(url, save_dir): | |||||
r = requests.get(url, verify=True) | |||||
if not os.path.exists(save_dir): | |||||
os.makedirs(save_dir) | |||||
zip_name = os.path.split(url)[-1] | |||||
save_path = os.path.join(save_dir, zip_name) | |||||
with open(save_path, 'wb') as f: | |||||
f.write(r.content) | |||||
unpack_dir = os.path.join(save_dir, os.path.splitext(zip_name)[0]) | |||||
shutil.unpack_archive(save_path, unpack_dir) | |||||
def train_func(work_dir, dist=False, log_config=3, imgs_per_gpu=4): | |||||
import easycv | |||||
config_path = os.path.join( | |||||
os.path.dirname(easycv.__file__), | |||||
'configs/detection/yolox/yolox_s_8xb16_300e_coco.py') | |||||
data_dir = os.path.join(work_dir, 'small_coco_test') | |||||
url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/datasets/small_coco.zip' | |||||
if is_master(): | |||||
_download_data(url, data_dir) | |||||
    import time
    time.sleep(1)

    cfg = Config.from_file(config_path)
    cfg.work_dir = work_dir
    cfg.total_epochs = 2
    cfg.checkpoint_config.interval = 1
    cfg.eval_config.interval = 1
    cfg.log_config = dict(
        interval=log_config,
        hooks=[
            dict(type='TextLoggerHook'),
            dict(type='TensorboardLoggerHook')
        ])
    cfg.data.train.data_source.ann_file = os.path.join(
        data_dir, 'small_coco/small_coco/instances_train2017_20.json')
    cfg.data.train.data_source.img_prefix = os.path.join(
        data_dir, 'small_coco/small_coco/train2017')
    cfg.data.val.data_source.ann_file = os.path.join(
        data_dir, 'small_coco/small_coco/instances_val2017_20.json')
    cfg.data.val.data_source.img_prefix = os.path.join(
        data_dir, 'small_coco/small_coco/val2017')
    cfg.data.imgs_per_gpu = imgs_per_gpu
    cfg.data.workers_per_gpu = 2
    cfg.data.val.imgs_per_gpu = 2

    ms_cfg_file = os.path.join(work_dir, 'ms_yolox_s_8xb16_300e_coco.json')
    from easycv.utils.ms_utils import to_ms_config
    if is_master():
        to_ms_config(
            cfg,
            dump=True,
            task=Tasks.image_object_detection,
            ms_model_name=Models.yolox,
            pipeline_name=Pipelines.easycv_detection,
            save_path=ms_cfg_file)

    trainer_name = Trainers.easycv
    kwargs = dict(
        task=Tasks.image_object_detection,
        cfg_file=ms_cfg_file,
        launcher='pytorch' if dist else None)
    trainer = build_trainer(trainer_name, kwargs)
    trainer.train()


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestSingleGpu(unittest.TestCase):

    def setUp(self):
        self.logger = get_logger()
        self.logger.info(('Testing %s.%s' %
                          (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    @unittest.skipIf(
        True, 'All test cases run in the master process, which causes '
        'registry conflicts; this test should run in a subprocess.')
    def test_single_gpu(self):
        # TODO: run in subprocess
        train_func(self.tmp_dir)

        results_files = os.listdir(self.tmp_dir)
        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(json_files), 1)
        with open(json_files[0], 'r') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 3,
                LogKeys.LR: 0.00013
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 10
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 3,
                LogKeys.LR: 0.00157
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 10
            }, json.loads(lines[3]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)

        for i in [0, 2]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
            self.assertIn(LogKeys.MEMORY, lines[i])
            self.assertIn('total_loss', lines[i])
        for i in [1, 3]:
            self.assertIn(
                'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP',
                lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i])


@unittest.skipIf(not torch.cuda.is_available()
                 or torch.cuda.device_count() <= 1, 'distributed unittest')
class EasyCVTrainerTestMultiGpus(DistributedTestCase):

    def setUp(self):
        self.logger = get_logger()
        self.logger.info(('Testing %s.%s' %
                          (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_multi_gpus(self):
        self.start(
            train_func,
            num_gpus=2,
            work_dir=self.tmp_dir,
            dist=True,
            log_config=2,
            imgs_per_gpu=5)

        results_files = os.listdir(self.tmp_dir)
        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(json_files), 1)
        with open(json_files[0], 'r') as f:
            lines = [i.strip() for i in f.readlines()]

        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 2,
                LogKeys.LR: 0.0002
            }, json.loads(lines[0]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 1,
                LogKeys.ITER: 5
            }, json.loads(lines[1]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.TRAIN,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 2,
                LogKeys.LR: 0.0018
            }, json.loads(lines[2]))
        self.assertDictContainsSubset(
            {
                LogKeys.MODE: ModeKeys.EVAL,
                LogKeys.EPOCH: 2,
                LogKeys.ITER: 5
            }, json.loads(lines[3]))
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)

        for i in [0, 2]:
            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
            self.assertIn(LogKeys.ITER_TIME, lines[i])
            self.assertIn(LogKeys.MEMORY, lines[i])
            self.assertIn('total_loss', lines[i])
        for i in [1, 3]:
            self.assertIn(
                'CocoDetectionEvaluator_DetectionBoxes_Precision/mAP',
                lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.50IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP@.75IOU', lines[i])
            self.assertIn('DetectionBoxes_Precision/mAP (small)', lines[i])


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,99 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import tempfile
import unittest

import requests
import torch

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
from modelscope.utils.torch_utils import is_master


def _download_data(url, save_dir):
    r = requests.get(url, verify=True)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    zip_name = os.path.split(url)[-1]
    save_path = os.path.join(save_dir, zip_name)
    with open(save_path, 'wb') as f:
        f.write(r.content)
    unpack_dir = os.path.join(save_dir, os.path.splitext(zip_name)[0])
    shutil.unpack_archive(save_path, unpack_dir)


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class EasyCVTrainerTestSegformer(unittest.TestCase):

    def setUp(self):
        self.logger = get_logger()
        self.logger.info(('Testing %s.%s' %
                          (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _train(self):
        from modelscope.trainers.easycv.trainer import EasyCVEpochBasedTrainer
        url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/datasets/small_coco_stuff164k.zip'
        data_dir = os.path.join(self.tmp_dir, 'data')
        if is_master():
            _download_data(url, data_dir)

        # adapt to distributed mode
        from easycv.utils.test_util import pseudo_dist_init
        pseudo_dist_init()

        root_path = os.path.join(data_dir, 'small_coco_stuff164k')
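        # cfg_options overrides nested fields of the model's EasyCV config via
        # dot-separated keys (here the dataset paths and the number of epochs).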
        cfg_options = {
            'train.max_epochs':
            2,
            'dataset.train.data_source.img_root':
            os.path.join(root_path, 'train2017'),
            'dataset.train.data_source.label_root':
            os.path.join(root_path, 'annotations/train2017'),
            'dataset.train.data_source.split':
            os.path.join(root_path, 'train.txt'),
            'dataset.val.data_source.img_root':
            os.path.join(root_path, 'val2017'),
            'dataset.val.data_source.label_root':
            os.path.join(root_path, 'annotations/val2017'),
            'dataset.val.data_source.split':
            os.path.join(root_path, 'val.txt'),
        }

        trainer_name = Trainers.easycv
        kwargs = dict(
            task=Tasks.image_segmentation,
            model='EasyCV/EasyCV-Segformer-b0',
            work_dir=self.tmp_dir,
            cfg_options=cfg_options)
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_single_gpu_segformer(self):
        self._train()

        results_files = os.listdir(self.tmp_dir)
        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
        self.assertEqual(len(json_files), 1)
        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)


if __name__ == '__main__':
    unittest.main()
@@ -4,6 +4,8 @@ import copy
import tempfile
import unittest
import json
from modelscope.utils.config import Config, check_config
obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
@@ -43,7 +45,8 @@ class ConfigTest(unittest.TestCase):
self.assertEqual(pretty_text, cfg.dump())
cfg.dump(ofile.name)
with open(ofile.name, 'r') as infile:
self.assertEqual(json_str, infile.read())
self.assertDictEqual(
    json.loads(json_str), json.loads(infile.read()))
with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
cfg.dump(ofile.name)
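Presumably the dumped JSON is no longer guaranteed to match the reference string byte for byte once jsonplus produces it, so the test now compares the parsed dicts, which ignores key order and whitespace. A standalone sketch of why that comparison is the robust one (not part of the diff):

import json

# equal as dicts even though the serialized strings differ
assert json.loads('{"a": 1, "b": 2}') == json.loads('{ "b": 2,  "a": 1 }')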