# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, Optional, Sequence

from modelscope.metainfo import Models, Preprocessors
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModeKeys, Tasks
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.logger import get_logger
from .builder import build_preprocessor

logger = get_logger(__name__)

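# Default mapping from a (model type, task) pair to a preprocessor type,
# used as a fallback when a model's configuration.json does not specify
# a preprocessor explicitly (see Preprocessor.from_pretrained below).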
PREPROCESSOR_MAP = {
    # nlp
    # bart
    (Models.bart, Tasks.text_error_correction):
    Preprocessors.text_error_correction,

    # bert
    (Models.bert, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.document_segmentation):
    Preprocessors.document_segmentation,
    (Models.bert, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (Models.bert, Tasks.sentence_embedding):
    Preprocessors.sentence_embedding,
    (Models.bert, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.nli):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.sentiment_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.sentence_similarity):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.zero_shot_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.bert, Tasks.text_ranking):
    Preprocessors.text_ranking,
    (Models.bert, Tasks.part_of_speech):
    Preprocessors.token_cls_tokenizer,
    (Models.bert, Tasks.token_classification):
    Preprocessors.token_cls_tokenizer,
    (Models.bert, Tasks.word_segmentation):
    Preprocessors.token_cls_tokenizer,

    # bloom
    (Models.bloom, Tasks.backbone):
    Preprocessors.text_gen_tokenizer,

    # gpt_neo
    # gpt_neo may need different preprocessors, but only one is registered for now
    (Models.gpt_neo, Tasks.backbone):
    Preprocessors.sentence_piece,

    # gpt3 uses different preprocessors for different model sizes, so it is not listed here.

    # palm_v2
    (Models.palm, Tasks.backbone):
    Preprocessors.text_gen_tokenizer,

    # T5
    (Models.T5, Tasks.backbone):
    Preprocessors.text2text_gen_preprocessor,
    (Models.T5, Tasks.text2text_generation):
    Preprocessors.text2text_gen_preprocessor,

    # deberta_v2
    (Models.deberta_v2, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.deberta_v2, Tasks.fill_mask):
    Preprocessors.fill_mask,

    # ponet
    (Models.ponet, Tasks.fill_mask):
    Preprocessors.fill_mask_ponet,

    # structbert
    (Models.structbert, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (Models.structbert, Tasks.faq_question_answering):
    Preprocessors.faq_question_answering_preprocessor,
    (Models.structbert, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.nli):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.sentiment_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.sentence_similarity):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.zero_shot_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.structbert, Tasks.part_of_speech):
    Preprocessors.token_cls_tokenizer,
    (Models.structbert, Tasks.token_classification):
    Preprocessors.token_cls_tokenizer,
    (Models.structbert, Tasks.word_segmentation):
    Preprocessors.token_cls_tokenizer,

    # veco
    (Models.veco, Tasks.backbone):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.fill_mask):
    Preprocessors.fill_mask,
    (Models.veco, Tasks.text_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.nli):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.sentiment_classification):
    Preprocessors.sen_cls_tokenizer,
    (Models.veco, Tasks.sentence_similarity):
    Preprocessors.sen_cls_tokenizer,

    # space
}
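# Example lookup: PREPROCESSOR_MAP[(Models.bert, Tasks.fill_mask)]
# resolves to Preprocessors.fill_mask.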


class Preprocessor(ABC):

    def __init__(self, mode=ModeKeys.INFERENCE, *args, **kwargs):
        self._mode = mode
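        # When launched by a distributed launcher, the LOCAL_RANK environment
        # variable identifies the device for this process; otherwise no
        # device is recorded.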
        self.device = int(
            os.environ['LOCAL_RANK']) if 'LOCAL_RANK' in os.environ else None

    @abstractmethod
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
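        """Process the input data and return a dict of model inputs.

        Subclasses implement the actual preprocessing logic.
        """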
        pass

    @property
    def mode(self):
        return self._mode

    @mode.setter
    def mode(self, value):
        self._mode = value

    @classmethod
    def from_pretrained(cls,
                        model_name_or_path: str,
                        revision: Optional[str] = DEFAULT_MODEL_REVISION,
                        cfg_dict: Optional[Config] = None,
                        preprocessor_mode=ModeKeys.INFERENCE,
                        **kwargs):
- """ Instantiate a model from local directory or remote model repo. Note
- that when loading from remote, the model revision can be specified.
- """
        if not os.path.exists(model_name_or_path):
            model_dir = snapshot_download(
                model_name_or_path, revision=revision)
        else:
            model_dir = model_name_or_path
        if cfg_dict is None:
            cfg = read_config(model_dir)
        else:
            cfg = cfg_dict
        task = cfg.task
        if 'task' in kwargs:
            task = kwargs.pop('task')
        field_name = Tasks.find_field_by_task(task)
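        # A preprocessor config may be split into 'train' and 'val' sections;
        # pick the one matching the requested mode.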
        sub_key = 'train' if preprocessor_mode == ModeKeys.TRAIN else 'val'

        if not hasattr(cfg, 'preprocessor'):
            logger.error('No preprocessor field found in cfg.')
            preprocessor_cfg = ConfigDict()
        else:
            preprocessor_cfg = cfg.preprocessor

        if 'type' not in preprocessor_cfg:
            if sub_key in preprocessor_cfg:
                sub_cfg = getattr(preprocessor_cfg, sub_key)
            else:
                logger.error(
                    f'Neither a {sub_key} key nor a type key was found in '
                    f'the preprocessor field of the configuration.json file.')
                sub_cfg = preprocessor_cfg
        else:
            sub_cfg = preprocessor_cfg

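        # Inject the model directory (and any caller overrides) so the built
        # preprocessor can locate vocab files and other assets in the model dir.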
        sub_cfg.update({'model_dir': model_dir})
        sub_cfg.update(kwargs)
        if 'type' in sub_cfg:
            if isinstance(sub_cfg, Sequence):
                # TODO: for Sequence, need to adapt to the `mode` and `model_dir` args,
                # and add mode for Compose or other plans
                raise NotImplementedError('Not supported yet!')
            sub_cfg = deepcopy(sub_cfg)

            preprocessor = build_preprocessor(sub_cfg, field_name)
        else:
            logger.error(
                f'Cannot find an available config to build a preprocessor at mode {preprocessor_mode}, '
                f'current config: {sub_cfg}. Trying to build by task and model information instead.'
            )
            model_cfg = getattr(cfg, 'model', ConfigDict())
            model_type = model_cfg.type if hasattr(
                model_cfg, 'type') else getattr(model_cfg, 'model_type', None)
            if task is None or model_type is None:
                logger.error(
                    f'Found task: {task}, model type: {model_type}. '
                    f'Insufficient information to build the preprocessor, skipping.'
                )
                return None
            if (model_type, task) not in PREPROCESSOR_MAP:
                logger.error(
                    f'No preprocessor key {(model_type, task)} found in PREPROCESSOR_MAP, '
                    f'skipping the preprocessor build.')
                return None

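            # Fall back to the default preprocessor type for this
            # (model type, task) pair, keeping any keys already collected.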
            sub_cfg = ConfigDict({
                'type': PREPROCESSOR_MAP[(model_type, task)],
                **sub_cfg
            })
            preprocessor = build_preprocessor(sub_cfg, field_name)
        preprocessor.mode = preprocessor_mode
        return preprocessor