enable finetune of ofa-mmspeech
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10981972
master^2
| @@ -41,6 +41,8 @@ __all__ = ['OfaForAllTasks'] | |||
| class OfaForAllTasks(TorchModel): | |||
| def __init__(self, model_dir, *args, **kwargs): | |||
| if os.path.exists(model_dir): | |||
| model_dir = os.path.abspath(model_dir) | |||
| super().__init__(model_dir=model_dir, *args, **kwargs) | |||
| self.cfg = Config.from_file( | |||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
| @@ -80,10 +80,11 @@ class OfaASRPreprocessor(OfaBasePreprocessor): | |||
| target = ' '.join(target_token_list[:self.max_tgt_length]) | |||
| sample['target'] = self.tokenize_text(target, add_bos=False) | |||
| phone_item = self.to_phone(target) - 3 | |||
| phone_item = self.to_phone(target) + 1 | |||
| phone_mask = torch.tensor([False]) | |||
| sample['phone_item'] = phone_item | |||
| sample['phone_item'] = phone_item + 3 | |||
| sample['phone_target'] = phone_item | |||
| sample['phone_mask'] = phone_mask | |||
| sample['prev_output_tokens'] = torch.cat( | |||
| @@ -1,5 +1,6 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import io | |||
| import os | |||
| import re | |||
| import string | |||
| from os import path as osp | |||
| @@ -32,6 +33,8 @@ class OfaBasePreprocessor: | |||
| self.cfg = cfg | |||
| self.mode = mode | |||
| self.language = self.cfg.model.get('language', 'en') | |||
| if os.path.exists(model_dir): | |||
| model_dir = os.path.abspath(model_dir) | |||
| if self.language == 'en': | |||
| tokenizer = OFATokenizer.from_pretrained(model_dir) | |||
| elif self.language in ['zh', 'cn']: | |||
| @@ -83,6 +83,10 @@ def collate_fn(samples, pad_idx, eos_idx): | |||
| batch['net_input']['phone_items'] = merge('phone_item') | |||
| batch['net_input']['phone_masks'] = torch.cat( | |||
| [s['phone_mask'] for s in samples]) | |||
| if samples[0].get('phone_target', None) is not None: | |||
| batch['phone_target'] = merge('phone_target') | |||
| batch['phone_length'] = torch.tensor( | |||
| [s['phone_target'].size(0) for s in samples], dtype=torch.long) | |||
| return batch | |||
| @@ -2,8 +2,8 @@ | |||
| import math | |||
| import os | |||
| import shutil | |||
| from functools import partial | |||
| from shutil import ignore_patterns | |||
| from typing import Callable, Dict, Optional, Tuple, Union | |||
| import torch | |||
| @@ -23,9 +23,9 @@ from modelscope.trainers.optimizer.builder import build_optimizer | |||
| from modelscope.trainers.parallel.utils import is_parallel | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, | |||
| Invoke, ModeKeys) | |||
| Invoke, ModeKeys, ModelFile) | |||
| from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion, | |||
| get_schedule) | |||
| get_schedule, recursive_overwrite) | |||
| @TRAINERS.register_module(module_name=Trainers.ofa) | |||
| @@ -58,23 +58,12 @@ class OFATrainer(EpochBasedTrainer): | |||
| work_dir = cfg.train.work_dir | |||
| else: | |||
| work_dir = kwargs['work_dir'] | |||
| tokenizer_files = { | |||
| 'zh': [ | |||
| 'tokenizer.json', 'tokenizer_config.json', 'vocab.txt', | |||
| 'config.json', 'ans2label.json' | |||
| ], | |||
| 'en': [ | |||
| 'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json', | |||
| 'ans2label.json' | |||
| ], | |||
| } | |||
| for filename in tokenizer_files[cfg.model.get('language', 'en')]: | |||
| finetune_file = os.path.join(work_dir, filename) | |||
| pretrain_file = os.path.join(model_dir, filename) | |||
| if os.path.exists(finetune_file): | |||
| continue | |||
| if os.path.exists(pretrain_file): | |||
| shutil.copy(pretrain_file, finetune_file) | |||
| os.makedirs(work_dir, exist_ok=True) | |||
| ignore_file_set = set() | |||
| ignore_file_set.add(ModelFile.CONFIGURATION) | |||
| recursive_overwrite( | |||
| model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set)) | |||
| if preprocessor is None: | |||
| preprocessor = { | |||
| @@ -3,6 +3,8 @@ | |||
| # This source code is licensed under the Apache 2.0 license | |||
| # found in the LICENSE file in the root directory. | |||
| import math | |||
| import os | |||
| import shutil | |||
| import numpy as np | |||
| import torch | |||
| @@ -11,6 +13,23 @@ import transformers | |||
| from torch.nn.modules.loss import _Loss | |||
| def recursive_overwrite(src, dst, ignore=None): | |||
| if os.path.isdir(src): | |||
| if not os.path.isdir(dst): | |||
| os.makedirs(dst) | |||
| files = os.listdir(src) | |||
| if ignore is not None: | |||
| ignored = ignore(src, files) | |||
| else: | |||
| ignored = set() | |||
| for f in files: | |||
| if f not in ignored: | |||
| recursive_overwrite( | |||
| os.path.join(src, f), os.path.join(dst, f), ignore) | |||
| else: | |||
| shutil.copyfile(src, dst) | |||
| def construct_rdrop_sample(x): | |||
| if isinstance(x, dict): | |||
| for key in x: | |||
| @@ -211,17 +230,17 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| return loss, nll_loss, ntokens | |||
| def compute_ctc_loss(self, model, output, sample): | |||
| lprobs = model.get_encoder_normalized_probs( | |||
| lprobs = model.model.get_encoder_normalized_probs( | |||
| output, log_probs=True).contiguous() # (T, B, C) from the encoder | |||
| non_padding_mask = ~output.encoder_padding_mask | |||
| input_lengths = non_padding_mask.long().sum(-1) | |||
| target_lengths = sample['ctc_output_lengths'] | |||
| target_lengths = sample['phone_length'] | |||
| pad_mask = torch.arange(target_lengths.max()).expand([ | |||
| target_lengths.shape[0], -1 | |||
| ]).to(target_lengths) < target_lengths.unsqueeze(1) | |||
| targets_flat = sample['ctc_outputs'].masked_select(pad_mask) | |||
| targets_flat = sample['phone_target'].masked_select(pad_mask) | |||
| with torch.backends.cudnn.flags(enabled=False): | |||
| loss = F.ctc_loss( | |||
| @@ -229,12 +248,12 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): | |||
| targets_flat, | |||
| input_lengths, | |||
| target_lengths, | |||
| blank=self.blank_idx, | |||
| blank=0, | |||
| reduction='sum', | |||
| zero_infinity=True, | |||
| ) | |||
| return loss | |||
| return loss / lprobs.shape[1] | |||
| def get_schedule(scheduler): | |||
| @@ -0,0 +1,108 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import shutil | |||
| import unittest | |||
| import json | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import DownloadMode, ModelFile | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestMMSpeechTrainer(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.finetune_cfg = \ | |||
| {'framework': 'pytorch', | |||
| 'task': 'auto-speech-recognition', | |||
| 'model': {'type': 'ofa', | |||
| 'beam_search': {'beam_size': 5, | |||
| 'max_len_b': 128, | |||
| 'min_len': 1, | |||
| 'no_repeat_ngram_size': 5, | |||
| 'constraint_range': '4,21134'}, | |||
| 'seed': 7, | |||
| 'max_src_length': 256, | |||
| 'language': 'zh', | |||
| 'gen_type': 'generation', | |||
| 'multimodal_type': 'mmspeech'}, | |||
| 'pipeline': {'type': 'ofa-asr'}, | |||
| 'n_frames_per_step': 1, | |||
| 'dataset': {'column_map': {'wav': 'Audio:FILE', 'text': 'Text:LABEL'}}, | |||
| 'train': {'work_dir': 'work/ckpts/asr_recognition', | |||
| # 'launcher': 'pytorch', | |||
| 'max_epochs': 1, | |||
| 'use_fp16': True, | |||
| 'dataloader': {'batch_size_per_gpu': 16, 'workers_per_gpu': 0}, | |||
| 'lr_scheduler': {'name': 'polynomial_decay', | |||
| 'warmup_proportion': 0.01, | |||
| 'lr_end': 1e-07}, | |||
| 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False}, | |||
| 'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01}, | |||
| 'optimizer_hook': {'type': 'TorchAMPOptimizerHook', | |||
| 'cumulative_iters': 1, | |||
| 'grad_clip': {'max_norm': 1.0, 'norm_type': 2}, | |||
| 'loss_keys': 'loss'}, | |||
| 'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion', | |||
| 'constraint_range': '4,21134', | |||
| 'drop_worst_after': 0, | |||
| 'drop_worst_ratio': 0.0, | |||
| 'ignore_eos': False, | |||
| 'ignore_prefix_size': 0, | |||
| 'label_smoothing': 0.1, | |||
| 'reg_alpha': 1.0, | |||
| 'report_accuracy': False, | |||
| 'sample_patch_num': 196, | |||
| 'sentence_avg': True, | |||
| 'use_rdrop': False, | |||
| 'ctc_weight': 1.0}, | |||
| 'hooks': [{'type': 'BestCkptSaverHook', | |||
| 'metric_key': 'accuracy', | |||
| 'interval': 100}, | |||
| {'type': 'TextLoggerHook', 'interval': 1}, | |||
| {'type': 'IterTimerHook'}, | |||
| {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]}, | |||
| 'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0}, | |||
| 'metrics': [{'type': 'accuracy'}]}, | |||
| 'preprocessor': []} | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer_std(self): | |||
| WORKSPACE = './workspace/ckpts/asr_recognition' | |||
| os.makedirs(WORKSPACE, exist_ok=True) | |||
| config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) | |||
| with open(config_file, 'w') as writer: | |||
| json.dump(self.finetune_cfg, writer) | |||
| pretrained_model = 'damo/ofa_mmspeech_pretrain_base_zh' | |||
| args = dict( | |||
| model=pretrained_model, | |||
| work_dir=WORKSPACE, | |||
| train_dataset=MsDataset.load( | |||
| 'aishell1_subset', | |||
| subset_name='default', | |||
| namespace='modelscope', | |||
| split='train', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), | |||
| eval_dataset=MsDataset.load( | |||
| 'aishell1_subset', | |||
| subset_name='default', | |||
| namespace='modelscope', | |||
| split='test', | |||
| download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), | |||
| cfg_file=config_file) | |||
| trainer = build_trainer(name=Trainers.ofa, default_args=args) | |||
| trainer.train() | |||
| self.assertIn( | |||
| ModelFile.TORCH_MODEL_BIN_FILE, | |||
| os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) | |||
| shutil.rmtree(WORKSPACE) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -76,8 +76,7 @@ class TestOfaTrainer(unittest.TestCase): | |||
| os.makedirs(WORKSPACE, exist_ok=True) | |||
| config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) | |||
| with open(config_file, 'w') as writer: | |||
| json.dump(self.finetune_cfg, writer) | |||
| json.dump(self.finetune_cfg, writer, indent=4) | |||
| pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh' | |||
| args = dict( | |||