shiyi.zxh · yichang.zyc · 3 years ago
Parent commit: c3a494e46d
8 changed files with 154 additions and 29 deletions
  1. modelscope/models/multi_modal/ofa_for_all_tasks.py (+2, -0)
  2. modelscope/preprocessors/ofa/asr.py (+3, -2)
  3. modelscope/preprocessors/ofa/base.py (+3, -0)
  4. modelscope/preprocessors/ofa/utils/collate.py (+4, -0)
  5. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+9, -20)
  6. modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+24, -5)
  7. tests/trainers/test_ofa_mmspeech_trainer.py (+108, -0)
  8. tests/trainers/test_ofa_trainer.py (+1, -2)

modelscope/models/multi_modal/ofa_for_all_tasks.py (+2, -0)

@@ -41,6 +41,8 @@ __all__ = ['OfaForAllTasks']
 class OfaForAllTasks(TorchModel):

     def __init__(self, model_dir, *args, **kwargs):
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         super().__init__(model_dir=model_dir, *args, **kwargs)
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
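
The same abspath guard is added to OfaBasePreprocessor below. A minimal sketch of the effect on a hypothetical relative checkpoint path; normalizing early keeps later os.path.join and from_pretrained lookups stable even if the working directory changes (that rationale is an inference, not stated in the commit):

import os

model_dir = 'work/ckpts/asr_recognition'    # hypothetical relative path
if os.path.exists(model_dir):
    model_dir = os.path.abspath(model_dir)  # e.g. '/home/user/work/ckpts/asr_recognition'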


modelscope/preprocessors/ofa/asr.py (+3, -2)

@@ -80,10 +80,11 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)

-        phone_item = self.to_phone(target) - 3
+        phone_item = self.to_phone(target) + 1
         phone_mask = torch.tensor([False])

-        sample['phone_item'] = phone_item
+        sample['phone_item'] = phone_item + 3
+        sample['phone_target'] = phone_item
         sample['phone_mask'] = phone_mask

         sample['prev_output_tokens'] = torch.cat(
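
Why two keys: phone_target keeps the ids offset by +1, leaving id 0 free for the CTC blank that compute_ctc_loss in ofa_trainer_utils.py now hard-codes (blank=0), while phone_item, the model input, is shifted a further +3, presumably past additional special ids. The offsets come from the patch; the interpretation is an assumption. A toy illustration:

import torch

raw = torch.tensor([0, 4, 2])    # hypothetical ids from self.to_phone(target)
phone_target = raw + 1           # CTC targets: tensor([1, 5, 3]); id 0 stays the blank
phone_item = phone_target + 3    # model input ids: tensor([4, 8, 6])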


modelscope/preprocessors/ofa/base.py (+3, -0)

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import io
+import os
 import re
 import string
 from os import path as osp
@@ -32,6 +33,8 @@ class OfaBasePreprocessor:
         self.cfg = cfg
         self.mode = mode
         self.language = self.cfg.model.get('language', 'en')
+        if os.path.exists(model_dir):
+            model_dir = os.path.abspath(model_dir)
         if self.language == 'en':
             tokenizer = OFATokenizer.from_pretrained(model_dir)
         elif self.language in ['zh', 'cn']:


modelscope/preprocessors/ofa/utils/collate.py (+4, -0)

@@ -83,6 +83,10 @@ def collate_fn(samples, pad_idx, eos_idx):
         batch['net_input']['phone_items'] = merge('phone_item')
         batch['net_input']['phone_masks'] = torch.cat(
             [s['phone_mask'] for s in samples])
+    if samples[0].get('phone_target', None) is not None:
+        batch['phone_target'] = merge('phone_target')
+        batch['phone_length'] = torch.tensor(
+            [s['phone_target'].size(0) for s in samples], dtype=torch.long)

     return batch
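
A rough sketch of what the new branch adds to the batch, using pad_sequence as a stand-in for this file's merge() helper (which pads with pad_idx rather than 0); the data is made up:

import torch
from torch.nn.utils.rnn import pad_sequence

phone_targets = [torch.tensor([5, 9, 7]), torch.tensor([4, 8])]  # per-sample phone ids
phone_target = pad_sequence(phone_targets, batch_first=True)     # tensor([[5, 9, 7], [4, 8, 0]])
phone_length = torch.tensor(
    [t.size(0) for t in phone_targets], dtype=torch.long)        # tensor([3, 2])

Both tensors are consumed by compute_ctc_loss in ofa_trainer_utils.py, where phone_length drives the padding mask that flattens phone_target.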




modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+9, -20)

@@ -2,8 +2,8 @@

 import math
 import os
-import shutil
 from functools import partial
+from shutil import ignore_patterns
 from typing import Callable, Dict, Optional, Tuple, Union

 import torch
@@ -23,9 +23,9 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
-                                       Invoke, ModeKeys)
+                                       Invoke, ModeKeys, ModelFile)
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                get_schedule)
+                                get_schedule, recursive_overwrite)


 @TRAINERS.register_module(module_name=Trainers.ofa)
@@ -58,23 +58,12 @@ class OFATrainer(EpochBasedTrainer):
             work_dir = cfg.train.work_dir
         else:
             work_dir = kwargs['work_dir']
-        tokenizer_files = {
-            'zh': [
-                'tokenizer.json', 'tokenizer_config.json', 'vocab.txt',
-                'config.json', 'ans2label.json'
-            ],
-            'en': [
-                'tokenizer.json', 'vocab.json', 'merges.txt', 'config.json',
-                'ans2label.json'
-            ],
-        }
-        for filename in tokenizer_files[cfg.model.get('language', 'en')]:
-            finetune_file = os.path.join(work_dir, filename)
-            pretrain_file = os.path.join(model_dir, filename)
-            if os.path.exists(finetune_file):
-                continue
-            if os.path.exists(pretrain_file):
-                shutil.copy(pretrain_file, finetune_file)
+        os.makedirs(work_dir, exist_ok=True)
+        ignore_file_set = set()
+        ignore_file_set.add(ModelFile.CONFIGURATION)
+        recursive_overwrite(
+            model_dir, work_dir, ignore=ignore_patterns(*ignore_file_set))

         if preprocessor is None:
             preprocessor = {
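
The hard-coded per-language tokenizer file lists are gone: the whole pretrain directory is now mirrored into work_dir, and only the configuration file is excluded so the finetune config already written there is not clobbered. For reference, shutil.ignore_patterns builds a copytree-style hook with signature ignore(src, names) returning the names to skip; a small demo (the 'configuration.json' literal is an assumption about what ModelFile.CONFIGURATION resolves to):

from shutil import ignore_patterns

skip = ignore_patterns('configuration.json')
print(skip('/path/to/model_dir', ['pytorch_model.bin', 'configuration.json']))
# -> {'configuration.json'}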


modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py (+24, -5)

@@ -3,6 +3,8 @@
 # This source code is licensed under the Apache 2.0 license
 # found in the LICENSE file in the root directory.
 import math
+import os
+import shutil

 import numpy as np
 import torch
@@ -11,6 +13,23 @@ import transformers
 from torch.nn.modules.loss import _Loss


+def recursive_overwrite(src, dst, ignore=None):
+    if os.path.isdir(src):
+        if not os.path.isdir(dst):
+            os.makedirs(dst)
+        files = os.listdir(src)
+        if ignore is not None:
+            ignored = ignore(src, files)
+        else:
+            ignored = set()
+        for f in files:
+            if f not in ignored:
+                recursive_overwrite(
+                    os.path.join(src, f), os.path.join(dst, f), ignore)
+    else:
+        shutil.copyfile(src, dst)
+
+
 def construct_rdrop_sample(x):
     if isinstance(x, dict):
         for key in x:
@@ -211,17 +230,17 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return loss, nll_loss, ntokens

     def compute_ctc_loss(self, model, output, sample):
-        lprobs = model.get_encoder_normalized_probs(
+        lprobs = model.model.get_encoder_normalized_probs(
             output, log_probs=True).contiguous()  # (T, B, C) from the encoder

         non_padding_mask = ~output.encoder_padding_mask
         input_lengths = non_padding_mask.long().sum(-1)

-        target_lengths = sample['ctc_output_lengths']
+        target_lengths = sample['phone_length']
         pad_mask = torch.arange(target_lengths.max()).expand([
             target_lengths.shape[0], -1
         ]).to(target_lengths) < target_lengths.unsqueeze(1)
-        targets_flat = sample['ctc_outputs'].masked_select(pad_mask)
+        targets_flat = sample['phone_target'].masked_select(pad_mask)

         with torch.backends.cudnn.flags(enabled=False):
             loss = F.ctc_loss(
@@ -229,12 +248,12 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
                 targets_flat,
                 input_lengths,
                 target_lengths,
-                blank=self.blank_idx,
+                blank=0,
                 reduction='sum',
                 zero_infinity=True,
             )

-        return loss
+        return loss / lprobs.shape[1]


 def get_schedule(scheduler):
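
Together with the collate_fn change above, compute_ctc_loss now reads phone_target/phone_length straight from the batch, hard-codes blank=0 to match the +1 id offset applied in the ASR preprocessor, and normalizes the summed loss by batch size (lprobs.shape[1] is B for (T, B, C) input). A self-contained sketch of the same shape handling, with dummy data:

import torch
import torch.nn.functional as F

T, B, C = 50, 2, 30                                  # time steps, batch size, phone vocab (blank = 0)
lprobs = torch.randn(T, B, C).log_softmax(-1)        # stand-in for the encoder's normalized log-probs
input_lengths = torch.full((B,), T, dtype=torch.long)
phone_target = torch.tensor([[5, 9, 7], [4, 8, 0]])  # padded targets, as built in collate_fn
phone_length = torch.tensor([3, 2])

pad_mask = torch.arange(int(phone_length.max())).expand(B, -1) < phone_length.unsqueeze(1)
targets_flat = phone_target.masked_select(pad_mask)  # tensor([5, 9, 7, 4, 8])

loss = F.ctc_loss(
    lprobs, targets_flat, input_lengths, phone_length,
    blank=0, reduction='sum', zero_infinity=True)
loss = loss / lprobs.shape[1]                        # the patch's per-batch normalization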


tests/trainers/test_ofa_mmspeech_trainer.py (+108, -0)

@@ -0,0 +1,108 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest

import json

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


class TestMMSpeechTrainer(unittest.TestCase):

    def setUp(self) -> None:
        self.finetune_cfg = \
            {'framework': 'pytorch',
             'task': 'auto-speech-recognition',
             'model': {'type': 'ofa',
                       'beam_search': {'beam_size': 5,
                                       'max_len_b': 128,
                                       'min_len': 1,
                                       'no_repeat_ngram_size': 5,
                                       'constraint_range': '4,21134'},
                       'seed': 7,
                       'max_src_length': 256,
                       'language': 'zh',
                       'gen_type': 'generation',
                       'multimodal_type': 'mmspeech'},
             'pipeline': {'type': 'ofa-asr'},
             'n_frames_per_step': 1,
             'dataset': {'column_map': {'wav': 'Audio:FILE',
                                        'text': 'Text:LABEL'}},
             'train': {'work_dir': 'work/ckpts/asr_recognition',
                       # 'launcher': 'pytorch',
                       'max_epochs': 1,
                       'use_fp16': True,
                       'dataloader': {'batch_size_per_gpu': 16,
                                      'workers_per_gpu': 0},
                       'lr_scheduler': {'name': 'polynomial_decay',
                                        'warmup_proportion': 0.01,
                                        'lr_end': 1e-07},
                       'lr_scheduler_hook': {'type': 'LrSchedulerHook',
                                             'by_epoch': False},
                       'optimizer': {'type': 'AdamW',
                                     'lr': 5e-05,
                                     'weight_decay': 0.01},
                       'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
                                          'cumulative_iters': 1,
                                          'grad_clip': {'max_norm': 1.0,
                                                        'norm_type': 2},
                                          'loss_keys': 'loss'},
                       'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
                                     'constraint_range': '4,21134',
                                     'drop_worst_after': 0,
                                     'drop_worst_ratio': 0.0,
                                     'ignore_eos': False,
                                     'ignore_prefix_size': 0,
                                     'label_smoothing': 0.1,
                                     'reg_alpha': 1.0,
                                     'report_accuracy': False,
                                     'sample_patch_num': 196,
                                     'sentence_avg': True,
                                     'use_rdrop': False,
                                     'ctc_weight': 1.0},
                       'hooks': [{'type': 'BestCkptSaverHook',
                                  'metric_key': 'accuracy',
                                  'interval': 100},
                                 {'type': 'TextLoggerHook', 'interval': 1},
                                 {'type': 'IterTimerHook'},
                                 {'type': 'EvaluationHook',
                                  'by_epoch': True,
                                  'interval': 1}]},
             'evaluation': {'dataloader': {'batch_size_per_gpu': 4,
                                           'workers_per_gpu': 0},
                            'metrics': [{'type': 'accuracy'}]},
             'preprocessor': []}

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_std(self):
        WORKSPACE = './workspace/ckpts/asr_recognition'
        os.makedirs(WORKSPACE, exist_ok=True)
        config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
        with open(config_file, 'w') as writer:
            json.dump(self.finetune_cfg, writer)

        pretrained_model = 'damo/ofa_mmspeech_pretrain_base_zh'

        args = dict(
            model=pretrained_model,
            work_dir=WORKSPACE,
            train_dataset=MsDataset.load(
                'aishell1_subset',
                subset_name='default',
                namespace='modelscope',
                split='train',
                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
            eval_dataset=MsDataset.load(
                'aishell1_subset',
                subset_name='default',
                namespace='modelscope',
                split='test',
                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
            cfg_file=config_file)
        trainer = build_trainer(name=Trainers.ofa, default_args=args)
        trainer.train()

        self.assertIn(
            ModelFile.TORCH_MODEL_BIN_FILE,
            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
        shutil.rmtree(WORKSPACE)


if __name__ == '__main__':
    unittest.main()

tests/trainers/test_ofa_trainer.py (+1, -2)

@@ -76,8 +76,7 @@ class TestOfaTrainer(unittest.TestCase):
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
-            json.dump(self.finetune_cfg, writer)
-
+            json.dump(self.finetune_cfg, writer, indent=4)
         pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'

         args = dict(