
ofa asr support url

yichang.zyc committed 3 years ago
commit 90034236ab
3 changed files with 17 additions and 3 deletions
1. modelscope/preprocessors/ofa/asr.py (+4, -2)
2. modelscope/preprocessors/ofa/base.py (+12, -0)
3. tests/pipelines/test_ofa_tasks.py (+1, -1)

modelscope/preprocessors/ofa/asr.py (+4, -2)

@@ -55,7 +55,8 @@ class OfaASRPreprocessor(OfaBasePreprocessor):

     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         speed = random.choice([0.9, 1.0, 1.1])
-        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
+        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
+        wav, sr = librosa.load(audio_bytes, 16000, mono=True)
         fbank = self.prepare_fbank(
             torch.tensor([wav], dtype=torch.float32),
             sr,
@@ -91,7 +92,8 @@ class OfaASRPreprocessor(OfaBasePreprocessor):

     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         speed = 1.0
-        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
+        audio_bytes = self.get_audio_bytes(data[self.column_map['wav']])
+        wav, sr = librosa.load(audio_bytes, 16000, mono=True)
         fbank = self.prepare_fbank(
             torch.tensor([wav], dtype=torch.float32),
             sr,
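
Both hunks make the same change: instead of handing librosa.load a raw value from the 'wav' column, the preprocessor first routes it through the new get_audio_bytes helper. This works because librosa.load accepts a file-like object as well as a path. A minimal standalone sketch of that loading path (the fixture path is borrowed from the test suite and assumed to exist locally):

import io

import librosa

# Read a local fixture into memory, standing in for bytes that
# get_audio_bytes would fetch from a URL.
with open('data/test/audios/asr_example_ofa.wav', 'rb') as f:
    raw_bytes = f.read()

# librosa.load accepts a file-like object; sr=16000 resamples to the
# rate the OFA preprocessor expects.
wav, sr = librosa.load(io.BytesIO(raw_bytes), sr=16000, mono=True)
print(wav.shape, sr)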


modelscope/preprocessors/ofa/base.py (+12, -0)

@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import io
 import re
 import string
 from os import path as osp
@@ -9,6 +10,7 @@ import torch
 import torchaudio
 from PIL import Image

+from modelscope.fileio import File
 from modelscope.models.multi_modal.ofa import OFATokenizer, OFATokenizerZH
 from modelscope.preprocessors.image import load_image
 from modelscope.utils.trie import Trie
@@ -170,6 +172,16 @@ class OfaBasePreprocessor:
             else load_image(path_or_url_or_pil)
         return image

+    def get_audio_bytes(self, path_or_url):
+        if isinstance(path_or_url, bytes):
+            audio_bytes = io.BytesIO(path_or_url)
+        elif isinstance(path_or_url, str):
+            file_bytes = File.read(path_or_url)
+            audio_bytes = io.BytesIO(file_bytes)
+        else:
+            raise TypeError(f'Unsupported input type: {type(path_or_url)}.')
+        return audio_bytes
+
     def prepare_fbank(self,
                       waveform,
                       sample_rate,
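
The new helper concentrates input handling in one place: raw bytes are wrapped directly in a BytesIO, while strings (local paths or URLs) are read through modelscope.fileio.File, which resolves both. A rough standalone mirror of the dispatch, with urllib standing in for File.read so the sketch runs without modelscope installed (an illustrative substitution, not the library's code):

import io
import urllib.request


def get_audio_bytes(path_or_url):
    # Already-decoded bytes: wrap them so downstream readers get a
    # file-like object.
    if isinstance(path_or_url, bytes):
        return io.BytesIO(path_or_url)
    if isinstance(path_or_url, str):
        # Remote file: fetch the payload into memory.
        if path_or_url.startswith(('http://', 'https://')):
            with urllib.request.urlopen(path_or_url) as resp:
                return io.BytesIO(resp.read())
        # Local path: read from disk.
        with open(path_or_url, 'rb') as f:
            return io.BytesIO(f.read())
    raise TypeError(f'Unsupported input type: {type(path_or_url)}.')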


tests/pipelines/test_ofa_tasks.py (+1, -1)

@@ -275,7 +275,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_asr_with_name(self):
-        model = 'damo/ofa_asr_pretrain_base_zh'
+        model = 'damo/ofa_mmspeech_pretrain_base_zh'
         ofa_pipe = pipeline(Tasks.auto_speech_recognition, model=model)
         example = {'wav': 'data/test/audios/asr_example_ofa.wav'}
         result = ofa_pipe(example)
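
With the preprocessor change in place, the 'wav' field of a pipeline input can be a local path, an http(s) URL, or raw bytes. A hedged end-to-end sketch (the URL below is a placeholder, not a real asset; the model name matches the updated test):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ofa_pipe = pipeline(
    Tasks.auto_speech_recognition,
    model='damo/ofa_mmspeech_pretrain_base_zh')

# Any of: local path, URL, or raw bytes read from a file.
result = ofa_pipe({'wav': 'https://example.com/asr_example_ofa.wav'})
print(result)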

