@@ -19,7 +19,7 @@ can also create samplers with this module to sample data.
 """
 from .core.configuration import config
-from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
+from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \
     GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \
     Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
@@ -27,7 +27,7 @@ from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, Seque
 from .engine.serializer_deserializer import serialize, deserialize, show
 from .engine.graphdata import GraphData
-__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
+__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset",
            "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
           "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
@@ -29,7 +29,7 @@ from .samplers import *
 from ..core.configuration import config, ConfigurationManager
-__all__ = ["config", "ConfigurationManager", "zip", "StorageDataset",
+__all__ = ["config", "ConfigurationManager", "zip",
           "ImageFolderDatasetV2", "MnistDataset",
           "MindDataset", "GeneratorDataset", "TFRecordDataset",
           "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
@@ -22,7 +22,6 @@ import glob
 import json
 import math
 import os
-import random
 import uuid
 import multiprocessing
 import queue
@@ -40,7 +39,7 @@ from mindspore._c_expression import typing
 from mindspore import log as logger
 from . import samplers
 from .iterators import DictIterator, TupleIterator
-from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
+from .validators import check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, \
     check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
@@ -480,7 +479,7 @@ class Dataset:
         If input_columns not provided or empty, all columns will be used.
         Args:
-            predicate(callable): python callable which returns a boolean value.
+            predicate(callable): Python callable which returns a boolean value; if it returns False, the element is filtered out.
             input_columns: (list[str], optional): List of names of the input columns, when
                 default=None, the predicate will be applied on all columns in the dataset.
             num_parallel_workers (int, optional): Number of workers to process the Dataset
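For reference, a minimal sketch of the documented `filter` contract (the generator, column name, and values below are made up for illustration): rows for which the predicate returns False are dropped from the pipeline.

```python
import numpy as np
import mindspore.dataset as ds

def gen():
    # hypothetical 5-row source with labels 0..4
    for i in range(5):
        yield (np.array(i, dtype=np.int32),)

data = ds.GeneratorDataset(gen, column_names=["label"])
# predicate returns False -> the row is filtered out
data = data.filter(predicate=lambda label: label != 0, input_columns=["label"])

for row in data.create_dict_iterator():
    print(row["label"])  # 1, 2, 3, 4
```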
@@ -899,7 +898,7 @@ class Dataset:
         def get_distribution(output_dataset):
             dev_id = 0
-            if isinstance(output_dataset, (StorageDataset, MindDataset)):
+            if isinstance(output_dataset, (MindDataset)):
                 return output_dataset.distribution, dev_id
             if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
                                            ManifestDataset, MnistDataset, VOCDataset, CelebADataset)):
@@ -984,57 +983,6 @@ class Dataset:
         """Create an Iterator over the dataset."""
         return self.create_tuple_iterator()
-    @staticmethod
-    def read_dir(dir_path, schema, columns_list=None, num_parallel_workers=None,
-                 deterministic_output=True, prefetch_size=None, shuffle=False, seed=None, distribution=""):
-        """
-        Append the path of all files in the dir_path to StorageDataset.
-        Args:
-            dir_path (str): Path to the directory that contains the dataset.
-            schema (str): Path to the json schema file.
-            columns_list (list[str], optional): List of columns to be read (default=None).
-                If not provided, read all columns.
-            num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
-                (default=None).
-            deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
-                or not (default=True). If True, performance might be affected.
-            prefetch_size (int, optional): Prefetch number of records ahead of the
-                user's request (default=None).
-            shuffle (bool, optional): Shuffle the list of files in the directory (default=False).
-            seed (int, optional): Create a random generator with a fixed seed. If set to None,
-                create a random seed (default=None).
-            distribution (str, optional): The path of distribution config file (default="").
-        Returns:
-            StorageDataset.
-        Raises:
-            ValueError: If dataset folder does not exist.
-            ValueError: If dataset folder permission denied.
-        """
-        logger.warning("WARN_DEPRECATED: The usage of read_dir is deprecated, please use TFRecordDataset with GLOB.")
-        list_files = []
-        if not os.path.isdir(dir_path):
-            raise ValueError("The dataset folder does not exist!")
-        if not os.access(dir_path, os.R_OK):
-            raise ValueError("The dataset folder permission denied!")
-        for root, _, files in os.walk(dir_path):
-            for file in files:
-                list_files.append(os.path.join(root, file))
-        list_files.sort()
-        if shuffle:
-            rand = random.Random(seed)
-            rand.shuffle(list_files)
-        return StorageDataset(list_files, schema, distribution, columns_list, num_parallel_workers,
-                              deterministic_output, prefetch_size)
     @property
     def input_indexs(self):
         return self._input_indexs
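The removed `read_dir` helper pointed callers at `TFRecordDataset` with a GLOB pattern in its deprecation warning. A minimal sketch of that replacement, with placeholder paths and column names:

```python
import mindspore.dataset as ds

# every file under the directory, matched by a glob pattern instead of os.walk
data = ds.TFRecordDataset(dataset_files="/path/to/dataset_dir/*",
                          schema="/path/to/datasetSchema.json",
                          columns_list=["image", "label"],
                          shuffle=False)
```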
@@ -1818,7 +1766,7 @@ class FilterDataset(DatasetOp):
     Args:
         input_dataset: Input Dataset to be mapped.
-        predicate: python callable which returns a boolean value.
+        predicate: Python callable which returns a boolean value; if it returns False, the element is filtered out.
         input_columns: (list[str]): List of names of the input columns, when
             default=None, the predicate will be applied all columns in the dataset.
         num_parallel_workers (int, optional): Number of workers to process the Dataset
@@ -2157,123 +2105,6 @@ class TransferDataset(DatasetOp):
         self.iterator = TupleIterator(self)
-class StorageDataset(SourceDataset):
-    """
-    A source dataset that reads and parses datasets stored on disk in various formats, including TFData format.
-    Args:
-        dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file. If numRows(parsed from schema) is not exist, read the full dataset.
-        distribution (str, optional): Path of distribution config file (default="").
-        columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
-        num_parallel_workers (int, optional): Number of parallel working threads (default=None).
-        deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
-            or not (default=True). If True, performance might be affected.
-        prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).
-    Raises:
-        RuntimeError: If schema file failed to read.
-        RuntimeError: If distribution file path is given but failed to read.
-    """
-    @check
-    def __init__(self, dataset_files, schema, distribution="", columns_list=None, num_parallel_workers=None,
-                 deterministic_output=None, prefetch_size=None):
-        super().__init__(num_parallel_workers)
-        logger.warning("WARN_DEPRECATED: The usage of StorageDataset is deprecated, please use TFRecordDataset.")
-        self.dataset_files = dataset_files
-        try:
-            with open(schema, 'r') as load_f:
-                json.load(load_f)
-        except json.decoder.JSONDecodeError:
-            raise RuntimeError("Json decode error when load schema file")
-        except Exception:
-            raise RuntimeError("Schema file failed to load")
-        if distribution != "":
-            try:
-                with open(distribution, 'r') as load_d:
-                    json.load(load_d)
-            except json.decoder.JSONDecodeError:
-                raise RuntimeError("Json decode error when load distribution file")
-            except Exception:
-                raise RuntimeError("Distribution file failed to load")
-        if self.dataset_files is None:
-            schema = None
-            distribution = None
-        self.schema = schema
-        self.distribution = distribution
-        self.columns_list = columns_list
-        self.deterministic_output = deterministic_output
-        self.prefetch_size = prefetch_size
-    def get_args(self):
-        args = super().get_args()
-        args["dataset_files"] = self.dataset_files
-        args["schema"] = self.schema
-        args["distribution"] = self.distribution
-        args["columns_list"] = self.columns_list
-        args["deterministic_output"] = self.deterministic_output
-        args["prefetch_size"] = self.prefetch_size
-        return args
-    def get_dataset_size(self):
-        """
-        Get the number of batches in an epoch.
-        Return:
-            Number, number of batches.
-        """
-        if self._dataset_size is None:
-            self._get_pipeline_info()
-        return self._dataset_size
-    # manually set dataset_size as a temporary solution.
-    def set_dataset_size(self, value):
-        logger.warning("WARN_DEPRECATED: This method is deprecated. Please use get_dataset_size directly.")
-        if value >= 0:
-            self._dataset_size = value
-        else:
-            raise ValueError('set dataset_size with negative value {}'.format(value))
-    def num_classes(self):
-        """
-        Get the number of classes in dataset.
-        Return:
-            Number, number of classes.
-        Raises:
-            ValueError: If dataset type is invalid.
-            ValueError: If dataset is not Imagenet dataset or manifest dataset.
-            RuntimeError: If schema file is given but failed to load.
-        """
-        cur_dataset = self
-        while cur_dataset.input:
-            cur_dataset = cur_dataset.input[0]
-        if not hasattr(cur_dataset, "schema"):
-            raise ValueError("Dataset type is invalid")
-        # Only IMAGENET/MANIFEST support numclass
-        try:
-            with open(cur_dataset.schema, 'r') as load_f:
-                load_dict = json.load(load_f)
-        except json.decoder.JSONDecodeError:
-            raise RuntimeError("Json decode error when load schema file")
-        except Exception:
-            raise RuntimeError("Schema file failed to load")
-        if load_dict["datasetType"] != "IMAGENET" and load_dict["datasetType"] != "MANIFEST":
-            raise ValueError("%s dataset does not support num_classes!" % (load_dict["datasetType"]))
-        if self._num_classes is None:
-            self._get_pipeline_info()
-        return self._num_classes
-    def is_shuffled(self):
-        return False
-    def is_sharded(self):
-        return False
 class RangeDataset(MappableDataset):
     """
@@ -168,8 +168,6 @@ class Iterator:
             op_type = OpName.SKIP
         elif isinstance(dataset, de.TakeDataset):
             op_type = OpName.TAKE
-        elif isinstance(dataset, de.StorageDataset):
-            op_type = OpName.STORAGE
         elif isinstance(dataset, de.ImageFolderDatasetV2):
             op_type = OpName.IMAGEFOLDER
         elif isinstance(dataset, de.GeneratorDataset):
@@ -230,11 +230,7 @@ def create_node(node):
     pyobj = None
     # Find a matching Dataset class and call the constructor with the corresponding args.
     # When a new Dataset class is introduced, another if clause and parsing code needs to be added.
-    if dataset_op == 'StorageDataset':
-        pyobj = pyclass(node['dataset_files'], node['schema'], node.get('distribution'),
-                        node.get('columns_list'), node.get('num_parallel_workers'))
-    elif dataset_op == 'ImageFolderDatasetV2':
+    if dataset_op == 'ImageFolderDatasetV2':
         sampler = construct_sampler(node.get('sampler'))
         pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'),
                         node.get('shuffle'), sampler, node.get('extensions'),
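`create_node` is the deserializer's dispatch table, so dropping the `StorageDataset` branch means serialized pipelines can only be rebuilt from the remaining ops. A hedged round-trip sketch (placeholder paths; `serialize`/`deserialize` are used the same way in the tests further below):

```python
import mindspore.dataset as ds

data = ds.TFRecordDataset(["/path/train-0000-of-0001.data"],
                          "/path/datasetSchema.json",
                          columns_list=["image"])
serialized = ds.serialize(data)                  # pipeline definition as a python dict
rebuilt = ds.deserialize(input_dict=serialized)  # rebuilt op by op through create_node
```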
@@ -31,7 +31,7 @@ SCHEMA_DIR = "{0}/resnet_all_datasetSchema.json".format(data_path)
 def test_me_de_train_dataset():
     data_list = ["{0}/train-00001-of-01024.data".format(data_path)]
-    data_set = ds.StorageDataset(data_list, schema=SCHEMA_DIR,
+    data_set = ds.TFRecordDataset(data_list, schema=SCHEMA_DIR,
                                  columns_list=["image/encoded", "image/class/label"])
     resize_height = 224
@@ -24,11 +24,6 @@ DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
 SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"
-DISTRIBUTION_ALL_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionAll.json"
-DISTRIBUTION_UNIQUE_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionUnique.json"
-DISTRIBUTION_RANDOM_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionRandom.json"
-DISTRIBUTION_EQUAL_DIR = "../data/dataset/test_tf_file_3_images2/dataDistributionEqualRows.json"
 def test_tf_file_normal():
     # apply dataset operations
@@ -42,61 +37,6 @@ def test_tf_file_normal():
     assert num_iter == 12
-def test_tf_file_distribution_all():
-    # apply dataset operations
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_ALL_DIR)
-    data1 = data1.repeat(2)
-    num_iter = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        num_iter += 1
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 24
-def test_tf_file_distribution_unique():
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_UNIQUE_DIR)
-    data1 = data1.repeat(1)
-    num_iter = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        num_iter += 1
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 4
-def test_tf_file_distribution_random():
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_RANDOM_DIR)
-    data1 = data1.repeat(1)
-    num_iter = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        num_iter += 1
-    logger.info("Number of data in data1: {}".format(num_iter))
-    assert num_iter == 4
-def test_tf_file_distribution_equal_rows():
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, DISTRIBUTION_EQUAL_DIR)
-    data1 = data1.repeat(2)
-    num_iter = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        num_iter += 1
-    assert num_iter == 4
 if __name__ == '__main__':
     logger.info('=======test normal=======')
     test_tf_file_normal()
-    logger.info('=======test all=======')
-    test_tf_file_distribution_all()
-    logger.info('=======test unique=======')
-    test_tf_file_distribution_unique()
-    logger.info('=======test random=======')
-    test_tf_file_distribution_random()
-    logger.info('=======test equal rows=======')
-    test_tf_file_distribution_equal_rows()
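The removed tests exercised `StorageDataset`'s JSON distribution configs; with `TFRecordDataset` the equivalent splitting goes through its sharding arguments instead. A minimal sketch, assuming the standard `num_shards`/`shard_id` parameters (paths are placeholders):

```python
import mindspore.dataset as ds

# shard 0 of 4: this worker reads a quarter of the data
data = ds.TFRecordDataset(["/path/train-0000-of-0001.data"],
                          "/path/datasetSchema.json",
                          num_shards=4, shard_id=0,
                          shuffle=ds.Shuffle.FILES)
```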
@@ -1,69 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-import mindspore.dataset as ds
-from mindspore import log as logger
-DATA_DIR = "../data/dataset/test_tf_file_3_images/data"
-SCHEMA = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
-COLUMNS = ["label"]
-GENERATE_GOLDEN = False
-def test_case_0():
-    logger.info("Test 0 readdir")
-    # apply dataset operations
-    data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, columns_list=None, num_parallel_workers=None,
-                                       deterministic_output=True, prefetch_size=None, shuffle=False, seed=None)
-    i = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        logger.info("item[label] is {}".format(item["label"]))
-        i = i + 1
-    assert (i == 3)
-def test_case_1():
-    logger.info("Test 1 readdir")
-    # apply dataset operations
-    data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, COLUMNS, num_parallel_workers=None,
-                                       deterministic_output=True, prefetch_size=None, shuffle=True, seed=None)
-    i = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        logger.info("item[label] is {}".format(item["label"]))
-        i = i + 1
-    assert (i == 3)
-def test_case_2():
-    logger.info("Test 2 readdir")
-    # apply dataset operations
-    data1 = ds.engine.Dataset.read_dir(DATA_DIR, SCHEMA, columns_list=None, num_parallel_workers=2,
-                                       deterministic_output=False, prefetch_size=16, shuffle=True, seed=10)
-    i = 0
-    for item in data1.create_dict_iterator(): # each data is a dictionary
-        logger.info("item[label] is {}".format(item["label"]))
-        i = i + 1
-    assert (i == 3)
-if __name__ == "__main__":
-    test_case_0()
-    test_case_1()
-    test_case_2()
@@ -177,7 +177,7 @@ def test_random_crop():
     SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
     # First dataset
-    data1 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
     decode_op = vision.Decode()
     random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
     data1 = data1.map(input_columns="image", operations=decode_op)
@@ -192,7 +192,7 @@ def test_random_crop():
     data1_1 = ds.deserialize(input_dict=ds1_dict)
     # Second dataset
-    data2 = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
     data2 = data2.map(input_columns="image", operations=decode_op)
     for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
@@ -1,51 +0,0 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-from util import save_and_check
-import mindspore.dataset as ds
-from mindspore import log as logger
-DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
-SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
-COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
-           "col_sint16", "col_sint32", "col_sint64"]
-GENERATE_GOLDEN = False
-def test_case_storage():
-    """
-    test StorageDataset
-    """
-    logger.info("Test Simple StorageDataset")
-    # define parameters
-    parameters = {"params": {}}
-    # apply dataset operations
-    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
-    filename = "storage_result.npz"
-    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
-def test_case_no_rows():
-    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
-    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
-    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
-    assert dataset.get_dataset_size() == 3
-    count = 0
-    for data in dataset.create_tuple_iterator():
-        count += 1
-    assert count == 3