Enable finetuning of OFA-MMSpeech
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10981972
master^2
@@ -41,6 +41,8 @@ __all__ = ['OfaForAllTasks']
 class OfaForAllTasks(TorchModel):

     def __init__(self, model_dir, *args, **kwargs):
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         super().__init__(model_dir=model_dir, *args, **kwargs)
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
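The same guard is added to OfaBasePreprocessor below. The likely motivation (my assumption, not stated in the diff) is that a relative `model_dir` can resolve inconsistently once it is handed to `from_pretrained`-style loaders or stored in the trainer config, while an absolute path is unambiguous:

```python
import os

model_dir = 'work/ckpts/asr'  # hypothetical relative checkpoint path
if os.path.exists(model_dir):
    # normalize to an absolute path, e.g. '/home/user/work/ckpts/asr',
    # so later consumers do not depend on the current working directory
    model_dir = os.path.abspath(model_dir)
```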
@@ -80,10 +80,11 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
-        phone_item = self.to_phone(target) - 3
+        phone_item = self.to_phone(target) + 1
         phone_mask = torch.tensor([False])
-        sample['phone_item'] = phone_item
+        sample['phone_item'] = phone_item + 3
+        sample['phone_target'] = phone_item
         sample['phone_mask'] = phone_mask
         sample['prev_output_tokens'] = torch.cat(
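The preprocessor now maintains two phone index spaces: `phone_target` (raw ids + 1) for the CTC loss and `phone_item` (a further + 3) for the encoder input. The + 1 lines up with `blank=0` in `compute_ctc_loss` below; the reading of the + 3 shift is my assumption, not stated in the diff:

```python
import torch

# Hypothetical output of to_phone() for one target (0-based phone ids):
raw_phones = torch.tensor([17, 3, 42])

phone_target = raw_phones + 1  # CTC target space: id 0 stays free for the blank
phone_item = phone_target + 3  # encoder input space: the first embedding slots
                               # are assumed reserved for special tokens
```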
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import io
+import os
 import re
 import string
 from os import path as osp
@@ -32,6 +33,8 @@ class OfaBasePreprocessor:
         self.cfg = cfg
         self.mode = mode
         self.language = self.cfg.model.get('language', 'en')
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         if self.language == 'en':
             tokenizer = OFATokenizer.from_pretrained(model_dir)
         elif self.language in ['zh', 'cn']:
@@ -83,6 +83,10 @@ def collate_fn(samples, pad_idx, eos_idx):
         batch['net_input']['phone_items'] = merge('phone_item')
         batch['net_input']['phone_masks'] = torch.cat(
             [s['phone_mask'] for s in samples])
+    if samples[0].get('phone_target', None) is not None:
+        batch['phone_target'] = merge('phone_target')
+        batch['phone_length'] = torch.tensor(
+            [s['phone_target'].size(0) for s in samples], dtype=torch.long)
     return batch
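For reference, a minimal sketch of what the two new batch keys carry, assuming `merge` pads the per-sample phone tensors to a common length (emulated here with `pad_sequence` and padding value 0):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

samples = [{'phone_target': torch.tensor([5, 9, 2])},
           {'phone_target': torch.tensor([7, 4])}]

# padded (B, T_max) matrix of phone ids plus the true per-sample lengths;
# F.ctc_loss later uses the lengths to mask out the padding
phone_target = pad_sequence([s['phone_target'] for s in samples],
                            batch_first=True, padding_value=0)
phone_length = torch.tensor([s['phone_target'].size(0) for s in samples],
                            dtype=torch.long)
print(phone_target)  # tensor([[5, 9, 2], [7, 4, 0]])
print(phone_length)  # tensor([3, 2])
```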
@@ -2,8 +2,8 @@
 import math
 import os
-import shutil
 from functools import partial
+from shutil import ignore_patterns
 from typing import Callable, Dict, Optional, Tuple, Union

 import torch
@@ -23,9 +23,9 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
-                                       Invoke, ModeKeys)
+                                       Invoke, ModeKeys, ModelFile)
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                get_schedule)
+                                get_schedule, recursive_overwrite)


 @TRAINERS.register_module(module_name=Trainers.ofa)
@@ -58,23 +58,12 @@ class OFATrainer(EpochBasedTrainer):
             work_dir = cfg.train.work_dir
         else:
             work_dir = kwargs['work_dir']
-        tokenizer_files = {
-            'zh': [
-                'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json', 'ans2label.json'
-            ],
-            'en': [
-                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
-                'ans2label.json'
-            ],
-        }
-        for filename in tokenizer_files[cfg.model.get('language', 'en')]:
-            finetune_file = os.path.join(work_dir, filename)
-            pretrain_file = os.path.join(model_dir, filename)
-            if os.path.exists(finetune_file):
-                continue
-            if os.path.exists(pretrain_file):
-                shutil.copy(pretrain_file, finetune_file)
+        os.makedirs(work_dir, exist_ok=True)
+        ignore_file_set = set()
+        ignore_file_set.add(ModelFile.CONFIGURATION)
+        recursive_overwrite(
+            model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set))

         if preprocessor is None:
             preprocessor = {
@@ -3,6 +3,8 @@
 # This source code is licensed under the Apache 2.0 license
 # found in the LICENSE file in the root directory.
 import math
+import os
+import shutil

 import numpy as np
 import torch
@@ -11,6 +13,23 @@ import transformers
 from torch.nn.modules.loss import _Loss


+def recursive_overwrite(src, dst, ignore=None):
+    if os.path.isdir(src):
+        if not os.path.isdir(dst):
+            os.makedirs(dst)
+        files = os.listdir(src)
+        if ignore is not None:
+            ignored = ignore(src, files)
+        else:
+            ignored = set()
+        for f in files:
+            if f not in ignored:
+                recursive_overwrite(
+                    os.path.join(src, f), os.path.join(dst, f), ignore)
+    else:
+        shutil.copyfile(src, dst)
+
+
 def construct_rdrop_sample(x):
     if isinstance(x, dict):
         for key in x:
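Unlike `shutil.copytree`, which refuses to write into an existing directory on Python versions before 3.8, `recursive_overwrite` deliberately copies over existing files. Its `ignore` argument follows the `shutil.ignore_patterns` callback protocol: `ignore(dir, names)` returns the subset of `names` to skip. Usage as in the trainer above, with placeholder paths and assuming `ModelFile.CONFIGURATION` resolves to `configuration.json`:

```python
from shutil import ignore_patterns

# copy the pretrained model directory into the work dir, overwriting stale
# files but never clobbering the finetune configuration already written there
recursive_overwrite('/path/to/pretrained_model',  # placeholder src
                    '/path/to/work_dir',          # placeholder dst
                    ignore=ignore_patterns('configuration.json'))
```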
@@ -211,17 +230,17 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return loss, nll_loss, ntokens

     def compute_ctc_loss(self, model, output, sample):
-        lprobs = model.get_encoder_normalized_probs(
+        lprobs = model.model.get_encoder_normalized_probs(
             output, log_probs=True).contiguous()  # (T, B, C) from the encoder

         non_padding_mask = ~output.encoder_padding_mask
         input_lengths = non_padding_mask.long().sum(-1)

-        target_lengths = sample['ctc_output_lengths']
+        target_lengths = sample['phone_length']
         pad_mask = torch.arange(target_lengths.max()).expand([
             target_lengths.shape[0], -1
         ]).to(target_lengths) < target_lengths.unsqueeze(1)
-        targets_flat = sample['ctc_outputs'].masked_select(pad_mask)
+        targets_flat = sample['phone_target'].masked_select(pad_mask)

         with torch.backends.cudnn.flags(enabled=False):
             loss = F.ctc_loss(
@@ -229,12 +248,12 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
                 lprobs,
                 targets_flat,
                 input_lengths,
                 target_lengths,
-                blank=self.blank_idx,
+                blank=0,
                 reduction='sum',
                 zero_infinity=True,
             )
-        return loss
+        return loss / lprobs.shape[1]


 def get_schedule(scheduler):
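Besides the key renames, two behavioral changes are folded in here: the blank index is pinned to 0 (matching the + 1 shift in the preprocessor) and the summed loss is divided by `lprobs.shape[1]`, i.e. the batch size, turning the CTC term into a per-sample mean. A self-contained sketch with made-up shapes:

```python
import torch
import torch.nn.functional as F

T, B, C = 50, 2, 40                           # time steps, batch size, phone classes
lprobs = torch.randn(T, B, C).log_softmax(-1)  # stand-in for encoder log-probs
input_lengths = torch.tensor([50, 45])
target_lengths = torch.tensor([12, 9])
# flattened targets; ids start at 1 because 0 is the blank
targets_flat = torch.randint(1, C, (int(target_lengths.sum()),))

with torch.backends.cudnn.flags(enabled=False):  # cuDNN CTC can be nondeterministic
    loss = F.ctc_loss(lprobs, targets_flat, input_lengths, target_lengths,
                      blank=0, reduction='sum', zero_infinity=True)
loss = loss / lprobs.shape[1]  # 'sum' over the batch -> mean per sample
```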
@@ -0,0 +1,108 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import os
+import shutil
+import unittest
+
+from modelscope.metainfo import Trainers
+from modelscope.msdatasets import MsDataset
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import DownloadMode, ModelFile
+from modelscope.utils.test_utils import test_level
+
+
+class TestMMSpeechTrainer(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.finetune_cfg = \
+            {'framework': 'pytorch',
+             'task': 'auto-speech-recognition',
+             'model': {'type': 'ofa',
+                       'beam_search': {'beam_size': 5,
+                                       'max_len_b': 128,
+                                       'min_len': 1,
+                                       'no_repeat_ngram_size': 5,
+                                       'constraint_range': '4,21134'},
+                       'seed': 7,
+                       'max_src_length': 256,
+                       'language': 'zh',
+                       'gen_type': 'generation',
+                       'multimodal_type': 'mmspeech'},
+             'pipeline': {'type': 'ofa-asr'},
+             'n_frames_per_step': 1,
+             'dataset': {'column_map': {'wav': 'Audio:FILE', 'text': 'Text:LABEL'}},
+             'train': {'work_dir': 'work/ckpts/asr_recognition',
+                       # 'launcher': 'pytorch',
+                       'max_epochs': 1,
+                       'use_fp16': True,
+                       'dataloader': {'batch_size_per_gpu': 16, 'workers_per_gpu': 0},
+                       'lr_scheduler': {'name': 'polynomial_decay',
+                                        'warmup_proportion': 0.01,
+                                        'lr_end': 1e-07},
+                       'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
+                       'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
+                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
+                                          'cumulative_iters': 1,
+                                          'grad_clip': {'max_norm': 1.0, 'norm_type': 2},
+                                          'loss_keys': 'loss'},
+                       'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
+                                     'constraint_range': '4,21134',
+                                     'drop_worst_after': 0,
+                                     'drop_worst_ratio': 0.0,
+                                     'ignore_eos': False,
+                                     'ignore_prefix_size': 0,
+                                     'label_smoothing': 0.1,
+                                     'reg_alpha': 1.0,
+                                     'report_accuracy': False,
+                                     'sample_patch_num': 196,
+                                     'sentence_avg': True,
+                                     'use_rdrop': False,
+                                     'ctc_weight': 1.0},
+                       'hooks': [{'type': 'BestCkptSaverHook',
+                                  'metric_key': 'accuracy',
+                                  'interval': 100},
+                                 {'type': 'TextLoggerHook', 'interval': 1},
+                                 {'type': 'IterTimerHook'},
+                                 {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
+             'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
+                            'metrics': [{'type': 'accuracy'}]},
+             'preprocessor': []}
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_std(self):
+        WORKSPACE = './workspace/ckpts/asr_recognition'
+        os.makedirs(WORKSPACE, exist_ok=True)
+
+        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
+        with open(config_file, 'w') as writer:
+            json.dump(self.finetune_cfg, writer)
+
+        pretrained_model = 'damo/ofa_mmspeech_pretrain_base_zh'
+        args = dict(
+            model=pretrained_model,
+            work_dir=WORKSPACE,
+            train_dataset=MsDataset.load(
+                'aishell1_subset',
+                subset_name='default',
+                namespace='modelscope',
+                split='train',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
+            eval_dataset=MsDataset.load(
+                'aishell1_subset',
+                subset_name='default',
+                namespace='modelscope',
+                split='test',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
+            cfg_file=config_file)
+        trainer = build_trainer(name=Trainers.ofa, default_args=args)
+        trainer.train()
+
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
+        shutil.rmtree(WORKSPACE)
+
+
+if __name__ == '__main__':
+    unittest.main()
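Note `'ctc_weight': 1.0` in the criterion config above, which is what activates `compute_ctc_loss` during finetuning. How the criterion combines the two terms is not visible in this diff; the presumed shape is simply:

```python
# Assumed combination inside AdjustLabelSmoothedCrossEntropyCriterion
# (not shown in this diff): the encoder-side CTC term is added to the
# label-smoothed cross entropy with the configured weight.
def combined_loss(ce_loss, ctc_loss, ctc_weight=1.0):
    return ce_loss + ctc_weight * ctc_loss
```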
@@ -76,8 +76,7 @@ class TestOfaTrainer(unittest.TestCase):
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
-            json.dump(self.finetune_cfg, writer)
+            json.dump(self.finetune_cfg, writer, indent=4)
         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
         args = dict(