shiyi.zxh yingda.chen 3 years ago
parent
commit
b386a4ee50
3 changed files with 18 additions and 6 deletions
  1. +9
    -3
      modelscope/preprocessors/ofa/asr.py
  2. +8
    -3
      modelscope/preprocessors/ofa/base.py
  3. +1
    -0
      requirements/multi-modal.txt

+ 9
- 3
modelscope/preprocessors/ofa/asr.py View File

@@ -5,6 +5,7 @@ import random
from pathlib import Path
from typing import Any, Dict

+import librosa
import soundfile as sf
import torch
from fairseq.data.audio.feature_transforms import \
@@ -54,9 +55,13 @@ class OfaASRPreprocessor(OfaBasePreprocessor):

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
speed = random.choice([0.9, 1.0, 1.1])
-        wav, sr = sf.read(self.column_map['wav'])
+        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
fbank = self.prepare_fbank(
-            torch.tensor([wav], dtype=torch.float32), sr, speed, is_train=True)
+            torch.tensor([wav], dtype=torch.float32),
+            sr,
+            speed,
+            target_sample_rate=16000,
+            is_train=True)
fbank_mask = torch.tensor([True])
sample = {
'fbank': fbank,
@@ -86,11 +91,12 @@ class OfaASRPreprocessor(OfaBasePreprocessor):

def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
speed = 1.0
-        wav, sr = sf.read(data[self.column_map['wav']])
+        wav, sr = librosa.load(data[self.column_map['wav']], 16000, mono=True)
fbank = self.prepare_fbank(
torch.tensor([wav], dtype=torch.float32),
sr,
speed,
+            target_sample_rate=16000,
is_train=False)
fbank_mask = torch.tensor([True])



+ 8
- 3
modelscope/preprocessors/ofa/base.py View File

@@ -170,10 +170,15 @@ class OfaBasePreprocessor:
else load_image(path_or_url_or_pil)
return image

-    def prepare_fbank(self, waveform, sample_rate, speed, is_train):
-        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+    def prepare_fbank(self,
+                      waveform,
+                      sample_rate,
+                      speed,
+                      target_sample_rate=16000,
+                      is_train=False):
+        waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
             waveform, sample_rate,
-            [['speed', str(speed)], ['rate', str(sample_rate)]])
+            [['speed', str(speed)], ['rate', str(target_sample_rate)]])
_waveform, _ = convert_waveform(
waveform, sample_rate, to_mono=True, normalize_volume=True)
# Kaldi compliance: 16-bit signed integers


+ 1
- 0
requirements/multi-modal.txt View File

@@ -1,4 +1,5 @@
ftfy>=6.0.3
+librosa
ofa>=0.0.2
pycocoevalcap>=1.2
pycocotools>=2.0.4


Loading…
Cancel
Save