From 857a57e03069b2b2ca807501921fe6786829c6b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=8B=E6=B1=9F=E5=A4=A9=E6=9E=A2?= Date: Wed, 30 Jun 2021 14:26:30 +0800 Subject: [PATCH] update serving --- tianshu_serving/README.md | 41 +-- tianshu_serving/batch_server.py | 32 +- tianshu_serving/config.py | 11 +- tianshu_serving/customize/__init__.py | 0 .../customize/common_inference_service.py | 84 ++++++ tianshu_serving/grpc_client.py | 16 +- tianshu_serving/grpc_server.py | 10 +- tianshu_serving/http_server.py | 15 +- tianshu_serving/logs/serving.log.2020-12-18 | 0 tianshu_serving/logs/serving.log.2021-01-14 | 1 + tianshu_serving/logs/serving.log.2021-02-03 | 1 + tianshu_serving/proto/inference.proto | 10 +- tianshu_serving/proto/inference_pb2.py | 273 +++++++++--------- tianshu_serving/proto/inference_pb2_grpc.py | 48 +-- .../service/common_inference_service.py | 87 ++++++ .../service/inference_service_manager.py | 51 ++-- .../service/oneflow_inference_service.py | 17 +- .../service/pytorch_inference_service.py | 34 +-- .../service/tensorflow_inference_service.py | 20 +- tianshu_serving/utils/file_utils.py | 54 ++-- tianshu_serving/utils/find_class_in_file.py | 88 ++++++ 21 files changed, 589 insertions(+), 304 deletions(-) create mode 100644 tianshu_serving/customize/__init__.py create mode 100644 tianshu_serving/customize/common_inference_service.py create mode 100644 tianshu_serving/logs/serving.log.2020-12-18 create mode 100644 tianshu_serving/logs/serving.log.2021-01-14 create mode 100644 tianshu_serving/logs/serving.log.2021-02-03 create mode 100644 tianshu_serving/service/common_inference_service.py create mode 100644 tianshu_serving/utils/find_class_in_file.py diff --git a/tianshu_serving/README.md b/tianshu_serving/README.md index 109cb2c..24ae9a1 100644 --- a/tianshu_serving/README.md +++ b/tianshu_serving/README.md @@ -5,15 +5,11 @@ + **支持oneflow、tensorflow、pytorch三种框架模型部署**
1、通过如下命令启动http在线推理服务 - - ``` - python http_server.py --platform='框架名称' --model_path='模型地址' - - ``` - + ``` + python http_server.py --platform='框架名称' --model_path='模型地址' + ``` 通过访问localhost:5000/docs进入swagger页面,调用localhost:5000/inference进行图片上传得道推理结果,结果如下所示: - - ``` + ``` { "image_name": "哈士奇.jpg", "predictions": [ @@ -39,28 +35,35 @@ } ] } - - ``` + ``` 2、同理通过如下命令启动grpc在线推理服务 - ``` - python grpc_server.py --platform='框架名称' --model_path='模型地址' - - ``` + ``` + python grpc_server.py --platform='框架名称' --model_path='模型地址' + ``` 再启动grpc_client.py进行上传图片推理得道结果,或者根据ip端口自行编写grpc客户端 - 3、支持多模型部署,可以自行配置config文件夹下的model_config_file.json进行多模型配置,启动http或grpc时输入不同的模型名称即可,或者自行修改inference接口入参来达到启动单一服务多模型推理的功能 ++ **支持多模型部署**
+ + 用户可以自行配置config文件夹下的model_config_file.json进行多模型配置,启动http或grpc时输入不同的模型名称即可,或者自行修改inference接口入参来达到启动单一服务多模型推理的功能 + **支持分布式模型部署推理**
需要推理大量图片时需要分布式推理功能,执行如下命令: - ``` - python batch_server.py --platform='框架名称' --model_path='模型地址' --input_path='批量图片地址' --output_path='输出JSON文件地址' - - ``` + ``` + python batch_server.py --platform='框架名称' --model_path='模型地址' --input_path='批量图片地址' --output_path='输出JSON文件地址' + ``` 输入的所有图片保存在input文件夹下,输入json文件保存在output_path文件夹,json名称与图片名称对应 + ++ **支持使用自定义推理脚本**
+ + 用户需要使用自定义推理脚本时,可根据common_inference_service.py脚本中注释的规则,自定义推理脚本,并替换原有的common_inference_service.py脚本。此外,在启动命令中添加use_script参数,即可在推理时使用自定义的推理脚本。命令如下所示: + ``` + python grpc_server.py --platform='框架名称' --model_path='模型地址' --user_script=True + ``` + + **代码还包含了各种参数配置,日志文件输出、是否启用TLS等**
\ No newline at end of file diff --git a/tianshu_serving/batch_server.py b/tianshu_serving/batch_server.py index 0cb9b15..ea1febc 100644 --- a/tianshu_serving/batch_server.py +++ b/tianshu_serving/batch_server.py @@ -27,49 +27,47 @@ log = Logger().logger def get_host_ip(): """ 查询本机ip地址 - :return: + return """ - global s - try: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - ip = s.getsockname()[0] - finally: - s.close() + hostname = socket.gethostname() + ip = socket.gethostbyname(hostname) return ip -def read_directory(images_path): +def read_directory(data_path): """ 读取文件夹并进行拆分文件 :return: """ - files = os.listdir(images_path) + files = os.listdir(data_path) num_files = len(files) index_list = list(range(num_files)) - images = list() + data_list = list() for index in index_list: # 是否开启分布式 if args.enable_distributed: ip = get_host_ip() + log.info("NODE_IPS:{}", os.getenv('NODE_IPS')) ip_list = os.getenv('NODE_IPS').split(",") num_ips = len(ip_list) ip_index = ip_list.index(ip) if ip_index == index % num_ips: filename = files[index] - image = {"image_name": filename, "image_path": images_path + filename} - images.append(image) + data = {"data_name": filename, "data_path": data_path + filename} + data_list.append(data) else: filename = files[index] - image = {"image_name": filename, "image_path": images_path + filename} - images.append(image) - return images + data = {"data_name": filename, "data_path": data_path + filename} + data_list.append(data) + return data_list def main(): - images = read_directory(args.input_path) - inference_service.inference_and_save_json(args.model_name, args.output_path, images) + data_list = read_directory(args.input_path) + inference_service.inference_and_save_json(args.model_name, args.output_path, data_list) if args.enable_distributed: ip = get_host_ip() + log.info("NODE_IPS:{}", os.getenv('NODE_IPS')) ip_list = os.getenv('NODE_IPS').split(",") # 主节点必须等待从节点推理完成 if ip == ip_list[0]: diff --git a/tianshu_serving/config.py b/tianshu_serving/config.py index b0b1dfe..fa30728 100644 --- a/tianshu_serving/config.py +++ b/tianshu_serving/config.py @@ -58,12 +58,12 @@ def get_parser(parser=None): parser.add_argument("--job_name", type=str, default="inference", help="oneflow job name") parser.add_argument("--prepare_mode", type=str, default="tfhub", help="tensorflow prepare mode(tfhub、caffe、tf、torch)") - parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="is use gpu") + parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="whether to use gpu") parser.add_argument('--channel_last', type=str2bool, nargs='?', const=False, help='Whether to use use channel last mode(nhwc)') - parser.add_argument("--model_path", type=str, default="/usr/local/model/pytorch_models/resnet50/", + parser.add_argument("--model_path", type=str, default="/usr/local/work/models/pytorch_models/resnet50/", help="model load directory if need") - parser.add_argument("--image_path", type=str, default='/usr/local/data/fish.jpg', help="image path") + parser.add_argument("--data_path", type=str, default='/usr/local/work/dog.jpg', help="input data path") parser.add_argument("--reshape_size", type=int_list, default='[224]', help="The reshape size of the image(eg. 224)") parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes") @@ -78,8 +78,9 @@ def get_parser(parser=None): parser.add_argument("--model_config_file", type=str, default="", help="The file of the model config(eg. '')") parser.add_argument("--enable_distributed", type=ast.literal_eval, default=False, help="If enable use distributed " "environment") - parser.add_argument("--input_path", type=str, default="/usr/local/data/images/", help="images path") - parser.add_argument("--output_path", type=str, default="/usr/local/output_path/", help="json path") + parser.add_argument("--input_path", type=str, default="/usr/local/input/", help="input batch data path") + parser.add_argument("--output_path", type=str, default="/usr/local/output/", help="output json path") + parser.add_argument("--use_script", type=ast.literal_eval, default=False, help="whether to use custom inference script") return parser diff --git a/tianshu_serving/customize/__init__.py b/tianshu_serving/customize/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tianshu_serving/customize/common_inference_service.py b/tianshu_serving/customize/common_inference_service.py new file mode 100644 index 0000000..1f0b2c7 --- /dev/null +++ b/tianshu_serving/customize/common_inference_service.py @@ -0,0 +1,84 @@ +""" +Copyright 2020 Tianshu AI Platform. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import io +import torch +import torch.nn.functional as functional +from PIL import Image +from torchvision import transforms +from imagenet1000_clsidx_to_labels import clsidx_2_labels +from logger import Logger + +log = Logger().logger + +#只能定义一个class +class CommonInferenceService: + # __init__初始化方法中接收args参数(其中模型路径参数为args.model_path,是否使用gpu参数为args.use_gpu),并加载模型(方法用户可自定义) + def __init__(self, args): + self.args = args + self.model = self.load_model() + + + def load_data(self, data_path): + image = open(data_path, 'rb').read() + image = Image.open(io.BytesIO(image)) + if image.mode != 'RGB': + image = image.convert("RGB") + image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) + image = transforms.ToTensor()(image) + image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) + image = image[None] + if self.args.use_gpu: + image = image.cuda() + return image + + def load_model(self): + if os.path.isfile(self.args.model_path): + self.checkpoint = torch.load(self.args.model_path) + else: + for file in os.listdir(self.args.model_path): + self.checkpoint = torch.load(self.args.model_path + file) + model = self.checkpoint["model"] + model.load_state_dict(self.checkpoint['state_dict']) + for parameter in model.parameters(): + parameter.requires_grad = False + if self.args.use_gpu: + model.cuda() + model.eval() + return model + + # inference方法名称固定 + def inference(self, data): + result = {"data_name": data['data_name']} + data = self.load_data(data['data_path']) + preds = functional.softmax(self.model(data), dim=1) + predictions = torch.topk(preds.data, k=5, dim=1) + result['predictions'] = list() + for prob, label in zip(predictions[0][0], predictions[1][0]): + predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} + result['predictions'].append(predictions) + return result +if __name__=="__main__": + import argparse + parser = argparse.ArgumentParser(description='tianshu serving') + parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path") + parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not") + parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not") + args = parser.parse_args() + server = CommonInferenceService(args) + + image_path = "./cat.jpg" + image = {"data_name": "cat.jpg", "data_path": image_path} + re = server.inference(image) + print(re) \ No newline at end of file diff --git a/tianshu_serving/grpc_client.py b/tianshu_serving/grpc_client.py index aef3279..c2b340a 100644 --- a/tianshu_serving/grpc_client.py +++ b/tianshu_serving/grpc_client.py @@ -21,8 +21,8 @@ log = Logger().logger parser = configs.get_parser() args = parser.parse_args() -_HOST = 'kohj2s.serving.dubhe.ai' -_PORT = '31365' +_HOST = '10.5.24.134' +_PORT = '8500' MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024 # 可根据具体需求设置,此处设为1G @@ -41,12 +41,12 @@ def run(): ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ], ) # 创建连接 client = inference_pb2_grpc.InferenceServiceStub(channel=channel) # 创建客户端 data_request = inference_pb2.DataRequest() - Image = data_request.images.add() - Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\哈士奇.jpg", "rb").read()), encoding='utf-8') - Image.image_name = "哈士奇.jpg" - Image = data_request.images.add() - Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\fish.jpg", "rb").read()), encoding='utf-8') - Image.image_name = "fish.jpg" + data1 = data_request.data_list.add() + data1.data_file = str(base64.b64encode(open("/usr/local/input/dog.jpg", "rb").read()), encoding='utf-8') + data1.data_name = "dog.jpg" + data2 = data_request.data_list.add() + data2.data_file = str(base64.b64encode(open("/usr/local/input/6.jpg", "rb").read()), encoding='utf-8') + data2.data_name = "6.jpg" response = client.inference(data_request) log.info(response.json_result.encode('utf-8').decode('unicode_escape')) diff --git a/tianshu_serving/grpc_server.py b/tianshu_serving/grpc_server.py index a736273..fc7480b 100644 --- a/tianshu_serving/grpc_server.py +++ b/tianshu_serving/grpc_server.py @@ -46,19 +46,21 @@ class InferenceService(inference_pb2_grpc.InferenceServiceServicer): 调用grpc方法进行推理 """ def inference(self, request, context): - image_files = request.images + data_list = request.data_list log.info("===============> grpc inference start <===============") try: - images = file_utils.upload_image_by_base64(image_files) # 上传图片到本地 + data_list_b64 = file_utils.upload_image_by_base64(data_list) # 上传图片到本地 except Exception as e: + log.error("upload data failed", e) return inference_pb2.DataResponse(json_result=json.dumps( - response_convert(Response(success=False, data=str(e), error="upload image fail")))) + response_convert(Response(success=False, data=str(e), error="upload data failed")))) try: - result = inference_service.inference(args.model_name, images) + result = inference_service.inference(args.model_name, data_list_b64) log.info("===============> grpc inference success <===============") return inference_pb2.DataResponse(json_result=json.dumps( response_convert(Response(success=True, data=result)))) except Exception as e: + log.error("inference fail", e) return inference_pb2.DataResponse(json_result=json.dumps( response_convert(Response(success=False, data=str(e), error="inference fail")))) diff --git a/tianshu_serving/http_server.py b/tianshu_serving/http_server.py index 907fb26..bb1be9c 100644 --- a/tianshu_serving/http_server.py +++ b/tianshu_serving/http_server.py @@ -49,7 +49,7 @@ async def inference(images_path: List[str] = None): threading.Thread(target=file_utils.download_image(images_path)) # 开启异步线程下载图片到本地 images = list() for image in images_path: - data = {"image_name": image.split("/")[-1], "image_path": image} + data = {"data_name": image.split("/")[-1], "data_path": image} images.append(data) try: data = inference_service.inference(args.model_name, images) @@ -59,17 +59,22 @@ async def inference(images_path: List[str] = None): @app.post("/inference") -async def inference(image_files: List[UploadFile] = File(...)): +async def inference(files: List[UploadFile] = File(...)): + """ + 上传本地文件推理 + """ log.info("===============> http inference start <===============") try: - images = file_utils.upload_image(image_files) # 上传图片到本地 + data_list = file_utils.upload_data(files) # 上传图片到本地 except Exception as e: - return Response(success=False, data=str(e), error="upload image fail") + log.error("upload data failed", e) + return Response(success=False, data=str(e), error="upload data failed") try: - result = inference_service.inference(args.model_name, images) + result = inference_service.inference(args.model_name, data_list) log.info("===============> http inference success <===============") return Response(success=True, data=result) except Exception as e: + log.error("inference fail", e) return Response(success=False, data=str(e), error="inference fail") diff --git a/tianshu_serving/logs/serving.log.2020-12-18 b/tianshu_serving/logs/serving.log.2020-12-18 new file mode 100644 index 0000000..e69de29 diff --git a/tianshu_serving/logs/serving.log.2021-01-14 b/tianshu_serving/logs/serving.log.2021-01-14 new file mode 100644 index 0000000..fcbb5ef --- /dev/null +++ b/tianshu_serving/logs/serving.log.2021-01-14 @@ -0,0 +1 @@ +2021-01-14 19:55:07,045 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.665"}, {"label": "Siberian husky", "probability": "0.294"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.024"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.015"}, {"label": "timber wolf, grey wolf, gray wolf, Canis lupus", "probability": "0.000"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "1.000"}, {"label": "tench, Tinca tinca", "probability": "0.000"}, {"label": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "probability": "0.000"}, {"label": "axolotl, mud puppy, Ambystoma mexicanum", "probability": "0.000"}, {"label": "barracouta, snoek", "probability": "0.000"}]}], "error": ""} diff --git a/tianshu_serving/logs/serving.log.2021-02-03 b/tianshu_serving/logs/serving.log.2021-02-03 new file mode 100644 index 0000000..cae2724 --- /dev/null +++ b/tianshu_serving/logs/serving.log.2021-02-03 @@ -0,0 +1 @@ +2021-02-03 15:05:00,069 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.679"}, {"label": "Siberian husky", "probability": "0.213"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.021"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.006"}, {"label": "white wolf, Arctic wolf, Canis lupus tundrarum", "probability": "0.001"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "0.871"}, {"label": "tench, Tinca tinca", "probability": "0.004"}, {"label": "puffer, pufferfish, blowfish, globefish", "probability": "0.002"}, {"label": "rock beauty, Holocanthus tricolor", "probability": "0.002"}, {"label": "barracouta, snoek", "probability": "0.001"}]}], "error": ""} diff --git a/tianshu_serving/proto/inference.proto b/tianshu_serving/proto/inference.proto index a824298..2dbb583 100644 --- a/tianshu_serving/proto/inference.proto +++ b/tianshu_serving/proto/inference.proto @@ -1,16 +1,16 @@ syntax = 'proto3'; - + service InferenceService { rpc inference(DataRequest) returns (DataResponse) {} } message DataRequest{ - repeated Image images = 1; + repeated Data data_list = 1; } -message Image { - string image_file = 1; - string image_name = 2; +message Data { + string data_file = 1; + string data_name = 2; } message DataResponse{ diff --git a/tianshu_serving/proto/inference_pb2.py b/tianshu_serving/proto/inference_pb2.py index 542cb07..467bf76 100644 --- a/tianshu_serving/proto/inference_pb2.py +++ b/tianshu_serving/proto/inference_pb2.py @@ -6,168 +6,177 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() + + + DESCRIPTOR = _descriptor.FileDescriptor( - name='inference.proto', - package='', - syntax='proto3', - serialized_options=None, - create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\x0finference.proto\"%\n\x0b\x44\x61taRequest\x12\x16\n\x06images\x18\x01 \x03(\x0b\x32\x06.Image\"/\n\x05Image\x12\x12\n\nimage_file\x18\x01 \x01(\t\x12\x12\n\nimage_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3' + name='inference.proto', + package='', + syntax='proto3', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x0finference.proto\"\'\n\x0b\x44\x61taRequest\x12\x18\n\tdata_list\x18\x01 \x03(\x0b\x32\x05.Data\",\n\x04\x44\x61ta\x12\x11\n\tdata_file\x18\x01 \x01(\t\x12\x11\n\tdata_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3' ) + + + _DATAREQUEST = _descriptor.Descriptor( - name='DataRequest', - full_name='DataRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='images', full_name='DataRequest.images', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=19, - serialized_end=56, + name='DataRequest', + full_name='DataRequest', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='data_list', full_name='DataRequest.data_list', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=19, + serialized_end=58, ) -_IMAGE = _descriptor.Descriptor( - name='Image', - full_name='Image', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='image_file', full_name='Image.image_file', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - _descriptor.FieldDescriptor( - name='image_name', full_name='Image.image_name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=58, - serialized_end=105, + +_DATA = _descriptor.Descriptor( + name='Data', + full_name='Data', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='data_file', full_name='Data.data_file', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='data_name', full_name='Data.data_name', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=60, + serialized_end=104, ) + _DATARESPONSE = _descriptor.Descriptor( - name='DataResponse', - full_name='DataResponse', - filename=None, - file=DESCRIPTOR, - containing_type=None, - create_key=_descriptor._internal_create_key, - fields=[ - _descriptor.FieldDescriptor( - name='json_result', full_name='DataResponse.json_result', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto3', - extension_ranges=[], - oneofs=[ - ], - serialized_start=107, - serialized_end=142, + name='DataResponse', + full_name='DataResponse', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='json_result', full_name='DataResponse.json_result', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=106, + serialized_end=141, ) -_DATAREQUEST.fields_by_name['images'].message_type = _IMAGE +_DATAREQUEST.fields_by_name['data_list'].message_type = _DATA DESCRIPTOR.message_types_by_name['DataRequest'] = _DATAREQUEST -DESCRIPTOR.message_types_by_name['Image'] = _IMAGE +DESCRIPTOR.message_types_by_name['Data'] = _DATA DESCRIPTOR.message_types_by_name['DataResponse'] = _DATARESPONSE _sym_db.RegisterFileDescriptor(DESCRIPTOR) DataRequest = _reflection.GeneratedProtocolMessageType('DataRequest', (_message.Message,), { - 'DESCRIPTOR': _DATAREQUEST, - '__module__': 'inference_pb2' - # @@protoc_insertion_point(class_scope:DataRequest) -}) + 'DESCRIPTOR' : _DATAREQUEST, + '__module__' : 'inference_pb2' + # @@protoc_insertion_point(class_scope:DataRequest) + }) _sym_db.RegisterMessage(DataRequest) -Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), { - 'DESCRIPTOR': _IMAGE, - '__module__': 'inference_pb2' - # @@protoc_insertion_point(class_scope:Image) -}) -_sym_db.RegisterMessage(Image) +Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), { + 'DESCRIPTOR' : _DATA, + '__module__' : 'inference_pb2' + # @@protoc_insertion_point(class_scope:Data) + }) +_sym_db.RegisterMessage(Data) DataResponse = _reflection.GeneratedProtocolMessageType('DataResponse', (_message.Message,), { - 'DESCRIPTOR': _DATARESPONSE, - '__module__': 'inference_pb2' - # @@protoc_insertion_point(class_scope:DataResponse) -}) + 'DESCRIPTOR' : _DATARESPONSE, + '__module__' : 'inference_pb2' + # @@protoc_insertion_point(class_scope:DataResponse) + }) _sym_db.RegisterMessage(DataResponse) + + _INFERENCESERVICE = _descriptor.ServiceDescriptor( - name='InferenceService', - full_name='InferenceService', - file=DESCRIPTOR, + name='InferenceService', + full_name='InferenceService', + file=DESCRIPTOR, + index=0, + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_start=143, + serialized_end=205, + methods=[ + _descriptor.MethodDescriptor( + name='inference', + full_name='InferenceService.inference', index=0, + containing_service=None, + input_type=_DATAREQUEST, + output_type=_DATARESPONSE, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=144, - serialized_end=206, - methods=[ - _descriptor.MethodDescriptor( - name='inference', - full_name='InferenceService.inference', - index=0, - containing_service=None, - input_type=_DATAREQUEST, - output_type=_DATARESPONSE, - serialized_options=None, - create_key=_descriptor._internal_create_key, - ), - ]) + ), +]) _sym_db.RegisterServiceDescriptor(_INFERENCESERVICE) DESCRIPTOR.services_by_name['InferenceService'] = _INFERENCESERVICE diff --git a/tianshu_serving/proto/inference_pb2_grpc.py b/tianshu_serving/proto/inference_pb2_grpc.py index 0d829a8..cfaa1fe 100644 --- a/tianshu_serving/proto/inference_pb2_grpc.py +++ b/tianshu_serving/proto/inference_pb2_grpc.py @@ -15,10 +15,10 @@ class InferenceServiceStub(object): channel: A grpc.Channel. """ self.inference = channel.unary_unary( - '/InferenceService/inference', - request_serializer=inference__pb2.DataRequest.SerializeToString, - response_deserializer=inference__pb2.DataResponse.FromString, - ) + '/InferenceService/inference', + request_serializer=inference__pb2.DataRequest.SerializeToString, + response_deserializer=inference__pb2.DataResponse.FromString, + ) class InferenceServiceServicer(object): @@ -33,34 +33,34 @@ class InferenceServiceServicer(object): def add_InferenceServiceServicer_to_server(servicer, server): rpc_method_handlers = { - 'inference': grpc.unary_unary_rpc_method_handler( - servicer.inference, - request_deserializer=inference__pb2.DataRequest.FromString, - response_serializer=inference__pb2.DataResponse.SerializeToString, - ), + 'inference': grpc.unary_unary_rpc_method_handler( + servicer.inference, + request_deserializer=inference__pb2.DataRequest.FromString, + response_serializer=inference__pb2.DataResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( - 'InferenceService', rpc_method_handlers) + 'InferenceService', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class InferenceService(object): """Missing associated documentation comment in .proto file.""" @staticmethod def inference(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): return grpc.experimental.unary_unary(request, target, '/InferenceService/inference', - inference__pb2.DataRequest.SerializeToString, - inference__pb2.DataResponse.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + inference__pb2.DataRequest.SerializeToString, + inference__pb2.DataResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tianshu_serving/service/common_inference_service.py b/tianshu_serving/service/common_inference_service.py new file mode 100644 index 0000000..45ac9dd --- /dev/null +++ b/tianshu_serving/service/common_inference_service.py @@ -0,0 +1,87 @@ +""" +Copyright 2020 Tianshu AI Platform. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import io +import torch +import torch.nn.functional as functional +from PIL import Image +from torchvision import transforms +from imagenet1000_clsidx_to_labels import clsidx_2_labels +from logger import Logger + +log = Logger().logger + +# 只能定义一个class +class CommonInferenceService: + # 请在__init__初始化方法中接收args参数,并加载模型(其中模型路径参数为args.model_path,是否使用gpu参数为args.use_gpu,模型加载方法用户可自定义) + def __init__(self, args): + self.args = args + self.model = self.load_model() + + + def load_data(self, data_path): + image = open(data_path, 'rb').read() + image = Image.open(io.BytesIO(image)) + if image.mode != 'RGB': + image = image.convert("RGB") + image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) + image = transforms.ToTensor()(image) + image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) + image = image[None] + if self.args.use_gpu: + image = image.cuda() + return image + + def load_model(self): + if os.path.isfile(self.args.model_path): + self.checkpoint = torch.load(self.args.model_path) + else: + for file in os.listdir(self.args.model_path): + self.checkpoint = torch.load(self.args.model_path + file) + model = self.checkpoint["model"] + model.load_state_dict(self.checkpoint['state_dict']) + for parameter in model.parameters(): + parameter.requires_grad = False + if self.args.use_gpu: + model.cuda() + model.eval() + return model + + # inference方法名称固定 + def inference(self, data): + result = {"data_name": data['data_name']} + log.info("===============> start load " + data['data_name'] + " <===============") + data = self.load_data(data['data_path']) + preds = functional.softmax(self.model(data), dim=1) + predictions = torch.topk(preds.data, k=5, dim=1) + result['predictions'] = list() + for prob, label in zip(predictions[0][0], predictions[1][0]): + predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} + result['predictions'].append(predictions) + return result + +# 非必须,可用于本地调试 +if __name__=="__main__": + import argparse + parser = argparse.ArgumentParser(description='dubhe serving') + parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path") + parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not") + parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not") + args = parser.parse_args() + server = CommonInferenceService(args) + + image_path = "./cat.jpg" + image = {"data_name": "cat.jpg", "data_path": image_path} + re = server.inference(image) + print(re) \ No newline at end of file diff --git a/tianshu_serving/service/inference_service_manager.py b/tianshu_serving/service/inference_service_manager.py index 765bc93..8c78b48 100644 --- a/tianshu_serving/service/inference_service_manager.py +++ b/tianshu_serving/service/inference_service_manager.py @@ -15,8 +15,10 @@ import time from service.oneflow_inference_service import OneFlowInferenceService from service.tensorflow_inference_service import TensorflowInferenceService from service.pytorch_inference_service import PytorchInferenceService +import service.common_inference_service as common_inference_service from logger import Logger from utils import file_utils +from utils.find_class_in_file import FindClassInFile log = Logger().logger @@ -36,47 +38,58 @@ class InferenceServiceManager: for model_config in model_config_list: model_name = model_config["model_name"] model_path = model_config["model_path"] + self.args.model_name = model_name + self.args.model_path = model_path model_platform = model_config.get("platform") if model_platform == "oneflow": - self.inference_service = OneFlowInferenceService(model_name, model_path) + self.inference_service = OneFlowInferenceService(self.args) elif model_platform == "tensorflow" or model_platform == "keras": - self.inference_service = TensorflowInferenceService(model_name, model_path) + self.inference_service = TensorflowInferenceService(self.args) elif model_platform == "pytorch": - self.inference_service = PytorchInferenceService(model_name, model_path) + self.inference_service = PytorchInferenceService(self.args) self.model_name_service_map[model_name] = self.inference_service else: # Read from command-line parameter - if self.args.platform == "oneflow": - self.inference_service = OneFlowInferenceService(self.args.model_name, self.args.model_path) - elif self.args.platform == "tensorflow" or self.args.platform == "keras": - self.inference_service = TensorflowInferenceService(self.args.model_name, self.args.model_path) - elif self.args.platform == "pytorch": - self.inference_service = PytorchInferenceService(self.args.model_name, self.args.model_path) + if self.args.use_script: + # 使用自定义推理脚本 + find_class_in_file = FindClassInFile() + cls = find_class_in_file.find(common_inference_service) + self.inference_service = cls[1](self.args) + + else : + # 使用默认推理脚本 + if self.args.platform == "oneflow": + self.inference_service = OneFlowInferenceService(self.args) + elif self.args.platform == "tensorflow" or self.args.platform == "keras": + self.inference_service = TensorflowInferenceService(self.args) + elif self.args.platform == "pytorch": + self.inference_service = PytorchInferenceService(self.args) + self.model_name_service_map[self.args.model_name] = self.inference_service - def inference(self, model_name, images): + def inference(self, model_name, data_list): """ 在线服务推理方法 """ inferenceService = self.model_name_service_map[model_name] result = list() - for image in images: - data = inferenceService.inference(image) - if len(images) == 1: - return data + for data in data_list: + output = inferenceService.inference(data) + if len(data_list) == 1: + return output else: - result.append(data) + result.append(output) return result - def inference_and_save_json(self, model_name, json_path, images): + def inference_and_save_json(self, model_name, json_path, data_list): """ 批量服务推理方法 """ inferenceService = self.model_name_service_map[model_name] - for image in images: - data = inferenceService.inference(image) - file_utils.writer_json_file(json_path, image['image_name'], data) + for data in data_list: + result = inferenceService.inference(data) + file_utils.writer_json_file(json_path, data['data_name'], result) time.sleep(1) diff --git a/tianshu_serving/service/oneflow_inference_service.py b/tianshu_serving/service/oneflow_inference_service.py index b47578f..bdad3f4 100644 --- a/tianshu_serving/service/oneflow_inference_service.py +++ b/tianshu_serving/service/oneflow_inference_service.py @@ -20,12 +20,8 @@ import google.protobuf.text_format as text_format import os from imagenet1000_clsidx_to_labels import clsidx_2_labels from logger import Logger -import config as configs from service.abstract_inference_service import AbstractInferenceService -parser = configs.get_parser() -args = parser.parse_args() - log = Logger().logger @@ -33,10 +29,11 @@ class OneFlowInferenceService(AbstractInferenceService): """ oneflow 框架推理service """ - def __init__(self, model_name, model_path): + def __init__(self, args): super().__init__() - self.model_name = model_name - self.model_path = model_path + self.args = args + self.model_name = args.model_name + self.model_path = args.model_path flow.clear_default_session() self.infer_session = flow.SimpleSession() self.load_model() @@ -91,9 +88,9 @@ class OneFlowInferenceService(AbstractInferenceService): return saved_model_proto def inference(self, image): - data = {"image_name": image['image_name']} - log.info("===============> start load " + image['image_name'] + " <===============") - images = self.load_image(image['image_path']) + data = {"data_name": image['data_name']} + log.info("===============> start load " + image['data_name'] + " <===============") + images = self.load_image(image['data_path']) predictions = self.infer_session.run('inference', image=images) diff --git a/tianshu_serving/service/pytorch_inference_service.py b/tianshu_serving/service/pytorch_inference_service.py index 10fd5b0..b91fe0b 100644 --- a/tianshu_serving/service/pytorch_inference_service.py +++ b/tianshu_serving/service/pytorch_inference_service.py @@ -16,15 +16,12 @@ import torch import torch.nn.functional as functional from PIL import Image from torchvision import transforms -import config import requests from imagenet1000_clsidx_to_labels import clsidx_2_labels from io import BytesIO from logger import Logger from service.abstract_inference_service import AbstractInferenceService -parser = config.get_parser() -args = parser.parse_args() log = Logger().logger @@ -33,10 +30,11 @@ class PytorchInferenceService(AbstractInferenceService): pytorch 框架推理service """ - def __init__(self, model_name, model_path): + def __init__(self, args): super().__init__() - self.model_name = model_name - self.model_path = model_path + self.args = args + self.model_name = args.model_name + self.model_path = args.model_path self.model = self.load_model() self.checkpoint = None @@ -52,38 +50,38 @@ class PytorchInferenceService(AbstractInferenceService): image = Image.open(io.BytesIO(image)) if image.mode != 'RGB': image = image.convert("RGB") - image = transforms.Resize((args.reshape_size[0], args.reshape_size[1]))(image) + image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) image = transforms.ToTensor()(image) image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) image = image[None] - if args.use_gpu: + if self.args.use_gpu: image = image.cuda() log.info("===============> load image success <===============") return image def load_model(self): - log.info("===============> start load pytorch model :" + args.model_path + " <===============") - if os.path.isfile(args.model_path): + log.info("===============> start load pytorch model :" + self.args.model_path + " <===============") + if os.path.isfile(self.args.model_path): self.checkpoint = torch.load(self.model_path) else: - for file in os.listdir(args.model_path): + for file in os.listdir(self.args.model_path): self.checkpoint = torch.load(self.model_path + file) - model = self.checkpoint[args.model_structure] + model = self.checkpoint[self.args.model_structure] model.load_state_dict(self.checkpoint['state_dict']) for parameter in model.parameters(): parameter.requires_grad = False - if args.use_gpu: + if self.args.use_gpu: model.cuda() model.eval() log.info("===============> load pytorch model success <===============") return model def inference(self, image): - data = {"image_name": image['image_name']} - log.info("===============> start load " + image['image_name'] + " <===============") - image = self.load_image(image['image_path']) - predis = functional.softmax(self.model(image), dim=1) - results = torch.topk(predis.data, k=5, dim=1) + data = {"data_name": image['data_name']} + log.info("===============> start load " + image['data_name'] + " <===============") + image = self.load_image(image['data_path']) + preds = functional.softmax(self.model(image), dim=1) + results = torch.topk(preds.data, k=5, dim=1) data['predictions'] = list() for prob, label in zip(results[0][0], results[1][0]): result = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} diff --git a/tianshu_serving/service/tensorflow_inference_service.py b/tianshu_serving/service/tensorflow_inference_service.py index a04a98c..b7703ce 100644 --- a/tianshu_serving/service/tensorflow_inference_service.py +++ b/tianshu_serving/service/tensorflow_inference_service.py @@ -13,7 +13,6 @@ limitations under the License. import tensorflow as tf import requests import numpy as np -import config as configs from imagenet1000_clsidx_to_labels import clsidx_2_labels from service.abstract_inference_service import AbstractInferenceService from utils.imagenet_preprocessing_utils import preprocess_input @@ -21,8 +20,6 @@ from logger import Logger from PIL import Image from io import BytesIO -parser = configs.get_parser() -args = parser.parse_args() log = Logger().logger @@ -30,11 +27,12 @@ class TensorflowInferenceService(AbstractInferenceService): """ tensorflow 框架推理service """ - def __init__(self, model_name, model_path): + def __init__(self, args): super().__init__() self.session = tf.compat.v1.Session(graph=tf.Graph()) - self.model_name = model_name - self.model_path = model_path + self.args = args + self.model_name = args.model_name + self.model_path = args.model_path self.signature_input_keys = [] self.signature_input_tensor_names = [] self.signature_output_keys = [] @@ -69,11 +67,11 @@ class TensorflowInferenceService(AbstractInferenceService): self.session, [tf.compat.v1.saved_model.tag_constants.SERVING], self.model_path) # 加载模型之前先校验用户传入signature name - if args.signature_name not in meta_graph.signature_def: + if self.args.signature_name not in meta_graph.signature_def: log.error("==============> Invalid signature name <==================") # 从signature中获取meta graph中输入和输出的节点信息 - signature = meta_graph.signature_def[args.signature_name] + signature = meta_graph.signature_def[self.args.signature_name] input_keys, input_tensor_names = get_tensors(signature.inputs) output_keys, output_tensor_names = get_tensors(signature.outputs) @@ -87,14 +85,14 @@ class TensorflowInferenceService(AbstractInferenceService): log.info("===============> load tensorflow model success <===============") def inference(self, image): - data = {"image_name": image['image_name']} + data = {"data_name": image['data_name']} # 获得用户输入的图片 - log.info("===============> start load " + image['image_name'] + " <===============") + log.info("===============> start load " + image['data_name'] + " <===============") # 推理所需的输入,目前的分类预置模型都只有一个输入 input_dict = {} input_keys = self.signature_input_keys input_data = {} - im = preprocess_input(self.load_image(image['image_path']), mode=args.prepare_mode) + im = preprocess_input(self.load_image(image['data_path']), mode=self.args.prepare_mode) if len(list(im.shape)) == 3: input_data[input_keys[0]] = np.expand_dims(im, axis=0) diff --git a/tianshu_serving/utils/file_utils.py b/tianshu_serving/utils/file_utils.py index 688ff37..72e3bc9 100644 --- a/tianshu_serving/utils/file_utils.py +++ b/tianshu_serving/utils/file_utils.py @@ -43,52 +43,52 @@ def download_image(images_path): save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + image_path.split("/")[-1].split(".")[-1]) -def upload_image(image_files): +def upload_data(files): """ 前端上传图片保存到本地 """ - save_image_dir = "/usr/local/images/" - if not os.path.exists(save_image_dir): - os.mkdir(save_image_dir) - images = list() - for image_file in image_files: + save_data_dir = "/usr/local/data/" + if not os.path.exists(save_data_dir): + os.mkdir(save_data_dir) + data_list = list() + for file in files: try: - suffix = Path(image_file.filename).suffix - with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_image_dir) as tmp: - shutil.copyfileobj(image_file.file, tmp) + suffix = Path(file.filename).suffix + with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_data_dir) as tmp: + shutil.copyfileobj(file.file, tmp) tmp_file_name = Path(tmp.name).name - file = {"image_name": image_file.filename, "image_path": save_image_dir + tmp_file_name} - images.append(file) + data = {"data_name": file.filename, "data_path": save_data_dir + tmp_file_name} + data_list.append(data) finally: - image_file.file.close() - return images + file.file.close + return data_list -def upload_image_by_base64(image_files): +def upload_image_by_base64(data_list): """ base64图片信息保存到本地 """ - save_image_dir = "/usr/local/images/" - if not os.path.exists(save_image_dir): - os.mkdir(save_image_dir) - images = list() - for img_file in image_files: - file_path = save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + img_file.image_name.split(".")[-1] - img_data = base64.b64decode(img_file.image_file) + save_data_dir = "/usr/local/data/" + if not os.path.exists(save_data_dir): + os.mkdir(save_data_dir) + data_list_b64 = list() + for data in data_list: + file_path = save_data_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + data.data_name.split(".")[-1] + file_b64 = base64.b64decode(data.data_file) file = open(file_path, 'wb') - file.write(img_data) + file.write(file_b64) file.close() - image = {"image_name": img_file.image_name, "image_path": file_path} - images.append(image) - return images + data_b64 = {"data_name": data.data_name, "data_path": file_path} + data_list_b64.append(data_b64) + return data_list_b64 -def writer_json_file(json_path, image_name, data): +def writer_json_file(json_path, data_name, data): """ 保存为json文件 """ if not os.path.exists(json_path): os.mkdir(json_path) - filename = json_path + image_name + '.json' + filename = json_path + data_name + '.json' with open(filename, 'w', encoding='utf-8') as file_obj: file_obj.write(json.dumps(data, ensure_ascii=False)) diff --git a/tianshu_serving/utils/find_class_in_file.py b/tianshu_serving/utils/find_class_in_file.py new file mode 100644 index 0000000..b85ee25 --- /dev/null +++ b/tianshu_serving/utils/find_class_in_file.py @@ -0,0 +1,88 @@ +""" +Copyright 2020 Tianshu AI Platform. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import re +import inspect +# from ac_opf.models import ac_model + + +class FindClassInFile: + """ + find class in the given module + + ### note ### + There is only one class in the given module. + If there are more than one classes, all of them will be omit except the first + ############ + + + method: find + args: + module: object-> the given file or module + encoding: string-> "utf8" by default + output: tuple(string-> name of the class, object-> class) + + + usage: + + # >>> import module + # + # >>> find_class_in_file = FindClassInFile() + # >>> cls = find_class_in_file.find(module) + # + # >>> cls_instance = cls[1](args) + + """ + + def __init__(self): + pass + + def _open_file(self, path, encoding="utf8"): + with open(path, "r", encoding=encoding) as f: + data = f.readlines() + + for line in data: + yield line + + def find(self, module, encoding="utf8"): + path = module.__file__ + lines = self._open_file(path=path, encoding=encoding) + + cls = "" + for line in lines: + if "class " in line: + cls = re.findall("class (.*?)[:(]", line)[0] + if cls: + break + + return self._valid(module, cls) + + def _valid(self, module, cls): + members = inspect.getmembers(module) + cand = [(i, j) for i, j in members if inspect.isclass(j) and (not inspect.isabstract(j)) and (i == cls)] + if not cand: + print("class not found in {}".format(module)) + return cand[0] + + +if __name__ == "__main__": + find_class_in_file = FindClassInFile() + # cls = find_class_in_file.find(ac_model) + + # print(cls) + + + + + + +