Browse Source

[to #43887377]fix: concurrent SDK API calls to snapshot download cause file conflicts

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9672696

    * [to #43887377]fix: concurrent SDK API calls to snapshot download cause file conflicts
master
mulin.lyh 3 years ago
parent
commit
578f82e501
7 changed files with 39 additions and 41 deletions
  1. +6
    -3
      modelscope/hub/file_download.py
  2. +31
    -24
      modelscope/hub/snapshot_download.py
  3. +0
    -3
      tests/hub/test_hub_operation.py
  4. +0
    -3
      tests/hub/test_hub_private_files.py
  5. +0
    -3
      tests/hub/test_hub_private_repository.py
  6. +0
    -3
      tests/hub/test_hub_repository.py
  7. +2
    -2
      tests/hub/test_utils.py

+ 6
- 3
modelscope/hub/file_download.py View File

@@ -79,6 +79,8 @@ def model_file_download(
cache_dir = get_cache_dir()
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, 'temp')
os.makedirs(temporary_cache_dir, exist_ok=True)

group_or_owner, name = model_id_to_group_owner_name(model_id)

@@ -152,12 +154,13 @@ def model_file_download(
temp_file_name = next(tempfile._get_candidate_names())
http_get_file(
url_to_download,
cache_dir,
temporary_cache_dir,
temp_file_name,
headers=headers,
cookies=None if cookies is None else cookies.get_dict())
return cache.put_file(file_to_download_info,
os.path.join(cache_dir, temp_file_name))
return cache.put_file(
file_to_download_info,
os.path.join(temporary_cache_dir, temp_file_name))


def http_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:


+ 31
- 24
modelscope/hub/snapshot_download.py View File

@@ -1,4 +1,5 @@
import os
import tempfile
from pathlib import Path
from typing import Dict, Optional, Union

@@ -58,6 +59,8 @@ def snapshot_download(model_id: str,
cache_dir = get_cache_dir()
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, 'temp')
os.makedirs(temporary_cache_dir, exist_ok=True)

group_or_owner, name = model_id_to_group_owner_name(model_id)

@@ -98,31 +101,35 @@ def snapshot_download(model_id: str,
headers=snapshot_header,
)

for model_file in model_files:
if model_file['Type'] == 'tree':
continue
# check model_file is exist in cache, if exist, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
logger.info(
f'File {file_name} already in cache, skip downloading!')
continue
with tempfile.TemporaryDirectory(
dir=temporary_cache_dir) as temp_cache_dir:
for model_file in model_files:
if model_file['Type'] == 'tree':
continue
# check model_file is exist in cache, if exist, skip download, otherwise download
if cache.exists(model_file):
file_name = os.path.basename(model_file['Name'])
logger.info(
f'File {file_name} already in cache, skip downloading!'
)
continue

# get download url
url = get_file_download_url(
model_id=model_id,
file_path=model_file['Path'],
revision=revision)
# get download url
url = get_file_download_url(
model_id=model_id,
file_path=model_file['Path'],
revision=revision)

# First download to /tmp
http_get_file(
url=url,
local_dir=cache_dir,
file_name=model_file['Name'],
headers=headers,
cookies=cookies)
# put file to cache
cache.put_file(model_file,
os.path.join(cache_dir, model_file['Name']))
# First download to /tmp
http_get_file(
url=url,
local_dir=temp_cache_dir,
file_name=model_file['Name'],
headers=headers,
cookies=cookies)
# put file to cache
cache.put_file(
model_file, os.path.join(temp_cache_dir,
model_file['Name']))

return os.path.join(cache.get_root_location())

+ 0
- 3
tests/hub/test_hub_operation.py View File

@@ -21,9 +21,6 @@ DEFAULT_GIT_PATH = 'git'
download_model_file_name = 'test.bin'


@unittest.skip(
"Access token is always change, we can't login with same access token, so skip!"
)
class HubOperationTest(unittest.TestCase):

def setUp(self):


+ 0
- 3
tests/hub/test_hub_private_files.py View File

@@ -18,9 +18,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
delete_credential)


@unittest.skip(
"Access token is always change, we can't login with same access token, so skip!"
)
class HubPrivateFileDownloadTest(unittest.TestCase):

def setUp(self):


+ 0
- 3
tests/hub/test_hub_private_repository.py View File

@@ -15,9 +15,6 @@ from .test_utils import (TEST_ACCESS_TOKEN1, TEST_ACCESS_TOKEN2,
DEFAULT_GIT_PATH = 'git'


@unittest.skip(
"Access token is always change, we can't login with same access token, so skip!"
)
class HubPrivateRepositoryTest(unittest.TestCase):

def setUp(self):


+ 0
- 3
tests/hub/test_hub_repository.py View File

@@ -24,9 +24,6 @@ logger.setLevel('DEBUG')
DEFAULT_GIT_PATH = 'git'


@unittest.skip(
"Access token is always change, we can't login with same access token, so skip!"
)
class HubRepositoryTest(unittest.TestCase):

def setUp(self):


+ 2
- 2
tests/hub/test_utils.py View File

@@ -6,8 +6,8 @@ from os.path import expanduser
from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH

# for user citest and sdkdev
TEST_ACCESS_TOKEN1 = 'OVAzNU9aZ2FYbXFhdGNzZll6VHRtalQ0T1BpZTNGeWVhMkxSSGpTSzU0dkM5WE5ObDFKdFRQWGc2U2ZIdjdPdg=='
TEST_ACCESS_TOKEN2 = 'aXRocHhGeG0rNXRWQWhBSnJpTTZUQ0RDbUlkcUJRS1dQR2lNb0xIa0JjRDBrT1JKYklZV05DVzROTTdtamxWcg=='
TEST_ACCESS_TOKEN1 = 'RGZZdkh2Z3BlMFU1VktjUkdIcUJtdjdqdnhQUEQrUVROdVBjclAzUGVycHFhU1BFZFBIaGtUOHB1eHQ2OTV3dQ=='
TEST_ACCESS_TOKEN2 = 'dFpadllseTZQbHlyK0E4amQxVC84a2RtZHdkUVhmMUl3M1VXZXU4dS9GZlRuVmFUTW5yQm8yTENYWEw2SVh0Uw=='

TEST_MODEL_CHINESE_NAME = '内部测试模型'
TEST_MODEL_ORG = 'citest'


Loading…
Cancel
Save