
test_serdes_dataset.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter

from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
from util import config_get_set_num_parallel_workers


def test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # Define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Construct the DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDatasetV2(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=[rescale_op, resize_op])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the
    # deserialized pipelines (data2, data3, data4)
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                          data3.create_dict_iterator(), data4.create_dict_iterator()):
        assert np.array_equal(item1['image'], item2['image'])
        assert np.array_equal(item1['image'], item3['image'])
        assert np.array_equal(item1['label'], item2['label'])
        assert np.array_equal(item1['label'], item3['label'])
        assert np.array_equal(item3['image'], item4['image'])
        assert np.array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    """
    Test serialize and deserialize on an MnistDataset pipeline.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, 100)
    one_hot_encode = c.OneHot(10)  # num_classes is an input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is an input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for data1, data2, data3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                   data3.create_dict_iterator()):
        assert np.array_equal(data1['image'], data2['image'])
        assert np.array_equal(data1['image'], data3['image'])
        assert np.array_equal(data1['label'], data2['label'])
        assert np.array_equal(data1['label'], data3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    """
    Test serialize and deserialize on a zipped TFRecordDataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            assert np.array_equal(t1, d3[offset])
            assert np.array_equal(t1, d3[offset + num_cols])
            assert np.array_equal(t1, d4[offset])
            assert np.array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
    """
    Test serialize and deserialize on a pipeline that uses RandomCrop.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
    original_num_parallel_workers = config_get_set_num_parallel_workers(1)

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(input_columns="image", operations=decode_op)
    data1 = data1.map(input_columns="image", operations=random_crop_op)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    _ = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(input_columns="image", operations=decode_op)

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                     data2.create_dict_iterator()):
        assert np.array_equal(item1['image'], item1_1['image'])
        _ = item2["image"]

    # Restore configuration num_parallel_workers
    ds.config.set_num_parallel_workers(original_num_parallel_workers)


def validate_jsonfile(filepath):
    """Return True if filepath exists and contains a valid JSON dictionary."""
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    """Remove the json files generated by the tests from the current directory."""
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save and load of a MindDataset pipeline
def test_minddataset(add_and_remove_cv_file):
    """Test serialize and deserialize on a MindDataset (cv) pipeline."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    _ = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for _ in data_set.create_dict_iterator():
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()
    # test_minddataset is not called here because it relies on the
    # add_and_remove_cv_file pytest fixture; run it through pytest instead.
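
For reference, the round trip these tests exercise reduces to a few lines. The sketch below is illustrative only (not an additional test case); it assumes the same 0.x-era mindspore.dataset serialize/deserialize API used in this file, reuses the test MNIST directory referenced above, and the output filename mnist_roundtrip.json is hypothetical.

import numpy as np
import mindspore.dataset as ds

# Fix the seed so both pipelines produce rows in the same order.
ds.config.set_seed(1)

# Build a small pipeline over the test MNIST data used in this file.
pipeline = ds.MnistDataset("../data/dataset/testMnistData", 100)
pipeline = pipeline.batch(batch_size=10, drop_remainder=True)

# Serialize the pipeline to a JSON file, then rebuild it from that file.
ds.serialize(pipeline, "mnist_roundtrip.json")
restored = ds.deserialize(json_filepath="mnist_roundtrip.json")

# The restored pipeline should yield the same rows as the original.
for row1, row2 in zip(pipeline.create_dict_iterator(), restored.create_dict_iterator()):
    assert np.array_equal(row1["image"], row2["image"])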