You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_manifestop.py 4.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.transforms.c_transforms as data_trans
  18. from mindspore import log as logger
# Manifest fixture read by every test in this module. It is a relative path,
# so it is presumably resolved against the directory the tests are launched
# from — NOTE(review): confirm the expected working directory.
DATA_FILE = "../data/dataset/testManifestData/test.manifest"
  20. def test_manifest_dataset_train():
  21. data = ds.ManifestDataset(DATA_FILE, decode=True)
  22. count = 0
  23. cat_count = 0
  24. dog_count = 0
  25. for item in data.create_dict_iterator():
  26. logger.info("item[image] is {}".format(item["image"]))
  27. count = count + 1
  28. if item["label"].size == 1 and item["label"] == 0:
  29. cat_count = cat_count + 1
  30. elif item["label"].size == 1 and item["label"] == 1:
  31. dog_count = dog_count + 1
  32. assert cat_count == 2
  33. assert dog_count == 1
  34. assert count == 4
  35. def test_manifest_dataset_eval():
  36. data = ds.ManifestDataset(DATA_FILE, "eval", decode=True)
  37. count = 0
  38. for item in data.create_dict_iterator():
  39. logger.info("item[image] is {}".format(item["image"]))
  40. count = count + 1
  41. if item["label"] != 0 and item["label"] != 1:
  42. assert 0
  43. assert count == 2
  44. def test_manifest_dataset_class_index():
  45. class_indexing = {"dog": 11}
  46. data = ds.ManifestDataset(DATA_FILE, decode=True, class_indexing=class_indexing)
  47. out_class_indexing = data.get_class_indexing()
  48. assert out_class_indexing == {"dog": 11}
  49. count = 0
  50. for item in data.create_dict_iterator():
  51. logger.info("item[image] is {}".format(item["image"]))
  52. count = count + 1
  53. if item["label"] != 11:
  54. assert 0
  55. assert count == 1
  56. def test_manifest_dataset_get_class_index():
  57. data = ds.ManifestDataset(DATA_FILE, decode=True)
  58. class_indexing = data.get_class_indexing()
  59. assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
  60. data = data.shuffle(4)
  61. class_indexing = data.get_class_indexing()
  62. assert class_indexing == {'cat': 0, 'dog': 1, 'flower': 2}
  63. count = 0
  64. for item in data.create_dict_iterator():
  65. logger.info("item[image] is {}".format(item["image"]))
  66. count = count + 1
  67. assert count == 4
  68. def test_manifest_dataset_multi_label():
  69. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  70. count = 0
  71. expect_label = [1, 0, 0, [0, 2]]
  72. for item in data.create_dict_iterator():
  73. assert item["label"].tolist() == expect_label[count]
  74. logger.info("item[image] is {}".format(item["image"]))
  75. count = count + 1
  76. assert count == 4
  77. def multi_label_hot(x):
  78. result = np.zeros(x.size // x.ndim, dtype=int)
  79. if x.ndim > 1:
  80. for i in range(x.ndim):
  81. result = np.add(result, x[i])
  82. else:
  83. result = np.add(result, x)
  84. return result
  85. def test_manifest_dataset_multi_label_onehot():
  86. data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
  87. expect_label = [[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [1, 0, 1]]]
  88. one_hot_encode = data_trans.OneHot(3)
  89. data = data.map(input_columns=["label"], operations=one_hot_encode)
  90. data = data.map(input_columns=["label"], operations=multi_label_hot)
  91. data = data.batch(2)
  92. count = 0
  93. for item in data.create_dict_iterator():
  94. assert item["label"].tolist() == expect_label[count]
  95. logger.info("item[image] is {}".format(item["image"]))
  96. count = count + 1
  97. if __name__ == '__main__':
  98. test_manifest_dataset_train()
  99. test_manifest_dataset_eval()
  100. test_manifest_dataset_class_index()
  101. test_manifest_dataset_get_class_index()
  102. test_manifest_dataset_multi_label()
  103. test_manifest_dataset_multi_label_onehot()