You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_hub_retry.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import unittest
  4. from http.client import HTTPMessage, HTTPResponse
  5. from io import StringIO
  6. from unittest.mock import Mock, patch
  7. import requests
  8. from urllib3.exceptions import MaxRetryError
  9. from modelscope.hub.api import HubApi
  10. from modelscope.hub.file_download import http_get_file
  class HubOperationTest(unittest.TestCase):
      """Retry tests for ``HubApi.get_model_files`` and ``http_get_file``.

      Every test patches ``urllib3.connectionpool.HTTPConnectionPool._get_conn``
      and queues canned responses on the mocked connection's ``getresponse``,
      so each call to ``getresponse`` yields the next queued response.
      """

      # Name of the fixture video under ``data/test/videos/`` and the value
      # advertised in the ``Content-Length`` header of the fake responses.
      _TEST_VIDEO_NAME = 'video_inpainting_test.mp4'
      _TEST_VIDEO_SIZE = '2957783'

      def setUp(self):
          """Create the API client and pick the model id used by the tests."""
          self.api = HubApi()
          self.model_id = 'damo/ofa_text-to-image-synthesis_coco_large_en'

      @staticmethod
      def _make_file_reader(path):
          """Return a ``read(content_length)`` callable that streams ``path``.

          The callable remembers how many bytes it has already handed out, so
          successive calls continue where the previous one stopped, even when
          the same callable is attached to several response objects.
          """
          offset = 0

          def read(content_length):
              nonlocal offset
              with open(path, 'rb') as f:
                  f.seek(offset)
                  content = f.read(content_length)
              offset += len(content)
              return content

          return read

      @staticmethod
      def _make_file_response(conn_mock, status, read):
          """Build a chunked ``HTTPResponse`` with the given status and reader."""
          rsp = HTTPResponse(conn_mock)
          rsp.status = status
          rsp.msg = HTTPMessage()
          rsp.msg.add_header('Content-Length',
                             HubOperationTest._TEST_VIDEO_SIZE)
          rsp.read = read
          rsp.chunked = True
          return rsp

      @staticmethod
      def _remove_if_exists(path):
          """Delete ``path`` when present so one test cannot pollute another."""
          if os.path.exists(path):
              os.remove(path)

      @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
      def test_retry_exception(self, getconn_mock):
          """Only 5xx responses are queued, so a ``RetryError`` is expected."""
          getconn_mock.return_value.getresponse.side_effect = [
              Mock(status=500, msg=HTTPMessage()),
              Mock(status=502, msg=HTTPMessage()),
              Mock(status=500, msg=HTTPMessage()),
          ]
          with self.assertRaises(requests.exceptions.RetryError):
              self.api.get_model_files(
                  model_id=self.model_id,
                  recursive=True,
              )

      @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
      def test_retry_and_success(self, getconn_mock):
          """Two failing responses are followed by a 200 carrying a file list."""
          # Adjacent literals keep the payload independent of source
          # indentation (no backslash continuations inside the string).
          response_body = (
              '{"Code": 200, "Data": { "Files": [ {"CommitMessage": '
              '"update","CommittedDate": 1667548386,"CommitterName": "行嗔","InCheck": false, '
              '"IsLFS": false, "Mode": "33188", "Name": "README.md", "Path": "README.md", '
              '"Revision": "e45fcc158894f18a7a8cfa3caf8b3dd1a2b26dc9",'
              '"Sha256": "8bf99f410ae0a572e5a4a85a3949ad268d49023e5c6ef200c9bd4307f9ed0660", '
              '"Size": 6399, "Type": "blob" } ] }, "Message": "success",'
              '"RequestId": "8c2a8249-ce50-49f4-85ea-36debf918714","Success": true}'
          )
          served = False

          def get_content(p):
              # The first call returns the whole payload; later calls
              # return None.
              nonlocal served
              if served:
                  return None
              served = True
              return response_body.encode('utf-8')

          rsp = HTTPResponse(getconn_mock)
          rsp.status = 200
          rsp.msg = HTTPMessage()
          rsp.read = get_content
          rsp.chunked = False
          # A 500 and a 502 are served before the 200 response.
          getconn_mock.return_value.getresponse.side_effect = [
              Mock(status=500, msg=HTTPMessage()),
              Mock(
                  status=502,
                  msg=HTTPMessage(),
                  body=response_body,
                  read=StringIO(response_body)),
              rsp,
          ]
          model_files = self.api.get_model_files(
              model_id=self.model_id,
              recursive=True,
          )
          # unittest assertion: not stripped under ``python -O``.
          self.assertGreater(len(model_files), 0)

      @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
      def test_retry_broken_continue(self, getconn_mock):
          """Five 502 responses precede a 200; the file must be downloaded."""
          test_file_name = self._TEST_VIDEO_NAME
          # One reader is shared by every response, as a single offset is
          # tracked across all attempts.
          get_content = self._make_file_reader(
              'data/test/videos/%s' % test_file_name)
          success_rsp = self._make_file_response(getconn_mock, 200,
                                                 get_content)
          failed_rsp = self._make_file_response(getconn_mock, 502, get_content)
          # Five failing responses, then the successful one.
          getconn_mock.return_value.getresponse.side_effect = [
              failed_rsp,
              failed_rsp,
              failed_rsp,
              failed_rsp,
              failed_rsp,
              success_rsp,
          ]
          url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
          local_path = './%s' % test_file_name
          # Registered before the download so a partially written file is
          # removed even when the call or the assertion below fails.
          self.addCleanup(self._remove_if_exists, local_path)
          http_get_file(
              url=url,
              local_dir='./',
              file_name=test_file_name,
              headers={},
              cookies=None)
          self.assertTrue(os.path.exists(local_path))

      @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
      def test_retry_broken_continue_retry_failed(self, getconn_mock):
          """Every queued response is a 502, so ``MaxRetryError`` is expected."""
          test_file_name = self._TEST_VIDEO_NAME
          get_content = self._make_file_reader(
              'data/test/videos/%s' % test_file_name)
          failed_rsp = self._make_file_response(getconn_mock, 502, get_content)
          # All six queued responses fail; the download is expected to give up.
          getconn_mock.return_value.getresponse.side_effect = [
              failed_rsp,
              failed_rsp,
              failed_rsp,
              failed_rsp,
              failed_rsp,
              failed_rsp,
          ]
          url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
          local_path = './%s' % test_file_name
          # If a file is unexpectedly left behind, remove it after the test so
          # later runs start from a clean state.
          self.addCleanup(self._remove_if_exists, local_path)
          with self.assertRaises(MaxRetryError):
              http_get_file(
                  url=url,
                  local_dir='./',
                  file_name=test_file_name,
                  headers={},
                  cookies=None)

          self.assertFalse(os.path.exists(local_path))
  if __name__ == '__main__':
      # Executed as a script: hand control to unittest's command-line runner.
      unittest.main()