@@ -248,15 +248,15 @@ class MsDataset:
                    break
            target_subset_name, target_dataset_structure = get_target_dataset_structure(
                dataset_json, subset_name, split)
            meta_map, file_map = get_dataset_files(target_dataset_structure,
                                                   dataset_name, namespace,
                                                   version)
            meta_map, file_map, args_map = get_dataset_files(
                target_dataset_structure, dataset_name, namespace, version)
            builder = load_dataset_builder(
                dataset_name,
                subset_name,
                namespace,
                meta_data_files=meta_map,
                zip_data_files=file_map,
                args_map=args_map,
                cache_dir=MS_DATASETS_CACHE,
                version=version,
                split=list(target_dataset_structure.keys()),
@@ -60,6 +60,8 @@ class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
                 classes=None,
                 seg_prefix=None,
                 folder_name=None,
                 ann_file=None,
                 img_prefix=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 **kwargs):
@@ -69,11 +71,9 @@ class ImageInstanceSegmentationCocoDataset(TorchTaskDataset):
        self.split = next(iter(split_config.keys()))
        self.preprocessor = preprocessor
        self.ann_file = osp.join(self.data_root,
                                 DATASET_STRUCTURE[self.split]['annotation'])
        self.ann_file = osp.join(self.data_root, ann_file)
        self.img_prefix = osp.join(self.data_root,
                                   DATASET_STRUCTURE[self.split]['images'])
        self.img_prefix = osp.join(self.data_root, img_prefix)
        self.seg_prefix = seg_prefix
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
@@ -1,6 +1,6 @@
import os
from collections import defaultdict
from typing import Mapping, Optional, Sequence, Union
from typing import Any, Mapping, Optional, Sequence, Union
from datasets.builder import DatasetBuilder
@@ -92,6 +92,7 @@ def get_dataset_files(subset_split_into: dict,
    """
    meta_map = defaultdict(dict)
    file_map = defaultdict(dict)
    args_map = defaultdict(dict)
    from modelscope.hub.api import HubApi
    modelscope_api = HubApi()
    for split, info in subset_split_into.items():
@@ -99,7 +100,8 @@ def get_dataset_files(subset_split_into: dict,
            info.get('meta', ''), dataset_name, namespace, revision)
        if info.get('file'):
            file_map[split] = info['file']
    return meta_map, file_map
        args_map[split] = info.get('args')
    return meta_map, file_map, args_map
def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
@@ -107,12 +109,16 @@ def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
                                               Sequence[str]]],
                         zip_data_files: Mapping[str, Union[str,
                                                            Sequence[str]]],
                         cache_dir: str, version: Optional[Union[str]],
                         split: Sequence[str],
                         args_map: Mapping[str, Any], cache_dir: str,
                         version: Optional[Union[str]], split: Sequence[str],
                         **config_kwargs) -> DatasetBuilder:
    sub_dir = os.path.join(version, '_'.join(split))
    meta_data_file = next(iter(meta_data_files.values()))
    if not meta_data_file:
        args_map = next(iter(args_map.values()))
        if args_map is None:
            args_map = {}
        args_map.update(config_kwargs)
        builder_instance = TaskSpecificDatasetBuilder(
            dataset_name=dataset_name,
            namespace=namespace,
@@ -121,7 +127,7 @@ def load_dataset_builder(dataset_name: str, subset_name: str, namespace: str,
            meta_data_files=meta_data_files,
            zip_data_files=zip_data_files,
            hash=sub_dir,
            **config_kwargs)
            **args_map)
    elif meta_data_file.endswith('.csv'):
        builder_instance = MsCsvDatasetBuilder(
            dataset_name=dataset_name,
@@ -36,9 +36,8 @@ class MsDatasetTest(unittest.TestCase):
        ms_ds_train = MsDataset.load(
            'pets_small',
            namespace=DEFAULT_DATASET_NAMESPACE,
            split='train',
            classes=('1', '2'),
            folder_name='Pets')
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            split='train')
        print(ms_ds_train.config_kwargs)
        assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))
@@ -20,7 +20,6 @@ class TestFinetuneMPlug(unittest.TestCase):
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        from modelscope.utils.constant import DownloadMode
        datadict = MsDataset.load(
            'coco_captions_small_slice',
@@ -15,7 +15,7 @@ from modelscope.msdatasets.task_datasets import \
    ImageInstanceSegmentationCocoDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level
@@ -41,38 +41,26 @@ class TestImageInstanceSegmentationTrainer(unittest.TestCase):
        if train_data_cfg is None:
            # use default toy data
            train_data_cfg = ConfigDict(
                name='pets_small',
                split='train',
                classes=('Cat', 'Dog'),
                folder_name='Pets',
                test_mode=False)
            train_data_cfg = ConfigDict(
                name='pets_small', split='train', test_mode=False)
        if val_data_cfg is None:
            val_data_cfg = ConfigDict(
                name='pets_small',
                split='validation',
                classes=('Cat', 'Dog'),
                folder_name='Pets',
                test_mode=True)
            val_data_cfg = ConfigDict(
                name='pets_small', split='validation', test_mode=True)
        self.train_dataset = MsDataset.load(
            dataset_name=train_data_cfg.name,
            split=train_data_cfg.split,
            classes=train_data_cfg.classes,
            folder_name=train_data_cfg.folder_name,
            test_mode=train_data_cfg.test_mode)
        assert self.train_dataset.config_kwargs[
            'classes'] == train_data_cfg.classes
            test_mode=train_data_cfg.test_mode,
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        assert self.train_dataset.config_kwargs['classes']
        assert next(
            iter(self.train_dataset.config_kwargs['split_config'].values()))
        self.eval_dataset = MsDataset.load(
            dataset_name=val_data_cfg.name,
            split=val_data_cfg.split,
            classes=val_data_cfg.classes,
            folder_name=val_data_cfg.folder_name,
            test_mode=val_data_cfg.test_mode)
        assert self.eval_dataset.config_kwargs[
            'classes'] == val_data_cfg.classes
            test_mode=val_data_cfg.test_mode,
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        assert self.eval_dataset.config_kwargs['classes']
        assert next(
            iter(self.eval_dataset.config_kwargs['split_config'].values()))