|
- # Copyright (c) Alibaba, Inc. and its affiliates.
-
- import os
- import os.path as osp
- from typing import List, Optional, Union
-
- from requests import HTTPError
-
- from modelscope.hub.constants import Licenses, ModelVisibility
- from modelscope.hub.file_download import model_file_download
- from modelscope.hub.snapshot_download import snapshot_download
- from modelscope.utils.config import Config
- from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
- ModelFile)
- from .logger import get_logger
-
- logger = get_logger()
-
-
- def create_model_if_not_exist(
- api,
- model_id: str,
- chinese_name: str,
- visibility: Optional[int] = ModelVisibility.PUBLIC,
- license: Optional[str] = Licenses.APACHE_V2,
- revision: Optional[str] = DEFAULT_MODEL_REVISION):
- exists = True
- try:
- api.get_model(model_id=model_id, revision=revision)
- except HTTPError:
- exists = False
- if exists:
- print(f'model {model_id} already exists, skip creation.')
- return False
- else:
- api.create_model(
- model_id=model_id,
- visibility=visibility,
- license=license,
- chinese_name=chinese_name,
- )
- print(f'model {model_id} successfully created.')
- return True
-
-
- def read_config(model_id_or_path: str,
- revision: Optional[str] = DEFAULT_MODEL_REVISION):
- """ Read config from hub or local path
-
- Args:
- model_id_or_path (str): Model repo name or local directory path.
- revision: revision of the model when getting from the hub
- Return:
- config (:obj:`Config`): config object
- """
- if not os.path.exists(model_id_or_path):
- local_path = model_file_download(
- model_id_or_path, ModelFile.CONFIGURATION, revision=revision)
- elif os.path.isdir(model_id_or_path):
- local_path = os.path.join(model_id_or_path, ModelFile.CONFIGURATION)
- elif os.path.isfile(model_id_or_path):
- local_path = model_id_or_path
-
- return Config.from_file(local_path)
-
-
- def auto_load(model: Union[str, List[str]]):
- if isinstance(model, str):
- if not osp.exists(model):
- model = snapshot_download(model)
- else:
- model = [
- snapshot_download(m) if not osp.exists(m) else m for m in model
- ]
-
- return model
-
-
- def get_model_type(model_dir):
- """Get the model type from the configuration.
-
- This method will try to get the model type from 'model.backbone.type',
- 'model.type' or 'model.model_type' field in the configuration.json file. If
- this file does not exist, the method will try to get the 'model_type' field
- from the config.json.
-
- Args:
- model_dir: The local model dir to use. @return: The model type
- string, returns None if nothing is found.
- """
- try:
- configuration_file = osp.join(model_dir, ModelFile.CONFIGURATION)
- config_file = osp.join(model_dir, 'config.json')
- if osp.isfile(configuration_file):
- cfg = Config.from_file(configuration_file)
- if hasattr(cfg.model, 'backbone'):
- return cfg.model.backbone.type
- elif hasattr(cfg.model,
- 'model_type') and not hasattr(cfg.model, 'type'):
- return cfg.model.model_type
- else:
- return cfg.model.type
- elif osp.isfile(config_file):
- cfg = Config.from_file(config_file)
- return cfg.model_type if hasattr(cfg, 'model_type') else None
- except Exception as e:
- logger.error(f'parse config file failed with error: {e}')
-
-
- def parse_label_mapping(model_dir):
- """Get the label mapping from the model dir.
-
- This method will do:
- 1. Try to read label-id mapping from the label_mapping.json
- 2. Try to read label-id mapping from the configuration.json
- 3. Try to read label-id mapping from the config.json
-
- Args:
- model_dir: The local model dir to use.
-
- Returns:
- The label2id mapping if found.
- """
- import json
- import os
- label2id = None
- label_path = os.path.join(model_dir, ModelFile.LABEL_MAPPING)
- if os.path.exists(label_path):
- with open(label_path, encoding='utf-8') as f:
- label_mapping = json.load(f)
- label2id = {name: idx for name, idx in label_mapping.items()}
-
- if label2id is None:
- config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
- config = Config.from_file(config_path)
- if hasattr(config, ConfigFields.model) and hasattr(
- config[ConfigFields.model], 'label2id'):
- label2id = config[ConfigFields.model].label2id
- elif hasattr(config, ConfigFields.model) and hasattr(
- config[ConfigFields.model], 'id2label'):
- id2label = config[ConfigFields.model].id2label
- label2id = {label: id for id, label in id2label.items()}
- elif hasattr(config, ConfigFields.preprocessor) and hasattr(
- config[ConfigFields.preprocessor], 'label2id'):
- label2id = config[ConfigFields.preprocessor].label2id
- elif hasattr(config, ConfigFields.preprocessor) and hasattr(
- config[ConfigFields.preprocessor], 'id2label'):
- id2label = config[ConfigFields.preprocessor].id2label
- label2id = {label: id for id, label in id2label.items()}
-
- config_path = os.path.join(model_dir, 'config.json')
- if label2id is None and os.path.exists(config_path):
- config = Config.from_file(config_path)
- if hasattr(config, 'label2id'):
- label2id = config.label2id
- elif hasattr(config, 'id2label'):
- id2label = config.id2label
- label2id = {label: id for id, label in id2label.items()}
- if label2id is not None:
- label2id = {label: int(id) for label, id in label2id.items()}
- return label2id
|