Browse Source

[to #42323743] retain local cached model files by default

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8963687
master
yingda.chen 3 years ago
parent
commit
d6868ddffe
5 changed files with 50 additions and 18 deletions
  1. +4
    -2
      maas_lib/models/base.py
  2. +8
    -5
      maas_lib/pipelines/base.py
  3. +11
    -0
      maas_lib/pipelines/util.py
  4. +13
    -6
      tests/pipelines/test_image_matting.py
  5. +14
    -5
      tests/pipelines/test_text_classification.py

+ 4
- 2
maas_lib/models/base.py View File

@@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download
from maas_hub.snapshot_download import snapshot_download from maas_hub.snapshot_download import snapshot_download


from maas_lib.models.builder import build_model from maas_lib.models.builder import build_model
from maas_lib.pipelines import util
from maas_lib.utils.config import Config from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE from maas_lib.utils.constant import CONFIGFILE


@@ -39,8 +40,9 @@ class Model(ABC):
if osp.exists(model_name_or_path): if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path local_model_dir = model_name_or_path
else: else:

local_model_dir = snapshot_download(model_name_or_path)
cache_path = util.get_model_cache_dir(model_name_or_path)
local_model_dir = cache_path if osp.exists(
cache_path) else snapshot_download(model_name_or_path)
# else: # else:
# raise ValueError( # raise ValueError(
# 'Remote model repo {model_name_or_path} does not exists') # 'Remote model repo {model_name_or_path} does not exists')


+ 8
- 5
maas_lib/pipelines/base.py View File

@@ -2,16 +2,15 @@


import os.path as osp import os.path as osp
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from multiprocessing.sharedctypes import Value
from typing import Any, Dict, Generator, List, Tuple, Union from typing import Any, Dict, Generator, List, Tuple, Union


from ali_maas_datasets import PyDataset from ali_maas_datasets import PyDataset
from maas_hub.snapshot_download import snapshot_download from maas_hub.snapshot_download import snapshot_download


from maas_lib.models import Model from maas_lib.models import Model
from maas_lib.pipelines import util
from maas_lib.preprocessors import Preprocessor from maas_lib.preprocessors import Preprocessor
from maas_lib.utils.config import Config from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE
from .util import is_model_name from .util import is_model_name


Tensor = Union['torch.Tensor', 'tf.Tensor'] Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -31,7 +30,7 @@ class Pipeline(ABC):
""" Base class for pipeline. """ Base class for pipeline.


If config_file is provided, model and preprocessor will be If config_file is provided, model and preprocessor will be
instantiated from corresponding config. Otherwise model
instantiated from corresponding config. Otherwise, model
and preprocessor will be constructed separately. and preprocessor will be constructed separately.


Args: Args:
@@ -44,7 +43,11 @@ class Pipeline(ABC):


if isinstance(model, str): if isinstance(model, str):
if not osp.exists(model): if not osp.exists(model):
model = snapshot_download(model)
cache_path = util.get_model_cache_dir(model)
if osp.exists(cache_path):
model = cache_path
else:
model = snapshot_download(model)


if is_model_name(model): if is_model_name(model):
self.model = Model.from_pretrained(model) self.model = Model.from_pretrained(model)
@@ -61,7 +64,7 @@ class Pipeline(ABC):


def __call__(self, input: Union[Input, List[Input]], *args, def __call__(self, input: Union[Input, List[Input]], *args,
**post_kwargs) -> Union[Dict[str, Any], Generator]: **post_kwargs) -> Union[Dict[str, Any], Generator]:
# moodel provider should leave it as it is
# model provider should leave it as it is
# maas library developer will handle this function # maas library developer will handle this function


# simple showcase, need to support iterator type for both tensorflow and pytorch # simple showcase, need to support iterator type for both tensorflow and pytorch


+ 11
- 0
maas_lib/pipelines/util.py View File

@@ -1,12 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp import os.path as osp


import json import json
from maas_hub.constants import MODEL_ID_SEPARATOR
from maas_hub.file_download import model_file_download from maas_hub.file_download import model_file_download


from maas_lib.utils.constant import CONFIGFILE from maas_lib.utils.constant import CONFIGFILE




# temp solution before the hub-cache is in place
def get_model_cache_dir(model_id: str, branch: str = 'master'):
    """Return the local directory expected to hold a cached model snapshot.

    The cache root is read from the ``MAAS_CACHE`` environment variable and
    falls back to ``~/.cache/maas/hub`` when the variable is unset. The
    model-specific directory name is always appended to that root, so two
    different models (or branches) never resolve to the same path, even
    when ``MAAS_CACHE`` overrides the root.

    Args:
        model_id: Model identifier; any ``/`` is replaced with
            ``MODEL_ID_SEPARATOR`` to form a single path component.
        branch: Branch name appended after a ``.``; defaults to ``master``.

    Returns:
        The path of the per-model cache directory. The path is only
        computed here; it is neither created nor checked for existence.
    """
    model_id_expanded = model_id.replace('/',
                                         MODEL_ID_SEPARATOR) + '.' + branch
    default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
    # Resolve the root first and append the model component afterwards.
    # Returning the raw MAAS_CACHE value would make every model share one
    # directory, and purging "a model's cache" would wipe the whole root.
    # NOTE(review): assumes the hub stores snapshots directly under
    # <cache root>/<expanded model id> -- confirm against snapshot_download.
    cache_root = os.getenv('MAAS_CACHE', os.path.join(default_cache_dir,
                                                      'hub'))
    return os.path.join(cache_root, model_id_expanded)


def is_model_name(model): def is_model_name(model):
if osp.exists(model): if osp.exists(model):
if osp.exists(osp.join(model, CONFIGFILE)): if osp.exists(osp.join(model, CONFIGFILE)):


+ 13
- 6
tests/pipelines/test_image_matting.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp import os.path as osp
import shutil
import tempfile import tempfile
import unittest import unittest


@@ -8,12 +9,20 @@ import cv2
from ali_maas_datasets import PyDataset from ali_maas_datasets import PyDataset


from maas_lib.fileio import File from maas_lib.fileio import File
from maas_lib.pipelines import pipeline
from maas_lib.pipelines import pipeline, util
from maas_lib.utils.constant import Tasks from maas_lib.utils.constant import Tasks




class ImageMattingTest(unittest.TestCase): class ImageMattingTest(unittest.TestCase):


def setUp(self) -> None:
    """Record the model id under test and optionally drop its local cache."""
    self.model_id = 'damo/image-matting-person'
    # Flip to False to keep a previously downloaded copy between test runs
    # instead of fetching the model again every time.
    purge_cache = True
    if not purge_cache:
        return
    cached_model_dir = util.get_model_cache_dir(self.model_id)
    # ignore_errors=True: a missing cache directory is not a failure.
    shutil.rmtree(cached_model_dir, ignore_errors=True)

def test_run(self): def test_run(self):
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
'.com/data/test/maas/image_matting/matting_person.pb' '.com/data/test/maas/image_matting/matting_person.pb'
@@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase):
# input_location = '/dir/to/images' # input_location = '/dir/to/images'


dataset = PyDataset.load(input_location, target='image') dataset = PyDataset.load(input_location, target='image')
img_matting = pipeline(
Tasks.image_matting, model='damo/image-matting-person')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
# note that for dataset output, the inference-output is a Generator that can be iterated. # note that for dataset output, the inference-output is a Generator that can be iterated.
result = img_matting(dataset) result = img_matting(dataset)
cv2.imwrite('result.png', next(result)['output_png']) cv2.imwrite('result.png', next(result)['output_png'])
print(f'Output written to {osp.abspath("result.png")}') print(f'Output written to {osp.abspath("result.png")}')


def test_run_modelhub(self): def test_run_modelhub(self):
img_matting = pipeline(
Tasks.image_matting, model='damo/image-matting-person')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)


result = img_matting( result = img_matting(
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'


+ 14
- 5
tests/pipelines/test_text_classification.py View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import tempfile
import os
import shutil
import unittest import unittest
import zipfile import zipfile
from pathlib import Path from pathlib import Path
@@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset
from maas_lib.fileio import File from maas_lib.fileio import File
from maas_lib.models import Model from maas_lib.models import Model
from maas_lib.models.nlp import SequenceClassificationModel from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
from maas_lib.preprocessors import SequenceClassificationPreprocessor from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.utils.constant import Tasks from maas_lib.utils.constant import Tasks




class SequenceClassificationTest(unittest.TestCase): class SequenceClassificationTest(unittest.TestCase):


def setUp(self) -> None:
    """Pin the model id used by the tests and clear its cached files."""
    self.model_id = 'damo/bert-base-sst2'
    # Set to False when re-downloading the model on every run is unwanted.
    purge_cache = True
    if not purge_cache:
        return
    stale_cache_dir = util.get_model_cache_dir(self.model_id)
    # Errors are ignored so a cache that does not exist yet is harmless.
    shutil.rmtree(stale_cache_dir, ignore_errors=True)

def predict(self, pipeline_ins: SequenceClassificationPipeline): def predict(self, pipeline_ins: SequenceClassificationPipeline):
from easynlp.appzoo import load_dataset from easynlp.appzoo import load_dataset


@@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase):
print(pipeline2('Hello world!')) print(pipeline2('Hello world!'))


def test_run_with_model_from_modelhub(self): def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained('damo/bert-base-sst2')
model = Model.from_pretrained(self.model_id)
preprocessor = SequenceClassificationPreprocessor( preprocessor = SequenceClassificationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None) model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline( pipeline_ins = pipeline(
@@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase):


def test_run_with_model_name(self): def test_run_with_model_name(self):
text_classification = pipeline( text_classification = pipeline(
task=Tasks.text_classification, model='damo/bert-base-sst2')
task=Tasks.text_classification, model=self.model_id)
result = text_classification( result = text_classification(
PyDataset.load('glue', name='sst2', target='sentence')) PyDataset.load('glue', name='sst2', target='sentence'))
self.printDataset(result) self.printDataset(result)


def test_run_with_dataset(self): def test_run_with_dataset(self):
model = Model.from_pretrained('damo/bert-base-sst2')
model = Model.from_pretrained(self.model_id)
preprocessor = SequenceClassificationPreprocessor( preprocessor = SequenceClassificationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None) model.model_dir, first_sequence='sentence', second_sequence=None)
text_classification = pipeline( text_classification = pipeline(


Loading…
Cancel
Save