
test_sync_wait.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger


def gen():
    # Generator source: 100 rows, each holding a single scalar.
    for i in range(100):
        yield (np.array(i),)


class Augment:
    # Minimal callback holder: update() receives the dict passed to sync_update().
    def __init__(self, loss):
        self.loss = loss

    def preprocess(self, input_):
        return input_

    def update(self, data):
        self.loss = data["loss"]


def test_simple_sync_wait():
    """
    Test simple sync wait: test sync in dataset pipeline
    """
    logger.info("test_simple_sync_wait")
    batch_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        assert data["input"][0] == count
        count += batch_size
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_simple_shuffle_sync():
    """
    Test simple shuffle sync: test shuffle before sync
    """
    logger.info("test_simple_shuffle_sync")
    shuffle_size = 4
    batch_size = 10
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.shuffle(shuffle_size)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="policy", data=data)


def test_two_sync():
    """
    Test two sync: dataset pipeline with two sync_wait operators
    """
    logger.info("test_two_sync")
    batch_size = 6
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # notice that with our design, we need to have step_size = shuffle_size
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
    dataset = dataset.batch(batch_size)
    count = 0
    for data in dataset.create_dict_iterator():
        count += 1
        data = {"loss": count}
        dataset.sync_update(condition_name="every batch", data=data)
        if count % 2 == 0:
            dataset.sync_update(condition_name="every 2 batches")


def test_sync_epoch():
    """
    Test sync wait with epochs: test sync with epochs in dataset pipeline
    """
    logger.info("test_sync_epoch")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)

    for _ in range(3):
        aug.update({"loss": 0})
        count = 0
        for data in dataset.create_dict_iterator():
            assert data["input"][0] == count
            count += batch_size
            data = {"loss": count}
            dataset.sync_update(condition_name="policy", data=data)


def test_multiple_iterators():
    """
    Test sync wait with multiple iterators: iterate two synced pipelines in lockstep
    """
    logger.info("test_multiple_iterators")
    batch_size = 30
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # 2nd dataset
    dataset2 = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset2 = dataset2.sync_wait(condition_name="policy", callback=aug.update)
    dataset2 = dataset2.map(input_columns=["input"], operations=[aug.preprocess])
    dataset2 = dataset2.batch(batch_size, drop_remainder=True)

    for item1, item2 in zip(dataset.create_dict_iterator(), dataset2.create_dict_iterator()):
        assert item1["input"][0] == item2["input"][0]
        data1 = {"loss": item1["input"][0]}
        data2 = {"loss": item2["input"][0]}
        dataset.sync_update(condition_name="policy", data=data1)
        dataset2.sync_update(condition_name="policy", data=data2)


def test_sync_exception_01():
    """
    Test sync: with shuffle in sync mode
    """
    logger.info("test_sync_exception_01")
    shuffle_size = 4
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    with pytest.raises(RuntimeError) as e:
        dataset.shuffle(shuffle_size)
    assert "No shuffle after sync operators" in str(e.value)


def test_sync_exception_02():
    """
    Test sync: with duplicated condition name
    """
    logger.info("test_sync_exception_02")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    with pytest.raises(RuntimeError) as e:
        dataset.sync_wait(num_batch=2, condition_name="every batch")
    assert "Condition name is already in use" in str(e.value)


def test_sync_exception_03():
    """
    Test sync: with wrong batch size
    """
    logger.info("test_sync_exception_03")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # try to create a sync_wait node with num_batch < 0
    with pytest.raises(ValueError) as e:
        dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
    assert "num_batch need to be greater than 0." in str(e.value)


def test_sync_exception_04():
    """
    Test sync: with negative batch size in update
    """
    logger.info("test_sync_exception_04")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    aug = Augment(0)
    # try to call sync_update with num_batch < 0
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    count = 0
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator():
            count += 1
            data = {"loss": count}
            dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
    assert "Sync_update batch size can only be positive" in str(e.value)


def test_sync_exception_05():
    """
    Test sync: sync_update with an unknown condition name after disable_sync
    """
    logger.info("test_sync_exception_05")
    dataset = ds.GeneratorDataset(gen, column_names=["input"])
    count = 0
    aug = Augment(0)
    # try to call sync_update with a condition name that was never registered
    dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
    dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
    with pytest.raises(RuntimeError) as e:
        for _ in dataset.create_dict_iterator():
            dataset.disable_sync()
            count += 1
            data = {"loss": count}
            dataset.disable_sync()
            dataset.sync_update(condition_name="every", data=data)
    assert "Condition name not found" in str(e.value)


if __name__ == "__main__":
    test_simple_sync_wait()
    test_simple_shuffle_sync()
    test_two_sync()
    test_sync_exception_01()
    test_sync_exception_02()
    test_sync_exception_03()
    test_sync_exception_04()
    test_sync_exception_05()
    test_sync_epoch()
    test_multiple_iterators()