@@ -48,13 +48,13 @@ class ANSPipeline(Pipeline): | |||||
def preprocess(self, inputs: Input) -> Dict[str, Any]: | def preprocess(self, inputs: Input) -> Dict[str, Any]: | ||||
if isinstance(inputs, bytes): | if isinstance(inputs, bytes): | ||||
raw_data, fs = sf.read(io.BytesIO(inputs)) | |||||
data1, fs = sf.read(io.BytesIO(inputs)) | |||||
elif isinstance(inputs, str): | elif isinstance(inputs, str): | ||||
raw_data, fs = sf.read(inputs) | |||||
data1, fs = sf.read(inputs) | |||||
else: | else: | ||||
raise TypeError(f'Unsupported type {type(inputs)}.') | raise TypeError(f'Unsupported type {type(inputs)}.') | ||||
if len(raw_data.shape) > 1: | |||||
data1 = raw_data[:, 0] | |||||
if len(data1.shape) > 1: | |||||
data1 = data1[:, 0] | |||||
if fs != self.SAMPLE_RATE: | if fs != self.SAMPLE_RATE: | ||||
data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) | data1 = librosa.resample(data1, fs, self.SAMPLE_RATE) | ||||
data1 = audio_norm(data1) | data1 = audio_norm(data1) | ||||
@@ -65,6 +65,21 @@ class SpeechSignalProcessTest(unittest.TestCase): | |||||
ans(NOISE_SPEECH_FILE, output_path=output_path) | ans(NOISE_SPEECH_FILE, output_path=output_path) | ||||
print(f'Processed audio saved to {output_path}') | print(f'Processed audio saved to {output_path}') | ||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
def test_ans_bytes(self): | |||||
# Download audio files | |||||
download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE) | |||||
model_id = 'damo/speech_frcrn_ans_cirm_16k' | |||||
ans = pipeline( | |||||
Tasks.speech_signal_process, | |||||
model=model_id, | |||||
pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k) | |||||
output_path = os.path.abspath('output.wav') | |||||
with open(NOISE_SPEECH_FILE, 'rb') as f: | |||||
data = f.read() | |||||
ans(data, output_path=output_path) | |||||
print(f'Processed audio saved to {output_path}') | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
unittest.main() | unittest.main() |