You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

asr_inference_pipeline.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, List, Sequence, Tuple, Union
  3. import yaml
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models import Model
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.pipelines.base import Pipeline
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.preprocessors import WavToScp
  10. from modelscope.utils.audio.audio_utils import (extract_pcm_from_wav,
  11. load_bytes_from_url)
  12. from modelscope.utils.constant import Frameworks, Tasks
  13. from modelscope.utils.logger import get_logger
  14. logger = get_logger()
  15. __all__ = ['AutomaticSpeechRecognitionPipeline']
  16. @PIPELINES.register_module(
  17. Tasks.auto_speech_recognition, module_name=Pipelines.asr_inference)
  18. class AutomaticSpeechRecognitionPipeline(Pipeline):
  19. """ASR Inference Pipeline
  20. """
  21. def __init__(self,
  22. model: Union[Model, str] = None,
  23. preprocessor: WavToScp = None,
  24. **kwargs):
  25. """use `model` and `preprocessor` to create an asr pipeline for prediction
  26. """
  27. super().__init__(model=model, preprocessor=preprocessor, **kwargs)
  28. self.model_cfg = self.model.forward()
  29. def __call__(self,
  30. audio_in: Union[str, bytes],
  31. audio_fs: int = None,
  32. recog_type: str = None,
  33. audio_format: str = None) -> Dict[str, Any]:
  34. from funasr.utils import asr_utils
  35. self.recog_type = recog_type
  36. self.audio_format = audio_format
  37. self.audio_fs = audio_fs
  38. if isinstance(audio_in, str):
  39. # load pcm data from url if audio_in is url str
  40. self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
  41. elif isinstance(audio_in, bytes):
  42. # load pcm data from wav data if audio_in is wave format
  43. self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
  44. else:
  45. self.audio_in = audio_in
  46. # set the sample_rate of audio_in if checking_audio_fs is valid
  47. if checking_audio_fs is not None:
  48. self.audio_fs = checking_audio_fs
  49. if recog_type is None or audio_format is None:
  50. self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
  51. audio_in=self.audio_in,
  52. recog_type=recog_type,
  53. audio_format=audio_format)
  54. if hasattr(asr_utils, 'sample_rate_checking'):
  55. checking_audio_fs = asr_utils.sample_rate_checking(
  56. self.audio_in, self.audio_format)
  57. if checking_audio_fs is not None:
  58. self.audio_fs = checking_audio_fs
  59. if self.preprocessor is None:
  60. self.preprocessor = WavToScp()
  61. output = self.preprocessor.forward(self.model_cfg, self.recog_type,
  62. self.audio_format, self.audio_in,
  63. self.audio_fs)
  64. output = self.forward(output)
  65. rst = self.postprocess(output)
  66. return rst
  67. def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  68. """Decoding
  69. """
  70. logger.info(f"Decoding with {inputs['audio_format']} files ...")
  71. data_cmd: Sequence[Tuple[str, str, str]]
  72. if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm':
  73. data_cmd = ['speech', 'sound']
  74. elif inputs['audio_format'] == 'kaldi_ark':
  75. data_cmd = ['speech', 'kaldi_ark']
  76. elif inputs['audio_format'] == 'tfrecord':
  77. data_cmd = ['speech', 'tfrecord']
  78. if inputs.__contains__('mvn_file'):
  79. data_cmd.append(inputs['mvn_file'])
  80. # generate asr inference command
  81. cmd = {
  82. 'model_type': inputs['model_type'],
  83. 'ngpu': 1, # 0: only CPU, ngpu>=1: gpu number if cuda is available
  84. 'log_level': 'ERROR',
  85. 'audio_in': inputs['audio_lists'],
  86. 'name_and_type': data_cmd,
  87. 'asr_model_file': inputs['am_model_path'],
  88. 'idx_text': '',
  89. 'sampled_ids': 'seq2seq/sampled_ids',
  90. 'sampled_lengths': 'seq2seq/sampled_lengths',
  91. 'lang': 'zh-cn',
  92. 'code_base': inputs['code_base'],
  93. 'mode': inputs['mode'],
  94. 'fs': {
  95. 'audio_fs': inputs['audio_fs'],
  96. 'model_fs': 16000
  97. }
  98. }
  99. if self.framework == Frameworks.torch:
  100. config_file = open(inputs['asr_model_config'], encoding='utf-8')
  101. root = yaml.full_load(config_file)
  102. config_file.close()
  103. frontend_conf = None
  104. if 'frontend_conf' in root:
  105. frontend_conf = root['frontend_conf']
  106. token_num_relax = None
  107. if 'token_num_relax' in root:
  108. token_num_relax = root['token_num_relax']
  109. decoding_ind = None
  110. if 'decoding_ind' in root:
  111. decoding_ind = root['decoding_ind']
  112. decoding_mode = None
  113. if 'decoding_mode' in root:
  114. decoding_mode = root['decoding_mode']
  115. cmd['beam_size'] = root['beam_size']
  116. cmd['penalty'] = root['penalty']
  117. cmd['maxlenratio'] = root['maxlenratio']
  118. cmd['minlenratio'] = root['minlenratio']
  119. cmd['ctc_weight'] = root['ctc_weight']
  120. cmd['lm_weight'] = root['lm_weight']
  121. cmd['asr_train_config'] = inputs['am_model_config']
  122. cmd['lm_file'] = inputs['lm_model_path']
  123. cmd['lm_train_config'] = inputs['lm_model_config']
  124. cmd['batch_size'] = inputs['model_config']['batch_size']
  125. cmd['frontend_conf'] = frontend_conf
  126. if frontend_conf is not None and 'fs' in frontend_conf:
  127. cmd['fs']['model_fs'] = frontend_conf['fs']
  128. cmd['token_num_relax'] = token_num_relax
  129. cmd['decoding_ind'] = decoding_ind
  130. cmd['decoding_mode'] = decoding_mode
  131. elif self.framework == Frameworks.tf:
  132. cmd['fs']['model_fs'] = inputs['model_config']['fs']
  133. cmd['hop_length'] = inputs['model_config']['hop_length']
  134. cmd['feature_dims'] = inputs['model_config']['feature_dims']
  135. cmd['predictions_file'] = 'text'
  136. cmd['mvn_file'] = inputs['am_mvn_file']
  137. cmd['vocab_file'] = inputs['vocab_file']
  138. cmd['lang'] = inputs['model_lang']
  139. if 'idx_text' in inputs:
  140. cmd['idx_text'] = inputs['idx_text']
  141. if 'sampled_ids' in inputs['model_config']:
  142. cmd['sampled_ids'] = inputs['model_config']['sampled_ids']
  143. if 'sampled_lengths' in inputs['model_config']:
  144. cmd['sampled_lengths'] = inputs['model_config'][
  145. 'sampled_lengths']
  146. else:
  147. raise ValueError('model type is mismatching')
  148. inputs['asr_result'] = self.run_inference(cmd)
  149. return inputs
  150. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  151. """process the asr results
  152. """
  153. from funasr.utils import asr_utils
  154. logger.info('Computing the result of ASR ...')
  155. rst = {}
  156. # single wav or pcm task
  157. if inputs['recog_type'] == 'wav':
  158. if 'asr_result' in inputs and len(inputs['asr_result']) > 0:
  159. text = inputs['asr_result'][0]['value']
  160. if len(text) > 0:
  161. rst[OutputKeys.TEXT] = text
  162. # run with datasets, and audio format is waveform or kaldi_ark or tfrecord
  163. elif inputs['recog_type'] != 'wav':
  164. inputs['reference_list'] = self.ref_list_tidy(inputs)
  165. if hasattr(asr_utils, 'set_parameters'):
  166. asr_utils.set_parameters(language=inputs['model_lang'])
  167. inputs['datasets_result'] = asr_utils.compute_wer(
  168. hyp_list=inputs['asr_result'],
  169. ref_list=inputs['reference_list'])
  170. else:
  171. raise ValueError('recog_type and audio_format are mismatching')
  172. if 'datasets_result' in inputs:
  173. rst[OutputKeys.TEXT] = inputs['datasets_result']
  174. return rst
  175. def ref_list_tidy(self, inputs: Dict[str, Any]) -> List[Any]:
  176. ref_list = []
  177. if inputs['audio_format'] == 'tfrecord':
  178. # should assemble idx + txt
  179. with open(inputs['reference_text'], 'r', encoding='utf-8') as r:
  180. text_lines = r.readlines()
  181. with open(inputs['idx_text'], 'r', encoding='utf-8') as i:
  182. idx_lines = i.readlines()
  183. j: int = 0
  184. while j < min(len(text_lines), len(idx_lines)):
  185. idx_str = idx_lines[j].strip()
  186. text_str = text_lines[j].strip().replace(' ', '')
  187. item = {'key': idx_str, 'value': text_str}
  188. ref_list.append(item)
  189. j += 1
  190. else:
  191. # text contain idx + sentence
  192. with open(inputs['reference_text'], 'r', encoding='utf-8') as f:
  193. lines = f.readlines()
  194. for line in lines:
  195. line_item = line.split(None, 1)
  196. if len(line_item) > 1:
  197. item = {
  198. 'key': line_item[0],
  199. 'value': line_item[1].strip('\n')
  200. }
  201. ref_list.append(item)
  202. return ref_list
  203. def run_inference(self, cmd):
  204. asr_result = []
  205. if self.framework == Frameworks.torch and cmd['code_base'] == 'funasr':
  206. from funasr.bin import asr_inference_launch
  207. if hasattr(asr_inference_launch, 'set_parameters'):
  208. asr_inference_launch.set_parameters(sample_rate=cmd['fs'])
  209. asr_inference_launch.set_parameters(language=cmd['lang'])
  210. asr_result = asr_inference_launch.inference_launch(
  211. mode=cmd['mode'],
  212. batch_size=cmd['batch_size'],
  213. maxlenratio=cmd['maxlenratio'],
  214. minlenratio=cmd['minlenratio'],
  215. beam_size=cmd['beam_size'],
  216. ngpu=cmd['ngpu'],
  217. ctc_weight=cmd['ctc_weight'],
  218. lm_weight=cmd['lm_weight'],
  219. penalty=cmd['penalty'],
  220. log_level=cmd['log_level'],
  221. data_path_and_name_and_type=cmd['name_and_type'],
  222. audio_lists=cmd['audio_in'],
  223. asr_train_config=cmd['asr_train_config'],
  224. asr_model_file=cmd['asr_model_file'],
  225. lm_file=cmd['lm_file'],
  226. lm_train_config=cmd['lm_train_config'],
  227. frontend_conf=cmd['frontend_conf'],
  228. token_num_relax=cmd['token_num_relax'],
  229. decoding_ind=cmd['decoding_ind'],
  230. decoding_mode=cmd['decoding_mode'])
  231. elif self.framework == Frameworks.torch:
  232. from easyasr import asr_inference_paraformer_espnet
  233. if hasattr(asr_inference_paraformer_espnet, 'set_parameters'):
  234. asr_inference_paraformer_espnet.set_parameters(
  235. sample_rate=cmd['fs'])
  236. asr_inference_paraformer_espnet.set_parameters(
  237. language=cmd['lang'])
  238. asr_result = asr_inference_paraformer_espnet.asr_inference(
  239. batch_size=cmd['batch_size'],
  240. maxlenratio=cmd['maxlenratio'],
  241. minlenratio=cmd['minlenratio'],
  242. beam_size=cmd['beam_size'],
  243. ngpu=cmd['ngpu'],
  244. ctc_weight=cmd['ctc_weight'],
  245. lm_weight=cmd['lm_weight'],
  246. penalty=cmd['penalty'],
  247. log_level=cmd['log_level'],
  248. name_and_type=cmd['name_and_type'],
  249. audio_lists=cmd['audio_in'],
  250. asr_train_config=cmd['asr_train_config'],
  251. asr_model_file=cmd['asr_model_file'],
  252. frontend_conf=cmd['frontend_conf'])
  253. elif self.framework == Frameworks.tf:
  254. from easyasr import asr_inference_paraformer_tf
  255. if hasattr(asr_inference_paraformer_tf, 'set_parameters'):
  256. asr_inference_paraformer_tf.set_parameters(
  257. language=cmd['lang'])
  258. else:
  259. # in order to support easyasr-0.0.2
  260. cmd['fs'] = cmd['fs']['model_fs']
  261. asr_result = asr_inference_paraformer_tf.asr_inference(
  262. ngpu=cmd['ngpu'],
  263. name_and_type=cmd['name_and_type'],
  264. audio_lists=cmd['audio_in'],
  265. idx_text_file=cmd['idx_text'],
  266. asr_model_file=cmd['asr_model_file'],
  267. vocab_file=cmd['vocab_file'],
  268. am_mvn_file=cmd['mvn_file'],
  269. predictions_file=cmd['predictions_file'],
  270. fs=cmd['fs'],
  271. hop_length=cmd['hop_length'],
  272. feature_dims=cmd['feature_dims'],
  273. sampled_ids=cmd['sampled_ids'],
  274. sampled_lengths=cmd['sampled_lengths'])
  275. return asr_result