You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

repository.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import os
  2. from typing import List, Optional
  3. from modelscope.hub.errors import GitError, InvalidParameter
  4. from modelscope.utils.logger import get_logger
  5. from .api import ModelScopeConfig
  6. from .constants import MODELSCOPE_URL_SCHEME
  7. from .git import GitCommandWrapper
  8. from .utils.utils import get_endpoint
  # Module-level logger obtained from modelscope.utils.logger.get_logger;
  # Repository uses it to report a missing git-lfs installation.
  logger = get_logger()
  class Repository:
      """A local git working copy of a ModelScope hub model repository.

      Constructing an instance makes sure ``model_dir`` holds a clone of the
      requested hub model (cloning it when needed); :meth:`push` then commits
      local changes and pushes them back to the remote.
      """

      def __init__(
          self,
          model_dir: str,
          clone_from: str,
          revision: Optional[str] = 'master',
          auth_token: Optional[str] = None,
          git_path: Optional[str] = None,
      ):
          """Instantiate a Repository by cloning the remote ModelScope hub repo.

          If ``model_dir`` is non-empty and its git remote already points at
          the requested model, nothing is cloned.

          Args:
              model_dir (`str`):
                  The local model root directory. Its parent directory and its
                  basename are passed to the clone as the base directory and
                  the repository name respectively.
              clone_from (`str`):
                  Model id on the ModelScope hub to clone from.
              revision (`Optional[str]`):
                  Revision of the model to clone; can be a branch, tag or
                  commit hash. Defaults to 'master'.
              auth_token (`Optional[str]`):
                  Token obtained from `HubApi.login()`. Usually safe to omit:
                  if None, the token saved by `ModelScopeConfig` is used.
              git_path (`Optional[str]`):
                  Path to the git command line executable; if None, 'git' is
                  used.
          """
          self.model_dir = model_dir
          self.model_base_dir = os.path.dirname(model_dir)
          self.model_repo_name = os.path.basename(model_dir)
          # Fall back to the token saved at login when none is given.
          self.auth_token = auth_token or ModelScopeConfig.get_token()

          # Use the caller-configured git binary for every git operation,
          # including the git-lfs check and install. A second, default-path
          # wrapper would silently ignore `git_path` for those steps.
          self.git_wrapper = GitCommandWrapper(git_path)
          lfs_installed = self.git_wrapper.is_lfs_installed()
          if not lfs_installed:
              logger.error('git lfs is not installed, please install.')

          os.makedirs(self.model_dir, exist_ok=True)
          url = self._get_model_id_url(clone_from)
          if os.listdir(self.model_dir):  # directory not empty.
              remote_url = self._get_remote_url()
              if remote_url:
                  # Compare without any credentials embedded in the URL.
                  remote_url = self.git_wrapper.remove_token_from_url(
                      remote_url)
              if remote_url == url:
                  # Already a clone of the requested model: need not clone.
                  return
              # NOTE(review): otherwise we fall through and clone into a
              # non-empty directory; presumably git rejects this — confirm.
          self.git_wrapper.clone(self.model_base_dir, self.auth_token, url,
                                 self.model_repo_name, revision)
          if lfs_installed:
              self.git_wrapper.git_lfs_install(self.model_dir)  # init repo lfs

      def _get_model_id_url(self, model_id):
          """Return the hub git URL for ``model_id``."""
          return f'{get_endpoint()}/{model_id}.git'

      def _get_remote_url(self):
          """Return the git remote URL of ``model_dir``, or None on GitError."""
          try:
              return self.git_wrapper.get_repo_remote_url(self.model_dir)
          except GitError:
              # Presumably raised when model_dir is not a git repo — confirm.
              return None

      def push(self,
               commit_message: str,
               branch: Optional[str] = 'master',
               force: bool = False):
          """Commit all local changes and push them to the remote.

          Runs, in order: ``git pull``, ``git add`` (all files),
          ``git commit`` and ``git push``.

          Args:
              commit_message (str): Commit message; must be a non-empty string.
              branch (Optional[str], optional): Branch to push; used as both
                  the local and the remote branch. Defaults to 'master'.
              force (bool, optional): Request a forced push (forwarded to the
                  git wrapper's push). Defaults to False.

          Raises:
              InvalidParameter: If ``commit_message`` is not a non-empty
                  string, or ``force`` is not a bool.
          """
          # Validate before touching the repository: an empty or
          # whitespace-only message would only fail at `git commit`, after
          # pull/add have already run.
          if not isinstance(commit_message, str) or not commit_message.strip():
              raise InvalidParameter('commit_message must be provided!')
          if not isinstance(force, bool):
              raise InvalidParameter('force must be bool')
          url = self.git_wrapper.get_repo_remote_url(self.model_dir)
          self.git_wrapper.pull(self.model_dir)
          self.git_wrapper.add(self.model_dir, all_files=True)
          self.git_wrapper.commit(self.model_dir, commit_message)
          # `force` was previously validated but never used. Forward it only
          # when requested so the default call is unchanged.
          # NOTE(review): assumes GitCommandWrapper.push accepts a `force`
          # keyword — confirm.
          push_kwargs = {'force': True} if force else {}
          self.git_wrapper.push(
              repo_dir=self.model_dir,
              token=self.auth_token,
              url=url,
              local_branch=branch,
              remote_branch=branch,
              **push_kwargs)