diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 3e962d85..6f3ff5f5 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -7,7 +7,7 @@ import PIL from modelscope.pipelines.base import Input from modelscope.preprocessors import load_image -from modelscope.utils.constant import TF_GRAPH_FILE, Tasks +from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline from ..builder import PIPELINES @@ -24,7 +24,7 @@ class ImageMattingPipeline(Pipeline): import tensorflow as tf if tf.__version__ >= '2.0': tf = tf.compat.v1 - model_path = osp.join(self.model, TF_GRAPH_FILE) + model_path = osp.join(self.model, 'matting_person.pb') config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index c51e2445..0d0f2492 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -75,12 +75,3 @@ class Hubs(object): # in order to avoid conflict with huggingface # config file we use maas_config instead CONFIGFILE = 'maas_config.json' - -README_FILE = 'README.md' -TF_SAVED_MODEL_FILE = 'saved_model.pb' -TF_GRAPH_FILE = 'tf_graph.pb' -TF_CHECKPOINT_FOLDER = 'tf_ckpts' -TF_CHECKPOINT_FILE = 'checkpoint' -TORCH_MODEL_FILE = 'pytorch_model.bin' -TENSORFLOW = 'tensorflow' -PYTORCH = 'pytorch' diff --git a/tests/pipelines/test_image_matting.py b/tests/pipelines/test_image_matting.py index 69195bd1..53006317 100644 --- a/tests/pipelines/test_image_matting.py +++ b/tests/pipelines/test_image_matting.py @@ -16,15 +16,14 @@ from modelscope.utils.hub import get_model_cache_dir class ImageMattingTest(unittest.TestCase): def setUp(self) -> None: - self.model_id = 'damo/cv_unet_image-matting_damo' + self.model_id = 'damo/image-matting-person' # switch to False if downloading everytime is not desired purge_cache = True if purge_cache: shutil.rmtree( get_model_cache_dir(self.model_id), ignore_errors=True) - @unittest.skip('deprecated, download model from model hub instead') - def test_run_with_direct_file_download(self): + def test_run(self): model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ '.com/data/test/maas/image_matting/matting_person.pb' with tempfile.TemporaryDirectory() as tmp_dir: