diff --git a/modelscope/preprocessors/ofa/asr.py b/modelscope/preprocessors/ofa/asr.py
index d74c2550..f4ae2097 100644
--- a/modelscope/preprocessors/ofa/asr.py
+++ b/modelscope/preprocessors/ofa/asr.py
@@ -55,7 +55,8 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
 
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         speed = random.choice([0.9, 1.0, 1.1])
-        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
+        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
+        wav, sr = librosa.load(audio_bytes, 16000, mono=True)
         fbank = self.prepare_fbank(
             torch.tensor([wav], dtype=torch.float32),
             sr,
@@ -91,7 +92,8 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
 
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         speed = 1.0
-        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
+        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
+        wav, sr = librosa.load(audio_bytes, 16000, mono=True)
         fbank = self.prepare_fbank(
             torch.tensor([wav], dtype=torch.float32),
             sr,
diff --git a/modelscope/preprocessors/ofa/base.py b/modelscope/preprocessors/ofa/base.py
index 8f18fe7a..4faa22fe 100644
--- a/modelscope/preprocessors/ofa/base.py
+++ b/modelscope/preprocessors/ofa/base.py
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import io
 import re
 import string
 from os import path as osp
@@ -9,6 +10,7 @@
 import torch
 import torchaudio
 from PIL import Image
+from modelscope.fileio import File
 from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.trie import Trie
@@ -170,6 +172,16 @@ class OfaBasePreprocessor:
             else load_image(path_or_url_or_pil)
         return image
 
+    def get_audio_bytes(self, path_or_url):
+        if isinstance(path_or_url, bytes):
+            audio_bytes = io.BytesIO(path_or_url)
+        elif isinstance(path_or_url, str):
+            file_bytes = File.read(path_or_url)
+            audio_bytes = io.BytesIO(file_bytes)
+        else:
+            raise TypeError(f'Unsupported input type: {type(path_or_url)}.')
+        return audio_bytes
+
     def prepare_fbank(self,
                       waveform,
                       sample_rate,
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index 9e1b47a1..6dec2c57 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -275,7 +275,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_asr_with_name(self):
-        model = 'damo/ofa_asr_pretrain_base_zh'
+        model = 'damo/ofa_mmspeech_pretrain_base_zh'
         ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model)
         example = {'wav': 'data/test/audios/asr_example_ofa.wav'}
         result = ofa_pipe(example)