You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_enumerate.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test enumerate"""
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor
  20. from mindspore import context
# Run every Cell defined in this module in MindSpore graph mode. This is set
# once at import time, so it is already in effect when each test builds its Net.
context.set_context(mode=context.GRAPH_MODE)
  22. def test_enumerate_list_const():
  23. class Net(nn.Cell):
  24. def __init__(self):
  25. super(Net, self).__init__()
  26. self.value = [11, 22, 33, 44]
  27. def construct(self):
  28. index_sum = 0
  29. value_sum = 0
  30. for i, j in enumerate(self.value):
  31. index_sum += i
  32. value_sum += j
  33. return index_sum, value_sum
  34. net = Net()
  35. assert net() == (6, 110)
  36. def test_enumerate_tuple_const():
  37. class Net(nn.Cell):
  38. def __init__(self):
  39. super(Net, self).__init__()
  40. self.value = (11, 22, 33, 44)
  41. def construct(self):
  42. index_sum = 0
  43. value_sum = 0
  44. for i, j in enumerate(self.value):
  45. index_sum += i
  46. value_sum += j
  47. return index_sum, value_sum
  48. net = Net()
  49. assert net() == (6, 110)
  50. def test_enumerate_list_parameter():
  51. class Net(nn.Cell):
  52. def __init__(self):
  53. super(Net, self).__init__()
  54. def construct(self, x, y, z):
  55. index_sum = 0
  56. value = [x, y, z]
  57. ret = ()
  58. for i, j in enumerate(value):
  59. index_sum += i
  60. ret += (j,)
  61. return index_sum, ret
  62. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  63. net = Net()
  64. net(x, x, x)
  65. def test_enumerate_tuple_parameter():
  66. class Net(nn.Cell):
  67. def __init__(self):
  68. super(Net, self).__init__()
  69. def construct(self, x, y, z):
  70. index_sum = 0
  71. value = (x, y, z)
  72. ret = ()
  73. for i, j in enumerate(value):
  74. index_sum += i
  75. ret += (j,)
  76. return index_sum, ret
  77. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  78. net = Net()
  79. net(x, x, x)
  80. def test_enumerate_tuple_const_1():
  81. class Net(nn.Cell):
  82. def __init__(self):
  83. super(Net, self).__init__()
  84. self.value = (11, 22, 33, 44)
  85. def construct(self):
  86. index_sum = 0
  87. value_sum = 0
  88. for i in enumerate(self.value):
  89. index_sum += i[0]
  90. value_sum += i[1]
  91. return index_sum, value_sum
  92. net = Net()
  93. assert net() == (6, 110)
  94. def test_enumerate_tuple_parameter_1():
  95. class Net(nn.Cell):
  96. def __init__(self):
  97. super(Net, self).__init__()
  98. def construct(self, x, y, z):
  99. index_sum = 0
  100. value = (x, y, z)
  101. ret = ()
  102. for i in enumerate(value):
  103. index_sum += i[0]
  104. ret += (i[1],)
  105. return index_sum, ret
  106. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  107. net = Net()
  108. net(x, x, x)
  109. def test_enumerate_tuple_const_2():
  110. class Net(nn.Cell):
  111. def __init__(self):
  112. super(Net, self).__init__()
  113. self.value = (11, 22, 33, 44)
  114. def construct(self):
  115. index_sum = 0
  116. value_sum = 0
  117. for i in enumerate(self.value, 1):
  118. index_sum += i[0]
  119. value_sum += i[1]
  120. return index_sum, value_sum
  121. net = Net()
  122. assert net() == (10, 110)
  123. def test_enumerate_tuple_parameter_2():
  124. class Net(nn.Cell):
  125. def __init__(self):
  126. super(Net, self).__init__()
  127. def construct(self, x, y, z):
  128. index_sum = 0
  129. value = (x, y, z)
  130. ret = ()
  131. for i in enumerate(value, 2):
  132. index_sum += i[0]
  133. ret += (i[1],)
  134. return index_sum, ret
  135. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  136. net = Net()
  137. net(x, x, x)
  138. def test_enumerate_first_input_type_error():
  139. class Net(nn.Cell):
  140. def __init__(self):
  141. super(Net, self).__init__()
  142. def construct(self, x):
  143. return enumerate(x)
  144. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  145. net = Net()
  146. with pytest.raises(TypeError) as ex:
  147. net(x)
  148. assert "For 'enumerate', the 'first input'" in str(ex.value)
  149. def test_enumerate_start_type_error():
  150. class Net(nn.Cell):
  151. def __init__(self):
  152. super(Net, self).__init__()
  153. def construct(self, x):
  154. return enumerate(x, start=1.2)
  155. x = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
  156. net = Net()
  157. with pytest.raises(TypeError) as ex:
  158. net((x, x))
  159. assert "For 'enumerate', the 'start'" in str(ex.value)