Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10078262master
@@ -8,6 +8,8 @@ from modelscope.models import Model | |||
from modelscope.pipelines.base import Pipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import WavToLists | |||
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav, | |||
load_bytes_from_url) | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.logger import get_logger | |||
@@ -40,6 +42,13 @@ class KeyWordSpottingKwsbpPipeline(Pipeline): | |||
if self.preprocessor is None: | |||
self.preprocessor = WavToLists() | |||
if isinstance(audio_in, str): | |||
# audio_in is a str: treat it as a URL and load its content as bytes | |||
audio_in = load_bytes_from_url(audio_in) | |||
elif isinstance(audio_in, bytes): | |||
# audio_in is bytes: extract the pcm data, in case it is in wav format | |||
audio_in = extract_pcm_from_wav(audio_in) | |||
output = self.preprocessor.forward(self.model.forward(), audio_in) | |||
output = self.forward(output) | |||
rst = self.postprocess(output) | |||
@@ -42,23 +42,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes: | |||
if len(data) > 44: | |||
frame_len = 44 | |||
file_len = len(data) | |||
header_fields = {} | |||
header_fields['ChunkID'] = str(data[0:4], 'UTF-8') | |||
header_fields['Format'] = str(data[8:12], 'UTF-8') | |||
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') | |||
if header_fields['ChunkID'] == 'RIFF' and header_fields[ | |||
'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ': | |||
header_fields['SubChunk1Size'] = struct.unpack('<I', | |||
data[16:20])[0] | |||
try: | |||
header_fields = {} | |||
header_fields['ChunkID'] = str(data[0:4], 'UTF-8') | |||
header_fields['Format'] = str(data[8:12], 'UTF-8') | |||
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') | |||
if header_fields['ChunkID'] == 'RIFF' and header_fields[ | |||
'Format'] == 'WAVE' and header_fields[ | |||
'Subchunk1ID'] == 'fmt ': | |||
header_fields['SubChunk1Size'] = struct.unpack( | |||
'<I', data[16:20])[0] | |||
if header_fields['SubChunk1Size'] == 16: | |||
frame_len = 44 | |||
elif header_fields['SubChunk1Size'] == 18: | |||
frame_len = 46 | |||
else: | |||
return data | |||
if header_fields['SubChunk1Size'] == 16: | |||
frame_len = 44 | |||
elif header_fields['SubChunk1Size'] == 18: | |||
frame_len = 46 | |||
else: | |||
return data | |||
data = wav[frame_len:file_len] | |||
data = wav[frame_len:file_len] | |||
except Exception: | |||
# best effort: ignore any header-parsing error and fall through to return data | |||
pass | |||
return data | |||
@@ -18,6 +18,7 @@ logger = get_logger() | |||
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav' | |||
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav' | |||
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav' | |||
POS_TESTSETS_FILE = 'pos_testsets.tar.gz' | |||
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz' | |||
@@ -76,6 +77,22 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
}] | |||
} | |||
}, | |||
'test_run_with_url': { | |||
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'], | |||
'checking_value': '小云小云', | |||
'example': { | |||
'wav_count': | |||
1, | |||
'kws_type': | |||
'pcm', | |||
'kws_list': [{ | |||
'keyword': '小云小云', | |||
'offset': 0.69, | |||
'length': 1.67, | |||
'confidence': 0.996023 | |||
}] | |||
} | |||
}, | |||
'test_run_with_pos_testsets': { | |||
'checking_item': ['recall'], | |||
'example': { | |||
@@ -237,6 +254,12 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck): | |||
self.check_result('test_run_with_wav_by_customized_keywords', | |||
kws_result) | |||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_url(self):
    """Run the pipeline with a wav URL as ``audio_in`` and check the result.

    ``URL_FILE`` is passed to ``run_pipeline`` together with the test's
    ``model_id``; the returned result is then validated by ``check_result``
    under the 'test_run_with_url' key.
    """
    url_result = self.run_pipeline(model_id=self.model_id, audio_in=URL_FILE)
    self.check_result('test_run_with_url', url_result)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_run_with_pos_testsets(self): | |||
wav_file_path = download_and_untar( | |||