From 087e684da53ce89508aac61f6dd0f53c0454c4f3 Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Wed, 27 Jul 2022 18:45:17 +0800 Subject: [PATCH] [to #42322933] fix ans_pipeline bug and add test --- modelscope/pipelines/audio/ans_pipeline.py | 8 ++++---- tests/pipelines/test_speech_signal_process.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index 8289ae50..13d25934 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -48,13 +48,13 @@ class ANSPipeline(Pipeline): def preprocess(self, inputs: Input) -> Dict[str, Any]: if isinstance(inputs, bytes): - raw_data, fs = sf.read(io.BytesIO(inputs)) + data1, fs = sf.read(io.BytesIO(inputs)) elif isinstance(inputs, str): - raw_data, fs = sf.read(inputs) + data1, fs = sf.read(inputs) else: raise TypeError(f'Unsupported type {type(inputs)}.') - if len(raw_data.shape) > 1: - data1 = raw_data[:, 0] + if len(data1.shape) > 1: + data1 = data1[:, 0] if fs != self.SAMPLE_RATE: data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) data1 = audio_norm(data1) diff --git a/tests/pipelines/test_speech_signal_process.py b/tests/pipelines/test_speech_signal_process.py index 9f00c4ef..f911c0eb 100644 --- a/tests/pipelines/test_speech_signal_process.py +++ b/tests/pipelines/test_speech_signal_process.py @@ -65,6 +65,21 @@ class SpeechSignalProcessTest(unittest.TestCase): ans(NOISE_SPEECH_FILE, output_path=output_path) print(f'Processed audio saved to {output_path}') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_ans_bytes(self): + # Download audio files + download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) + model_id = 'damo/speech_frcrn_ans_cirm_16k' + ans = pipeline( + Tasks.speech_signal_process, + model=model_id, + pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) + output_path = os.path.abspath('output.wav') + with open(NOISE_SPEECH_FILE, 'rb') as f: + data = f.read() + ans(data, output_path=output_path) + print(f'Processed audio saved to {output_path}') + if __name__ == '__main__': unittest.main()