Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8963687master
@@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download | |||
from maas_hub.snapshot_download import snapshot_download | |||
from maas_lib.models.builder import build_model | |||
from maas_lib.pipelines import util | |||
from maas_lib.utils.config import Config | |||
from maas_lib.utils.constant import CONFIGFILE | |||
@@ -39,8 +40,9 @@ class Model(ABC): | |||
if osp.exists(model_name_or_path): | |||
local_model_dir = model_name_or_path | |||
else: | |||
local_model_dir = snapshot_download(model_name_or_path) | |||
cache_path = util.get_model_cache_dir(model_name_or_path) | |||
local_model_dir = cache_path if osp.exists( | |||
cache_path) else snapshot_download(model_name_or_path) | |||
# else: | |||
# raise ValueError( | |||
# 'Remote model repo {model_name_or_path} does not exists') | |||
@@ -2,16 +2,15 @@ | |||
import os.path as osp | |||
from abc import ABC, abstractmethod | |||
from multiprocessing.sharedctypes import Value | |||
from typing import Any, Dict, Generator, List, Tuple, Union | |||
from ali_maas_datasets import PyDataset | |||
from maas_hub.snapshot_download import snapshot_download | |||
from maas_lib.models import Model | |||
from maas_lib.pipelines import util | |||
from maas_lib.preprocessors import Preprocessor | |||
from maas_lib.utils.config import Config | |||
from maas_lib.utils.constant import CONFIGFILE | |||
from .util import is_model_name | |||
Tensor = Union['torch.Tensor', 'tf.Tensor'] | |||
@@ -31,7 +30,7 @@ class Pipeline(ABC): | |||
""" Base class for pipeline. | |||
If config_file is provided, model and preprocessor will be | |||
instantiated from corresponding config. Otherwise model | |||
instantiated from corresponding config. Otherwise, model | |||
and preprocessor will be constructed separately. | |||
Args: | |||
@@ -44,7 +43,11 @@ class Pipeline(ABC): | |||
if isinstance(model, str): | |||
if not osp.exists(model): | |||
model = snapshot_download(model) | |||
cache_path = util.get_model_cache_dir(model) | |||
if osp.exists(cache_path): | |||
model = cache_path | |||
else: | |||
model = snapshot_download(model) | |||
if is_model_name(model): | |||
self.model = Model.from_pretrained(model) | |||
@@ -61,7 +64,7 @@ class Pipeline(ABC): | |||
def __call__(self, input: Union[Input, List[Input]], *args, | |||
**post_kwargs) -> Union[Dict[str, Any], Generator]: | |||
# moodel provider should leave it as it is | |||
# model provider should leave it as it is | |||
# maas library developer will handle this function | |||
# simple showcase, need to support iterator type for both tensorflow and pytorch | |||
@@ -1,12 +1,23 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import json | |||
from maas_hub.constants import MODEL_ID_SEPARATOR | |||
from maas_hub.file_download import model_file_download | |||
from maas_lib.utils.constant import CONFIGFILE | |||
# temp solution before the hub-cache is in place | |||
def get_model_cache_dir(model_id: str, branch: str = 'master'): | |||
model_id_expanded = model_id.replace('/', | |||
MODEL_ID_SEPARATOR) + '.' + branch | |||
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas')) | |||
return os.getenv('MAAS_CACHE', | |||
os.path.join(default_cache_dir, 'hub', model_id_expanded)) | |||
def is_model_name(model): | |||
if osp.exists(model): | |||
if osp.exists(osp.join(model, CONFIGFILE)): | |||
@@ -1,6 +1,7 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
import shutil | |||
import tempfile | |||
import unittest | |||
@@ -8,12 +9,20 @@ import cv2 | |||
from ali_maas_datasets import PyDataset | |||
from maas_lib.fileio import File | |||
from maas_lib.pipelines import pipeline | |||
from maas_lib.pipelines import pipeline, util | |||
from maas_lib.utils.constant import Tasks | |||
class ImageMattingTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/image-matting-person' | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
util.get_model_cache_dir(self.model_id), ignore_errors=True) | |||
def test_run(self): | |||
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ | |||
'.com/data/test/maas/image_matting/matting_person.pb' | |||
@@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase): | |||
# input_location = '/dir/to/images' | |||
dataset = PyDataset.load(input_location, target='image') | |||
img_matting = pipeline( | |||
Tasks.image_matting, model='damo/image-matting-person') | |||
img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||
# note that for dataset output, the inference-output is a Generator that can be iterated. | |||
result = img_matting(dataset) | |||
cv2.imwrite('result.png', next(result)['output_png']) | |||
print(f'Output written to {osp.abspath("result.png")}') | |||
def test_run_modelhub(self): | |||
img_matting = pipeline( | |||
Tasks.image_matting, model='damo/image-matting-person') | |||
img_matting = pipeline(Tasks.image_matting, model=self.model_id) | |||
result = img_matting( | |||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' | |||
@@ -1,5 +1,6 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import tempfile | |||
import os | |||
import shutil | |||
import unittest | |||
import zipfile | |||
from pathlib import Path | |||
@@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset | |||
from maas_lib.fileio import File | |||
from maas_lib.models import Model | |||
from maas_lib.models.nlp import SequenceClassificationModel | |||
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline | |||
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util | |||
from maas_lib.preprocessors import SequenceClassificationPreprocessor | |||
from maas_lib.utils.constant import Tasks | |||
class SequenceClassificationTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id = 'damo/bert-base-sst2' | |||
# switch to False if downloading everytime is not desired | |||
purge_cache = True | |||
if purge_cache: | |||
shutil.rmtree( | |||
util.get_model_cache_dir(self.model_id), ignore_errors=True) | |||
def predict(self, pipeline_ins: SequenceClassificationPipeline): | |||
from easynlp.appzoo import load_dataset | |||
@@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase): | |||
print(pipeline2('Hello world!')) | |||
def test_run_with_model_from_modelhub(self): | |||
model = Model.from_pretrained('damo/bert-base-sst2') | |||
model = Model.from_pretrained(self.model_id) | |||
preprocessor = SequenceClassificationPreprocessor( | |||
model.model_dir, first_sequence='sentence', second_sequence=None) | |||
pipeline_ins = pipeline( | |||
@@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase): | |||
def test_run_with_model_name(self): | |||
text_classification = pipeline( | |||
task=Tasks.text_classification, model='damo/bert-base-sst2') | |||
task=Tasks.text_classification, model=self.model_id) | |||
result = text_classification( | |||
PyDataset.load('glue', name='sst2', target='sentence')) | |||
self.printDataset(result) | |||
def test_run_with_dataset(self): | |||
model = Model.from_pretrained('damo/bert-base-sst2') | |||
model = Model.from_pretrained(self.model_id) | |||
preprocessor = SequenceClassificationPreprocessor( | |||
model.model_dir, first_sequence='sentence', second_sequence=None) | |||
text_classification = pipeline( | |||