
[to #42670107]pydataset fetch data from datahub

* pydataset fetch data from datahub
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9060856
Branch: master
feiwu.yfw, yingda.chen · 3 years ago
commit c7238a470b
9 changed files with 580 additions and 103 deletions
  1. modelscope/preprocessors/nlp.py (+24, -15)
  2. modelscope/pydatasets/config.py (+22, -0)
  3. modelscope/pydatasets/py_dataset.py (+323, -58)
  4. modelscope/pydatasets/utils/__init__.py (+0, -0)
  5. modelscope/pydatasets/utils/ms_api.py (+66, -0)
  6. modelscope/utils/test_utils.py (+15, -0)
  7. tests/pipelines/test_image_matting.py (+11, -0)
  8. tests/pipelines/test_text_classification.py (+24, -4)
  9. tests/pydatasets/test_py_dataset.py (+95, -26)

modelscope/preprocessors/nlp.py (+24, -15)

@@ -53,12 +53,12 @@ class SequenceClassificationPreprocessor(Preprocessor):
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
print(f'this is the tokenzier {self.tokenizer}')

@type_assert(object, (str, tuple))
def __call__(self, data: Union[str, tuple]) -> Dict[str, Any]:
@type_assert(object, (str, tuple, Dict))
def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str or tuple):
data (str, tuple or Dict):
sentence1 (str): a sentence
Example:
'you are so handsome.'
@@ -70,22 +70,31 @@ class SequenceClassificationPreprocessor(Preprocessor):
sentence2 (str): a sentence
Example:
'you are so beautiful.'
or
{field1: field_value1, field2: field_value2}
field1 (str): field name, default 'first_sequence'
field_value1 (str): a sentence
Example:
'you are so handsome.'

field2 (str): field name, default 'second_sequence'
field_value2 (str): a sentence
Example:
'you are so beautiful.'

Returns:
Dict[str, Any]: the preprocessed data
"""

if not isinstance(data, tuple):
data = (
data,
None,
)

sentence1, sentence2 = data
new_data = {
self.first_sequence: sentence1,
self.second_sequence: sentence2
}
if isinstance(data, str):
new_data = {self.first_sequence: data}
elif isinstance(data, tuple):
sentence1, sentence2 = data
new_data = {
self.first_sequence: sentence1,
self.second_sequence: sentence2
}
else:
new_data = data

# preprocess the data for the model input



modelscope/pydatasets/config.py (+22, -0)

@@ -0,0 +1,22 @@
import os
from pathlib import Path

# Cache location
DEFAULT_CACHE_HOME = '~/.cache'
CACHE_HOME = os.getenv('CACHE_HOME', DEFAULT_CACHE_HOME)
DEFAULT_MS_CACHE_HOME = os.path.join(CACHE_HOME, 'modelscope/hub')
MS_CACHE_HOME = os.path.expanduser(
os.getenv('MS_CACHE_HOME', DEFAULT_MS_CACHE_HOME))

DEFAULT_MS_DATASETS_CACHE = os.path.join(MS_CACHE_HOME, 'datasets')
MS_DATASETS_CACHE = Path(
os.getenv('MS_DATASETS_CACHE', DEFAULT_MS_DATASETS_CACHE))

DOWNLOADED_DATASETS_DIR = 'downloads'
DEFAULT_DOWNLOADED_DATASETS_PATH = os.path.join(MS_DATASETS_CACHE,
DOWNLOADED_DATASETS_DIR)
DOWNLOADED_DATASETS_PATH = Path(
os.getenv('DOWNLOADED_DATASETS_PATH', DEFAULT_DOWNLOADED_DATASETS_PATH))

MS_HUB_ENDPOINT = os.environ.get('MS_HUB_ENDPOINT',
'http://101.201.119.157:31752')
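
A small illustration of how these defaults resolve (the override path is made up): the environment variables are read at import time, so they have to be set before this module is first imported.

    import os

    # hypothetical override; must happen before modelscope.pydatasets.config is imported
    os.environ['MS_DATASETS_CACHE'] = '/data/ms_datasets'

    from modelscope.pydatasets.config import (DOWNLOADED_DATASETS_PATH,
                                              MS_DATASETS_CACHE)

    print(MS_DATASETS_CACHE)         # /data/ms_datasets
    print(DOWNLOADED_DATASETS_PATH)  # /data/ms_datasets/downloads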

modelscope/pydatasets/py_dataset.py (+323, -58)

@@ -1,64 +1,81 @@
from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
Union)
import os
from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union)

from datasets import Dataset, load_dataset
import numpy as np
from datasets import Dataset
from datasets import load_dataset as hf_load_dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.file_utils import (is_relative_path,
relative_to_absolute_path)

from modelscope.pydatasets.config import MS_DATASETS_CACHE
from modelscope.pydatasets.utils.ms_api import MsApi
from modelscope.utils.constant import Hubs
from modelscope.utils.logger import get_logger

logger = get_logger()


def format_list(para) -> List:
if para is None:
para = []
elif isinstance(para, str):
para = [para]
elif len(set(para)) < len(para):
raise ValueError(f'List columns contains duplicates: {para}')
return para


class PyDataset:
_hf_ds = None # holds the underlying HuggingFace Dataset
"""A PyDataset backed by hugging face Dataset."""

def __init__(self, hf_ds: Dataset):
def __init__(self, hf_ds: Dataset, target: Optional[str] = None):
self._hf_ds = hf_ds
self.target = None
self.target = target

def __iter__(self):
if isinstance(self._hf_ds, Dataset):
for item in self._hf_ds:
if self.target is not None:
yield item[self.target]
else:
yield item
else:
for ds in self._hf_ds.values():
for item in ds:
if self.target is not None:
yield item[self.target]
else:
yield item
for item in self._hf_ds:
if self.target is not None:
yield item[self.target]
else:
yield item

def __getitem__(self, key):
return self._hf_ds[key]

@classmethod
def from_hf_dataset(cls,
hf_ds: Dataset,
target: str = None) -> 'PyDataset':
dataset = cls(hf_ds)
dataset.target = target
return dataset
target: str = None) -> Union[dict, 'PyDataset']:
if isinstance(hf_ds, Dataset):
return cls(hf_ds, target)
if len(hf_ds.keys()) == 1:
return cls(next(iter(hf_ds.values())), target)
return {k: cls(v, target) for k, v in hf_ds.items()}

@staticmethod
def load(path: Union[str, list],
target: Optional[str] = None,
version: Optional[str] = None,
name: Optional[str] = None,
split: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str,
Union[str,
Sequence[str]]]]] = None,
hub: Optional[Hubs] = None) -> 'PyDataset':
def load(
dataset_name: Union[str, list],
target: Optional[str] = None,
version: Optional[str] = None,
hub: Optional[Hubs] = Hubs.modelscope,
subset_name: Optional[str] = None,
split: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None
) -> Union[dict, 'PyDataset']:
"""Load a PyDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
Args:

path (str): Path or name of the dataset.
dataset_name (str): Path or name of the dataset.
target (str, optional): Name of the column to output.
version (str, optional): Version of the dataset script to load.
name (str, optional): Defining the subset_name of the dataset.
subset_name (str, optional): Defining the subset_name of the dataset.
data_dir (str, optional): Defining the data_dir of the dataset configuration.
data_files (str or Sequence or Mapping, optional): Path(s) to source data file(s).
split (str, optional): Which split of the data to load.
@@ -67,53 +84,302 @@ class PyDataset:
Returns:
PyDataset (obj:`PyDataset`): PyDataset object for a certain dataset.
"""
if Hubs.modelscope == hub:
# TODO: parse data meta information from modelscope hub
# and possibly download data files to local (and update path)
print('getting data from modelscope hub')
if isinstance(path, str):
dataset = load_dataset(
path,
name=name,
if hub == Hubs.huggingface:
dataset = hf_load_dataset(
dataset_name,
name=subset_name,
revision=version,
split=split,
data_dir=data_dir,
data_files=data_files)
elif isinstance(path, list):
return PyDataset.from_hf_dataset(dataset, target=target)
else:
return PyDataset._load_ms_dataset(
dataset_name,
target=target,
subset_name=subset_name,
version=version,
split=split,
data_dir=data_dir,
data_files=data_files)

@staticmethod
def _load_ms_dataset(
dataset_name: Union[str, list],
target: Optional[str] = None,
version: Optional[str] = None,
subset_name: Optional[str] = None,
split: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None
) -> Union[dict, 'PyDataset']:
if isinstance(dataset_name, str):
use_hf = False
if dataset_name in _PACKAGED_DATASETS_MODULES or os.path.isdir(dataset_name) or \
(os.path.isfile(dataset_name) and dataset_name.endswith('.py')):
use_hf = True
elif is_relative_path(dataset_name):
ms_api = MsApi()
dataset_scripts = ms_api.fetch_dataset_scripts(
dataset_name, version)
if 'py' in dataset_scripts: # dataset copied from hf datasets
dataset_name = dataset_scripts['py'][0]
use_hf = True
else:
raise FileNotFoundError(
f"Couldn't find a dataset script at {relative_to_absolute_path(dataset_name)} "
f'or any data file in the same directory.')

if use_hf:
dataset = hf_load_dataset(
dataset_name,
name=subset_name,
revision=version,
split=split,
data_dir=data_dir,
data_files=data_files,
cache_dir=MS_DATASETS_CACHE)
else:
# TODO load from ms datahub
raise NotImplementedError(
f'Dataset {dataset_name} load from modelscope datahub to be implemented in '
f'the future')
elif isinstance(dataset_name, list):
if target is None:
target = 'target'
dataset = Dataset.from_dict({target: [p] for p in path})
dataset = Dataset.from_dict({target: dataset_name})
else:
raise TypeError('path must be a str or a list, but got'
f' {type(path)}')
f' {type(dataset_name)}')
return PyDataset.from_hf_dataset(dataset, target=target)

def to_torch_dataset_with_processors(
self,
preprocessors: Union[Callable, List[Callable]],
columns: Union[str, List[str]] = None,
):
preprocessor_list = preprocessors if isinstance(
preprocessors, list) else [preprocessors]

columns = format_list(columns)

columns = [
key for key in self._hf_ds.features.keys() if key in columns
]
sample = next(iter(self._hf_ds))

sample_res = {k: np.array(sample[k]) for k in columns}
for processor in preprocessor_list:
sample_res.update(
{k: np.array(v)
for k, v in processor(sample).items()})

def is_numpy_number(value):
return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
value.dtype, np.floating)

retained_columns = []
for k in sample_res.keys():
if not is_numpy_number(sample_res[k]):
logger.warning(
f'Data of column {k} is non-numeric, will be removed')
continue
retained_columns.append(k)

import torch

class MsIterableDataset(torch.utils.data.IterableDataset):

def __init__(self, dataset: Iterable):
super(MsIterableDataset).__init__()
self.dataset = dataset

def __iter__(self):
for item_dict in self.dataset:
res = {
k: np.array(item_dict[k])
for k in columns if k in retained_columns
}
for preprocessor in preprocessor_list:
res.update({
k: np.array(v)
for k, v in preprocessor(item_dict).items()
if k in retained_columns
})
yield res

return MsIterableDataset(self._hf_ds)

def to_torch_dataset(
self,
columns: Union[str, List[str]] = None,
output_all_columns: bool = False,
preprocessors: Union[Callable, List[Callable]] = None,
**format_kwargs,
):
self._hf_ds.reset_format()
self._hf_ds.set_format(
type='torch',
columns=columns,
output_all_columns=output_all_columns,
format_kwargs=format_kwargs)
return self._hf_ds
"""Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
torch.utils.data.DataLoader.

Args:
preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
every sample of the dataset. The output type of processors is dict, and each numeric field of the dict
will be used as a field of torch.utils.data.Dataset.
columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only). If the
preprocessor is None, the arg columns must have at least one column. If the `preprocessors` is not None,
the output fields of processors will also be added.
format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`.

Returns:
:class:`torch.utils.data.Dataset`

"""
if not TORCH_AVAILABLE:
raise ImportError(
'The function to_torch_dataset requires pytorch to be installed'
)
if preprocessors is not None:
return self.to_torch_dataset_with_processors(preprocessors)
else:
self._hf_ds.reset_format()
self._hf_ds.set_format(
type='torch', columns=columns, format_kwargs=format_kwargs)
return self._hf_ds

def to_tf_dataset_with_processors(
self,
batch_size: int,
shuffle: bool,
preprocessors: Union[Callable, List[Callable]],
drop_remainder: bool = None,
prefetch: bool = True,
label_cols: Union[str, List[str]] = None,
columns: Union[str, List[str]] = None,
):
preprocessor_list = preprocessors if isinstance(
preprocessors, list) else [preprocessors]

label_cols = format_list(label_cols)
columns = format_list(columns)
cols_to_retain = list(set(label_cols + columns))
retained_columns = [
key for key in self._hf_ds.features.keys() if key in cols_to_retain
]
import tensorflow as tf
tf_dataset = tf.data.Dataset.from_tensor_slices(
np.arange(len(self._hf_ds), dtype=np.int64))
if shuffle:
tf_dataset = tf_dataset.shuffle(buffer_size=len(self._hf_ds))

def func(i, return_dict=False):
i = int(i)
res = {k: np.array(self._hf_ds[i][k]) for k in retained_columns}
for preprocessor in preprocessor_list:
# TODO preprocessor output may have the same key
res.update({
k: np.array(v)
for k, v in preprocessor(self._hf_ds[i]).items()
})
if return_dict:
return res
return tuple(list(res.values()))

sample_res = func(0, True)

@tf.function(input_signature=[tf.TensorSpec(None, tf.int64)])
def fetch_function(i):
output = tf.numpy_function(
func,
inp=[i],
Tout=[
tf.dtypes.as_dtype(val.dtype)
for val in sample_res.values()
],
)
return {key: output[i] for i, key in enumerate(sample_res)}

tf_dataset = tf_dataset.map(
fetch_function, num_parallel_calls=tf.data.AUTOTUNE)
if label_cols:

def split_features_and_labels(input_batch):
labels = {
key: tensor
for key, tensor in input_batch.items() if key in label_cols
}
if len(input_batch) == 1:
input_batch = next(iter(input_batch.values()))
if len(labels) == 1:
labels = next(iter(labels.values()))
return input_batch, labels

tf_dataset = tf_dataset.map(split_features_and_labels)

elif len(columns) == 1:
tf_dataset = tf_dataset.map(lambda x: next(iter(x.values())))
if batch_size > 1:
tf_dataset = tf_dataset.batch(
batch_size, drop_remainder=drop_remainder)

if prefetch:
tf_dataset = tf_dataset.prefetch(tf.data.experimental.AUTOTUNE)
return tf_dataset

def to_tf_dataset(
self,
columns: Union[str, List[str]],
batch_size: int,
shuffle: bool,
collate_fn: Callable,
preprocessors: Union[Callable, List[Callable]] = None,
columns: Union[str, List[str]] = None,
collate_fn: Callable = None,
drop_remainder: bool = None,
collate_fn_args: Dict[str, Any] = None,
label_cols: Union[str, List[str]] = None,
dummy_labels: bool = False,
prefetch: bool = True,
):
"""Create a tf.data.Dataset from the MS Dataset. This tf.data.Dataset can be passed to tf methods like
model.fit() or model.predict().

Args:
batch_size (int): Number of samples in a single batch.
shuffle(bool): Shuffle the dataset order.
preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
every sample of the dataset. The output type of processors is dict, and each field of the dict will be
used as a field of the tf.data.Dataset. If the `preprocessors` is None, the `collate_fn`
shouldn't be None.
columns (str or List[str], default None): Dataset column(s) to be loaded. If the preprocessor is None,
the arg columns must have at least one column. If the `preprocessors` is not None, the output fields of
processors will also be added.
collate_fn(Callable, default None): A callable object used to collect lists of samples into a batch. If
the `preprocessors` is None, the `collate_fn` shouldn't be None.
drop_remainder(bool, default None): Drop the last incomplete batch when loading.
collate_fn_args (Dict, optional): A `dict` of arguments to be passed to the `collate_fn`.
label_cols (str or List[str], default None): Dataset column(s) to load as labels.
prefetch (bool, default True): Prefetch data.

Returns:
:class:`tf.data.Dataset`

"""
if not TF_AVAILABLE:
raise ImportError(
'The function to_tf_dataset requires Tensorflow to be installed.'
)
if preprocessors is not None:
return self.to_tf_dataset_with_processors(
batch_size,
shuffle,
preprocessors,
drop_remainder=drop_remainder,
prefetch=prefetch,
label_cols=label_cols,
columns=columns)

if collate_fn is None:
logger.error(
'The `preprocessors` and the `collate_fn` should not both be None.'
)
return None
self._hf_ds.reset_format()
return self._hf_ds.to_tf_dataset(
columns,
@@ -123,7 +389,6 @@ class PyDataset:
drop_remainder=drop_remainder,
collate_fn_args=collate_fn_args,
label_cols=label_cols,
dummy_labels=dummy_labels,
prefetch=prefetch)

def to_hf_dataset(self) -> Dataset:
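
Taken together with the tests below, a sketch of the intended entry points; the dataset names mirror the tests in this change, while the image file paths are made up.

    from modelscope.pydatasets import PyDataset
    from modelscope.utils.constant import Hubs

    # The ModelScope datahub is the default hub
    ms_ds_train = PyDataset.load('squad', split='train', target='context')

    # Explicitly load from the Hugging Face hub instead
    hf_ds = PyDataset.load('glue', subset_name='sst2', split='train',
                           target='sentence', hub=Hubs.huggingface)

    # A plain list of values is wrapped into a single-column dataset
    img_ds = PyDataset.load(['image1.png', 'image2.png'], target='image_path')

    # When target is set, iteration yields just that column
    for context in ms_ds_train:
        print(context)
        break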


modelscope/pydatasets/utils/__init__.py (+0, -0)


modelscope/pydatasets/utils/ms_api.py (+66, -0)

@@ -0,0 +1,66 @@
import os
from collections import defaultdict
from typing import Optional

import requests

from modelscope.pydatasets.config import (DOWNLOADED_DATASETS_PATH,
MS_HUB_ENDPOINT)
from modelscope.utils.logger import get_logger

logger = get_logger()


class MsApi:

def __init__(self, endpoint=MS_HUB_ENDPOINT):
self.endpoint = endpoint

def list_datasets(self):
path = f'{self.endpoint}/api/v1/datasets'
headers = None
params = {}
r = requests.get(path, params=params, headers=headers)
r.raise_for_status()
dataset_list = r.json()['Data']
return [x['Name'] for x in dataset_list]

def fetch_dataset_scripts(self,
dataset_name: str,
version: Optional[str] = 'master',
force_download=False):
datahub_url = f'{self.endpoint}/api/v1/datasets?Query={dataset_name}'
r = requests.get(datahub_url)
r.raise_for_status()
dataset_list = r.json()['Data']
if len(dataset_list) == 0:
return None
dataset_id = dataset_list[0]['Id']
version = version or 'master'
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={version}'
r = requests.get(datahub_url)
r.raise_for_status()
file_list = r.json()['Data']['Files']
cache_dir = os.path.join(DOWNLOADED_DATASETS_PATH, dataset_name,
version)
os.makedirs(cache_dir, exist_ok=True)
local_paths = defaultdict(list)
for file_info in file_list:
file_path = file_info['Path']
if file_path.endswith('.py'):
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/files?' \
f'Revision={version}&Path={file_path}'
r = requests.get(datahub_url)
r.raise_for_status()
content = r.json()['Data']['Content']
local_path = os.path.join(cache_dir, file_path)
if os.path.exists(local_path) and not force_download:
logger.warning(
f"Reusing dataset {dataset_name}'s python file ({local_path})"
)
local_paths['py'].append(local_path)
continue
with open(local_path, 'w') as f:
f.writelines(content)
local_paths['py'].append(local_path)
return local_paths
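
A short sketch of how the client is used; the dataset name mirrors the tests, and its availability on the datahub is assumed.

    from modelscope.pydatasets.utils.ms_api import MsApi

    api = MsApi()
    print(api.list_datasets())  # names of datasets known to the ModelScope datahub

    # Downloads any .py loading scripts of the dataset into the local cache
    # (under DOWNLOADED_DATASETS_PATH/<dataset_name>/<version>) and returns their paths.
    scripts = api.fetch_dataset_scripts('squad', version='master')
    if scripts and 'py' in scripts:
        print(scripts['py'][0])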

modelscope/utils/test_utils.py (+15, -0)

@@ -2,6 +2,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import unittest

from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE

TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'
@@ -15,6 +18,18 @@ def test_level():
return TEST_LEVEL


def require_tf(test_case):
if not TF_AVAILABLE:
test_case = unittest.skip('test requires TensorFlow')(test_case)
return test_case


def require_torch(test_case):
if not TORCH_AVAILABLE:
test_case = unittest.skip('test requires PyTorch')(test_case)
return test_case


def set_test_level(level: int):
global TEST_LEVEL
TEST_LEVEL = level
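
The new decorators are exercised in tests/pydatasets/test_py_dataset.py below; a minimal sketch with made-up test names:

    import unittest

    from modelscope.utils.test_utils import require_tf, require_torch

    class ConversionTest(unittest.TestCase):

        @require_torch  # skipped automatically when PyTorch is not installed
        def test_torch_path(self):
            pass

        @require_tf  # skipped automatically when TensorFlow is not installed
        def test_tf_path(self):
            pass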

tests/pipelines/test_image_matting.py (+11, -0)

@@ -66,6 +66,17 @@ class ImageMattingTest(unittest.TestCase):
cv2.imwrite('result.png', result['output_png'])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_modelscope_dataset(self):
dataset = PyDataset.load('beans', split='train', target='image')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
result = img_matting(dataset)
for i in range(10):
cv2.imwrite(f'result_{i}.png', next(result)['output_png'])
print(
f'Output written to dir: {osp.dirname(osp.abspath("result_0.png"))}'
)


if __name__ == '__main__':
unittest.main()

tests/pipelines/test_text_classification.py (+24, -4)

@@ -86,7 +86,11 @@ class SequenceClassificationTest(unittest.TestCase):
task=Tasks.text_classification, model=self.model_id)
result = text_classification(
PyDataset.load(
'glue', name='sst2', target='sentence', hub=Hubs.huggingface))
'glue',
subset_name='sst2',
split='train',
target='sentence',
hub=Hubs.huggingface))
self.printDataset(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -94,7 +98,11 @@ class SequenceClassificationTest(unittest.TestCase):
text_classification = pipeline(task=Tasks.text_classification)
result = text_classification(
PyDataset.load(
'glue', name='sst2', target='sentence', hub=Hubs.huggingface))
'glue',
subset_name='sst2',
split='train',
target='sentence',
hub=Hubs.huggingface))
self.printDataset(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -105,9 +113,21 @@ class SequenceClassificationTest(unittest.TestCase):
text_classification = pipeline(
Tasks.text_classification, model=model, preprocessor=preprocessor)
# loaded from huggingface dataset
# TODO: rename parameter as dataset_name and subset_name
dataset = PyDataset.load(
'glue', name='sst2', target='sentence', hub=Hubs.huggingface)
'glue',
subset_name='sst2',
split='train',
target='sentence',
hub=Hubs.huggingface)
result = text_classification(dataset)
self.printDataset(result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_modelscope_dataset(self):
text_classification = pipeline(task=Tasks.text_classification)
# loaded from modelscope dataset
dataset = PyDataset.load(
'squad', split='train', target='context', hub=Hubs.modelscope)
result = text_classification(dataset)
self.printDataset(result)



tests/pydatasets/test_py_dataset.py (+95, -26)

@@ -2,42 +2,111 @@ import unittest

import datasets as hfdata

from modelscope.models import Model
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.preprocessors.base import Preprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs
from modelscope.utils.test_utils import require_tf, require_torch, test_level


class PyDatasetTest(unittest.TestCase):
class ImgPreprocessor(Preprocessor):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.path_field = kwargs.pop('image_path', 'image_path')
self.width = kwargs.pop('width', 'width')
self.height = kwargs.pop('height', 'height')

def setUp(self):
# ds1 initialized from in memory json
self.json_data = {
'dummy': [{
'a': i,
'x': i * 10,
'c': i * 100
} for i in range(1, 11)]
def __call__(self, data):
import cv2
image_path = data.get(self.path_field)
if not image_path:
return None
img = cv2.imread(image_path)
return {
'image':
cv2.resize(img,
(data.get(self.height, 128), data.get(self.width, 128)))
}
hfds1 = hfdata.Dataset.from_dict(self.json_data)
self.ds1 = PyDataset.from_hf_dataset(hfds1)

# ds2 initialized from hg hub
hfds2 = hfdata.load_dataset(
'glue', 'mrpc', revision='2.0.0', split='train')
self.ds2 = PyDataset.from_hf_dataset(hfds2)

def tearDown(self):
pass
class PyDatasetTest(unittest.TestCase):

def test_ds_basic(self):
ms_ds_full = PyDataset.load('squad')
ms_ds_full_hf = hfdata.load_dataset('squad')
ms_ds_train = PyDataset.load('squad', split='train')
ms_ds_train_hf = hfdata.load_dataset('squad', split='train')
ms_image_train = PyDataset.from_hf_dataset(
hfdata.load_dataset('beans', split='train'))
self.assertEqual(ms_ds_full['train'][0], ms_ds_full_hf['train'][0])
self.assertEqual(ms_ds_full['validation'][0],
ms_ds_full_hf['validation'][0])
self.assertEqual(ms_ds_train[0], ms_ds_train_hf[0])
print(next(iter(ms_ds_full['train'])))
print(next(iter(ms_ds_train)))
print(next(iter(ms_image_train)))

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_text(self):
model_id = 'damo/bert-base-sst2'
nlp_model = Model.from_pretrained(model_id)
preprocessor = SequenceClassificationPreprocessor(
nlp_model.model_dir,
first_sequence='context',
second_sequence=None)
ms_ds_train = PyDataset.load('squad', split='train')
pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
import torch
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
print(next(iter(dataloader)))

def test_to_hf_dataset(self):
hfds = self.ds1.to_hf_dataset()
hfds1 = hfdata.Dataset.from_dict(self.json_data)
self.assertEqual(hfds.data, hfds1.data)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@require_tf
def test_to_tf_dataset_text(self):
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
model_id = 'damo/bert-base-sst2'
nlp_model = Model.from_pretrained(model_id)
preprocessor = SequenceClassificationPreprocessor(
nlp_model.model_dir,
first_sequence='context',
second_sequence=None)
ms_ds_train = PyDataset.load('squad', split='train')
tf_dataset = ms_ds_train.to_tf_dataset(
batch_size=5,
shuffle=True,
preprocessors=preprocessor,
drop_remainder=True)
print(next(iter(tf_dataset)))

# simple map function
hfds = hfds.map(lambda e: {'new_feature': e['dummy']['a']})
self.assertEqual(len(hfds['new_feature']), 10)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_img(self):
ms_image_train = PyDataset.from_hf_dataset(
hfdata.load_dataset('beans', split='train'))
pt_dataset = ms_image_train.to_torch_dataset(
preprocessors=ImgPreprocessor(
image_path='image_file_path', label='labels'))
import torch
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
print(next(iter(dataloader)))

hfds2 = self.ds2.to_hf_dataset()
self.assertTrue(hfds2[0]['sentence1'].startswith('Amrozi'))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@require_tf
def test_to_tf_dataset_img(self):
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
ms_image_train = PyDataset.load('beans', split='train')
tf_dataset = ms_image_train.to_tf_dataset(
batch_size=5,
shuffle=True,
preprocessors=ImgPreprocessor(image_path='image_file_path'),
drop_remainder=True,
label_cols='labels')
print(next(iter(tf_dataset)))


if __name__ == '__main__':
    unittest.main()