@@ -55,7 +55,8 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         speed = random.choice([0.9, 1.0, 1.1])
-        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
+        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
+        wav, sr = librosa.load(audio_bytes, 16000, mono=True)
         fbank = self.prepare_fbank(
             torch.tensor([wav], dtype=torch.float32),
             sr,
@@ -91,7 +92,8 @@ class OfaASRPreprocessor(OfaBasePreprocessor):
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         speed = 1.0
-        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
+        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
+        wav, sr = librosa.load(audio_bytes, 16000, mono=True)
         fbank = self.prepare_fbank(
             torch.tensor([wav], dtype=torch.float32),
             sr,
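
Both `_build_train_sample` and `_build_infer_sample` now funnel the `wav` column through the new `get_audio_bytes` helper, so `librosa.load` receives a seekable file-like object instead of a bare path. That lets the same code path decode audio that arrives as raw bytes (e.g. from a binary dataset column) as well as audio on disk. A minimal sketch of the underlying idea, independent of the preprocessor (the WAV path is the repo's test fixture; the 16 kHz target rate matches the diff):

```python
import io

import librosa

# Read a WAV file fully into memory to mimic a binary dataset column.
with open('data/test/audios/asr_example_ofa.wav', 'rb') as f:
    raw = f.read()

# librosa.load accepts file-like objects, so BytesIO-wrapped bytes decode
# exactly like a path would, resampled here to 16 kHz mono.
wav, sr = librosa.load(io.BytesIO(raw), sr=16000, mono=True)
print(wav.shape, sr)  # e.g. (n_samples,) 16000
```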
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import io
 import re
 import string
 from os import path as osp
@@ -9,6 +10,7 @@ import torch
 import torchaudio
 from PIL import Image
+from modelscope.fileio import File
 from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.trie import Trie
@@ -170,6 +172,16 @@ class OfaBasePreprocessor:
             else load_image(path_or_url_or_pil)
         return image
 
+    def get_audio_bytes(self, path_or_url):
+        if isinstance(path_or_url, bytes):
+            audio_bytes = io.BytesIO(path_or_url)
+        elif isinstance(path_or_url, str):
+            file_bytes = File.read(path_or_url)
+            audio_bytes = io.BytesIO(file_bytes)
+        else:
+            raise TypeError(f'Unsupported input type: {type(path_or_url)}.')
+        return audio_bytes
+
     def prepare_fbank(self,
                       waveform,
                       sample_rate,
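
`File.read` here is ModelScope's generic reader (`modelscope.fileio.File`), which returns file content as `bytes` whether the string is a local path or a fetchable URL; wrapping the result in `io.BytesIO` gives downstream decoders a uniform, seekable object. A rough usage sketch, assuming `pre` is an already-constructed preprocessor such as `OfaASRPreprocessor` (its construction is not shown here):

```python
import librosa

# String input: resolved via File.read (local path or URL alike).
buf = pre.get_audio_bytes('data/test/audios/asr_example_ofa.wav')
wav, sr = librosa.load(buf, sr=16000, mono=True)

# Bytes input: wrapped in BytesIO directly, no filesystem access needed.
with open('data/test/audios/asr_example_ofa.wav', 'rb') as f:
    buf = pre.get_audio_bytes(f.read())

# Any other type (None, Path, ndarray, ...) raises TypeError.
```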
@@ -275,7 +275,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_asr_with_name(self):
-        model = 'damo/ofa_asr_pretrain_base_zh'
+        model = 'damo/ofa_mmspeech_pretrain_base_zh'
         ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model)
         example = {'wav': 'data/test/audios/asr_example_ofa.wav'}
         result = ofa_pipe(example)
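
Because the preprocessor now accepts raw bytes for the `wav` field, the same pipeline should also handle in-memory audio; a hedged variant of the test above (not part of this diff):

```python
# Pass the audio as bytes instead of a path; get_audio_bytes wraps
# them in BytesIO before decoding.
with open('data/test/audios/asr_example_ofa.wav', 'rb') as f:
    result = ofa_pipe({'wav': f.read()})
print(result)
```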