| @@ -5,15 +5,11 @@ | |||
| + **支持oneflow、tensorflow、pytorch三种框架模型部署** </br> | |||
| 1、通过如下命令启动http在线推理服务 | |||
| ``` | |||
| python http_server.py --platform='框架名称' --model_path='模型地址' | |||
| ``` | |||
| ``` | |||
| python http_server.py --platform='框架名称' --model_path='模型地址' | |||
| ``` | |||
| 通过访问localhost:5000/docs进入swagger页面,调用localhost:5000/inference进行图片上传得道推理结果,结果如下所示: | |||
| ``` | |||
| ``` | |||
| { | |||
| "image_name": "哈士奇.jpg", | |||
| "predictions": [ | |||
| @@ -39,28 +35,35 @@ | |||
| } | |||
| ] | |||
| } | |||
| ``` | |||
| ``` | |||
| 2、同理通过如下命令启动grpc在线推理服务 | |||
| ``` | |||
| python grpc_server.py --platform='框架名称' --model_path='模型地址' | |||
| ``` | |||
| ``` | |||
| python grpc_server.py --platform='框架名称' --model_path='模型地址' | |||
| ``` | |||
| 再启动grpc_client.py进行上传图片推理得道结果,或者根据ip端口自行编写grpc客户端 | |||
| 3、支持多模型部署,可以自行配置config文件夹下的model_config_file.json进行多模型配置,启动http或grpc时输入不同的模型名称即可,或者自行修改inference接口入参来达到启动单一服务多模型推理的功能 | |||
| + **支持多模型部署** </br> | |||
| 用户可以自行配置config文件夹下的model_config_file.json进行多模型配置,启动http或grpc时输入不同的模型名称即可,或者自行修改inference接口入参来达到启动单一服务多模型推理的功能 | |||
| + **支持分布式模型部署推理** </br> | |||
| 需要推理大量图片时需要分布式推理功能,执行如下命令: | |||
| ``` | |||
| python batch_server.py --platform='框架名称' --model_path='模型地址' --input_path='批量图片地址' --output_path='输出JSON文件地址' | |||
| ``` | |||
| ``` | |||
| python batch_server.py --platform='框架名称' --model_path='模型地址' --input_path='批量图片地址' --output_path='输出JSON文件地址' | |||
| ``` | |||
| 输入的所有图片保存在input文件夹下,输入json文件保存在output_path文件夹,json名称与图片名称对应 | |||
| + **支持使用自定义推理脚本** </br> | |||
| 用户需要使用自定义推理脚本时,可根据common_inference_service.py脚本中注释的规则,自定义推理脚本,并替换原有的common_inference_service.py脚本。此外,在启动命令中添加use_script参数,即可在推理时使用自定义的推理脚本。命令如下所示: | |||
| ``` | |||
| python grpc_server.py --platform='框架名称' --model_path='模型地址' --user_script=True | |||
| ``` | |||
| + **代码还包含了各种参数配置,日志文件输出、是否启用TLS等** </br> | |||
| @@ -27,49 +27,47 @@ log = Logger().logger | |||
| def get_host_ip(): | |||
| """ | |||
| 查询本机ip地址 | |||
| :return: | |||
| return | |||
| """ | |||
| global s | |||
| try: | |||
| s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |||
| ip = s.getsockname()[0] | |||
| finally: | |||
| s.close() | |||
| hostname = socket.gethostname() | |||
| ip = socket.gethostbyname(hostname) | |||
| return ip | |||
| def read_directory(images_path): | |||
| def read_directory(data_path): | |||
| """ | |||
| 读取文件夹并进行拆分文件 | |||
| :return: | |||
| """ | |||
| files = os.listdir(images_path) | |||
| files = os.listdir(data_path) | |||
| num_files = len(files) | |||
| index_list = list(range(num_files)) | |||
| images = list() | |||
| data_list = list() | |||
| for index in index_list: | |||
| # 是否开启分布式 | |||
| if args.enable_distributed: | |||
| ip = get_host_ip() | |||
| log.info("NODE_IPS:{}", os.getenv('NODE_IPS')) | |||
| ip_list = os.getenv('NODE_IPS').split(",") | |||
| num_ips = len(ip_list) | |||
| ip_index = ip_list.index(ip) | |||
| if ip_index == index % num_ips: | |||
| filename = files[index] | |||
| image = {"image_name": filename, "image_path": images_path + filename} | |||
| images.append(image) | |||
| data = {"data_name": filename, "data_path": data_path + filename} | |||
| data_list.append(data) | |||
| else: | |||
| filename = files[index] | |||
| image = {"image_name": filename, "image_path": images_path + filename} | |||
| images.append(image) | |||
| return images | |||
| data = {"data_name": filename, "data_path": data_path + filename} | |||
| data_list.append(data) | |||
| return data_list | |||
| def main(): | |||
| images = read_directory(args.input_path) | |||
| inference_service.inference_and_save_json(args.model_name, args.output_path, images) | |||
| data_list = read_directory(args.input_path) | |||
| inference_service.inference_and_save_json(args.model_name, args.output_path, data_list) | |||
| if args.enable_distributed: | |||
| ip = get_host_ip() | |||
| log.info("NODE_IPS:{}", os.getenv('NODE_IPS')) | |||
| ip_list = os.getenv('NODE_IPS').split(",") | |||
| # 主节点必须等待从节点推理完成 | |||
| if ip == ip_list[0]: | |||
| @@ -58,12 +58,12 @@ def get_parser(parser=None): | |||
| parser.add_argument("--job_name", type=str, default="inference", help="oneflow job name") | |||
| parser.add_argument("--prepare_mode", type=str, default="tfhub", | |||
| help="tensorflow prepare mode(tfhub、caffe、tf、torch)") | |||
| parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="is use gpu") | |||
| parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="whether to use gpu") | |||
| parser.add_argument('--channel_last', type=str2bool, nargs='?', const=False, | |||
| help='Whether to use use channel last mode(nhwc)') | |||
| parser.add_argument("--model_path", type=str, default="/usr/local/model/pytorch_models/resnet50/", | |||
| parser.add_argument("--model_path", type=str, default="/usr/local/work/models/pytorch_models/resnet50/", | |||
| help="model load directory if need") | |||
| parser.add_argument("--image_path", type=str, default='/usr/local/data/fish.jpg', help="image path") | |||
| parser.add_argument("--data_path", type=str, default='/usr/local/work/dog.jpg', help="input data path") | |||
| parser.add_argument("--reshape_size", type=int_list, default='[224]', | |||
| help="The reshape size of the image(eg. 224)") | |||
| parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes") | |||
| @@ -78,8 +78,9 @@ def get_parser(parser=None): | |||
| parser.add_argument("--model_config_file", type=str, default="", help="The file of the model config(eg. '')") | |||
| parser.add_argument("--enable_distributed", type=ast.literal_eval, default=False, help="If enable use distributed " | |||
| "environment") | |||
| parser.add_argument("--input_path", type=str, default="/usr/local/data/images/", help="images path") | |||
| parser.add_argument("--output_path", type=str, default="/usr/local/output_path/", help="json path") | |||
| parser.add_argument("--input_path", type=str, default="/usr/local/input/", help="input batch data path") | |||
| parser.add_argument("--output_path", type=str, default="/usr/local/output/", help="output json path") | |||
| parser.add_argument("--use_script", type=ast.literal_eval, default=False, help="whether to use custom inference script") | |||
| return parser | |||
| @@ -0,0 +1,84 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| """ | |||
| import os | |||
| import io | |||
| import torch | |||
| import torch.nn.functional as functional | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
| from logger import Logger | |||
| log = Logger().logger | |||
| #只能定义一个class | |||
| class CommonInferenceService: | |||
| # __init__初始化方法中接收args参数(其中模型路径参数为args.model_path,是否使用gpu参数为args.use_gpu),并加载模型(方法用户可自定义) | |||
| def __init__(self, args): | |||
| self.args = args | |||
| self.model = self.load_model() | |||
| def load_data(self, data_path): | |||
| image = open(data_path, 'rb').read() | |||
| image = Image.open(io.BytesIO(image)) | |||
| if image.mode != 'RGB': | |||
| image = image.convert("RGB") | |||
| image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) | |||
| image = transforms.ToTensor()(image) | |||
| image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) | |||
| image = image[None] | |||
| if self.args.use_gpu: | |||
| image = image.cuda() | |||
| return image | |||
| def load_model(self): | |||
| if os.path.isfile(self.args.model_path): | |||
| self.checkpoint = torch.load(self.args.model_path) | |||
| else: | |||
| for file in os.listdir(self.args.model_path): | |||
| self.checkpoint = torch.load(self.args.model_path + file) | |||
| model = self.checkpoint["model"] | |||
| model.load_state_dict(self.checkpoint['state_dict']) | |||
| for parameter in model.parameters(): | |||
| parameter.requires_grad = False | |||
| if self.args.use_gpu: | |||
| model.cuda() | |||
| model.eval() | |||
| return model | |||
| # inference方法名称固定 | |||
| def inference(self, data): | |||
| result = {"data_name": data['data_name']} | |||
| data = self.load_data(data['data_path']) | |||
| preds = functional.softmax(self.model(data), dim=1) | |||
| predictions = torch.topk(preds.data, k=5, dim=1) | |||
| result['predictions'] = list() | |||
| for prob, label in zip(predictions[0][0], predictions[1][0]): | |||
| predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} | |||
| result['predictions'].append(predictions) | |||
| return result | |||
| if __name__=="__main__": | |||
| import argparse | |||
| parser = argparse.ArgumentParser(description='tianshu serving') | |||
| parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path") | |||
| parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not") | |||
| parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not") | |||
| args = parser.parse_args() | |||
| server = CommonInferenceService(args) | |||
| image_path = "./cat.jpg" | |||
| image = {"data_name": "cat.jpg", "data_path": image_path} | |||
| re = server.inference(image) | |||
| print(re) | |||
| @@ -21,8 +21,8 @@ log = Logger().logger | |||
| parser = configs.get_parser() | |||
| args = parser.parse_args() | |||
| _HOST = 'kohj2s.serving.dubhe.ai' | |||
| _PORT = '31365' | |||
| _HOST = '10.5.24.134' | |||
| _PORT = '8500' | |||
| MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024 # 可根据具体需求设置,此处设为1G | |||
| @@ -41,12 +41,12 @@ def run(): | |||
| ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ], ) # 创建连接 | |||
| client = inference_pb2_grpc.InferenceServiceStub(channel=channel) # 创建客户端 | |||
| data_request = inference_pb2.DataRequest() | |||
| Image = data_request.images.add() | |||
| Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\哈士奇.jpg", "rb").read()), encoding='utf-8') | |||
| Image.image_name = "哈士奇.jpg" | |||
| Image = data_request.images.add() | |||
| Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\fish.jpg", "rb").read()), encoding='utf-8') | |||
| Image.image_name = "fish.jpg" | |||
| data1 = data_request.data_list.add() | |||
| data1.data_file = str(base64.b64encode(open("/usr/local/input/dog.jpg", "rb").read()), encoding='utf-8') | |||
| data1.data_name = "dog.jpg" | |||
| data2 = data_request.data_list.add() | |||
| data2.data_file = str(base64.b64encode(open("/usr/local/input/6.jpg", "rb").read()), encoding='utf-8') | |||
| data2.data_name = "6.jpg" | |||
| response = client.inference(data_request) | |||
| log.info(response.json_result.encode('utf-8').decode('unicode_escape')) | |||
| @@ -46,19 +46,21 @@ class InferenceService(inference_pb2_grpc.InferenceServiceServicer): | |||
| 调用grpc方法进行推理 | |||
| """ | |||
| def inference(self, request, context): | |||
| image_files = request.images | |||
| data_list = request.data_list | |||
| log.info("===============> grpc inference start <===============") | |||
| try: | |||
| images = file_utils.upload_image_by_base64(image_files) # 上传图片到本地 | |||
| data_list_b64 = file_utils.upload_image_by_base64(data_list) # 上传图片到本地 | |||
| except Exception as e: | |||
| log.error("upload data failed", e) | |||
| return inference_pb2.DataResponse(json_result=json.dumps( | |||
| response_convert(Response(success=False, data=str(e), error="upload image fail")))) | |||
| response_convert(Response(success=False, data=str(e), error="upload data failed")))) | |||
| try: | |||
| result = inference_service.inference(args.model_name, images) | |||
| result = inference_service.inference(args.model_name, data_list_b64) | |||
| log.info("===============> grpc inference success <===============") | |||
| return inference_pb2.DataResponse(json_result=json.dumps( | |||
| response_convert(Response(success=True, data=result)))) | |||
| except Exception as e: | |||
| log.error("inference fail", e) | |||
| return inference_pb2.DataResponse(json_result=json.dumps( | |||
| response_convert(Response(success=False, data=str(e), error="inference fail")))) | |||
| @@ -49,7 +49,7 @@ async def inference(images_path: List[str] = None): | |||
| threading.Thread(target=file_utils.download_image(images_path)) # 开启异步线程下载图片到本地 | |||
| images = list() | |||
| for image in images_path: | |||
| data = {"image_name": image.split("/")[-1], "image_path": image} | |||
| data = {"data_name": image.split("/")[-1], "data_path": image} | |||
| images.append(data) | |||
| try: | |||
| data = inference_service.inference(args.model_name, images) | |||
| @@ -59,17 +59,22 @@ async def inference(images_path: List[str] = None): | |||
| @app.post("/inference") | |||
| async def inference(image_files: List[UploadFile] = File(...)): | |||
| async def inference(files: List[UploadFile] = File(...)): | |||
| """ | |||
| 上传本地文件推理 | |||
| """ | |||
| log.info("===============> http inference start <===============") | |||
| try: | |||
| images = file_utils.upload_image(image_files) # 上传图片到本地 | |||
| data_list = file_utils.upload_data(files) # 上传图片到本地 | |||
| except Exception as e: | |||
| return Response(success=False, data=str(e), error="upload image fail") | |||
| log.error("upload data failed", e) | |||
| return Response(success=False, data=str(e), error="upload data failed") | |||
| try: | |||
| result = inference_service.inference(args.model_name, images) | |||
| result = inference_service.inference(args.model_name, data_list) | |||
| log.info("===============> http inference success <===============") | |||
| return Response(success=True, data=result) | |||
| except Exception as e: | |||
| log.error("inference fail", e) | |||
| return Response(success=False, data=str(e), error="inference fail") | |||
| @@ -0,0 +1 @@ | |||
| 2021-01-14 19:55:07,045 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.665"}, {"label": "Siberian husky", "probability": "0.294"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.024"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.015"}, {"label": "timber wolf, grey wolf, gray wolf, Canis lupus", "probability": "0.000"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "1.000"}, {"label": "tench, Tinca tinca", "probability": "0.000"}, {"label": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "probability": "0.000"}, {"label": "axolotl, mud puppy, Ambystoma mexicanum", "probability": "0.000"}, {"label": "barracouta, snoek", "probability": "0.000"}]}], "error": ""} | |||
| @@ -0,0 +1 @@ | |||
| 2021-02-03 15:05:00,069 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.679"}, {"label": "Siberian husky", "probability": "0.213"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.021"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.006"}, {"label": "white wolf, Arctic wolf, Canis lupus tundrarum", "probability": "0.001"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "0.871"}, {"label": "tench, Tinca tinca", "probability": "0.004"}, {"label": "puffer, pufferfish, blowfish, globefish", "probability": "0.002"}, {"label": "rock beauty, Holocanthus tricolor", "probability": "0.002"}, {"label": "barracouta, snoek", "probability": "0.001"}]}], "error": ""} | |||
| @@ -1,16 +1,16 @@ | |||
| syntax = 'proto3'; | |||
| service InferenceService { | |||
| rpc inference(DataRequest) returns (DataResponse) {} | |||
| } | |||
| message DataRequest{ | |||
| repeated Image images = 1; | |||
| repeated Data data_list = 1; | |||
| } | |||
| message Image { | |||
| string image_file = 1; | |||
| string image_name = 2; | |||
| message Data { | |||
| string data_file = 1; | |||
| string data_name = 2; | |||
| } | |||
| message DataResponse{ | |||
| @@ -6,168 +6,177 @@ from google.protobuf import descriptor as _descriptor | |||
| from google.protobuf import message as _message | |||
| from google.protobuf import reflection as _reflection | |||
| from google.protobuf import symbol_database as _symbol_database | |||
| # @@protoc_insertion_point(imports) | |||
| _sym_db = _symbol_database.Default() | |||
| DESCRIPTOR = _descriptor.FileDescriptor( | |||
| name='inference.proto', | |||
| package='', | |||
| syntax='proto3', | |||
| serialized_options=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| serialized_pb=b'\n\x0finference.proto\"%\n\x0b\x44\x61taRequest\x12\x16\n\x06images\x18\x01 \x03(\x0b\x32\x06.Image\"/\n\x05Image\x12\x12\n\nimage_file\x18\x01 \x01(\t\x12\x12\n\nimage_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3' | |||
| name='inference.proto', | |||
| package='', | |||
| syntax='proto3', | |||
| serialized_options=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| serialized_pb=b'\n\x0finference.proto\"\'\n\x0b\x44\x61taRequest\x12\x18\n\tdata_list\x18\x01 \x03(\x0b\x32\x05.Data\",\n\x04\x44\x61ta\x12\x11\n\tdata_file\x18\x01 \x01(\t\x12\x11\n\tdata_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3' | |||
| ) | |||
| _DATAREQUEST = _descriptor.Descriptor( | |||
| name='DataRequest', | |||
| full_name='DataRequest', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='images', full_name='DataRequest.images', index=0, | |||
| number=1, type=11, cpp_type=10, label=3, | |||
| has_default_value=False, default_value=[], | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=19, | |||
| serialized_end=56, | |||
| name='DataRequest', | |||
| full_name='DataRequest', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='data_list', full_name='DataRequest.data_list', index=0, | |||
| number=1, type=11, cpp_type=10, label=3, | |||
| has_default_value=False, default_value=[], | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=19, | |||
| serialized_end=58, | |||
| ) | |||
| _IMAGE = _descriptor.Descriptor( | |||
| name='Image', | |||
| full_name='Image', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='image_file', full_name='Image.image_file', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| _descriptor.FieldDescriptor( | |||
| name='image_name', full_name='Image.image_name', index=1, | |||
| number=2, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=58, | |||
| serialized_end=105, | |||
| _DATA = _descriptor.Descriptor( | |||
| name='Data', | |||
| full_name='Data', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='data_file', full_name='Data.data_file', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| _descriptor.FieldDescriptor( | |||
| name='data_name', full_name='Data.data_name', index=1, | |||
| number=2, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=60, | |||
| serialized_end=104, | |||
| ) | |||
| _DATARESPONSE = _descriptor.Descriptor( | |||
| name='DataResponse', | |||
| full_name='DataResponse', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='json_result', full_name='DataResponse.json_result', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=107, | |||
| serialized_end=142, | |||
| name='DataResponse', | |||
| full_name='DataResponse', | |||
| filename=None, | |||
| file=DESCRIPTOR, | |||
| containing_type=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| fields=[ | |||
| _descriptor.FieldDescriptor( | |||
| name='json_result', full_name='DataResponse.json_result', index=0, | |||
| number=1, type=9, cpp_type=9, label=1, | |||
| has_default_value=False, default_value=b"".decode('utf-8'), | |||
| message_type=None, enum_type=None, containing_type=None, | |||
| is_extension=False, extension_scope=None, | |||
| serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
| ], | |||
| extensions=[ | |||
| ], | |||
| nested_types=[], | |||
| enum_types=[ | |||
| ], | |||
| serialized_options=None, | |||
| is_extendable=False, | |||
| syntax='proto3', | |||
| extension_ranges=[], | |||
| oneofs=[ | |||
| ], | |||
| serialized_start=106, | |||
| serialized_end=141, | |||
| ) | |||
| _DATAREQUEST.fields_by_name['images'].message_type = _IMAGE | |||
| _DATAREQUEST.fields_by_name['data_list'].message_type = _DATA | |||
| DESCRIPTOR.message_types_by_name['DataRequest'] = _DATAREQUEST | |||
| DESCRIPTOR.message_types_by_name['Image'] = _IMAGE | |||
| DESCRIPTOR.message_types_by_name['Data'] = _DATA | |||
| DESCRIPTOR.message_types_by_name['DataResponse'] = _DATARESPONSE | |||
| _sym_db.RegisterFileDescriptor(DESCRIPTOR) | |||
| DataRequest = _reflection.GeneratedProtocolMessageType('DataRequest', (_message.Message,), { | |||
| 'DESCRIPTOR': _DATAREQUEST, | |||
| '__module__': 'inference_pb2' | |||
| # @@protoc_insertion_point(class_scope:DataRequest) | |||
| }) | |||
| 'DESCRIPTOR' : _DATAREQUEST, | |||
| '__module__' : 'inference_pb2' | |||
| # @@protoc_insertion_point(class_scope:DataRequest) | |||
| }) | |||
| _sym_db.RegisterMessage(DataRequest) | |||
| Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), { | |||
| 'DESCRIPTOR': _IMAGE, | |||
| '__module__': 'inference_pb2' | |||
| # @@protoc_insertion_point(class_scope:Image) | |||
| }) | |||
| _sym_db.RegisterMessage(Image) | |||
| Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), { | |||
| 'DESCRIPTOR' : _DATA, | |||
| '__module__' : 'inference_pb2' | |||
| # @@protoc_insertion_point(class_scope:Data) | |||
| }) | |||
| _sym_db.RegisterMessage(Data) | |||
| DataResponse = _reflection.GeneratedProtocolMessageType('DataResponse', (_message.Message,), { | |||
| 'DESCRIPTOR': _DATARESPONSE, | |||
| '__module__': 'inference_pb2' | |||
| # @@protoc_insertion_point(class_scope:DataResponse) | |||
| }) | |||
| 'DESCRIPTOR' : _DATARESPONSE, | |||
| '__module__' : 'inference_pb2' | |||
| # @@protoc_insertion_point(class_scope:DataResponse) | |||
| }) | |||
| _sym_db.RegisterMessage(DataResponse) | |||
| _INFERENCESERVICE = _descriptor.ServiceDescriptor( | |||
| name='InferenceService', | |||
| full_name='InferenceService', | |||
| file=DESCRIPTOR, | |||
| name='InferenceService', | |||
| full_name='InferenceService', | |||
| file=DESCRIPTOR, | |||
| index=0, | |||
| serialized_options=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| serialized_start=143, | |||
| serialized_end=205, | |||
| methods=[ | |||
| _descriptor.MethodDescriptor( | |||
| name='inference', | |||
| full_name='InferenceService.inference', | |||
| index=0, | |||
| containing_service=None, | |||
| input_type=_DATAREQUEST, | |||
| output_type=_DATARESPONSE, | |||
| serialized_options=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| serialized_start=144, | |||
| serialized_end=206, | |||
| methods=[ | |||
| _descriptor.MethodDescriptor( | |||
| name='inference', | |||
| full_name='InferenceService.inference', | |||
| index=0, | |||
| containing_service=None, | |||
| input_type=_DATAREQUEST, | |||
| output_type=_DATARESPONSE, | |||
| serialized_options=None, | |||
| create_key=_descriptor._internal_create_key, | |||
| ), | |||
| ]) | |||
| ), | |||
| ]) | |||
| _sym_db.RegisterServiceDescriptor(_INFERENCESERVICE) | |||
| DESCRIPTOR.services_by_name['InferenceService'] = _INFERENCESERVICE | |||
| @@ -15,10 +15,10 @@ class InferenceServiceStub(object): | |||
| channel: A grpc.Channel. | |||
| """ | |||
| self.inference = channel.unary_unary( | |||
| '/InferenceService/inference', | |||
| request_serializer=inference__pb2.DataRequest.SerializeToString, | |||
| response_deserializer=inference__pb2.DataResponse.FromString, | |||
| ) | |||
| '/InferenceService/inference', | |||
| request_serializer=inference__pb2.DataRequest.SerializeToString, | |||
| response_deserializer=inference__pb2.DataResponse.FromString, | |||
| ) | |||
| class InferenceServiceServicer(object): | |||
| @@ -33,34 +33,34 @@ class InferenceServiceServicer(object): | |||
| def add_InferenceServiceServicer_to_server(servicer, server): | |||
| rpc_method_handlers = { | |||
| 'inference': grpc.unary_unary_rpc_method_handler( | |||
| servicer.inference, | |||
| request_deserializer=inference__pb2.DataRequest.FromString, | |||
| response_serializer=inference__pb2.DataResponse.SerializeToString, | |||
| ), | |||
| 'inference': grpc.unary_unary_rpc_method_handler( | |||
| servicer.inference, | |||
| request_deserializer=inference__pb2.DataRequest.FromString, | |||
| response_serializer=inference__pb2.DataResponse.SerializeToString, | |||
| ), | |||
| } | |||
| generic_handler = grpc.method_handlers_generic_handler( | |||
| 'InferenceService', rpc_method_handlers) | |||
| 'InferenceService', rpc_method_handlers) | |||
| server.add_generic_rpc_handlers((generic_handler,)) | |||
| # This class is part of an EXPERIMENTAL API. | |||
| # This class is part of an EXPERIMENTAL API. | |||
| class InferenceService(object): | |||
| """Missing associated documentation comment in .proto file.""" | |||
| @staticmethod | |||
| def inference(request, | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| target, | |||
| options=(), | |||
| channel_credentials=None, | |||
| call_credentials=None, | |||
| insecure=False, | |||
| compression=None, | |||
| wait_for_ready=None, | |||
| timeout=None, | |||
| metadata=None): | |||
| return grpc.experimental.unary_unary(request, target, '/InferenceService/inference', | |||
| inference__pb2.DataRequest.SerializeToString, | |||
| inference__pb2.DataResponse.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| inference__pb2.DataRequest.SerializeToString, | |||
| inference__pb2.DataResponse.FromString, | |||
| options, channel_credentials, | |||
| insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
| @@ -0,0 +1,87 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| """ | |||
| import os | |||
| import io | |||
| import torch | |||
| import torch.nn.functional as functional | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
| from logger import Logger | |||
| log = Logger().logger | |||
| # 只能定义一个class | |||
| class CommonInferenceService: | |||
| # 请在__init__初始化方法中接收args参数,并加载模型(其中模型路径参数为args.model_path,是否使用gpu参数为args.use_gpu,模型加载方法用户可自定义) | |||
| def __init__(self, args): | |||
| self.args = args | |||
| self.model = self.load_model() | |||
| def load_data(self, data_path): | |||
| image = open(data_path, 'rb').read() | |||
| image = Image.open(io.BytesIO(image)) | |||
| if image.mode != 'RGB': | |||
| image = image.convert("RGB") | |||
| image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) | |||
| image = transforms.ToTensor()(image) | |||
| image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) | |||
| image = image[None] | |||
| if self.args.use_gpu: | |||
| image = image.cuda() | |||
| return image | |||
| def load_model(self): | |||
| if os.path.isfile(self.args.model_path): | |||
| self.checkpoint = torch.load(self.args.model_path) | |||
| else: | |||
| for file in os.listdir(self.args.model_path): | |||
| self.checkpoint = torch.load(self.args.model_path + file) | |||
| model = self.checkpoint["model"] | |||
| model.load_state_dict(self.checkpoint['state_dict']) | |||
| for parameter in model.parameters(): | |||
| parameter.requires_grad = False | |||
| if self.args.use_gpu: | |||
| model.cuda() | |||
| model.eval() | |||
| return model | |||
| # inference方法名称固定 | |||
| def inference(self, data): | |||
| result = {"data_name": data['data_name']} | |||
| log.info("===============> start load " + data['data_name'] + " <===============") | |||
| data = self.load_data(data['data_path']) | |||
| preds = functional.softmax(self.model(data), dim=1) | |||
| predictions = torch.topk(preds.data, k=5, dim=1) | |||
| result['predictions'] = list() | |||
| for prob, label in zip(predictions[0][0], predictions[1][0]): | |||
| predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} | |||
| result['predictions'].append(predictions) | |||
| return result | |||
| # 非必须,可用于本地调试 | |||
| if __name__=="__main__": | |||
| import argparse | |||
| parser = argparse.ArgumentParser(description='dubhe serving') | |||
| parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path") | |||
| parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not") | |||
| parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not") | |||
| args = parser.parse_args() | |||
| server = CommonInferenceService(args) | |||
| image_path = "./cat.jpg" | |||
| image = {"data_name": "cat.jpg", "data_path": image_path} | |||
| re = server.inference(image) | |||
| print(re) | |||
| @@ -15,8 +15,10 @@ import time | |||
| from service.oneflow_inference_service import OneFlowInferenceService | |||
| from service.tensorflow_inference_service import TensorflowInferenceService | |||
| from service.pytorch_inference_service import PytorchInferenceService | |||
| import service.common_inference_service as common_inference_service | |||
| from logger import Logger | |||
| from utils import file_utils | |||
| from utils.find_class_in_file import FindClassInFile | |||
| log = Logger().logger | |||
| @@ -36,47 +38,58 @@ class InferenceServiceManager: | |||
| for model_config in model_config_list: | |||
| model_name = model_config["model_name"] | |||
| model_path = model_config["model_path"] | |||
| self.args.model_name = model_name | |||
| self.args.model_path = model_path | |||
| model_platform = model_config.get("platform") | |||
| if model_platform == "oneflow": | |||
| self.inference_service = OneFlowInferenceService(model_name, model_path) | |||
| self.inference_service = OneFlowInferenceService(self.args) | |||
| elif model_platform == "tensorflow" or model_platform == "keras": | |||
| self.inference_service = TensorflowInferenceService(model_name, model_path) | |||
| self.inference_service = TensorflowInferenceService(self.args) | |||
| elif model_platform == "pytorch": | |||
| self.inference_service = PytorchInferenceService(model_name, model_path) | |||
| self.inference_service = PytorchInferenceService(self.args) | |||
| self.model_name_service_map[model_name] = self.inference_service | |||
| else: | |||
| # Read from command-line parameter | |||
| if self.args.platform == "oneflow": | |||
| self.inference_service = OneFlowInferenceService(self.args.model_name, self.args.model_path) | |||
| elif self.args.platform == "tensorflow" or self.args.platform == "keras": | |||
| self.inference_service = TensorflowInferenceService(self.args.model_name, self.args.model_path) | |||
| elif self.args.platform == "pytorch": | |||
| self.inference_service = PytorchInferenceService(self.args.model_name, self.args.model_path) | |||
| if self.args.use_script: | |||
| # 使用自定义推理脚本 | |||
| find_class_in_file = FindClassInFile() | |||
| cls = find_class_in_file.find(common_inference_service) | |||
| self.inference_service = cls[1](self.args) | |||
| else : | |||
| # 使用默认推理脚本 | |||
| if self.args.platform == "oneflow": | |||
| self.inference_service = OneFlowInferenceService(self.args) | |||
| elif self.args.platform == "tensorflow" or self.args.platform == "keras": | |||
| self.inference_service = TensorflowInferenceService(self.args) | |||
| elif self.args.platform == "pytorch": | |||
| self.inference_service = PytorchInferenceService(self.args) | |||
| self.model_name_service_map[self.args.model_name] = self.inference_service | |||
| def inference(self, model_name, images): | |||
| def inference(self, model_name, data_list): | |||
| """ | |||
| 在线服务推理方法 | |||
| """ | |||
| inferenceService = self.model_name_service_map[model_name] | |||
| result = list() | |||
| for image in images: | |||
| data = inferenceService.inference(image) | |||
| if len(images) == 1: | |||
| return data | |||
| for data in data_list: | |||
| output = inferenceService.inference(data) | |||
| if len(data_list) == 1: | |||
| return output | |||
| else: | |||
| result.append(data) | |||
| result.append(output) | |||
| return result | |||
| def inference_and_save_json(self, model_name, json_path, images): | |||
| def inference_and_save_json(self, model_name, json_path, data_list): | |||
| """ | |||
| 批量服务推理方法 | |||
| """ | |||
| inferenceService = self.model_name_service_map[model_name] | |||
| for image in images: | |||
| data = inferenceService.inference(image) | |||
| file_utils.writer_json_file(json_path, image['image_name'], data) | |||
| for data in data_list: | |||
| result = inferenceService.inference(data) | |||
| file_utils.writer_json_file(json_path, data['data_name'], result) | |||
| time.sleep(1) | |||
| @@ -20,12 +20,8 @@ import google.protobuf.text_format as text_format | |||
| import os | |||
| from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
| from logger import Logger | |||
| import config as configs | |||
| from service.abstract_inference_service import AbstractInferenceService | |||
| parser = configs.get_parser() | |||
| args = parser.parse_args() | |||
| log = Logger().logger | |||
| @@ -33,10 +29,11 @@ class OneFlowInferenceService(AbstractInferenceService): | |||
| """ | |||
| oneflow 框架推理service | |||
| """ | |||
| def __init__(self, model_name, model_path): | |||
| def __init__(self, args): | |||
| super().__init__() | |||
| self.model_name = model_name | |||
| self.model_path = model_path | |||
| self.args = args | |||
| self.model_name = args.model_name | |||
| self.model_path = args.model_path | |||
| flow.clear_default_session() | |||
| self.infer_session = flow.SimpleSession() | |||
| self.load_model() | |||
| @@ -91,9 +88,9 @@ class OneFlowInferenceService(AbstractInferenceService): | |||
| return saved_model_proto | |||
| def inference(self, image): | |||
| data = {"image_name": image['image_name']} | |||
| log.info("===============> start load " + image['image_name'] + " <===============") | |||
| images = self.load_image(image['image_path']) | |||
| data = {"data_name": image['data_name']} | |||
| log.info("===============> start load " + image['data_name'] + " <===============") | |||
| images = self.load_image(image['data_path']) | |||
| predictions = self.infer_session.run('inference', image=images) | |||
| @@ -16,15 +16,12 @@ import torch | |||
| import torch.nn.functional as functional | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| import config | |||
| import requests | |||
| from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
| from io import BytesIO | |||
| from logger import Logger | |||
| from service.abstract_inference_service import AbstractInferenceService | |||
| parser = config.get_parser() | |||
| args = parser.parse_args() | |||
| log = Logger().logger | |||
| @@ -33,10 +30,11 @@ class PytorchInferenceService(AbstractInferenceService): | |||
| pytorch 框架推理service | |||
| """ | |||
| def __init__(self, model_name, model_path): | |||
| def __init__(self, args): | |||
| super().__init__() | |||
| self.model_name = model_name | |||
| self.model_path = model_path | |||
| self.args = args | |||
| self.model_name = args.model_name | |||
| self.model_path = args.model_path | |||
| self.model = self.load_model() | |||
| self.checkpoint = None | |||
| @@ -52,38 +50,38 @@ class PytorchInferenceService(AbstractInferenceService): | |||
| image = Image.open(io.BytesIO(image)) | |||
| if image.mode != 'RGB': | |||
| image = image.convert("RGB") | |||
| image = transforms.Resize((args.reshape_size[0], args.reshape_size[1]))(image) | |||
| image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) | |||
| image = transforms.ToTensor()(image) | |||
| image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) | |||
| image = image[None] | |||
| if args.use_gpu: | |||
| if self.args.use_gpu: | |||
| image = image.cuda() | |||
| log.info("===============> load image success <===============") | |||
| return image | |||
| def load_model(self): | |||
| log.info("===============> start load pytorch model :" + args.model_path + " <===============") | |||
| if os.path.isfile(args.model_path): | |||
| log.info("===============> start load pytorch model :" + self.args.model_path + " <===============") | |||
| if os.path.isfile(self.args.model_path): | |||
| self.checkpoint = torch.load(self.model_path) | |||
| else: | |||
| for file in os.listdir(args.model_path): | |||
| for file in os.listdir(self.args.model_path): | |||
| self.checkpoint = torch.load(self.model_path + file) | |||
| model = self.checkpoint[args.model_structure] | |||
| model = self.checkpoint[self.args.model_structure] | |||
| model.load_state_dict(self.checkpoint['state_dict']) | |||
| for parameter in model.parameters(): | |||
| parameter.requires_grad = False | |||
| if args.use_gpu: | |||
| if self.args.use_gpu: | |||
| model.cuda() | |||
| model.eval() | |||
| log.info("===============> load pytorch model success <===============") | |||
| return model | |||
| def inference(self, image): | |||
| data = {"image_name": image['image_name']} | |||
| log.info("===============> start load " + image['image_name'] + " <===============") | |||
| image = self.load_image(image['image_path']) | |||
| predis = functional.softmax(self.model(image), dim=1) | |||
| results = torch.topk(predis.data, k=5, dim=1) | |||
| data = {"data_name": image['data_name']} | |||
| log.info("===============> start load " + image['data_name'] + " <===============") | |||
| image = self.load_image(image['data_path']) | |||
| preds = functional.softmax(self.model(image), dim=1) | |||
| results = torch.topk(preds.data, k=5, dim=1) | |||
| data['predictions'] = list() | |||
| for prob, label in zip(results[0][0], results[1][0]): | |||
| result = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} | |||
| @@ -13,7 +13,6 @@ limitations under the License. | |||
| import tensorflow as tf | |||
| import requests | |||
| import numpy as np | |||
| import config as configs | |||
| from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
| from service.abstract_inference_service import AbstractInferenceService | |||
| from utils.imagenet_preprocessing_utils import preprocess_input | |||
| @@ -21,8 +20,6 @@ from logger import Logger | |||
| from PIL import Image | |||
| from io import BytesIO | |||
| parser = configs.get_parser() | |||
| args = parser.parse_args() | |||
| log = Logger().logger | |||
| @@ -30,11 +27,12 @@ class TensorflowInferenceService(AbstractInferenceService): | |||
| """ | |||
| tensorflow 框架推理service | |||
| """ | |||
| def __init__(self, model_name, model_path): | |||
| def __init__(self, args): | |||
| super().__init__() | |||
| self.session = tf.compat.v1.Session(graph=tf.Graph()) | |||
| self.model_name = model_name | |||
| self.model_path = model_path | |||
| self.args = args | |||
| self.model_name = args.model_name | |||
| self.model_path = args.model_path | |||
| self.signature_input_keys = [] | |||
| self.signature_input_tensor_names = [] | |||
| self.signature_output_keys = [] | |||
| @@ -69,11 +67,11 @@ class TensorflowInferenceService(AbstractInferenceService): | |||
| self.session, [tf.compat.v1.saved_model.tag_constants.SERVING], self.model_path) | |||
| # 加载模型之前先校验用户传入signature name | |||
| if args.signature_name not in meta_graph.signature_def: | |||
| if self.args.signature_name not in meta_graph.signature_def: | |||
| log.error("==============> Invalid signature name <==================") | |||
| # 从signature中获取meta graph中输入和输出的节点信息 | |||
| signature = meta_graph.signature_def[args.signature_name] | |||
| signature = meta_graph.signature_def[self.args.signature_name] | |||
| input_keys, input_tensor_names = get_tensors(signature.inputs) | |||
| output_keys, output_tensor_names = get_tensors(signature.outputs) | |||
| @@ -87,14 +85,14 @@ class TensorflowInferenceService(AbstractInferenceService): | |||
| log.info("===============> load tensorflow model success <===============") | |||
| def inference(self, image): | |||
| data = {"image_name": image['image_name']} | |||
| data = {"data_name": image['data_name']} | |||
| # 获得用户输入的图片 | |||
| log.info("===============> start load " + image['image_name'] + " <===============") | |||
| log.info("===============> start load " + image['data_name'] + " <===============") | |||
| # 推理所需的输入,目前的分类预置模型都只有一个输入 | |||
| input_dict = {} | |||
| input_keys = self.signature_input_keys | |||
| input_data = {} | |||
| im = preprocess_input(self.load_image(image['image_path']), mode=args.prepare_mode) | |||
| im = preprocess_input(self.load_image(image['data_path']), mode=self.args.prepare_mode) | |||
| if len(list(im.shape)) == 3: | |||
| input_data[input_keys[0]] = np.expand_dims(im, axis=0) | |||
| @@ -43,52 +43,52 @@ def download_image(images_path): | |||
| save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + image_path.split("/")[-1].split(".")[-1]) | |||
| def upload_image(image_files): | |||
| def upload_data(files): | |||
| """ | |||
| 前端上传图片保存到本地 | |||
| """ | |||
| save_image_dir = "/usr/local/images/" | |||
| if not os.path.exists(save_image_dir): | |||
| os.mkdir(save_image_dir) | |||
| images = list() | |||
| for image_file in image_files: | |||
| save_data_dir = "/usr/local/data/" | |||
| if not os.path.exists(save_data_dir): | |||
| os.mkdir(save_data_dir) | |||
| data_list = list() | |||
| for file in files: | |||
| try: | |||
| suffix = Path(image_file.filename).suffix | |||
| with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_image_dir) as tmp: | |||
| shutil.copyfileobj(image_file.file, tmp) | |||
| suffix = Path(file.filename).suffix | |||
| with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_data_dir) as tmp: | |||
| shutil.copyfileobj(file.file, tmp) | |||
| tmp_file_name = Path(tmp.name).name | |||
| file = {"image_name": image_file.filename, "image_path": save_image_dir + tmp_file_name} | |||
| images.append(file) | |||
| data = {"data_name": file.filename, "data_path": save_data_dir + tmp_file_name} | |||
| data_list.append(data) | |||
| finally: | |||
| image_file.file.close() | |||
| return images | |||
| file.file.close | |||
| return data_list | |||
| def upload_image_by_base64(image_files): | |||
| def upload_image_by_base64(data_list): | |||
| """ | |||
| base64图片信息保存到本地 | |||
| """ | |||
| save_image_dir = "/usr/local/images/" | |||
| if not os.path.exists(save_image_dir): | |||
| os.mkdir(save_image_dir) | |||
| images = list() | |||
| for img_file in image_files: | |||
| file_path = save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + img_file.image_name.split(".")[-1] | |||
| img_data = base64.b64decode(img_file.image_file) | |||
| save_data_dir = "/usr/local/data/" | |||
| if not os.path.exists(save_data_dir): | |||
| os.mkdir(save_data_dir) | |||
| data_list_b64 = list() | |||
| for data in data_list: | |||
| file_path = save_data_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + data.data_name.split(".")[-1] | |||
| file_b64 = base64.b64decode(data.data_file) | |||
| file = open(file_path, 'wb') | |||
| file.write(img_data) | |||
| file.write(file_b64) | |||
| file.close() | |||
| image = {"image_name": img_file.image_name, "image_path": file_path} | |||
| images.append(image) | |||
| return images | |||
| data_b64 = {"data_name": data.data_name, "data_path": file_path} | |||
| data_list_b64.append(data_b64) | |||
| return data_list_b64 | |||
| def writer_json_file(json_path, image_name, data): | |||
| def writer_json_file(json_path, data_name, data): | |||
| """ | |||
| 保存为json文件 | |||
| """ | |||
| if not os.path.exists(json_path): | |||
| os.mkdir(json_path) | |||
| filename = json_path + image_name + '.json' | |||
| filename = json_path + data_name + '.json' | |||
| with open(filename, 'w', encoding='utf-8') as file_obj: | |||
| file_obj.write(json.dumps(data, ensure_ascii=False)) | |||
| @@ -0,0 +1,88 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| """ | |||
| import re | |||
| import inspect | |||
| # from ac_opf.models import ac_model | |||
| class FindClassInFile: | |||
| """ | |||
| find class in the given module | |||
| ### note ### | |||
| There is only one class in the given module. | |||
| If there are more than one classes, all of them will be omit except the first | |||
| ############ | |||
| method: find | |||
| args: | |||
| module: object-> the given file or module | |||
| encoding: string-> "utf8" by default | |||
| output: tuple(string-> name of the class, object-> class) | |||
| usage: | |||
| # >>> import module | |||
| # | |||
| # >>> find_class_in_file = FindClassInFile() | |||
| # >>> cls = find_class_in_file.find(module) | |||
| # | |||
| # >>> cls_instance = cls[1](args) | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| def _open_file(self, path, encoding="utf8"): | |||
| with open(path, "r", encoding=encoding) as f: | |||
| data = f.readlines() | |||
| for line in data: | |||
| yield line | |||
| def find(self, module, encoding="utf8"): | |||
| path = module.__file__ | |||
| lines = self._open_file(path=path, encoding=encoding) | |||
| cls = "" | |||
| for line in lines: | |||
| if "class " in line: | |||
| cls = re.findall("class (.*?)[:(]", line)[0] | |||
| if cls: | |||
| break | |||
| return self._valid(module, cls) | |||
| def _valid(self, module, cls): | |||
| members = inspect.getmembers(module) | |||
| cand = [(i, j) for i, j in members if inspect.isclass(j) and (not inspect.isabstract(j)) and (i == cls)] | |||
| if not cand: | |||
| print("class not found in {}".format(module)) | |||
| return cand[0] | |||
| if __name__ == "__main__": | |||
| find_class_in_file = FindClassInFile() | |||
| # cls = find_class_in_file.find(ac_model) | |||
| # print(cls) | |||