- """
- Copyright 2020 Tianshu AI Platform. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
import json
import time

from service.oneflow_inference_service import OneFlowInferenceService
from service.tensorflow_inference_service import TensorflowInferenceService
from service.pytorch_inference_service import PytorchInferenceService
import service.common_inference_service as common_inference_service
from logger import Logger
from utils import file_utils
from utils.find_class_in_file import FindClassInFile

log = Logger().logger
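
# Sketch of the JSON layout expected in the ``model_config_file`` argument,
# inferred from the keys read in ``InferenceServiceManager.init``; the field
# values below are hypothetical examples:
#
# {
#     "model_config_list": [
#         {
#             "model_name": "demo_model",
#             "model_path": "/models/demo_model",
#             "platform": "pytorch"
#         }
#     ]
# }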


class InferenceServiceManager:

    def __init__(self, args):
        self.inference_service = None
        self.args = args
        # Maps model name -> inference service instance.
        self.model_name_service_map = {}

    def init(self):
        if self.args.model_config_file != "":
            # Multi-model mode: load every model listed in the config file.
            with open(self.args.model_config_file) as data_file:
                model_config_file_dict = json.load(data_file)
                model_config_list = model_config_file_dict["model_config_list"]
                for model_config in model_config_list:
                    model_name = model_config["model_name"]
                    model_path = model_config["model_path"]
                    self.args.model_name = model_name
                    self.args.model_path = model_path
                    model_platform = model_config.get("platform")

                    if model_platform == "oneflow":
                        self.inference_service = OneFlowInferenceService(self.args)
                    elif model_platform in ("tensorflow", "keras"):
                        self.inference_service = TensorflowInferenceService(self.args)
                    elif model_platform == "pytorch":
                        self.inference_service = PytorchInferenceService(self.args)
                    else:
                        # Fail fast instead of mapping the model name to None.
                        raise ValueError("Unsupported platform: %s" % model_platform)

                    self.model_name_service_map[model_name] = self.inference_service
        else:
            # Single-model mode: read the configuration from command-line arguments.
            if self.args.use_script:
                # Use a custom inference script: locate the service class defined
                # in common_inference_service and instantiate it.
                find_class_in_file = FindClassInFile()
                cls = find_class_in_file.find(common_inference_service)
                self.inference_service = cls[1](self.args)
            else:
                # Use the default inference service for the given platform.
                if self.args.platform == "oneflow":
                    self.inference_service = OneFlowInferenceService(self.args)
                elif self.args.platform in ("tensorflow", "keras"):
                    self.inference_service = TensorflowInferenceService(self.args)
                elif self.args.platform == "pytorch":
                    self.inference_service = PytorchInferenceService(self.args)
                else:
                    raise ValueError("Unsupported platform: %s" % self.args.platform)

            self.model_name_service_map[self.args.model_name] = self.inference_service

    def inference(self, model_name, data_list):
        """
        Inference method for the online service.

        Returns a single output when data_list contains one item,
        otherwise a list of outputs, one per input item.
        """
        inference_service = self.model_name_service_map[model_name]
        result = []
        for data in data_list:
            output = inference_service.inference(data)
            if len(data_list) == 1:
                return output
            result.append(output)
        return result

    def inference_and_save_json(self, model_name, json_path, data_list):
        """
        Inference method for the batch service: writes each result to a JSON
        file under json_path, named after the input's ``data_name``.
        """
        inference_service = self.model_name_service_map[model_name]
        for data in data_list:
            result = inference_service.inference(data)
            file_utils.writer_json_file(json_path, data['data_name'], result)
            # Pause briefly between items.
            time.sleep(1)
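

# Minimal usage sketch, assuming an argparse-style entry point. The flag names
# mirror the ``args`` attributes read above; the defaults and the request
# payload are hypothetical and depend on the concrete inference service.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_config_file", default="")
    parser.add_argument("--model_name", default="demo_model")
    parser.add_argument("--model_path", default="/models/demo_model")
    parser.add_argument("--platform", default="pytorch")
    parser.add_argument("--use_script", action="store_true")
    cli_args = parser.parse_args()

    manager = InferenceServiceManager(cli_args)
    manager.init()
    # Single-item request, so ``inference`` returns the output directly.
    log.info(manager.inference(cli_args.model_name, [{"data_name": "sample"}]))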