Browse Source

[to #42323743] retain local cached model files by default

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8963687
master
yingda.chen 3 years ago
parent
commit
d6868ddffe
5 changed files with 50 additions and 18 deletions
  1. +4
    -2
      maas_lib/models/base.py
  2. +8
    -5
      maas_lib/pipelines/base.py
  3. +11
    -0
      maas_lib/pipelines/util.py
  4. +13
    -6
      tests/pipelines/test_image_matting.py
  5. +14
    -5
      tests/pipelines/test_text_classification.py

+ 4
- 2
maas_lib/models/base.py View File

@@ -8,6 +8,7 @@ from maas_hub.file_download import model_file_download
from maas_hub.snapshot_download import snapshot_download

from maas_lib.models.builder import build_model
from maas_lib.pipelines import util
from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE

@@ -39,8 +40,9 @@ class Model(ABC):
if osp.exists(model_name_or_path):
local_model_dir = model_name_or_path
else:

local_model_dir = snapshot_download(model_name_or_path)
cache_path = util.get_model_cache_dir(model_name_or_path)
local_model_dir = cache_path if osp.exists(
cache_path) else snapshot_download(model_name_or_path)
# else:
# raise ValueError(
# 'Remote model repo {model_name_or_path} does not exists')


+ 8
- 5
maas_lib/pipelines/base.py View File

@@ -2,16 +2,15 @@

import os.path as osp
from abc import ABC, abstractmethod
from multiprocessing.sharedctypes import Value
from typing import Any, Dict, Generator, List, Tuple, Union

from ali_maas_datasets import PyDataset
from maas_hub.snapshot_download import snapshot_download

from maas_lib.models import Model
from maas_lib.pipelines import util
from maas_lib.preprocessors import Preprocessor
from maas_lib.utils.config import Config
from maas_lib.utils.constant import CONFIGFILE
from .util import is_model_name

Tensor = Union['torch.Tensor', 'tf.Tensor']
@@ -31,7 +30,7 @@ class Pipeline(ABC):
""" Base class for pipeline.

If config_file is provided, model and preprocessor will be
instantiated from corresponding config. Otherwise model
instantiated from corresponding config. Otherwise, model
and preprocessor will be constructed separately.

Args:
@@ -44,7 +43,11 @@ class Pipeline(ABC):

if isinstance(model, str):
if not osp.exists(model):
model = snapshot_download(model)
cache_path = util.get_model_cache_dir(model)
if osp.exists(cache_path):
model = cache_path
else:
model = snapshot_download(model)

if is_model_name(model):
self.model = Model.from_pretrained(model)
@@ -61,7 +64,7 @@ class Pipeline(ABC):

def __call__(self, input: Union[Input, List[Input]], *args,
**post_kwargs) -> Union[Dict[str, Any], Generator]:
# moodel provider should leave it as it is
# model provider should leave it as it is
# maas library developer will handle this function

# simple showcase, need to support iterator type for both tensorflow and pytorch


+ 11
- 0
maas_lib/pipelines/util.py View File

@@ -1,12 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp

import json
from maas_hub.constants import MODEL_ID_SEPARATOR
from maas_hub.file_download import model_file_download

from maas_lib.utils.constant import CONFIGFILE


# temp solution before the hub-cache is in place
def get_model_cache_dir(model_id: str, branch: str = 'master'):
model_id_expanded = model_id.replace('/',
MODEL_ID_SEPARATOR) + '.' + branch
default_cache_dir = os.path.expanduser(os.path.join('~/.cache', 'maas'))
return os.getenv('MAAS_CACHE',
os.path.join(default_cache_dir, 'hub', model_id_expanded))


def is_model_name(model):
if osp.exists(model):
if osp.exists(osp.join(model, CONFIGFILE)):


+ 13
- 6
tests/pipelines/test_image_matting.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import shutil
import tempfile
import unittest

@@ -8,12 +9,20 @@ import cv2
from ali_maas_datasets import PyDataset

from maas_lib.fileio import File
from maas_lib.pipelines import pipeline
from maas_lib.pipelines import pipeline, util
from maas_lib.utils.constant import Tasks


class ImageMattingTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/image-matting-person'
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
util.get_model_cache_dir(self.model_id), ignore_errors=True)

def test_run(self):
model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \
'.com/data/test/maas/image_matting/matting_person.pb'
@@ -36,16 +45,14 @@ class ImageMattingTest(unittest.TestCase):
# input_location = '/dir/to/images'

dataset = PyDataset.load(input_location, target='image')
img_matting = pipeline(
Tasks.image_matting, model='damo/image-matting-person')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)
# note that for dataset output, the inference-output is a Generator that can be iterated.
result = img_matting(dataset)
cv2.imwrite('result.png', next(result)['output_png'])
print(f'Output written to {osp.abspath("result.png")}')

def test_run_modelhub(self):
img_matting = pipeline(
Tasks.image_matting, model='damo/image-matting-person')
img_matting = pipeline(Tasks.image_matting, model=self.model_id)

result = img_matting(
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'


+ 14
- 5
tests/pipelines/test_text_classification.py View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import tempfile
import os
import shutil
import unittest
import zipfile
from pathlib import Path
@@ -9,13 +10,21 @@ from ali_maas_datasets import PyDataset
from maas_lib.fileio import File
from maas_lib.models import Model
from maas_lib.models.nlp import SequenceClassificationModel
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline
from maas_lib.pipelines import SequenceClassificationPipeline, pipeline, util
from maas_lib.preprocessors import SequenceClassificationPreprocessor
from maas_lib.utils.constant import Tasks


class SequenceClassificationTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/bert-base-sst2'
# switch to False if downloading everytime is not desired
purge_cache = True
if purge_cache:
shutil.rmtree(
util.get_model_cache_dir(self.model_id), ignore_errors=True)

def predict(self, pipeline_ins: SequenceClassificationPipeline):
from easynlp.appzoo import load_dataset

@@ -60,7 +69,7 @@ class SequenceClassificationTest(unittest.TestCase):
print(pipeline2('Hello world!'))

def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained('damo/bert-base-sst2')
model = Model.from_pretrained(self.model_id)
preprocessor = SequenceClassificationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
pipeline_ins = pipeline(
@@ -71,13 +80,13 @@ class SequenceClassificationTest(unittest.TestCase):

def test_run_with_model_name(self):
text_classification = pipeline(
task=Tasks.text_classification, model='damo/bert-base-sst2')
task=Tasks.text_classification, model=self.model_id)
result = text_classification(
PyDataset.load('glue', name='sst2', target='sentence'))
self.printDataset(result)

def test_run_with_dataset(self):
model = Model.from_pretrained('damo/bert-base-sst2')
model = Model.from_pretrained(self.model_id)
preprocessor = SequenceClassificationPreprocessor(
model.model_dir, first_sequence='sentence', second_sequence=None)
text_classification = pipeline(


Loading…
Cancel
Save