Browse Source

[tests] add unittest

master^2
pengzhendong 2 years ago
parent
commit
2605824dea
3 changed files with 146 additions and 22 deletions
  1. +8
    -15
      modelscope/models/audio/asr/wenet_automatic_speech_recognition.py
  2. +7
    -7
      modelscope/pipelines/audio/asr_wenet_inference_pipeline.py
  3. +131
    -0
      tests/pipelines/test_wenet_automatic_speech_recognition.py

+ 8
- 15
modelscope/models/audio/asr/wenet_automatic_speech_recognition.py View File

@@ -8,6 +8,7 @@ from modelscope.models.base import Model
from modelscope.models.builder import MODELS from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks from modelscope.utils.constant import Tasks


import json
import wenetruntime as wenet import wenetruntime as wenet


__all__ = ['WeNetAutomaticSpeechRecognition'] __all__ = ['WeNetAutomaticSpeechRecognition']
@@ -23,23 +24,15 @@ class WeNetAutomaticSpeechRecognition(Model):


Args: Args:
model_dir (str): the model path. model_dir (str): the model path.
am_model_name (str): the am model name from configuration.json
model_config (Dict[str, Any]): the detail config about model from configuration.json
""" """
super().__init__(model_dir, am_model_name, model_config, *args, super().__init__(model_dir, am_model_name, model_config, *args,
**kwargs) **kwargs)
self.model_cfg = {
# the recognition model dir path
'model_dir': model_dir,
# the recognition model config dict
'model_config': model_config
}
self.decoder = None

def forward(self) -> Dict[str, Any]:
"""preload model and return the info of the model
"""
model_dir = self.model_cfg['model_dir']
self.decoder = wenet.Decoder(model_dir, lang='chs') self.decoder = wenet.Decoder(model_dir, lang='chs')


return self.model_cfg
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Decode one audio input and return the top recognition hypothesis.

    Args:
        inputs (Dict[str, Any]): must contain the keys 'audio' and
            'audio_format'. When 'audio_format' equals 'wav' the value of
            'audio' is handed to ``self.decoder.decode_wav``; for any other
            format it is handed to ``self.decoder.decode``.

    Returns:
        Dict[str, Any]: ``{'text': sentence}`` where ``sentence`` is the
            'sentence' field of the first entry of the decoder's 'nbest'
            list.
    """
    audio = inputs['audio']
    if inputs['audio_format'] == 'wav':
        rst = self.decoder.decode_wav(audio)
    else:
        rst = self.decoder.decode(audio)
    # The decoder result is a JSON string; keep only the best hypothesis.
    text = json.loads(rst)['nbest'][0]['sentence']
    return {'text': text}

+ 7
- 7
modelscope/pipelines/audio/asr_wenet_inference_pipeline.py View File

@@ -29,8 +29,6 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
"""use `model` and `preprocessor` to create an asr pipeline for prediction """use `model` and `preprocessor` to create an asr pipeline for prediction
""" """
super().__init__(model=model, preprocessor=preprocessor, **kwargs) super().__init__(model=model, preprocessor=preprocessor, **kwargs)
self.model_cfg = self.model.forward()
self.decoder = self.model.decoder


def __call__(self, def __call__(self,
audio_in: Union[str, bytes], audio_in: Union[str, bytes],
@@ -68,17 +66,19 @@ class WeNetAutomaticSpeechRecognitionPipeline(Pipeline):
if checking_audio_fs is not None: if checking_audio_fs is not None:
self.audio_fs = checking_audio_fs self.audio_fs = checking_audio_fs


self.model_cfg['audio'] = self.audio_in
self.model_cfg['audio_fs'] = self.audio_fs

output = self.forward(self.model_cfg)
inputs = {
'audio': self.audio_in,
'audio_format': self.audio_format,
'audio_fs': self.audio_fs
}
output = self.forward(inputs)
rst = self.postprocess(output['asr_result']) rst = self.postprocess(output['asr_result'])
return rst return rst


def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Decoding """Decoding
""" """
inputs['asr_result'] = self.decoder.decode(inputs['audio'])
inputs['asr_result'] = self.model(inputs)
return inputs return inputs


def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:


+ 131
- 0
tests/pipelines/test_wenet_automatic_speech_recognition.py View File

@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import unittest
from typing import Any, Dict, Union

import numpy as np
import soundfile

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level

# Module-level logger used by the test helpers below.
logger = get_logger()

# Local wav fixture (joined with the current working directory by the tests)
# and a remote wav file with the same file name, used by the URL test.
WAV_FILE = 'data/test/audios/asr_example.wav'
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example.wav'


class WeNetAutomaticSpeechRecognitionTest(unittest.TestCase,
                                          DemoCompatibilityCheck):
    """Pipeline tests for the WeNet ASR model fed with PCM bytes, a wav file
    path and a wav URL.

    ``action_info`` maps every test name to the output key that must be
    present in the pipeline result ('checking_item') and to the name of the
    example entry ('example') that is printed when the check fails.
    """
    action_info = {
        'test_run_with_pcm': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_url': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'test_run_with_wav': {
            'checking_item': OutputKeys.TEXT,
            'example': 'wav_example'
        },
        'wav_example': {
            'text': '每一天都要快乐喔'
        }
    }

    def setUp(self) -> None:
        """Set the model id and task and make sure the workspace exists."""
        self.am_model_id = 'wenet/u2pp_conformer-asr-cn-16k-online'
        # this temporary workspace dir will store waveform files
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        self.task = Tasks.auto_speech_recognition
        # exist_ok avoids the check-then-create race of exists() + mkdir()
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        """Remove the temporary workspace dir (.tmp), ignoring errors."""
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[str, bytes],
                     sr: Union[int, None] = None) -> Dict[str, Any]:
        """Build an ASR pipeline for ``model_id`` and run it on ``audio_in``.

        Args:
            model_id (str): id of the model the pipeline is created for.
            audio_in (Union[str, bytes]): audio input handed to the pipeline
                (the tests pass a file path, a URL or raw bytes).
            sr (Union[int, None]): forwarded to the pipeline as ``audio_fs``.

        Returns:
            Dict[str, Any]: whatever the pipeline call returns.
        """
        asr_pipeline = pipeline(task=self.task, model=model_id)
        return asr_pipeline(audio_in, audio_fs=sr)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failure for test ``functions`` and raise ``ValueError``.

        Args:
            functions (str): name of the failing test, a key of
                ``action_info``.
            result (Dict[str, Any]): the result that failed the check; it is
                logged next to the expected example to ease debugging.

        Raises:
            ValueError: always, after the messages have been logged.
        """
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        example_name = self.action_info[functions]['example']
        logger.error(
            ColorCodes.MAGENTA + functions + ' correct result example:'
            + ColorCodes.YELLOW + str(self.action_info[example_name])
            + ColorCodes.END)
        logger.error(
            ColorCodes.MAGENTA + functions + ' actual result:'
            + ColorCodes.YELLOW + str(result) + ColorCodes.END)
        raise ValueError('asr result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Check that ``result`` contains the key expected for ``functions``.

        Only the presence of the key is verified; its value is logged, not
        compared with the example. A missing key is reported through
        ``log_error``, which raises ``ValueError``.
        """
        checking_item = self.action_info[functions]['checking_item']
        if checking_item in result:
            logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                        + ColorCodes.END)
            logger.info(ColorCodes.YELLOW + str(result[checking_item])
                        + ColorCodes.END)
        else:
            self.log_error(functions, result)

    def wav2bytes(self, wav_file):
        """Read ``wav_file`` and return ``(int16 PCM bytes, sample rate)``."""
        audio, fs = soundfile.read(wav_file)

        # float samples -> int16: scale by 2**15 and clip to the int16 range
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        i = np.iinfo(dtype)
        abs_max = 2**(i.bits - 1)
        # evaluates to 0 for int16 (i.min == -abs_max)
        offset = i.min + abs_max
        audio = (audio * abs_max + offset).clip(i.min, i.max).astype(dtype)

        # int16(PCM_16) -> byte
        audio = audio.tobytes()
        return audio, fs

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        """run with raw PCM bytes converted from the local wav file
        """
        logger.info('Run ASR test with wav data (wenet)...')
        audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=audio, sr=sr)
        self.check_result('test_run_with_pcm', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        """run with single waveform file
        """
        logger.info('Run ASR test with waveform file (wenet)...')
        wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=wav_file_path)
        self.check_result('test_run_with_wav', rec_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_url(self):
        """run with single url file
        """
        logger.info('Run ASR test with url file (wenet)...')
        rec_result = self.run_pipeline(
            model_id=self.am_model_id, audio_in=URL_FILE)
        self.check_result('test_run_with_url', rec_result)


# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save