|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import unittest
- from http.client import HTTPMessage, HTTPResponse
- from io import StringIO
- from unittest.mock import Mock, patch
-
- import requests
- from urllib3.exceptions import MaxRetryError
-
- from modelscope.hub.api import HubApi
- from modelscope.hub.file_download import http_get_file
-
-
- class HubOperationTest(unittest.TestCase):
-
- def setUp(self):
- self.api = HubApi()
- self.model_id = 'damo/ofa_text-to-image-synthesis_coco_large_en'
-
- @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
- def test_retry_exception(self, getconn_mock):
- getconn_mock.return_value.getresponse.side_effect = [
- Mock(status=500, msg=HTTPMessage()),
- Mock(status=502, msg=HTTPMessage()),
- Mock(status=500, msg=HTTPMessage()),
- ]
- with self.assertRaises(requests.exceptions.RetryError):
- self.api.get_model_files(
- model_id=self.model_id,
- recursive=True,
- )
-
- @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
- def test_retry_and_success(self, getconn_mock):
- response_body = '{"Code": 200, "Data": { "Files": [ {"CommitMessage": \
- "update","CommittedDate": 1667548386,"CommitterName": "行嗔","InCheck": false, \
- "IsLFS": false, "Mode": "33188", "Name": "README.md", "Path": "README.md", \
- "Revision": "e45fcc158894f18a7a8cfa3caf8b3dd1a2b26dc9",\
- "Sha256": "8bf99f410ae0a572e5a4a85a3949ad268d49023e5c6ef200c9bd4307f9ed0660", \
- "Size": 6399, "Type": "blob" } ] }, "Message": "success",\
- "RequestId": "8c2a8249-ce50-49f4-85ea-36debf918714","Success": true}'
-
- first = 0
-
- def get_content(p):
- nonlocal first
- if first > 0:
- return None
- else:
- first += 1
- return response_body.encode('utf-8')
-
- rsp = HTTPResponse(getconn_mock)
- rsp.status = 200
- rsp.msg = HTTPMessage()
- rsp.read = get_content
- rsp.chunked = False
- # retry 2 times and success.
- getconn_mock.return_value.getresponse.side_effect = [
- Mock(status=500, msg=HTTPMessage()),
- Mock(
- status=502,
- msg=HTTPMessage(),
- body=response_body,
- read=StringIO(response_body)),
- rsp,
- ]
- model_files = self.api.get_model_files(
- model_id=self.model_id,
- recursive=True,
- )
- assert len(model_files) > 0
-
- @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
- def test_retry_broken_continue(self, getconn_mock):
- test_file_name = 'video_inpainting_test.mp4'
- fp = 0
-
- def get_content(content_length):
- nonlocal fp
- with open('data/test/videos/%s' % test_file_name, 'rb') as f:
- f.seek(fp)
- content = f.read(content_length)
- fp += len(content)
- return content
-
- success_rsp = HTTPResponse(getconn_mock)
- success_rsp.status = 200
- success_rsp.msg = HTTPMessage()
- success_rsp.msg.add_header('Content-Length', '2957783')
- success_rsp.read = get_content
- success_rsp.chunked = True
-
- failed_rsp = HTTPResponse(getconn_mock)
- failed_rsp.status = 502
- failed_rsp.msg = HTTPMessage()
- failed_rsp.msg.add_header('Content-Length', '2957783')
- failed_rsp.read = get_content
- failed_rsp.chunked = True
-
- # retry 5 times and success.
- getconn_mock.return_value.getresponse.side_effect = [
- failed_rsp,
- failed_rsp,
- failed_rsp,
- failed_rsp,
- failed_rsp,
- success_rsp,
- ]
- url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
- http_get_file(
- url=url,
- local_dir='./',
- file_name=test_file_name,
- headers={},
- cookies=None)
-
- assert os.path.exists('./%s' % test_file_name)
- os.remove('./%s' % test_file_name)
-
- @patch('urllib3.connectionpool.HTTPConnectionPool._get_conn')
- def test_retry_broken_continue_retry_failed(self, getconn_mock):
- test_file_name = 'video_inpainting_test.mp4'
- fp = 0
-
- def get_content(content_length):
- nonlocal fp
- with open('data/test/videos/%s' % test_file_name, 'rb') as f:
- f.seek(fp)
- content = f.read(content_length)
- fp += len(content)
- return content
-
- failed_rsp = HTTPResponse(getconn_mock)
- failed_rsp.status = 502
- failed_rsp.msg = HTTPMessage()
- failed_rsp.msg.add_header('Content-Length', '2957783')
- failed_rsp.read = get_content
- failed_rsp.chunked = True
-
- # retry 6 times and success.
- getconn_mock.return_value.getresponse.side_effect = [
- failed_rsp,
- failed_rsp,
- failed_rsp,
- failed_rsp,
- failed_rsp,
- failed_rsp,
- ]
- url = 'http://www.modelscope.cn/api/v1/models/%s' % test_file_name
- with self.assertRaises(MaxRetryError):
- http_get_file(
- url=url,
- local_dir='./',
- file_name=test_file_name,
- headers={},
- cookies=None)
-
- assert not os.path.exists('./%s' % test_file_name)
-
-
- if __name__ == '__main__':
- unittest.main()
|