Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9174075 * fix msdataset (master)
@@ -32,3 +32,18 @@ def raise_on_error(rsp):
        return True
    else:
        raise RequestError(rsp['Message'])
# TODO: use raise_on_error instead once the modelhub and datahub responses have a uniform structure.
def datahub_raise_on_error(url, rsp):
    """Raise an exception if the datahub response contains an error.

    Args:
        url (str): The request url.
        rsp (dict): The server response.
    """
    if rsp.get('Code') == 200:
        return True
    else:
        raise RequestError(
            f"Url = {url}, Status = {rsp.get('status')}, error = {rsp.get('error')}, message = {rsp.get('message')}"
        )
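For reference, a minimal self-contained sketch of how this helper behaves. The response dicts below are made-up payloads, not real datahub responses, and RequestError is stubbed in place of the class from modelscope.hub.errors:

class RequestError(Exception):
    """Stub for modelscope.hub.errors.RequestError."""


def datahub_raise_on_error(url, rsp):
    # Mirrors the helper above: 'Code' == 200 means success.
    if rsp.get('Code') == 200:
        return True
    else:
        raise RequestError(
            f"Url = {url}, Status = {rsp.get('status')}, "
            f"error = {rsp.get('error')}, message = {rsp.get('message')}")


# Success path returns True.
assert datahub_raise_on_error('http://hub/api/v1/datasets/ns/name',
                              {'Code': 200, 'Data': {'Id': 7}})

# Failure path surfaces url, status, error, and message in one exception.
try:
    datahub_raise_on_error('http://hub/api/v1/datasets/ns/name',
                           {'Code': 404, 'status': 404,
                            'error': 'NotFound', 'message': 'no such dataset'})
except RequestError as err:
    print(err)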
@@ -19,4 +19,4 @@ DOWNLOADED_DATASETS_PATH = Path(
    os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH))
MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT',
                                 'http://101.201.119.157:31752')
                                 'http://123.57.189.90:31752')
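Because the value goes through os.environ.get, the endpoint can still be overridden without code changes; a minimal sketch, with a placeholder URL:

import os

# Placeholder endpoint; must be set before the config module is first imported.
os.environ['MS_HUB_ENDPOINT'] = 'http://127.0.0.1:31752'

from modelscope.msdatasets.config import MS_HUB_ENDPOINT

print(MS_HUB_ENDPOINT)  # http://127.0.0.1:31752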
@@ -3,7 +3,7 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                    Sequence, Union)
import numpy as np
from datasets import Dataset
from datasets import Dataset, DatasetDict
from datasets import load_dataset as hf_load_dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
@@ -12,7 +12,7 @@ from datasets.utils.file_utils import (is_relative_path,
from modelscope.msdatasets.config import MS_DATASETS_CACHE
from modelscope.msdatasets.utils.ms_api import MsApi
from modelscope.utils.constant import Hubs
from modelscope.utils.constant import DownloadMode, Hubs
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -34,6 +34,10 @@ class MsDataset:
    def __init__(self, hf_ds: Dataset, target: Optional[str] = None):
        self._hf_ds = hf_ds
        if target is not None and target not in self._hf_ds.features:
            raise TypeError(
                f'"target" must be a column of the dataset ({list(self._hf_ds.features.keys())}), but got {target}'
            )
        self.target = target
    def __iter__(self):
@@ -48,17 +52,23 @@ class MsDataset:
    @classmethod
    def from_hf_dataset(cls,
                        hf_ds: Dataset,
                        hf_ds: Union[Dataset, DatasetDict],
                        target: str = None) -> Union[dict, 'MsDataset']:
        if isinstance(hf_ds, Dataset):
            return cls(hf_ds, target)
        if len(hf_ds.keys()) == 1:
            return cls(next(iter(hf_ds.values())), target)
        return {k: cls(v, target) for k, v in hf_ds.items()}
        elif isinstance(hf_ds, DatasetDict):
            if len(hf_ds.keys()) == 1:
                return cls(next(iter(hf_ds.values())), target)
            return {k: cls(v, target) for k, v in hf_ds.items()}
        else:
            raise TypeError(
                f'"hf_ds" must be a Dataset or DatasetDict, but got {type(hf_ds)}'
            )
    @staticmethod
    def load(
        dataset_name: Union[str, list],
        namespace: Optional[str] = None,
        target: Optional[str] = None,
        version: Optional[str] = None,
        hub: Optional[Hubs] = Hubs.modelscope,
@@ -67,23 +77,32 @@ class MsDataset:
        data_dir: Optional[str] = None,
        data_files: Optional[Union[str, Sequence[str],
                                   Mapping[str, Union[str,
                                                      Sequence[str]]]]] = None
                                   Sequence[str]]]]] = None,
        download_mode: Optional[DownloadMode] = DownloadMode.
        REUSE_DATASET_IF_EXISTS
    ) -> Union[dict, 'MsDataset']:
"""Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset. | |||
Args: | |||
dataset_name (str): Path or name of the dataset. | |||
namespace(str, optional): Namespace of the dataset. It should not be None, if you load a remote dataset | |||
from Hubs.modelscope, | |||
target (str, optional): Name of the column to output. | |||
version (str, optional): Version of the dataset script to load: | |||
subset_name (str, optional): Defining the subset_name of the dataset. | |||
data_dir (str, optional): Defining the data_dir of the dataset configuration. I | |||
data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s). | |||
split (str, optional): Which split of the data to load. | |||
hub (Hubs, optional): When loading from a remote hub, where it is from | |||
hub (Hubs or str, optional): When loading from a remote hub, where it is from. default Hubs.modelscope | |||
download_mode (DownloadMode or str, optional): How to treat existing datasets. default | |||
DownloadMode.REUSE_DATASET_IF_EXISTS | |||
Returns: | |||
MsDataset (obj:`MsDataset`): MsDataset object for a certain dataset. | |||
""" | |||
        download_mode = DownloadMode(download_mode
                                     or DownloadMode.REUSE_DATASET_IF_EXISTS)
        hub = Hubs(hub or Hubs.modelscope)
        if hub == Hubs.huggingface:
            dataset = hf_load_dataset(
                dataset_name,
@@ -91,21 +110,25 @@ class MsDataset:
                revision=version,
                split=split,
                data_dir=data_dir,
                data_files=data_files)
                data_files=data_files,
                download_mode=download_mode.value)
            return MsDataset.from_hf_dataset(dataset, target=target)
        else:
        elif hub == Hubs.modelscope:
            return MsDataset._load_ms_dataset(
                dataset_name,
                namespace=namespace,
                target=target,
                subset_name=subset_name,
                version=version,
                split=split,
                data_dir=data_dir,
                data_files=data_files)
                data_files=data_files,
                download_mode=download_mode)
    @staticmethod
    def _load_ms_dataset(
        dataset_name: Union[str, list],
        namespace: Optional[str] = None,
        target: Optional[str] = None,
        version: Optional[str] = None,
        subset_name: Optional[str] = None,
@@ -113,17 +136,19 @@ class MsDataset:
        data_dir: Optional[str] = None,
        data_files: Optional[Union[str, Sequence[str],
                                   Mapping[str, Union[str,
                                                      Sequence[str]]]]] = None
                                   Sequence[str]]]]] = None,
        download_mode: Optional[DownloadMode] = None
    ) -> Union[dict, 'MsDataset']:
        if isinstance(dataset_name, str):
            use_hf = False
            if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \
                    (os.path.isfile(dataset_name) and dataset_name.endswith('.py')):
                use_hf = True
            elif is_relative_path(dataset_name):
            elif is_relative_path(dataset_name) and dataset_name.count(
                    '/') == 0:
                ms_api = MsApi()
                dataset_scripts = ms_api.fetch_dataset_scripts(
                    dataset_name, version)
                    dataset_name, namespace, download_mode, version)
                if 'py' in dataset_scripts:  # dataset copied from hf datasets
                    dataset_name = dataset_scripts['py'][0]
                    use_hf = True
@@ -140,7 +165,8 @@ class MsDataset:
                split=split,
                data_dir=data_dir,
                data_files=data_files,
                cache_dir=MS_DATASETS_CACHE)
                cache_dir=MS_DATASETS_CACHE,
                download_mode=download_mode.value)
        else:
            # TODO load from ms datahub
            raise NotImplementedError(
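Putting the new signature together, a hedged usage sketch, assuming the import path used by the tests below; the squad/damotest names mirror those tests and are examples, not guaranteed datasets:

from modelscope.msdatasets import MsDataset
from modelscope.utils.constant import DownloadMode

# namespace is now required when loading from Hubs.modelscope.
ms_ds_train = MsDataset.load(
    'squad',
    namespace='damotest',
    split='train',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)

# Without split, a multi-split dataset comes back as a dict of
# MsDataset objects keyed by split, via from_hf_dataset.
ms_ds_full = MsDataset.load('squad', namespace='damotest')

# Plain strings are coerced through the enums, so this is equivalent
# to hub=Hubs.huggingface with the default download mode.
hf_ds = MsDataset.load(
    'squad', hub='huggingface', download_mode='reuse_dataset_if_exists')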
@@ -1,11 +1,14 @@
import os
import shutil
from collections import defaultdict
from typing import Optional
import requests
from modelscope.hub.errors import NotExistError, datahub_raise_on_error
from modelscope.msdatasets.config import (DOWNLOADED_DATASETS_PATH,
                                          MS_HUB_ENDPOINT)
from modelscope.utils.constant import DownloadMode
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -27,23 +30,38 @@ class MsApi:
    def fetch_dataset_scripts(self,
                              dataset_name: str,
                              version: Optional[str] = 'master',
                              force_download=False):
        datahub_url = f'{self.endpoint}/api/v1/datasets?Query={dataset_name}'
        r = requests.get(datahub_url)
        r.raise_for_status()
        dataset_list = r.json()['Data']
        if len(dataset_list) == 0:
            return None
        dataset_id = dataset_list[0]['Id']
                              namespace: str,
                              download_mode: Optional[DownloadMode],
                              version: Optional[str] = 'master'):
        if namespace is None:
            raise ValueError(
                f'Dataset from Hubs.modelscope should have a valid "namespace", but got {namespace}'
            )
        version = version or 'master'
        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}'
        r = requests.get(datahub_url)
        r.raise_for_status()
        file_list = r.json()['Data']['Files']
        cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name,
                                 version)
                                 namespace, version)
        download_mode = DownloadMode(download_mode
                                     or DownloadMode.REUSE_DATASET_IF_EXISTS)
        if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(
                cache_dir):
            shutil.rmtree(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)
        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
        r = requests.get(datahub_url)
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp)
        dataset_id = resp['Data']['Id']
        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}'
        r = requests.get(datahub_url)
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp)
        file_list = resp['Data']
        if file_list is None:
            raise NotExistError(
                f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, '
                f'version = {version}] does not exist')
        file_list = file_list['Files']
        local_paths = defaultdict(list)
        for file_info in file_list:
            file_path = file_info['Path']
@@ -54,7 +72,7 @@ class MsApi:
            r.raise_for_status()
            content = r.json()['Data']['Content']
            local_path = os.path.join(cache_dir, file_path)
            if os.path.exists(local_path) and not force_download:
            if os.path.exists(local_path):
                logger.warning(
                    f"Reusing dataset {dataset_name}'s python file ({local_path})"
                )
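A hedged sketch of calling the reworked helper directly; the dataset and namespace are illustrative, and the result is the defaultdict of local script paths built above:

from modelscope.msdatasets.utils.ms_api import MsApi
from modelscope.utils.constant import DownloadMode

ms_api = MsApi()
# FORCE_REDOWNLOAD wipes DOWNLOADED_DATASETS_PATH/<dataset>/<namespace>/<version>
# before the scripts are fetched again.
scripts = ms_api.fetch_dataset_scripts(
    'squad', 'damotest', DownloadMode.FORCE_REDOWNLOAD, version='master')
# A 'py' entry means the dataset ships a Hugging Face style loading script.
if scripts and 'py' in scripts:
    print(scripts['py'][0])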
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import enum
class Fields(object):
@@ -69,13 +70,20 @@ class InputFields(object):
    audio = 'audio'
class Hubs(object):
class Hubs(enum.Enum):
    """ Source from which an entity (such as a Dataset or Model) is stored
    """
    modelscope = 'modelscope'
    huggingface = 'huggingface'
class DownloadMode(enum.Enum):
    """ How to treat existing datasets
    """
    REUSE_DATASET_IF_EXISTS = 'reuse_dataset_if_exists'
    FORCE_REDOWNLOAD = 'force_redownload'
class ModelFile(object):
    CONFIGURATION = 'configuration.json'
    README = 'README.md'
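Turning Hubs and DownloadMode into enum.Enum members is what makes the string coercion in MsDataset.load work; a minimal standalone sketch of that pattern:

import enum


class DownloadMode(enum.Enum):
    """Mirrors modelscope.utils.constant.DownloadMode."""
    REUSE_DATASET_IF_EXISTS = 'reuse_dataset_if_exists'
    FORCE_REDOWNLOAD = 'force_redownload'


# Lookup by value converts user-supplied strings into members.
assert DownloadMode('force_redownload') is DownloadMode.FORCE_REDOWNLOAD

# None falls back to the default, via the same expression load uses.
mode = None
mode = DownloadMode(mode or DownloadMode.REUSE_DATASET_IF_EXISTS)
assert mode is DownloadMode.REUSE_DATASET_IF_EXISTS

# Unknown values fail loudly instead of being silently accepted.
try:
    DownloadMode('overwrite')
except ValueError as err:
    print(err)  # 'overwrite' is not a valid DownloadMode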
@@ -32,11 +32,12 @@ class ImgPreprocessor(Preprocessor):
class MsDatasetTest(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_basic(self):
        ms_ds_full = MsDataset.load('squad')
        ms_ds_full = MsDataset.load('squad', namespace='damotest')
        ms_ds_full_hf = hfdata.load_dataset('squad')
        ms_ds_train = MsDataset.load('squad', split='train')
        ms_ds_train = MsDataset.load(
            'squad', namespace='damotest', split='train')
        ms_ds_train_hf = hfdata.load_dataset('squad', split='train')
        ms_image_train = MsDataset.from_hf_dataset(
            hfdata.load_dataset('beans', split='train'))
@@ -48,7 +49,7 @@ class MsDatasetTest(unittest.TestCase):
        print(next(iter(ms_ds_train)))
        print(next(iter(ms_image_train)))
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_torch
    def test_to_torch_dataset_text(self):
        model_id = 'damo/bert-base-sst2'
@@ -57,13 +58,14 @@ class MsDatasetTest(unittest.TestCase):
            nlp_model.model_dir,
            first_sequence='context',
            second_sequence=None)
        ms_ds_train = MsDataset.load('squad', split='train')
        ms_ds_train = MsDataset.load(
            'squad', namespace='damotest', split='train')
        pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
        import torch
        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
        print(next(iter(dataloader)))
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_tf
    def test_to_tf_dataset_text(self):
        import tensorflow as tf
@@ -74,7 +76,8 @@ class MsDatasetTest(unittest.TestCase):
            nlp_model.model_dir,
            first_sequence='context',
            second_sequence=None)
        ms_ds_train = MsDataset.load('squad', split='train')
        ms_ds_train = MsDataset.load(
            'squad', namespace='damotest', split='train')
        tf_dataset = ms_ds_train.to_tf_dataset(
            batch_size=5,
            shuffle=True,
@@ -85,8 +88,8 @@ class MsDatasetTest(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @require_torch
    def test_to_torch_dataset_img(self):
        ms_image_train = MsDataset.from_hf_dataset(
            hfdata.load_dataset('beans', split='train'))
        ms_image_train = MsDataset.load(
            'beans', namespace='damotest', split='train')
        pt_dataset = ms_image_train.to_torch_dataset(
            preprocessors=ImgPreprocessor(
                image_path='image_file_path', label='labels'))
@@ -99,7 +102,8 @@ class MsDatasetTest(unittest.TestCase):
    def test_to_tf_dataset_img(self):
        import tensorflow as tf
        tf.compat.v1.enable_eager_execution()
        ms_image_train = MsDataset.load('beans', split='train')
        ms_image_train = MsDataset.load(
            'beans', namespace='damotest', split='train')
        tf_dataset = ms_image_train.to_tf_dataset(
            batch_size=5,
            shuffle=True,
@@ -62,7 +62,8 @@ class ImageMattingTest(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_modelscope_dataset(self):
        dataset = MsDataset.load('beans', split='train', target='image')
        dataset = MsDataset.load(
            'beans', namespace='damotest', split='train', target='image')
        img_matting = pipeline(Tasks.image_matting, model=self.model_id)
        result = img_matting(dataset)
        for i in range(10):
@@ -87,12 +87,16 @@ class SequenceClassificationTest(unittest.TestCase):
        result = text_classification(dataset)
        self.printDataset(result)
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_modelscope_dataset(self):
        text_classification = pipeline(task=Tasks.text_classification)
        # loaded from modelscope dataset
        dataset = MsDataset.load(
            'squad', split='train', target='context', hub=Hubs.modelscope)
            'squad',
            namespace='damotest',
            split='train',
            target='context',
            hub=Hubs.modelscope)
        result = text_classification(dataset)
        self.printDataset(result)