@@ -248,15 +248,15 @@ class MsDataset:
                 break
         target_subset_name, target_dataset_structure = get_target_dataset_structure(
             dataset_json, subset_name, split)
-        meta_map, file_map = get_dataset_files(target_dataset_structure,
-                                               dataset_name, namespace,
-                                               version)
+        meta_map, file_map, args_map = get_dataset_files(
+            target_dataset_structure, dataset_name, namespace, version)
         builder = load_dataset_builder(
             dataset_name,
             subset_name,
             namespace,
             meta_data_files=meta_map,
             zip_data_files=file_map,
+            args_map=args_map,
             cache_dir=MS_DATASETS_CACHE,
             version=version,
             split=list(target_dataset_structure.keys()),
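Note: with this hunk, any per-split 'args' stored in the dataset's JSON metadata are forwarded to the builder, so callers no longer need to pass builder kwargs such as classes or folder_name by hand. A minimal usage sketch (the namespace value is illustrative, not taken from this diff):

    from modelscope.msdatasets import MsDataset

    # The split's stored 'args' supply defaults such as classes/folder_name;
    # kwargs passed here would still override them (see args_map.update below).
    ds = MsDataset.load('pets_small', namespace='modelscope', split='train')
    print(ds.config_kwargs)
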
@@ -60,6 +60,8 @@ class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
                  classes=None,
                  seg_prefix=None,
                  folder_name=None,
+                 ann_file=None,
+                 img_prefix=None,
                  test_mode=False,
                  filter_empty_gt=True,
                  **kwargs):
@@ -69,11 +71,9 @@ class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
         self.split = next(iter(split_config.keys()))
         self.preprocessor = preprocessor
-        self.ann_file = osp.join(self.data_root,
-                                 DATASET_STRUCTURE[self.split]['annotation'])
+        self.ann_file = osp.join(self.data_root, ann_file)
-        self.img_prefix = osp.join(self.data_root,
-                                   DATASET_STRUCTURE[self.split]['images'])
+        self.img_prefix = osp.join(self.data_root, img_prefix)
         self.seg_prefix = seg_prefix
         self.test_mode = test_mode
         self.filter_empty_gt = filter_empty_gt
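Note: with the DATASET_STRUCTURE lookup removed, annotation and image locations become plain constructor arguments resolved against data_root. A hedged construction sketch; the paths, the split_config shape, and preprocessor=None are assumptions for illustration only:

    from modelscope.msdatasets.task_datasets import \
        ImageInstanceSegmentationCocoDataset

    dataset = ImageInstanceSegmentationCocoDataset(
        split_config={'train': '/tmp/pets_small'},  # split name -> data_root
        preprocessor=None,
        classes=('Cat', 'Dog'),
        ann_file='annotations/instances_train.json',  # joined to data_root
        img_prefix='images/train')  # joined to data_root
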
@@ -1,6 +1,6 @@
 import os
 from collections import defaultdict
-from typing import Mapping, Optional, Sequence, Union
+from typing import Any, Mapping, Optional, Sequence, Union
 
 from datasets.builder import DatasetBuilder
@@ -92,6 +92,7 @@ def get_dataset_files(subset_split_into: dict,
     """
     meta_map = defaultdict(dict)
     file_map = defaultdict(dict)
+    args_map = defaultdict(dict)
     from modelscope.hub.api import HubApi
     modelscope_api = HubApi()
     for split, info in subset_split_into.items():
@@ -99,7 +100,8 @@ def get_dataset_files(subset_split_into: dict,
             info.get('meta', ''), dataset_name, namespace, revision)
         if info.get('file'):
             file_map[split] = info['file']
-    return meta_map, file_map
+        args_map[split] = info.get('args')
+    return meta_map, file_map, args_map
 
 
 def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
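Note: the hunk above makes get_dataset_files return a third mapping. A shape sketch with invented values, assuming the dataset's structure JSON carries an optional 'args' dict per split:

    subset_split_into = {
        'train': {
            'meta': 'train_meta.csv',
            'file': 'train.zip',
            'args': {'classes': ('Cat', 'Dog'), 'folder_name': 'Pets'},
        }
    }
    # get_dataset_files(subset_split_into, ...) would then yield, besides
    # meta_map and file_map:
    #   args_map == {'train': {'classes': ('Cat', 'Dog'), 'folder_name': 'Pets'}}
    # args_map[split] is None for splits that declare no 'args'.
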
@@ -107,12 +109,16 @@ def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
                                                              Sequence[str]]],
                          zip_data_files: Mapping[str, Union[str,
                                                             Sequence[str]]],
-                         cache_dir: str, version: Optional[Union[str]],
-                         split: Sequence[str],
+                         args_map: Mapping[str, Any], cache_dir: str,
+                         version: Optional[Union[str]], split: Sequence[str],
                          **config_kwargs) -> DatasetBuilder:
     sub_dir = os.path.join(version, '_'.join(split))
     meta_data_file = next(iter(meta_data_files.values()))
     if not meta_data_file:
+        args_map = next(iter(args_map.values()))
+        if args_map is None:
+            args_map = {}
+        args_map.update(config_kwargs)
         builder_instance = TaskSpecificDatasetBuilder(
             dataset_name=dataset_name,
             namespace=namespace,
@@ -121,7 +127,7 @@ def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
             meta_data_files=meta_data_files,
             zip_data_files=zip_data_files,
             hash=sub_dir,
-            **config_kwargs)
+            **args_map)
     elif meta_data_file.endswith('.csv'):
         builder_instance = MsCsvDatasetBuilder(
             dataset_name=dataset_name,
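Note: args_map.update(config_kwargs) in the hunk above fixes the precedence: args stored with the dataset act as defaults, and kwargs supplied by the caller win on conflict. A self-contained illustration of that merge (values invented):

    stored_args = {'classes': ('Cat', 'Dog'), 'folder_name': 'Pets'}
    caller_kwargs = {'classes': ('1', '2')}
    stored_args.update(caller_kwargs)
    assert stored_args == {'classes': ('1', '2'), 'folder_name': 'Pets'}
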
@@ -36,9 +36,8 @@ class MsDatasetTest(unittest.TestCase):
         ms_ds_train = MsDataset.load(
             'pets_small',
             namespace=DEFAULT_DATASET_NAMESPACE,
-            split='train',
-            classes=('1', '2'),
-            folder_name='Pets')
+            download_mode=DownloadMode.FORCE_REDOWNLOAD,
+            split='train')
         print(ms_ds_train.config_kwargs)
         assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))
@@ -20,7 +20,6 @@ class TestFinetuneMPlug(unittest.TestCase):
         self.tmp_dir = tempfile.TemporaryDirectory().name
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
         from modelscope.utils.constant import DownloadMode
         datadict = MsDataset.load(
             'coco_captions_small_slice',
@@ -15,7 +15,7 @@ from modelscope.msdatasets.task_datasets import \
     ImageInstanceSegmentationCocoDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.config import Config, ConfigDict
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level
@@ -41,38 +41,26 @@ class TestImageInstanceSegmentationTrainer(unittest.TestCase):
         if train_data_cfg is None:
             # use default toy data
             train_data_cfg = ConfigDict(
-                name='pets_small',
-                split='train',
-                classes=('Cat', 'Dog'),
-                folder_name='Pets',
-                test_mode=False)
+                name='pets_small', split='train', test_mode=False)
         if val_data_cfg is None:
             val_data_cfg = ConfigDict(
-                name='pets_small',
-                split='validation',
-                classes=('Cat', 'Dog'),
-                folder_name='Pets',
-                test_mode=True)
+                name='pets_small', split='validation', test_mode=True)
         self.train_dataset = MsDataset.load(
             dataset_name=train_data_cfg.name,
             split=train_data_cfg.split,
-            classes=train_data_cfg.classes,
-            folder_name=train_data_cfg.folder_name,
-            test_mode=train_data_cfg.test_mode)
-        assert self.train_dataset.config_kwargs[
-            'classes'] == train_data_cfg.classes
+            test_mode=train_data_cfg.test_mode,
+            download_mode=DownloadMode.FORCE_REDOWNLOAD)
+        assert self.train_dataset.config_kwargs['classes']
         assert next(
             iter(self.train_dataset.config_kwargs['split_config'].values()))
         self.eval_dataset = MsDataset.load(
             dataset_name=val_data_cfg.name,
             split=val_data_cfg.split,
-            classes=val_data_cfg.classes,
-            folder_name=val_data_cfg.folder_name,
-            test_mode=val_data_cfg.test_mode)
-        assert self.eval_dataset.config_kwargs[
-            'classes'] == val_data_cfg.classes
+            test_mode=val_data_cfg.test_mode,
+            download_mode=DownloadMode.FORCE_REDOWNLOAD)
+        assert self.eval_dataset.config_kwargs['classes']
         assert next(
             iter(self.eval_dataset.config_kwargs['split_config'].values()))