yingda.chen 3 years ago
parent
commit
5db8480c64
2 changed files with 5 additions and 6 deletions
  1. +2
    -2
      modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py
  2. +3
    -4
      modelscope/pipelines/audio/asr/asr_inference_pipeline.py

+ 2
- 2
modelscope/pipelines/audio/asr/asr_engine/asr_inference_paraformer_espnet.py View File

@@ -6,7 +6,7 @@ import logging
import sys
import time
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -33,7 +33,7 @@ from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorers.ctc import CTCPrefixScorer
from espnet.nets.scorers.length_bonus import LengthBonus
from espnet.utils.cli_utils import get_commandline_args
from typeguard import check_argument_types, check_return_type
from typeguard import check_argument_types

from .espnet.tasks.asr import ASRTaskNAR as ASRTask



+ 3
- 4
modelscope/pipelines/audio/asr/asr_inference_pipeline.py View File

@@ -1,8 +1,7 @@
import io
import os
import shutil
import threading
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Sequence, Tuple, Union

import yaml

@@ -12,7 +11,6 @@ from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import WavToScp
from modelscope.utils.constant import Tasks
from .asr_engine import asr_env_checking, asr_inference_paraformer_espnet
from .asr_engine.common import asr_utils

__all__ = ['AutomaticSpeechRecognitionPipeline']
@@ -30,7 +28,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
**kwargs):
"""use `model` and `preprocessor` to create an asr pipeline for prediction
"""
from .asr_engine import asr_env_checking
assert model is not None, 'asr model should be provided'

model_list: List = []
@@ -199,6 +197,7 @@ class AsrInferenceThread(threading.Thread):

def run(self):
if self._cmd['model_type'] == 'pytorch':
from .asr_engine import asr_inference_paraformer_espnet
asr_inference_paraformer_espnet.asr_inference(
batch_size=self._cmd['batch_size'],
output_dir=self._cmd['output_dir'],


Loading…
Cancel
Save