@@ -7,7 +7,7 @@ import PIL | |||||
from modelscope.pipelines.base import Input | from modelscope.pipelines.base import Input | ||||
from modelscope.preprocessors import load_image | from modelscope.preprocessors import load_image | ||||
from modelscope.utils.constant import TF_GRAPH_FILE, Tasks | |||||
from modelscope.utils.constant import Tasks | |||||
from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
from ..base import Pipeline | from ..base import Pipeline | ||||
from ..builder import PIPELINES | from ..builder import PIPELINES | ||||
@@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): | |||||
import tensorflow as tf | import tensorflow as tf | ||||
if tf.__version__ >= '2.0': | if tf.__version__ >= '2.0': | ||||
tf = tf.compat.v1 | tf = tf.compat.v1 | ||||
model_path = osp.join(self.model, TF_GRAPH_FILE) | |||||
model_path = osp.join(self.model, 'matting_person.pb') | |||||
config = tf.ConfigProto(allow_soft_placement=True) | config = tf.ConfigProto(allow_soft_placement=True) | ||||
config.gpu_options.allow_growth = True | config.gpu_options.allow_growth = True | ||||
@@ -75,12 +75,3 @@ class Hubs(object): | |||||
# in order to avoid conflict with huggingface | # in order to avoid conflict with huggingface | ||||
# config file we use maas_config instead | # config file we use maas_config instead | ||||
CONFIGFILE = 'maas_config.json' | CONFIGFILE = 'maas_config.json' | ||||
README_FILE = 'README.md' | |||||
TF_SAVED_MODEL_FILE = 'saved_model.pb' | |||||
TF_GRAPH_FILE = 'tf_graph.pb' | |||||
TF_CHECKPOINT_FOLDER = 'tf_ckpts' | |||||
TF_CHECKPOINT_FILE = 'checkpoint' | |||||
TORCH_MODEL_FILE = 'pytorch_model.bin' | |||||
TENSORFLOW = 'tensorflow' | |||||
PYTORCH = 'pytorch' |
@@ -16,15 +16,14 @@ from modelscope.utils.hub import get_model_cache_dir | |||||
class ImageMattingTest(unittest.TestCase): | class ImageMattingTest(unittest.TestCase): | ||||
def setUp(self) -> None: | def setUp(self) -> None: | ||||
self.model_id = 'damo/cv_unet_image-matting_damo' | |||||
self.model_id = 'damo/image-matting-person' | |||||
# switch to False if downloading everytime is not desired | # switch to False if downloading everytime is not desired | ||||
purge_cache = True | purge_cache = True | ||||
if purge_cache: | if purge_cache: | ||||
shutil.rmtree( | shutil.rmtree( | ||||
get_model_cache_dir(self.model_id), ignore_errors=True) | get_model_cache_dir(self.model_id), ignore_errors=True) | ||||
@unittest.skip('deprecated, download model from model hub instead') | |||||
def test_run_with_direct_file_download(self): | |||||
def test_run(self): | |||||
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | ||||
'.com/data/test/maas/image_matting/matting_person.pb' | '.com/data/test/maas/image_matting/matting_person.pb' | ||||
with tempfile.TemporaryDirectory() as tmp_dir: | with tempfile.TemporaryDirectory() as tmp_dir: | ||||