@@ -1,5 +1,4 @@ | |||||
# Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
import jsonplus | |||||
import numpy as np | import numpy as np | ||||
from .base import FormatHandler | from .base import FormatHandler | ||||
@@ -25,11 +24,13 @@ class JsonHandler(FormatHandler): | |||||
"""Use jsonplus, serialization of Python types to JSON that "just works".""" | """Use jsonplus, serialization of Python types to JSON that "just works".""" | ||||
def load(self, file): | def load(self, file): | ||||
import jsonplus | |||||
return jsonplus.loads(file.read()) | return jsonplus.loads(file.read()) | ||||
def dump(self, obj, file, **kwargs): | def dump(self, obj, file, **kwargs): | ||||
file.write(self.dumps(obj, **kwargs)) | file.write(self.dumps(obj, **kwargs)) | ||||
def dumps(self, obj, **kwargs): | def dumps(self, obj, **kwargs): | ||||
import jsonplus | |||||
kwargs.setdefault('default', set_default) | kwargs.setdefault('default', set_default) | ||||
return jsonplus.dumps(obj, **kwargs) | return jsonplus.dumps(obj, **kwargs) |
@@ -1,5 +1,6 @@ | |||||
import hashlib | import hashlib | ||||
import os | import os | ||||
from typing import Optional | |||||
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DATA_ENDPOINT, | ||||
DEFAULT_MODELSCOPE_DOMAIN, | DEFAULT_MODELSCOPE_DOMAIN, | ||||
@@ -23,14 +24,16 @@ def model_id_to_group_owner_name(model_id): | |||||
return group_or_owner, name | return group_or_owner, name | ||||
def get_cache_dir(): | |||||
def get_cache_dir(model_id: Optional[str] = None): | |||||
""" | """ | ||||
cache dir precedence: | cache dir precedence: | ||||
function parameter > enviroment > ~/.cache/modelscope/hub | function parameter > enviroment > ~/.cache/modelscope/hub | ||||
""" | """ | ||||
default_cache_dir = get_default_cache_dir() | default_cache_dir = get_default_cache_dir() | ||||
return os.getenv('MODELSCOPE_CACHE', os.path.join(default_cache_dir, | |||||
'hub')) | |||||
base_path = os.getenv('MODELSCOPE_CACHE', | |||||
os.path.join(default_cache_dir, 'hub')) | |||||
return base_path if model_id is None else os.path.join( | |||||
base_path, model_id + '/') | |||||
def get_endpoint(): | def get_endpoint(): | ||||