
test_serdes_dataset.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing dataset serialize and deserialize in DE
"""
import filecmp
import glob
import json
import os

import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore.dataset.transforms.vision import Inter


def test_imagefolder(remove_json_files=True):
    """
    Test simulating resnet50 dataset pipeline.
    """
    data_dir = "../data/dataset/testPK/data"
    ds.config.set_seed(1)

    # Define data augmentation parameters
    rescale = 1.0 / 255.0
    shift = 0.0
    resize_height, resize_width = 224, 224
    weights = [1.0, 0.1, 0.02, 0.3, 0.4, 0.05, 1.2, 0.13, 0.14, 0.015, 0.16, 1.1]

    # Construct the DE pipeline
    sampler = ds.WeightedRandomSampler(weights, 11)
    data1 = ds.ImageFolderDatasetV2(data_dir, sampler=sampler)
    data1 = data1.repeat(1)
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
    rescale_op = vision.Rescale(rescale, shift)
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=[rescale_op, resize_op])
    data1 = data1.batch(2)

    # Serialize the dataset pre-processing pipeline.
    # data1 should still work after saving.
    ds.serialize(data1, "imagenet_dataset_pipeline.json")
    ds1_dict = ds.serialize(data1)
    assert validate_jsonfile("imagenet_dataset_pipeline.json") is True

    # Print the serialized pipeline to stdout
    ds.show(data1)

    # Deserialize the serialized json file
    data2 = ds.deserialize(json_filepath="imagenet_dataset_pipeline.json")

    # Serialize the pipeline we just deserialized.
    # The content of the json file should be the same as the previous serialization.
    ds.serialize(data2, "imagenet_dataset_pipeline_1.json")
    assert validate_jsonfile("imagenet_dataset_pipeline_1.json") is True
    assert filecmp.cmp('imagenet_dataset_pipeline.json', 'imagenet_dataset_pipeline_1.json')

    # Deserialize the latest json file again
    data3 = ds.deserialize(json_filepath="imagenet_dataset_pipeline_1.json")
    data4 = ds.deserialize(input_dict=ds1_dict)
    num_samples = 0

    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipelines
    for item1, item2, item3, item4 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                          data3.create_dict_iterator(), data4.create_dict_iterator()):
        assert np.array_equal(item1['image'], item2['image'])
        assert np.array_equal(item1['image'], item3['image'])
        assert np.array_equal(item1['label'], item2['label'])
        assert np.array_equal(item1['label'], item3['label'])
        assert np.array_equal(item3['image'], item4['image'])
        assert np.array_equal(item3['label'], item4['label'])
        num_samples += 1

    logger.info("Number of data in data1: {}".format(num_samples))
    assert num_samples == 6

    # Remove the generated json files
    if remove_json_files:
        delete_json_files()


def test_mnist_dataset(remove_json_files=True):
    """
    Test serializing and deserializing a MnistDataset pipeline with OneHot and batch.
    """
    data_dir = "../data/dataset/testMnistData"
    ds.config.set_seed(1)

    data1 = ds.MnistDataset(data_dir, 100)
    one_hot_encode = c.OneHot(10)  # num_classes is an input argument
    data1 = data1.map(input_columns="label", operations=one_hot_encode)

    # batch_size is an input argument
    data1 = data1.batch(batch_size=10, drop_remainder=True)

    ds.serialize(data1, "mnist_dataset_pipeline.json")
    assert validate_jsonfile("mnist_dataset_pipeline.json") is True

    data2 = ds.deserialize(json_filepath="mnist_dataset_pipeline.json")
    ds.serialize(data2, "mnist_dataset_pipeline_1.json")
    assert validate_jsonfile("mnist_dataset_pipeline_1.json") is True
    assert filecmp.cmp('mnist_dataset_pipeline.json', 'mnist_dataset_pipeline_1.json')

    data3 = ds.deserialize(json_filepath="mnist_dataset_pipeline_1.json")

    num = 0
    for item1, item2, item3 in zip(data1.create_dict_iterator(), data2.create_dict_iterator(),
                                   data3.create_dict_iterator()):
        assert np.array_equal(item1['image'], item2['image'])
        assert np.array_equal(item1['image'], item3['image'])
        assert np.array_equal(item1['label'], item2['label'])
        assert np.array_equal(item1['label'], item3['label'])
        num += 1

    logger.info("mnist total num samples is {}".format(str(num)))
    assert num == 10

    if remove_json_files:
        delete_json_files()


def test_zip_dataset(remove_json_files=True):
    """
    Test serializing and deserializing a zipped TFRecordDataset pipeline.
    """
    files = ["../data/dataset/testTFTestAllTypes/test.data"]
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
    ds.config.set_seed(1)

    ds0 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data1 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.GLOBAL)
    data2 = ds.TFRecordDataset(files, schema=schema_file, shuffle=ds.Shuffle.FILES)
    data2 = data2.shuffle(10000)
    data2 = data2.rename(input_columns=["col_sint16", "col_sint32", "col_sint64", "col_float",
                                        "col_1d", "col_2d", "col_3d", "col_binary"],
                         output_columns=["column_sint16", "column_sint32", "column_sint64", "column_float",
                                         "column_1d", "column_2d", "column_3d", "column_binary"])
    data3 = ds.zip((data1, data2))

    ds.serialize(data3, "zip_dataset_pipeline.json")
    assert validate_jsonfile("zip_dataset_pipeline.json") is True
    assert validate_jsonfile("zip_dataset_pipeline_typo.json") is False

    data4 = ds.deserialize(json_filepath="zip_dataset_pipeline.json")
    ds.serialize(data4, "zip_dataset_pipeline_1.json")
    assert validate_jsonfile("zip_dataset_pipeline_1.json") is True
    assert filecmp.cmp('zip_dataset_pipeline.json', 'zip_dataset_pipeline_1.json')

    rows = 0
    for d0, d3, d4 in zip(ds0, data3, data4):
        num_cols = len(d0)
        offset = 0
        for t1 in d0:
            assert np.array_equal(t1, d3[offset])
            assert np.array_equal(t1, d3[offset + num_cols])
            assert np.array_equal(t1, d4[offset])
            assert np.array_equal(t1, d4[offset + num_cols])
            offset += 1
        rows += 1
    assert rows == 12

    if remove_json_files:
        delete_json_files()


def test_random_crop():
    """
    Test serializing and deserializing a pipeline that uses RandomCrop.
    """
    logger.info("test_random_crop")
    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

    # First dataset
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    decode_op = vision.Decode()
    random_crop_op = vision.RandomCrop([512, 512], [200, 200, 200, 200])
    data1 = data1.map(input_columns="image", operations=decode_op)
    data1 = data1.map(input_columns="image", operations=random_crop_op)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data1)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, indent=2)

    # Reconstruct dataset pipeline from its serialized form
    data1_1 = ds.deserialize(input_dict=ds1_dict)

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    data2 = data2.map(input_columns="image", operations=decode_op)

    for item1, item1_1, item2 in zip(data1.create_dict_iterator(), data1_1.create_dict_iterator(),
                                     data2.create_dict_iterator()):
        assert np.array_equal(item1['image'], item1_1['image'])
        image2 = item2["image"]


def validate_jsonfile(filepath):
    """Return True only if the file exists and contains a valid JSON object."""
    try:
        file_exist = os.path.exists(filepath)
        with open(filepath, 'r') as jfile:
            loaded_json = json.load(jfile)
    except IOError:
        return False
    return file_exist and isinstance(loaded_json, dict)


def delete_json_files():
    """Remove the json files generated by these tests from the current directory."""
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            logger.info("Error while deleting: {}".format(f))


# Test save load minddataset
from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME, FILES_NUM, \
    FileWriter, Inter


def test_minddataset(add_and_remove_cv_file):
    """Test serializing and deserializing a cv MindDataset pipeline."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    indices = [1, 2, 3, 5, 7]
    sampler = ds.SubsetRandomSampler(indices)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    # Serializing into python dictionary
    ds1_dict = ds.serialize(data_set)
    # Serializing into json object
    ds1_json = json.dumps(ds1_dict, sort_keys=True)

    # Reconstruct dataset pipeline from its serialized form
    data_set = ds.deserialize(input_dict=ds1_dict)
    ds2_dict = ds.serialize(data_set)
    # Serializing into json object
    ds2_json = json.dumps(ds2_dict, sort_keys=True)
    assert ds1_json == ds2_json

    data = get_data(CV_DIR_NAME)
    assert data_set.get_dataset_size() == 5
    num_iter = 0
    for item in data_set.create_dict_iterator():
        num_iter += 1
    assert num_iter == 5


if __name__ == '__main__':
    test_imagefolder()
    test_zip_dataset()
    test_mnist_dataset()
    test_random_crop()