Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8963687master
@@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download
 from maas_hub.snapshot_download import snapshot_download
 from maas_lib.models.builder import build_model
+from maas_lib.pipelines import util
 from maas_lib.utils.config import Config
 from maas_lib.utils.constant import CONFIGFILE
@@ -39,8 +40,9 @@ class Model(ABC):
         if osp.exists(model_name_or_path):
             local_model_dir = model_name_or_path
         else:
-            local_model_dir = snapshot_download(model_name_or_path)
+            cache_path = util.get_model_cache_dir(model_name_or_path)
+            local_model_dir = cache_path if osp.exists(
+                cache_path) else snapshot_download(model_name_or_path)
             # else:
             #     raise ValueError(
             #         'Remote model repo {model_name_or_path} does not exists')
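
A usage sketch of the cache-first lookup introduced above (not part of the diff; Model.from_pretrained and the 'damo/bert-base-sst2' id are taken from the tests later in this review):

    from maas_lib.models import Model

    # No local copy yet: from_pretrained falls back to snapshot_download.
    model = Model.from_pretrained('damo/bert-base-sst2')

    # Once util.get_model_cache_dir('damo/bert-base-sst2') exists on disk,
    # from_pretrained uses that directory and skips the download.
    model = Model.from_pretrained('damo/bert-base-sst2')
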
@@ -2,16 +2,15 @@
 import os.path as osp
 from abc import ABC, abstractmethod
-from multiprocessing.sharedctypes import Value
 from typing import Any, Dict, Generator, List, Tuple, Union

 from ali_maas_datasets import PyDataset
 from maas_hub.snapshot_download import snapshot_download

 from maas_lib.models import Model
+from maas_lib.pipelines import util
 from maas_lib.preprocessors import Preprocessor
 from maas_lib.utils.config import Config
-from maas_lib.utils.constant import CONFIGFILE
 from .util import is_model_name

 Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -31,7 +30,7 @@ class Pipeline(ABC):
     """ Base class for pipeline.

         If config_file is provided, model and preprocessor will be
-        instantiated from corresponding config. Otherwise model
+        instantiated from corresponding config. Otherwise, model
         and preprocessor will be constructed separately.

         Args:
@@ -44,7 +43,11 @@ class Pipeline(ABC):
         if isinstance(model, str):
             if not osp.exists(model):
-                model = snapshot_download(model)
+                cache_path = util.get_model_cache_dir(model)
+                if osp.exists(cache_path):
+                    model = cache_path
+                else:
+                    model = snapshot_download(model)

             if is_model_name(model):
                 self.model = Model.from_pretrained(model)
@@ -61,7 +64,7 @@ class Pipeline(ABC):
     def __call__(self, input: Union[Input, List[Input]], *args,
                  **post_kwargs) -> Union[Dict[str, Any], Generator]:
-        # moodel provider should leave it as it is
+        # model provider should leave it as it is
         # maas library developer will handle this function

         # simple showcase, need to support iterator type for both tensorflow and pytorch
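
The docstring above describes two ways to build a pipeline; a sketch of both modes (not part of the diff; the preprocessor keyword of the pipeline() factory is assumed here from the sequence-classification test further down in this review):

    from maas_lib.models import Model
    from maas_lib.pipelines import pipeline
    from maas_lib.preprocessors import SequenceClassificationPreprocessor
    from maas_lib.utils.constant import Tasks

    # Mode 1: pass only a model name, as in test_run_with_model_name below.
    pipe = pipeline(task=Tasks.text_classification, model='damo/bert-base-sst2')

    # Mode 2: construct model and preprocessor separately, as in
    # test_run_with_model_from_modelhub below.
    model = Model.from_pretrained('damo/bert-base-sst2')
    preprocessor = SequenceClassificationPreprocessor(
        model.model_dir, first_sequence='sentence', second_sequence=None)
    pipe = pipeline(
        task=Tasks.text_classification, model=model, preprocessor=preprocessor)
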
@@ -1,12 +1,23 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import os.path as osp

 import json
+from maas_hub.constants import MODEL_ID_SEPARATOR
 from maas_hub.file_download import model_file_download

 from maas_lib.utils.constant import CONFIGFILE


+# temp solution before the hub-cache is in place
+def get_model_cache_dir(model_id: str, branch: str = 'master'):
+    model_id_expanded = model_id.replace('/',
+                                         MODEL_ID_SEPARATOR) + '.' + branch
+    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
+    return os.getenv('MAAS_CACHE',
+                     os.path.join(default_cache_dir, 'hub', model_id_expanded))
+
+
 def is_model_name(model):
     if osp.exists(model):
         if osp.exists(osp.join(model, CONFIGFILE)):
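
A worked example of the helper added above (not part of the diff; the concrete value of MODEL_ID_SEPARATOR comes from maas_hub, and '___' is only assumed for illustration):

    import os

    from maas_lib.pipelines import util

    # With MAAS_CACHE unset, 'damo/image-matting-person' on the default
    # 'master' branch expands to something like
    #   ~/.cache/maas/hub/damo___image-matting-person.master
    print(util.get_model_cache_dir('damo/image-matting-person'))

    # If MAAS_CACHE is set, os.getenv returns that value unchanged, so the
    # per-model subdirectory is not appended.
    os.environ['MAAS_CACHE'] = '/tmp/maas_cache'
    print(util.get_model_cache_dir('damo/image-matting-person'))  # /tmp/maas_cache
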
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import os.path as osp
+import shutil
 import tempfile
 import unittest
@@ -8,12 +9,20 @@ import cv2
 from ali_maas_datasets import PyDataset

 from maas_lib.fileio import File
-from maas_lib.pipelines import pipeline
+from maas_lib.pipelines import pipeline, util
 from maas_lib.utils.constant import Tasks


 class ImageMattingTest(unittest.TestCase):

+    def setUp(self) -> None:
+        self.model_id = 'damo/image-matting-person'
+        # switch to False if downloading everytime is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                util.get_model_cache_dir(self.model_id), ignore_errors=True)
+
     def test_run(self):
         model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
                      '.com/data/test/maas/image_matting/matting_person.pb'
@@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase):
         # input_location = '/dir/to/images'
         dataset = PyDataset.load(input_location, target='image')
-        img_matting = pipeline(
-            Tasks.image_matting, model='damo/image-matting-person')
+        img_matting = pipeline(Tasks.image_matting, model=self.model_id)
         # note that for dataset output, the inference-output is a Generator that can be iterated.
         result = img_matting(dataset)
         cv2.imwrite('result.png', next(result)['output_png'])
         print(f'Output written to {osp.abspath("result.png")}')

     def test_run_modelhub(self):
-        img_matting = pipeline(
-            Tasks.image_matting, model='damo/image-matting-person')
+        img_matting = pipeline(Tasks.image_matting, model=self.model_id)
         result = img_matting(
             'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
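
The comment in test_run above notes that the dataset output is a Generator; a sketch (not part of the diff) of iterating every result rather than only the first, using the same names and 'output_png' key as the existing test:

    import cv2
    from ali_maas_datasets import PyDataset
    from maas_lib.pipelines import pipeline
    from maas_lib.utils.constant import Tasks

    input_location = '/dir/to/images'  # any local image directory
    dataset = PyDataset.load(input_location, target='image')
    img_matting = pipeline(Tasks.image_matting, model='damo/image-matting-person')

    # The dataset run yields one result per image; write each one out
    # instead of only the first element taken by next() in test_run.
    for idx, frame in enumerate(img_matting(dataset)):
        cv2.imwrite(f'result_{idx}.png', frame['output_png'])
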
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import tempfile
+import os
+import shutil
 import unittest
 import zipfile
 from pathlib import Path
@@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset
 from maas_lib.fileio import File
 from maas_lib.models import Model
 from maas_lib.models.nlp import SequenceClassificationModel
-from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
+from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
 from maas_lib.preprocessors import SequenceClassificationPreprocessor
 from maas_lib.utils.constant import Tasks


 class SequenceClassificationTest(unittest.TestCase):

+    def setUp(self) -> None:
+        self.model_id = 'damo/bert-base-sst2'
+        # switch to False if downloading everytime is not desired
+        purge_cache = True
+        if purge_cache:
+            shutil.rmtree(
+                util.get_model_cache_dir(self.model_id), ignore_errors=True)
+
     def predict(self, pipeline_ins: SequenceClassificationPipeline):
         from easynlp.appzoo import load_dataset
@@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase):
         print(pipeline2('Hello world!'))

     def test_run_with_model_from_modelhub(self):
-        model = Model.from_pretrained('damo/bert-base-sst2')
+        model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
             model.model_dir, first_sequence='sentence', second_sequence=None)
         pipeline_ins = pipeline(
@@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase):
     def test_run_with_model_name(self):
         text_classification = pipeline(
-            task=Tasks.text_classification, model='damo/bert-base-sst2')
+            task=Tasks.text_classification, model=self.model_id)
         result = text_classification(
             PyDataset.load('glue', name='sst2', target='sentence'))
         self.printDataset(result)

     def test_run_with_dataset(self):
-        model = Model.from_pretrained('damo/bert-base-sst2')
+        model = Model.from_pretrained(self.model_id)
         preprocessor = SequenceClassificationPreprocessor(
             model.model_dir, first_sequence='sentence', second_sequence=None)
         text_classification = pipeline(