Browse Source

[to #42322933] fix ans_pipeline bug and add test

master
bin.xue 3 years ago
parent
commit
087e684da5
2 changed files with 19 additions and 4 deletions
  1. +4
    -4
      modelscope/pipelines/audio/ans_pipeline.py
  2. +15
    -0
      tests/pipelines/test_speech_signal_process.py

+ 4
- 4
modelscope/pipelines/audio/ans_pipeline.py View File

@@ -48,13 +48,13 @@ class ANSPipeline(Pipeline):


def preprocess(self, inputs: Input) -> Dict[str, Any]: def preprocess(self, inputs: Input) -> Dict[str, Any]:
if isinstance(inputs, bytes): if isinstance(inputs, bytes):
raw_data, fs = sf.read(io.BytesIO(inputs))
data1, fs = sf.read(io.BytesIO(inputs))
elif isinstance(inputs, str): elif isinstance(inputs, str):
raw_data, fs = sf.read(inputs)
data1, fs = sf.read(inputs)
else: else:
raise TypeError(f'Unsupported type {type(inputs)}.') raise TypeError(f'Unsupported type {type(inputs)}.')
if len(raw_data.shape) > 1:
data1 = raw_data[:, 0]
if len(data1.shape) > 1:
data1 = data1[:, 0]
if fs != self.SAMPLE_RATE: if fs != self.SAMPLE_RATE:
data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) data1 = librosa.resample(data1, fs, self.SAMPLE_RATE)
data1 = audio_norm(data1) data1 = audio_norm(data1)


+ 15
- 0
tests/pipelines/test_speech_signal_process.py View File

@@ -65,6 +65,21 @@ class SpeechSignalProcessTest(unittest.TestCase):
ans(NOISE_SPEECH_FILE, output_path=output_path) ans(NOISE_SPEECH_FILE, output_path=output_path)
print(f'Processed audio saved to {output_path}') print(f'Processed audio saved to {output_path}')


@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans_bytes(self):
# Download audio files
download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(
Tasks.speech_signal_process,
model=model_id,
pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
output_path = os.path.abspath('output.wav')
with open(NOISE_SPEECH_FILE, 'rb') as f:
data = f.read()
ans(data, output_path=output_path)
print(f'Processed audio saved to {output_path}')



if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save