|
|
@@ -9,25 +9,26 @@ import cv2 |
|
|
|
from modelscope.fileio import File |
|
|
|
from modelscope.pipelines import pipeline |
|
|
|
from modelscope.pydatasets import PyDataset |
|
|
|
from modelscope.utils.constant import Tasks |
|
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
|
from modelscope.utils.hub import get_model_cache_dir |
|
|
|
|
|
|
|
|
|
|
|
class ImageMattingTest(unittest.TestCase): |
|
|
|
|
|
|
|
def setUp(self) -> None: |
|
|
|
self.model_id = 'damo/image-matting-person' |
|
|
|
self.model_id = 'damo/cv_unet_image-matting_damo' |
|
|
|
# switch to False if downloading everytime is not desired |
|
|
|
purge_cache = True |
|
|
|
if purge_cache: |
|
|
|
shutil.rmtree( |
|
|
|
get_model_cache_dir(self.model_id), ignore_errors=True) |
|
|
|
|
|
|
|
def test_run(self): |
|
|
|
@unittest.skip('deprecated, download model from model hub instead') |
|
|
|
def test_run_with_direct_file_download(self): |
|
|
|
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ |
|
|
|
'.com/data/test/maas/image_matting/matting_person.pb' |
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
|
model_file = osp.join(tmp_dir, 'matting_person.pb') |
|
|
|
model_file = osp.join(tmp_dir, ModelFile.TF_GRAPH_FILE) |
|
|
|
with open(model_file, 'wb') as ofile: |
|
|
|
ofile.write(File.read(model_path)) |
|
|
|
img_matting = pipeline(Tasks.image_matting, model=tmp_dir) |
|
|
|