shichen.fsc yingda.chen 3 years ago
parent
commit
84c384cc57
3 changed files with 52 additions and 15 deletions
  1. +9
    -0
      modelscope/pipelines/audio/kws_kwsbp_pipeline.py
  2. +20
    -15
      modelscope/utils/audio/audio_utils.py
  3. +23
    -0
      tests/pipelines/test_key_word_spotting.py

+ 9
- 0
modelscope/pipelines/audio/kws_kwsbp_pipeline.py View File

@@ -8,6 +8,8 @@ from modelscope.models import Model
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToLists
from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
load_bytes_from_url)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

@@ -40,6 +42,13 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
if self.preprocessor is None:
self.preprocessor = WavToLists()

if isinstance(audio_in, str):
# if audio_in is a str, treat it as a URL and load the audio bytes from it
audio_in = load_bytes_from_url(audio_in)
elif isinstance(audio_in, bytes):
# if audio_in is bytes, strip the WAV header (when present) to get the raw PCM data
audio_in = extract_pcm_from_wav(audio_in)

output = self.preprocessor.forward(self.model.forward(), audio_in)
output = self.forward(output)
rst = self.postprocess(output)


+ 20
- 15
modelscope/utils/audio/audio_utils.py View File

@@ -42,23 +42,28 @@ def extract_pcm_from_wav(wav: bytes) -> bytes:
if len(data) > 44:
frame_len = 44
file_len = len(data)
header_fields = {}
header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
header_fields['Format'] = str(data[8:12], 'UTF-8')
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
if header_fields['ChunkID'] == 'RIFF' and header_fields[
'Format'] == 'WAVE' and header_fields['Subchunk1ID'] == 'fmt ':
header_fields['SubChunk1Size'] = struct.unpack('<I',
data[16:20])[0]
try:
header_fields = {}
header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
header_fields['Format'] = str(data[8:12], 'UTF-8')
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
if header_fields['ChunkID'] == 'RIFF' and header_fields[
'Format'] == 'WAVE' and header_fields[
'Subchunk1ID'] == 'fmt ':
header_fields['SubChunk1Size'] = struct.unpack(
'<I', data[16:20])[0]

if header_fields['SubChunk1Size'] == 16:
frame_len = 44
elif header_fields['SubChunk1Size'] == 18:
frame_len = 46
else:
return data
if header_fields['SubChunk1Size'] == 16:
frame_len = 44
elif header_fields['SubChunk1Size'] == 18:
frame_len = 46
else:
return data

data = wav[frame_len:file_len]
data = wav[frame_len:file_len]
except Exception:
# best effort: on any header-parsing error, fall through and return data as it is
pass

return data



+ 23
- 0
tests/pipelines/test_key_word_spotting.py View File

@@ -18,6 +18,7 @@ logger = get_logger()

POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'
URL_FILE = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/20200707_xiaoyun.wav'

POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testsets.tar.gz'
@@ -76,6 +77,22 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
}]
}
},
'test_run_with_url': {
'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
'checking_value': '小云小云',
'example': {
'wav_count':
1,
'kws_type':
'pcm',
'kws_list': [{
'keyword': '小云小云',
'offset': 0.69,
'length': 1.67,
'confidence': 0.996023
}]
}
},
'test_run_with_pos_testsets': {
'checking_item': ['recall'],
'example': {
@@ -237,6 +254,12 @@ class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
self.check_result('test_run_with_wav_by_customized_keywords',
kws_result)

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_url(self):
# Run the pipeline with URL_FILE (an https link to a .wav file) as audio_in,
# rather than a local path, and hand the result to check_result under the
# key 'test_run_with_url'.
# NOTE(review): presumably check_result compares against the expectation
# entry of the same name (keyword '小云小云') -- confirm in check_result.
kws_result = self.run_pipeline(
model_id=self.model_id, audio_in=URL_FILE)
self.check_result('test_run_with_url', kws_result)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_pos_testsets(self):
wav_file_path = download_and_untar(


Loading…
Cancel
Save