You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

builder.py 17 kB

[to #42322933] NLP 1030 Refactor Features: 1. Refactor the directory structure of nlp models. All model files are placed into either the model folder or the task_model folder 2. Refactor all the comments to google style 3. Add detailed comments to important tasks and nlp models, to list the description of the model, and its preprocessor & trainer 4. Model exporting now supports a direct call to TorchModelExporter (no need to derive from it) 5. Refactor the model save_pretrained method to support direct running (independent from the trainer) 6. Remove the judgement of Model in the pipeline base class, to support running externally registered models in our pipelines 7. Nlp trainer now has an NLPTrainingArguments class; users can pass arguments into the dataclass and use it as a normal cfg_modify_fn, to simplify modifying the cfg. 8. Merge the BACKBONES and the MODELS, so users can get a backbone with the Model.from_pretrained call 9. Model.from_pretrained now supports a task argument, so users can take a backbone and load it with a specific task class. 10. Support the Preprocessor.from_pretrained method 11. Add standard return classes to important nlp tasks, so some of the pipelines and the models are independent now; the return values of the models will always be tensors, and the pipelines will take care of the conversion to numpy and the subsequent steps. 12. Split the file of the nlp preprocessors, to make the dir structure clearer. Bug Fixes: 1. Fix a bug where lr_scheduler can be called earlier than the optimizer's step 2. Fix a bug where the direct call of Pipelines (not from pipeline(xxx)) throws an error 3. Fix a bug where the trainer does not call the correct TaskDataset class 4. Fix a bug where the internal loading of the dataset throws an error in the trainer class Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10490585
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import List, Optional, Union
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models.base import Model
  7. from modelscope.utils.config import ConfigDict, check_config
  8. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Tasks
  9. from modelscope.utils.hub import read_config
  10. from modelscope.utils.registry import Registry, build_from_cfg
  11. from .base import Pipeline
  12. from .util import is_official_hub_path
  13. PIPELINES = Registry('pipelines')
  14. DEFAULT_MODEL_FOR_PIPELINE = {
  15. # TaskName: (pipeline_module_name, model_repo)
  16. Tasks.sentence_embedding:
  17. (Pipelines.sentence_embedding,
  18. 'damo/nlp_corom_sentence-embedding_english-base'),
  19. Tasks.text_ranking: (Pipelines.text_ranking,
  20. 'damo/nlp_corom_passage-ranking_english-base'),
  21. Tasks.word_segmentation:
  22. (Pipelines.word_segmentation,
  23. 'damo/nlp_structbert_word-segmentation_chinese-base'),
  24. Tasks.part_of_speech: (Pipelines.part_of_speech,
  25. 'damo/nlp_structbert_part-of-speech_chinese-base'),
  26. Tasks.token_classification:
  27. (Pipelines.part_of_speech,
  28. 'damo/nlp_structbert_part-of-speech_chinese-base'),
  29. Tasks.named_entity_recognition:
  30. (Pipelines.named_entity_recognition,
  31. 'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
  32. Tasks.relation_extraction:
  33. (Pipelines.relation_extraction,
  34. 'damo/nlp_bert_relation-extraction_chinese-base'),
  35. Tasks.information_extraction:
  36. (Pipelines.relation_extraction,
  37. 'damo/nlp_bert_relation-extraction_chinese-base'),
  38. Tasks.sentence_similarity:
  39. (Pipelines.sentence_similarity,
  40. 'damo/nlp_structbert_sentence-similarity_chinese-base'),
  41. Tasks.translation: (Pipelines.csanmt_translation,
  42. 'damo/nlp_csanmt_translation_zh2en'),
  43. Tasks.nli: (Pipelines.nli, 'damo/nlp_structbert_nli_chinese-base'),
  44. Tasks.sentiment_classification:
  45. (Pipelines.sentiment_classification,
  46. 'damo/nlp_structbert_sentiment-classification_chinese-base'
  47. ), # TODO: revise back after passing the pr
  48. Tasks.portrait_matting: (Pipelines.portrait_matting,
  49. 'damo/cv_unet_image-matting'),
  50. Tasks.human_detection: (Pipelines.human_detection,
  51. 'damo/cv_resnet18_human-detection'),
  52. Tasks.image_object_detection: (Pipelines.object_detection,
  53. 'damo/cv_vit_object-detection_coco'),
  54. Tasks.image_denoising: (Pipelines.image_denoise,
  55. 'damo/cv_nafnet_image-denoise_sidd'),
  56. Tasks.text_classification:
  57. (Pipelines.sentiment_classification,
  58. 'damo/nlp_structbert_sentiment-classification_chinese-base'),
  59. Tasks.text_generation: (Pipelines.text_generation,
  60. 'damo/nlp_palm2.0_text-generation_chinese-base'),
  61. Tasks.zero_shot_classification:
  62. (Pipelines.zero_shot_classification,
  63. 'damo/nlp_structbert_zero-shot-classification_chinese-base'),
  64. Tasks.task_oriented_conversation: (Pipelines.dialog_modeling,
  65. 'damo/nlp_space_dialog-modeling'),
  66. Tasks.dialog_state_tracking: (Pipelines.dialog_state_tracking,
  67. 'damo/nlp_space_dialog-state-tracking'),
  68. Tasks.table_question_answering:
  69. (Pipelines.table_question_answering_pipeline,
  70. 'damo/nlp-convai-text2sql-pretrain-cn'),
  71. Tasks.text_error_correction:
  72. (Pipelines.text_error_correction,
  73. 'damo/nlp_bart_text-error-correction_chinese'),
  74. Tasks.image_captioning: (Pipelines.image_captioning,
  75. 'damo/ofa_image-caption_coco_large_en'),
  76. Tasks.image_portrait_stylization:
  77. (Pipelines.person_image_cartoon,
  78. 'damo/cv_unet_person-image-cartoon_compound-models'),
  79. Tasks.ocr_detection: (Pipelines.ocr_detection,
  80. 'damo/cv_resnet18_ocr-detection-line-level_damo'),
  81. Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
  82. Tasks.feature_extraction: (Pipelines.feature_extraction,
  83. 'damo/pert_feature-extraction_base-test'),
  84. Tasks.action_recognition: (Pipelines.action_recognition,
  85. 'damo/cv_TAdaConv_action-recognition'),
  86. Tasks.action_detection: (Pipelines.action_detection,
  87. 'damo/cv_ResNetC3D_action-detection_detection2d'),
  88. Tasks.live_category: (Pipelines.live_category,
  89. 'damo/cv_resnet50_live-category'),
  90. Tasks.video_category: (Pipelines.video_category,
  91. 'damo/cv_resnet50_video-category'),
  92. Tasks.multi_modal_embedding: (Pipelines.multi_modal_embedding,
  93. 'damo/multi-modal_clip-vit-base-patch16_zh'),
  94. Tasks.generative_multi_modal_embedding:
  95. (Pipelines.generative_multi_modal_embedding,
  96. 'damo/multi-modal_gemm-vit-large-patch14_generative-multi-modal-embedding'
  97. ),
  98. Tasks.multi_modal_similarity:
  99. (Pipelines.multi_modal_similarity,
  100. 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'),
  101. Tasks.visual_question_answering:
  102. (Pipelines.visual_question_answering,
  103. 'damo/mplug_visual-question-answering_coco_large_en'),
  104. Tasks.video_embedding: (Pipelines.cmdssl_video_embedding,
  105. 'damo/cv_r2p1d_video_embedding'),
  106. Tasks.text_to_image_synthesis:
  107. (Pipelines.text_to_image_synthesis,
  108. 'damo/cv_diffusion_text-to-image-synthesis_tiny'),
  109. Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
  110. 'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
  111. Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
  112. 'damo/cv_canonical_body-3d-keypoints_video'),
  113. Tasks.hand_2d_keypoints:
  114. (Pipelines.hand_2d_keypoints,
  115. 'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'),
  116. Tasks.face_detection: (Pipelines.face_detection,
  117. 'damo/cv_resnet_facedetection_scrfd10gkps'),
  118. Tasks.card_detection: (Pipelines.card_detection,
  119. 'damo/cv_resnet_carddetection_scrfd34gkps'),
  120. Tasks.face_detection:
  121. (Pipelines.face_detection,
  122. 'damo/cv_resnet101_face-detection_cvpr22papermogface'),
  123. Tasks.face_recognition: (Pipelines.face_recognition,
  124. 'damo/cv_ir101_facerecognition_cfglint'),
  125. Tasks.facial_expression_recognition:
  126. (Pipelines.facial_expression_recognition,
  127. 'damo/cv_vgg19_facial-expression-recognition_fer'),
  128. Tasks.face_2d_keypoints: (Pipelines.face_2d_keypoints,
  129. 'damo/cv_mobilenet_face-2d-keypoints_alignment'),
  130. Tasks.video_multi_modal_embedding:
  131. (Pipelines.video_multi_modal_embedding,
  132. 'damo/multi_modal_clip_vtretrival_msrvtt_53'),
  133. Tasks.image_color_enhancement:
  134. (Pipelines.image_color_enhance,
  135. 'damo/cv_csrnet_image-color-enhance-models'),
  136. Tasks.virtual_try_on: (Pipelines.virtual_try_on,
  137. 'damo/cv_daflow_virtual-try-on_base'),
  138. Tasks.image_colorization: (Pipelines.image_colorization,
  139. 'damo/cv_unet_image-colorization'),
  140. Tasks.image_segmentation:
  141. (Pipelines.image_instance_segmentation,
  142. 'damo/cv_swin-b_image-instance-segmentation_coco'),
  143. Tasks.image_style_transfer: (Pipelines.image_style_transfer,
  144. 'damo/cv_aams_style-transfer_damo'),
  145. Tasks.face_image_generation: (Pipelines.face_image_generation,
  146. 'damo/cv_gan_face-image-generation'),
  147. Tasks.image_super_resolution: (Pipelines.image_super_resolution,
  148. 'damo/cv_rrdb_image-super-resolution'),
  149. Tasks.image_portrait_enhancement:
  150. (Pipelines.image_portrait_enhancement,
  151. 'damo/cv_gpen_image-portrait-enhancement'),
  152. Tasks.product_retrieval_embedding:
  153. (Pipelines.product_retrieval_embedding,
  154. 'damo/cv_resnet50_product-bag-embedding-models'),
  155. Tasks.image_to_image_generation:
  156. (Pipelines.image_to_image_generation,
  157. 'damo/cv_latent_diffusion_image2image_generate'),
  158. Tasks.image_classification:
  159. (Pipelines.daily_image_classification,
  160. 'damo/cv_vit-base_image-classification_Dailylife-labels'),
  161. Tasks.image_object_detection:
  162. (Pipelines.image_object_detection_auto,
  163. 'damo/cv_yolox_image-object-detection-auto'),
  164. Tasks.ocr_recognition:
  165. (Pipelines.ocr_recognition,
  166. 'damo/cv_convnextTiny_ocr-recognition-general_damo'),
  167. Tasks.skin_retouching: (Pipelines.skin_retouching,
  168. 'damo/cv_unet_skin-retouching'),
  169. Tasks.faq_question_answering:
  170. (Pipelines.faq_question_answering,
  171. 'damo/nlp_structbert_faq-question-answering_chinese-base'),
  172. Tasks.crowd_counting: (Pipelines.crowd_counting,
  173. 'damo/cv_hrnet_crowd-counting_dcanet'),
  174. Tasks.video_single_object_tracking:
  175. (Pipelines.video_single_object_tracking,
  176. 'damo/cv_vitb_video-single-object-tracking_ostrack'),
  177. Tasks.image_reid_person: (Pipelines.image_reid_person,
  178. 'damo/cv_passvitb_image-reid-person_market'),
  179. Tasks.text_driven_segmentation:
  180. (Pipelines.text_driven_segmentation,
  181. 'damo/cv_vitl16_segmentation_text-driven-seg'),
  182. Tasks.movie_scene_segmentation:
  183. (Pipelines.movie_scene_segmentation,
  184. 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
  185. Tasks.shop_segmentation: (Pipelines.shop_segmentation,
  186. 'damo/cv_vitb16_segmentation_shop-seg'),
  187. Tasks.image_inpainting: (Pipelines.image_inpainting,
  188. 'damo/cv_fft_inpainting_lama'),
  189. Tasks.video_inpainting: (Pipelines.video_inpainting,
  190. 'damo/cv_video-inpainting'),
  191. Tasks.human_wholebody_keypoint:
  192. (Pipelines.human_wholebody_keypoint,
  193. 'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
  194. Tasks.hand_static: (Pipelines.hand_static,
  195. 'damo/cv_mobileface_hand-static'),
  196. Tasks.face_human_hand_detection:
  197. (Pipelines.face_human_hand_detection,
  198. 'damo/cv_nanodet_face-human-hand-detection'),
  199. Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
  200. Tasks.product_segmentation: (Pipelines.product_segmentation,
  201. 'damo/cv_F3Net_product-segmentation'),
  202. Tasks.referring_video_object_segmentation:
  203. (Pipelines.referring_video_object_segmentation,
  204. 'damo/cv_swin-t_referring_video-object-segmentation'),
  205. }
  206. def normalize_model_input(model, model_revision):
  207. """ normalize the input model, to ensure that a model str is a valid local path: in other words,
  208. for model represented by a model id, the model shall be downloaded locally
  209. """
  210. if isinstance(model, str) and is_official_hub_path(model, model_revision):
  211. # skip revision download if model is a local directory
  212. if not os.path.exists(model):
  213. # note that if there is already a local copy, snapshot_download will check and skip downloading
  214. model = snapshot_download(model, revision=model_revision)
  215. elif isinstance(model, list) and isinstance(model[0], str):
  216. for idx in range(len(model)):
  217. if is_official_hub_path(
  218. model[idx],
  219. model_revision) and not os.path.exists(model[idx]):
  220. model[idx] = snapshot_download(
  221. model[idx], revision=model_revision)
  222. return model
  223. def build_pipeline(cfg: ConfigDict,
  224. task_name: str = None,
  225. default_args: dict = None):
  226. """ build pipeline given model config dict.
  227. Args:
  228. cfg (:obj:`ConfigDict`): config dict for model object.
  229. task_name (str, optional): task name, refer to
  230. :obj:`Tasks` for more details.
  231. default_args (dict, optional): Default initialization arguments.
  232. """
  233. return build_from_cfg(
  234. cfg, PIPELINES, group_key=task_name, default_args=default_args)
  235. def pipeline(task: str = None,
  236. model: Union[str, List[str], Model, List[Model]] = None,
  237. preprocessor=None,
  238. config_file: str = None,
  239. pipeline_name: str = None,
  240. framework: str = None,
  241. device: str = 'gpu',
  242. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  243. **kwargs) -> Pipeline:
  244. """ Factory method to build an obj:`Pipeline`.
  245. Args:
  246. task (str): Task name defining which pipeline will be returned.
  247. model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object.
  248. preprocessor: preprocessor object.
  249. config_file (str, optional): path to config file.
  250. pipeline_name (str, optional): pipeline class name or alias name.
  251. framework (str, optional): framework type.
  252. model_revision: revision of model(s) if getting from model hub, for multiple models, expecting
  253. all models to have the same revision
  254. device (str, optional): whether to use gpu or cpu is used to do inference.
  255. Return:
  256. pipeline (obj:`Pipeline`): pipeline object for certain task.
  257. Examples:
  258. ```python
  259. >>> # Using default model for a task
  260. >>> p = pipeline('image-classification')
  261. >>> # Using pipeline with a model name
  262. >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased')
  263. >>> # Using pipeline with a model object
  264. >>> resnet = Model.from_pretrained('Resnet')
  265. >>> p = pipeline('image-classification', model=resnet)
  266. >>> # Using pipeline with a list of model names
  267. >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2'])
  268. """
  269. if task is None and pipeline_name is None:
  270. raise ValueError('task or pipeline_name is required')
  271. model = normalize_model_input(model, model_revision)
  272. if pipeline_name is None:
  273. # get default pipeline for this task
  274. if isinstance(model, str) \
  275. or (isinstance(model, list) and isinstance(model[0], str)):
  276. if is_official_hub_path(model, revision=model_revision):
  277. # read config file from hub and parse
  278. cfg = read_config(
  279. model, revision=model_revision) if isinstance(
  280. model, str) else read_config(
  281. model[0], revision=model_revision)
  282. check_config(cfg)
  283. pipeline_name = cfg.pipeline.type
  284. else:
  285. # used for test case, when model is str and is not hub path
  286. pipeline_name = get_pipeline_by_model_name(task, model)
  287. elif model is not None:
  288. # get pipeline info from Model object
  289. first_model = model[0] if isinstance(model, list) else model
  290. if not hasattr(first_model, 'pipeline'):
  291. # model is instantiated by user, we should parse config again
  292. cfg = read_config(first_model.model_dir)
  293. check_config(cfg)
  294. first_model.pipeline = cfg.pipeline
  295. pipeline_name = first_model.pipeline.type
  296. else:
  297. pipeline_name, default_model_repo = get_default_pipeline_info(task)
  298. model = normalize_model_input(default_model_repo, model_revision)
  299. cfg = ConfigDict(type=pipeline_name, model=model)
  300. cfg.device = device
  301. if kwargs:
  302. cfg.update(kwargs)
  303. if preprocessor is not None:
  304. cfg.preprocessor = preprocessor
  305. return build_pipeline(cfg, task_name=task)
  306. def add_default_pipeline_info(task: str,
  307. model_name: str,
  308. modelhub_name: str = None,
  309. overwrite: bool = False):
  310. """ Add default model for a task.
  311. Args:
  312. task (str): task name.
  313. model_name (str): model_name.
  314. modelhub_name (str): name for default modelhub.
  315. overwrite (bool): overwrite default info.
  316. """
  317. if not overwrite:
  318. assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
  319. f'task {task} already has default model.'
  320. DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
  321. def get_default_pipeline_info(task):
  322. """ Get default info for certain task.
  323. Args:
  324. task (str): task name.
  325. Return:
  326. A tuple: first element is pipeline name(model_name), second element
  327. is modelhub name.
  328. """
  329. if task not in DEFAULT_MODEL_FOR_PIPELINE:
  330. # support pipeline which does not register default model
  331. pipeline_name = list(PIPELINES.modules[task].keys())[0]
  332. default_model = None
  333. else:
  334. pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
  335. return pipeline_name, default_model
  336. def get_pipeline_by_model_name(task: str, model: Union[str, List[str]]):
  337. """ Get pipeline name by task name and model name
  338. Args:
  339. task (str): task name.
  340. model (str| list[str]): model names
  341. """
  342. if isinstance(model, str):
  343. model_key = model
  344. else:
  345. model_key = '_'.join(model)
  346. assert model_key in PIPELINES.modules[task], \
  347. f'pipeline for task {task} model {model_key} not found.'
  348. return model_key