|
- # Copyright (c) OpenMMLab. All rights reserved.
- """Tests for async interface."""
-
- import asyncio
- import os
- import sys
-
- import asynctest
- import mmcv
- import torch
-
- from mmdet.apis import async_inference_detector, init_detector
-
- if sys.version_info >= (3, 7):
- from mmdet.utils.contextmanagers import concurrent
-
-
class AsyncTestCase(asynctest.TestCase):
    """``asynctest`` base case that bounds each coroutine test by a timeout.

    ``use_default_loop`` and ``forbid_get_event_loop`` are ``asynctest``
    options. Here they request a dedicated loop per test rather than the
    default one, and disallow ``get_event_loop()`` inside tests.

    The timeout, in seconds, is read once from the ``ASYNCIO_TEST_TIMEOUT``
    environment variable and defaults to 30.
    """

    use_default_loop = False
    forbid_get_event_loop = True

    TEST_TIMEOUT = int(os.getenv('ASYNCIO_TEST_TIMEOUT', '30'))

    def _run_test_method(self, method):
        # Calling the test method runs a synchronous test to completion.
        # For an ``async def`` test it only creates the coroutine, which is
        # then driven on this case's loop under ``TEST_TIMEOUT``.
        outcome = method()
        if not asyncio.iscoroutine(outcome):
            return
        guarded = asyncio.wait_for(outcome, timeout=self.TEST_TIMEOUT)
        self.loop.run_until_complete(guarded)
-
-
class MaskRCNNDetector:
    """Async detector wrapper that runs inference on a pool of CUDA streams.

    Args:
        model_config: Model config forwarded to ``init_detector``.
        checkpoint: Optional checkpoint forwarded to ``init_detector``.
            Defaults to ``None``.
        streamqueue_size (int): Number of CUDA streams that :meth:`init`
            puts into the stream queue. Defaults to 3.
        device (str): Device used for the model and for the CUDA streams.
            Defaults to ``'cuda:0'``.
    """

    def __init__(self,
                 model_config,
                 checkpoint=None,
                 streamqueue_size=3,
                 device='cuda:0'):

        self.streamqueue_size = streamqueue_size
        self.device = device
        # Build the model, forwarding the caller's ``checkpoint`` to
        # ``init_detector`` instead of discarding it.
        self.model = init_detector(
            model_config, checkpoint=checkpoint, device=self.device)
        # Filled by ``init``; it must be awaited before ``apredict``.
        self.streamqueue = None

    async def init(self):
        """Create the stream queue and fill it with CUDA streams.

        Puts ``streamqueue_size`` ``torch.cuda.Stream`` objects, all on
        ``self.device``, into a fresh ``asyncio.Queue``. Must be awaited
        before :meth:`apredict`, which reads ``self.streamqueue``.
        """
        self.streamqueue = asyncio.Queue()
        for _ in range(self.streamqueue_size):
            stream = torch.cuda.Stream(device=self.device)
            self.streamqueue.put_nowait(stream)

    # ``concurrent`` is only imported on Python >= 3.7 (see the module
    # imports), so ``apredict`` is only defined there.
    if sys.version_info >= (3, 7):

        async def apredict(self, img):
            """Run async inference on a single image.

            Args:
                img: Image path (loaded with ``mmcv.imread``) or any other
                    value, which is passed to the detector unchanged.

            Returns:
                Whatever ``async_inference_detector`` returns for ``img``.
            """
            if isinstance(img, str):
                img = mmcv.imread(img)
            # ``concurrent`` takes a stream from the queue for the duration
            # of the call; presumably it is returned on exit (see
            # mmdet.utils.contextmanagers.concurrent).
            async with concurrent(self.streamqueue):
                result = await async_inference_detector(self.model, img)
            return result
-
-
class AsyncInferenceTestCase(AsyncTestCase):
    """Smoke test for ``async_inference_detector`` via MaskRCNNDetector."""

    if sys.version_info >= (3, 7):

        async def test_simple_inference(self):
            """Run Mask R-CNN on ``demo/demo.jpg`` (requires a CUDA GPU).

            Asserts only that the returned bbox results are truthy.
            """
            if not torch.cuda.is_available():
                self.skipTest('test requires GPU and torch+cuda')

            # Paths are resolved relative to the current working directory.
            # Run this test from the repository root so that ``configs/``
            # and ``demo/`` are reachable.
            root_dir = os.getcwd()
            model_config = os.path.join(
                root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
            img_path = os.path.join(root_dir, 'demo/demo.jpg')

            ori_grad_enabled = torch.is_grad_enabled()
            try:
                detector = MaskRCNNDetector(model_config)
                await detector.init()
                bboxes, _ = await detector.apredict(img_path)
                self.assertTrue(bboxes)
            finally:
                # The async inference detector modifies grad_enabled.
                # Restore it even if inference or the assertion fails, so
                # that other tests are not affected.
                torch.set_grad_enabled(ori_grad_enabled)
|