@@ -5,15 +5,11 @@
+ **Supports model deployment for the oneflow, tensorflow, and pytorch frameworks** </br>
1. Start the HTTP online inference service with the following command:
```
python http_server.py --platform='framework name' --model_path='model path'
```
Visit localhost:5000/docs to open the Swagger page, then call localhost:5000/inference to upload an image and get the inference result, which looks like this:
```
{
    "image_name": "哈士奇.jpg",
    "predictions": [
@@ -39,28 +35,35 @@
        }
    ]
}
```
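For reference, a request against this endpoint can also be issued from the command line; a minimal sketch with curl, assuming the service listens on the default port 5000 (the multipart field name `files` matches the FastAPI handler shown later in this diff):
```
curl -X POST "http://localhost:5000/inference" -F "files=@哈士奇.jpg"
```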
2、同理通过如下命令启动grpc在线推理服务 | 2、同理通过如下命令启动grpc在线推理服务 | ||||
```
python grpc_server.py --platform='framework name' --model_path='model path'
```
Then run grpc_client.py to upload images and get the inference result, or write your own gRPC client against the service's IP and port.
+ **Supports multi-model deployment** </br>
You can configure multiple models in model_config_file.json under the config folder and pass the corresponding model name when starting the HTTP or gRPC service, or modify the parameters of the inference interface so that a single service serves multiple models. A sample configuration is sketched below.
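A hypothetical model_config_file.json, assuming the file holds a JSON array with the fields read by InferenceServiceManager later in this diff (model_name, model_path, platform); the names and paths are placeholders:
```
[
    {"model_name": "resnet50", "model_path": "/usr/local/work/models/pytorch_models/resnet50/", "platform": "pytorch"},
    {"model_name": "mobilenet", "model_path": "/usr/local/work/models/tf_models/mobilenet/", "platform": "tensorflow"}
]
```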
+ **Supports distributed model deployment and inference** </br>
Distributed inference is needed when a large number of images must be processed; run the following command:
```
python batch_server.py --platform='framework name' --model_path='model path' --input_path='batch image path' --output_path='output JSON path'
```
All input images are stored in the input folder; the output JSON files are written to the output_path folder, with each JSON file named after its corresponding image.
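In distributed mode the nodes coordinate through the NODE_IPS environment variable (read by batch_server.py, shown below in this diff), and the input files are sharded round-robin across nodes. A hypothetical two-node launch, executed on each node with the same IP order; the IPs and paths are placeholders:
```
NODE_IPS=10.5.24.134,10.5.24.135 python batch_server.py --platform='pytorch' --model_path='model path' --input_path='batch image path' --output_path='output JSON path' --enable_distributed=True
```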
+ **Supports custom inference scripts** </br>
To use a custom inference script, write your own script following the commented rules in the common_inference_service.py script, replace the original common_inference_service.py with it, and add the use_script parameter to the start command so the custom script is used at inference time. The command looks like this:
```
python grpc_server.py --platform='framework name' --model_path='model path' --use_script=True
```
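Distilled from the commented rules in common_inference_service.py (included later in this diff), a minimal custom-script skeleton might look like this; the class name and method bodies are placeholders:
```
# A hypothetical custom script (replacing common_inference_service.py);
# only one class may be defined in the file
class MyInferenceService:
    # __init__ must accept args; args.model_path and args.use_gpu are provided
    def __init__(self, args):
        self.args = args
        self.model = None  # user-defined model loading goes here

    # the method name "inference" is fixed
    def inference(self, data):
        # data is a dict: {"data_name": ..., "data_path": ...}
        return {"data_name": data["data_name"], "predictions": []}
```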
+ **The code also includes various parameter options, log file output, optional TLS, and more** </br>
@@ -27,49 +27,47 @@ log = Logger().logger
def get_host_ip():
    """
    Get the local IP address.
    :return:
    """
    hostname = socket.gethostname()
    ip = socket.gethostbyname(hostname)
    return ip
def read_directory(data_path):
    """
    Read the directory and split the files across nodes.
    :return:
    """
    files = os.listdir(data_path)
    num_files = len(files)
    index_list = list(range(num_files))
    data_list = list()
    for index in index_list:
        # Whether distributed mode is enabled
        if args.enable_distributed:
            ip = get_host_ip()
            log.info("NODE_IPS:{}".format(os.getenv('NODE_IPS')))
            ip_list = os.getenv('NODE_IPS').split(",")
            num_ips = len(ip_list)
            ip_index = ip_list.index(ip)
            # Round-robin sharding: this node handles files whose index modulo
            # the node count equals its position in NODE_IPS
            if ip_index == index % num_ips:
                filename = files[index]
                data = {"data_name": filename, "data_path": data_path + filename}
                data_list.append(data)
        else:
            filename = files[index]
            data = {"data_name": filename, "data_path": data_path + filename}
            data_list.append(data)
    return data_list
def main():
    data_list = read_directory(args.input_path)
    inference_service.inference_and_save_json(args.model_name, args.output_path, data_list)
    if args.enable_distributed:
        ip = get_host_ip()
        log.info("NODE_IPS:{}".format(os.getenv('NODE_IPS')))
        ip_list = os.getenv('NODE_IPS').split(",")
        # The master node must wait for the worker nodes to finish inference
        if ip == ip_list[0]:
@@ -58,12 +58,12 @@ def get_parser(parser=None):
    parser.add_argument("--job_name", type=str, default="inference", help="oneflow job name")
    parser.add_argument("--prepare_mode", type=str, default="tfhub",
                        help="tensorflow prepare mode(tfhub, caffe, tf, torch)")
    parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="whether to use gpu")
    parser.add_argument('--channel_last', type=str2bool, nargs='?', const=False,
                        help='Whether to use channel last mode(nhwc)')
    parser.add_argument("--model_path", type=str, default="/usr/local/work/models/pytorch_models/resnet50/",
                        help="model load directory if need")
    parser.add_argument("--data_path", type=str, default='/usr/local/work/dog.jpg', help="input data path")
    parser.add_argument("--reshape_size", type=int_list, default='[224]',
                        help="The reshape size of the image(eg. 224)")
    parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes")
@@ -78,8 +78,9 @@ def get_parser(parser=None):
    parser.add_argument("--model_config_file", type=str, default="", help="The file of the model config(eg. '')")
    parser.add_argument("--enable_distributed", type=ast.literal_eval, default=False,
                        help="whether to enable the distributed environment")
    parser.add_argument("--input_path", type=str, default="/usr/local/input/", help="input batch data path")
    parser.add_argument("--output_path", type=str, default="/usr/local/output/", help="output json path")
    parser.add_argument("--use_script", type=ast.literal_eval, default=False, help="whether to use custom inference script")
    return parser
@@ -0,0 +1,84 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger

log = Logger().logger


# Only one class may be defined in this file
class CommonInferenceService:
    # __init__ receives the args parameter (model path: args.model_path, GPU flag:
    # args.use_gpu) and loads the model (the loading logic is user-defined)
    def __init__(self, args):
        self.args = args
        self.model = self.load_model()

    def load_data(self, data_path):
        with open(data_path, 'rb') as f:
            image = f.read()
        image = Image.open(io.BytesIO(image))
        if image.mode != 'RGB':
            image = image.convert("RGB")
        image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image)
        image = transforms.ToTensor()(image)
        image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
        image = image[None]
        if self.args.use_gpu:
            image = image.cuda()
        return image

    def load_model(self):
        if os.path.isfile(self.args.model_path):
            self.checkpoint = torch.load(self.args.model_path)
        else:
            for file in os.listdir(self.args.model_path):
                self.checkpoint = torch.load(self.args.model_path + file)
        model = self.checkpoint["model"]
        model.load_state_dict(self.checkpoint['state_dict'])
        for parameter in model.parameters():
            parameter.requires_grad = False
        if self.args.use_gpu:
            model.cuda()
        model.eval()
        return model

    # The method name "inference" is fixed
    def inference(self, data):
        result = {"data_name": data['data_name']}
        data = self.load_data(data['data_path'])
        preds = functional.softmax(self.model(data), dim=1)
        topk = torch.topk(preds.data, k=5, dim=1)
        result['predictions'] = list()
        for prob, label in zip(topk[0][0], topk[1][0]):
            prediction = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
            result['predictions'].append(prediction)
        return result


if __name__ == "__main__":
    import argparse
    import ast
    parser = argparse.ArgumentParser(description='tianshu serving')
    parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path")
    # ast.literal_eval instead of bool: bool('False') would evaluate to True
    parser.add_argument('--use_gpu', type=ast.literal_eval, default=True, help="use gpu or not")
    parser.add_argument('--reshape_size', type=ast.literal_eval, default=[224, 224], help="reshape size")
    args = parser.parse_args()
    server = CommonInferenceService(args)
    image_path = "./cat.jpg"
    image = {"data_name": "cat.jpg", "data_path": image_path}
    result = server.inference(image)
    print(result)
@@ -21,8 +21,8 @@ log = Logger().logger
parser = configs.get_parser()
args = parser.parse_args()

_HOST = '10.5.24.134'
_PORT = '8500'
MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024  # Set according to actual needs; 1 GB here
@@ -41,12 +41,12 @@ def run():
                             ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ], )  # Create the channel
    client = inference_pb2_grpc.InferenceServiceStub(channel=channel)  # Create the client stub
    data_request = inference_pb2.DataRequest()
    data1 = data_request.data_list.add()
    data1.data_file = str(base64.b64encode(open("/usr/local/input/dog.jpg", "rb").read()), encoding='utf-8')
    data1.data_name = "dog.jpg"
    data2 = data_request.data_list.add()
    data2.data_file = str(base64.b64encode(open("/usr/local/input/6.jpg", "rb").read()), encoding='utf-8')
    data2.data_name = "6.jpg"
    response = client.inference(data_request)
    log.info(response.json_result.encode('utf-8').decode('unicode_escape'))
@@ -46,19 +46,21 @@ class InferenceService(inference_pb2_grpc.InferenceServiceServicer):
    Handle inference calls over gRPC
    """
    def inference(self, request, context):
        data_list = request.data_list
        log.info("===============> grpc inference start <===============")
        try:
            data_list_b64 = file_utils.upload_image_by_base64(data_list)  # Save the uploaded images locally
        except Exception as e:
            log.error("upload data failed: %s", e)
            return inference_pb2.DataResponse(json_result=json.dumps(
                response_convert(Response(success=False, data=str(e), error="upload data failed"))))
        try:
            result = inference_service.inference(args.model_name, data_list_b64)
            log.info("===============> grpc inference success <===============")
            return inference_pb2.DataResponse(json_result=json.dumps(
                response_convert(Response(success=True, data=result))))
        except Exception as e:
            log.error("inference fail: %s", e)
            return inference_pb2.DataResponse(json_result=json.dumps(
                response_convert(Response(success=False, data=str(e), error="inference fail"))))
@@ -49,7 +49,7 @@ async def inference(images_path: List[str] = None):
    # Download the images to local disk in a background thread
    threading.Thread(target=file_utils.download_image, args=(images_path,)).start()
    images = list()
    for image in images_path:
        data = {"data_name": image.split("/")[-1], "data_path": image}
        images.append(data)
    try:
        data = inference_service.inference(args.model_name, images)
@@ -59,17 +59,22 @@ async def inference(images_path: List[str] = None):
@app.post("/inference")
async def inference(files: List[UploadFile] = File(...)):
    """
    Inference on uploaded local files
    """
    log.info("===============> http inference start <===============")
    try:
        data_list = file_utils.upload_data(files)  # Save the uploaded files locally
    except Exception as e:
        log.error("upload data failed: %s", e)
        return Response(success=False, data=str(e), error="upload data failed")
    try:
        result = inference_service.inference(args.model_name, data_list)
        log.info("===============> http inference success <===============")
        return Response(success=True, data=result)
    except Exception as e:
        log.error("inference fail: %s", e)
        return Response(success=False, data=str(e), error="inference fail")
@@ -0,0 +1 @@
2021-01-14 19:55:07,045 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.665"}, {"label": "Siberian husky", "probability": "0.294"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.024"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.015"}, {"label": "timber wolf, grey wolf, gray wolf, Canis lupus", "probability": "0.000"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "1.000"}, {"label": "tench, Tinca tinca", "probability": "0.000"}, {"label": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "probability": "0.000"}, {"label": "axolotl, mud puppy, Ambystoma mexicanum", "probability": "0.000"}, {"label": "barracouta, snoek", "probability": "0.000"}]}], "error": ""}
@@ -0,0 +1 @@
2021-02-03 15:05:00,069 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.679"}, {"label": "Siberian husky", "probability": "0.213"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.021"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.006"}, {"label": "white wolf, Arctic wolf, Canis lupus tundrarum", "probability": "0.001"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "0.871"}, {"label": "tench, Tinca tinca", "probability": "0.004"}, {"label": "puffer, pufferfish, blowfish, globefish", "probability": "0.002"}, {"label": "rock beauty, Holocanthus tricolor", "probability": "0.002"}, {"label": "barracouta, snoek", "probability": "0.001"}]}], "error": ""}
@@ -1,16 +1,16 @@
syntax = 'proto3';

service InferenceService {
    rpc inference(DataRequest) returns (DataResponse) {}
}

message DataRequest{
    repeated Data data_list = 1;
}

message Data {
    string data_file = 1;
    string data_name = 2;
}

message DataResponse{
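If inference.proto is edited further, the generated inference_pb2.py and inference_pb2_grpc.py files below would typically be regenerated with grpcio-tools; a standard invocation, assuming the proto file sits in the working directory:
```
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. inference.proto
```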
@@ -6,168 +6,177 @@ from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()

DESCRIPTOR = _descriptor.FileDescriptor(
  name='inference.proto',
  package='',
  syntax='proto3',
  serialized_options=None,
  create_key=_descriptor._internal_create_key,
  serialized_pb=b'\n\x0finference.proto\"\'\n\x0b\x44\x61taRequest\x12\x18\n\tdata_list\x18\x01 \x03(\x0b\x32\x05.Data\",\n\x04\x44\x61ta\x12\x11\n\tdata_file\x18\x01 \x01(\t\x12\x11\n\tdata_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3'
)
_DATAREQUEST = _descriptor.Descriptor(
  name='DataRequest',
  full_name='DataRequest',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='data_list', full_name='DataRequest.data_list', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=19,
  serialized_end=58,
)
_DATA = _descriptor.Descriptor(
  name='Data',
  full_name='Data',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='data_file', full_name='Data.data_file', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
    _descriptor.FieldDescriptor(
      name='data_name', full_name='Data.data_name', index=1,
      number=2, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=60,
  serialized_end=104,
)
_DATARESPONSE = _descriptor.Descriptor(
  name='DataResponse',
  full_name='DataResponse',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  create_key=_descriptor._internal_create_key,
  fields=[
    _descriptor.FieldDescriptor(
      name='json_result', full_name='DataResponse.json_result', index=0,
      number=1, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=b"".decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto3',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=106,
  serialized_end=141,
)
_DATAREQUEST.fields_by_name['data_list'].message_type = _DATA
DESCRIPTOR.message_types_by_name['DataRequest'] = _DATAREQUEST
DESCRIPTOR.message_types_by_name['Data'] = _DATA
DESCRIPTOR.message_types_by_name['DataResponse'] = _DATARESPONSE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

DataRequest = _reflection.GeneratedProtocolMessageType('DataRequest', (_message.Message,), {
  'DESCRIPTOR' : _DATAREQUEST,
  '__module__' : 'inference_pb2'
  # @@protoc_insertion_point(class_scope:DataRequest)
  })
_sym_db.RegisterMessage(DataRequest)

Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), {
  'DESCRIPTOR' : _DATA,
  '__module__' : 'inference_pb2'
  # @@protoc_insertion_point(class_scope:Data)
  })
_sym_db.RegisterMessage(Data)

DataResponse = _reflection.GeneratedProtocolMessageType('DataResponse', (_message.Message,), {
  'DESCRIPTOR' : _DATARESPONSE,
  '__module__' : 'inference_pb2'
  # @@protoc_insertion_point(class_scope:DataResponse)
  })
_sym_db.RegisterMessage(DataResponse)
_INFERENCESERVICE = _descriptor.ServiceDescriptor(
  name='InferenceService',
  full_name='InferenceService',
  file=DESCRIPTOR,
  index=0,
  serialized_options=None,
  create_key=_descriptor._internal_create_key,
  serialized_start=144,
  serialized_end=206,
  methods=[
  _descriptor.MethodDescriptor(
    name='inference',
    full_name='InferenceService.inference',
    index=0,
    containing_service=None,
    input_type=_DATAREQUEST,
    output_type=_DATARESPONSE,
    serialized_options=None,
    create_key=_descriptor._internal_create_key,
  ),
])
_sym_db.RegisterServiceDescriptor(_INFERENCESERVICE)

DESCRIPTOR.services_by_name['InferenceService'] = _INFERENCESERVICE
@@ -15,10 +15,10 @@ class InferenceServiceStub(object):
        channel: A grpc.Channel.
        """
        self.inference = channel.unary_unary(
                '/InferenceService/inference',
                request_serializer=inference__pb2.DataRequest.SerializeToString,
                response_deserializer=inference__pb2.DataResponse.FromString,
                )
class InferenceServiceServicer(object):
@@ -33,34 +33,34 @@ class InferenceServiceServicer(object):
def add_InferenceServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'inference': grpc.unary_unary_rpc_method_handler(
                    servicer.inference,
                    request_deserializer=inference__pb2.DataRequest.FromString,
                    response_serializer=inference__pb2.DataResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'InferenceService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class InferenceService(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def inference(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(request, target, '/InferenceService/inference',
            inference__pb2.DataRequest.SerializeToString,
            inference__pb2.DataResponse.FromString,
            options, channel_credentials,
            insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
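For completeness, a hedged sketch of calling this experimental convenience wrapper without building a stub; the host and port are taken from the grpc_client.py change above, and populating data_list is elided:
```
import inference_pb2
import inference_pb2_grpc

request = inference_pb2.DataRequest()  # fill request.data_list before a real call
response = inference_pb2_grpc.InferenceService.inference(
    request, target='10.5.24.134:8500', insecure=True)
print(response.json_result)
```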
@@ -0,0 +1,87 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger

log = Logger().logger


# Only one class may be defined in this file
class CommonInferenceService:
    # __init__ must accept the args parameter and load the model (model path:
    # args.model_path, GPU flag: args.use_gpu; the loading logic is user-defined)
    def __init__(self, args):
        self.args = args
        self.model = self.load_model()

    def load_data(self, data_path):
        with open(data_path, 'rb') as f:
            image = f.read()
        image = Image.open(io.BytesIO(image))
        if image.mode != 'RGB':
            image = image.convert("RGB")
        image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image)
        image = transforms.ToTensor()(image)
        image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
        image = image[None]
        if self.args.use_gpu:
            image = image.cuda()
        return image

    def load_model(self):
        if os.path.isfile(self.args.model_path):
            self.checkpoint = torch.load(self.args.model_path)
        else:
            for file in os.listdir(self.args.model_path):
                self.checkpoint = torch.load(self.args.model_path + file)
        model = self.checkpoint["model"]
        model.load_state_dict(self.checkpoint['state_dict'])
        for parameter in model.parameters():
            parameter.requires_grad = False
        if self.args.use_gpu:
            model.cuda()
        model.eval()
        return model

    # The method name "inference" is fixed
    def inference(self, data):
        result = {"data_name": data['data_name']}
        log.info("===============> start load " + data['data_name'] + " <===============")
        data = self.load_data(data['data_path'])
        preds = functional.softmax(self.model(data), dim=1)
        topk = torch.topk(preds.data, k=5, dim=1)
        result['predictions'] = list()
        for prob, label in zip(topk[0][0], topk[1][0]):
            prediction = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
            result['predictions'].append(prediction)
        return result


# Optional; useful for local debugging
if __name__ == "__main__":
    import argparse
    import ast
    parser = argparse.ArgumentParser(description='dubhe serving')
    parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path")
    # ast.literal_eval instead of bool: bool('False') would evaluate to True
    parser.add_argument('--use_gpu', type=ast.literal_eval, default=True, help="use gpu or not")
    parser.add_argument('--reshape_size', type=ast.literal_eval, default=[224, 224], help="reshape size")
    args = parser.parse_args()
    server = CommonInferenceService(args)
    image_path = "./cat.jpg"
    image = {"data_name": "cat.jpg", "data_path": image_path}
    result = server.inference(image)
    print(result)
@@ -15,8 +15,10 @@ import time
from service.oneflow_inference_service import OneFlowInferenceService
from service.tensorflow_inference_service import TensorflowInferenceService
from service.pytorch_inference_service import PytorchInferenceService
import service.common_inference_service as common_inference_service
from logger import Logger
from utils import file_utils
from utils.find_class_in_file import FindClassInFile

log = Logger().logger
@@ -36,47 +38,58 @@ class InferenceServiceManager:
            for model_config in model_config_list:
                model_name = model_config["model_name"]
                model_path = model_config["model_path"]
                self.args.model_name = model_name
                self.args.model_path = model_path
                model_platform = model_config.get("platform")
                if model_platform == "oneflow":
                    self.inference_service = OneFlowInferenceService(self.args)
                elif model_platform == "tensorflow" or model_platform == "keras":
                    self.inference_service = TensorflowInferenceService(self.args)
                elif model_platform == "pytorch":
                    self.inference_service = PytorchInferenceService(self.args)
                self.model_name_service_map[model_name] = self.inference_service
        else:
            # Read from command-line parameters
            if self.args.use_script:
                # Use the custom inference script
                find_class_in_file = FindClassInFile()
                cls = find_class_in_file.find(common_inference_service)
                self.inference_service = cls[1](self.args)
            else:
                # Use the default inference script
                if self.args.platform == "oneflow":
                    self.inference_service = OneFlowInferenceService(self.args)
                elif self.args.platform == "tensorflow" or self.args.platform == "keras":
                    self.inference_service = TensorflowInferenceService(self.args)
                elif self.args.platform == "pytorch":
                    self.inference_service = PytorchInferenceService(self.args)
            self.model_name_service_map[self.args.model_name] = self.inference_service

    def inference(self, model_name, data_list):
        """
        Inference method for the online service
        """
        inferenceService = self.model_name_service_map[model_name]
        result = list()
        for data in data_list:
            output = inferenceService.inference(data)
            if len(data_list) == 1:
                return output
            else:
                result.append(output)
        return result

    def inference_and_save_json(self, model_name, json_path, data_list):
        """
        Inference method for the batch service
        """
        inferenceService = self.model_name_service_map[model_name]
        for data in data_list:
            result = inferenceService.inference(data)
            file_utils.writer_json_file(json_path, data['data_name'], result)
            time.sleep(1)
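A minimal sketch of how the servers in this diff drive this manager; the constructor signature is an assumption inferred from the self.args usage above, while the inference call matches grpc_server.py and http_server.py:
```
inference_service = InferenceServiceManager(args)  # constructor taking parsed args (assumed)
data_list = [{"data_name": "dog.jpg", "data_path": "/usr/local/input/dog.jpg"}]
result = inference_service.inference(args.model_name, data_list)
```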
@@ -20,12 +20,8 @@ import google.protobuf.text_format as text_format
import os
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger
from service.abstract_inference_service import AbstractInferenceService

log = Logger().logger
@@ -33,10 +29,11 @@ class OneFlowInferenceService(AbstractInferenceService):
    """
    Inference service for the oneflow framework
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model_name = args.model_name
        self.model_path = args.model_path
        flow.clear_default_session()
        self.infer_session = flow.SimpleSession()
        self.load_model()
@@ -91,9 +88,9 @@ class OneFlowInferenceService(AbstractInferenceService):
        return saved_model_proto

    def inference(self, image):
        data = {"data_name": image['data_name']}
        log.info("===============> start load " + image['data_name'] + " <===============")
        images = self.load_image(image['data_path'])

        predictions = self.infer_session.run('inference', image=images)
@@ -16,15 +16,12 @@ import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
import requests
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from io import BytesIO
from logger import Logger
from service.abstract_inference_service import AbstractInferenceService

log = Logger().logger
@@ -33,10 +30,11 @@ class PytorchInferenceService(AbstractInferenceService):
    """
    Inference service for the pytorch framework
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.model_name = args.model_name
        self.model_path = args.model_path
        # Initialize checkpoint before load_model(), which assigns it
        self.checkpoint = None
        self.model = self.load_model()
@@ -52,38 +50,38 @@ class PytorchInferenceService(AbstractInferenceService):
        image = Image.open(io.BytesIO(image))
        if image.mode != 'RGB':
            image = image.convert("RGB")
        image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image)
        image = transforms.ToTensor()(image)
        image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
        image = image[None]
        if self.args.use_gpu:
            image = image.cuda()
        log.info("===============> load image success <===============")
        return image

    def load_model(self):
        log.info("===============> start load pytorch model :" + self.args.model_path + " <===============")
        if os.path.isfile(self.args.model_path):
            self.checkpoint = torch.load(self.model_path)
        else:
            for file in os.listdir(self.args.model_path):
                self.checkpoint = torch.load(self.model_path + file)
        model = self.checkpoint[self.args.model_structure]
        model.load_state_dict(self.checkpoint['state_dict'])
        for parameter in model.parameters():
            parameter.requires_grad = False
        if self.args.use_gpu:
            model.cuda()
        model.eval()
        log.info("===============> load pytorch model success <===============")
        return model

    def inference(self, image):
        data = {"data_name": image['data_name']}
        log.info("===============> start load " + image['data_name'] + " <===============")
        image = self.load_image(image['data_path'])
        preds = functional.softmax(self.model(image), dim=1)
        results = torch.topk(preds.data, k=5, dim=1)
        data['predictions'] = list()
        for prob, label in zip(results[0][0], results[1][0]):
            result = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
@@ -13,7 +13,6 @@ limitations under the License.
import tensorflow as tf
import requests
import numpy as np
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from service.abstract_inference_service import AbstractInferenceService
from utils.imagenet_preprocessing_utils import preprocess_input
@@ -21,8 +20,6 @@ from logger import Logger
from PIL import Image
from io import BytesIO

log = Logger().logger
@@ -30,11 +27,12 @@ class TensorflowInferenceService(AbstractInferenceService):
    """
    Inference service for the tensorflow framework
    """
    def __init__(self, args):
        super().__init__()
        self.session = tf.compat.v1.Session(graph=tf.Graph())
        self.args = args
        self.model_name = args.model_name
        self.model_path = args.model_path
        self.signature_input_keys = []
        self.signature_input_tensor_names = []
        self.signature_output_keys = []
@@ -69,11 +67,11 @@ class TensorflowInferenceService(AbstractInferenceService):
            self.session, [tf.compat.v1.saved_model.tag_constants.SERVING], self.model_path)

        # Validate the user-supplied signature name before loading the model
        if self.args.signature_name not in meta_graph.signature_def:
            log.error("==============> Invalid signature name <==================")

        # Get the input and output node info of the meta graph from the signature
        signature = meta_graph.signature_def[self.args.signature_name]
        input_keys, input_tensor_names = get_tensors(signature.inputs)
        output_keys, output_tensor_names = get_tensors(signature.outputs)
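When the signature check above fails, the signatures available in a SavedModel can be listed with TensorFlow's standard saved_model_cli tool; a typical invocation (the model directory is a placeholder):
```
saved_model_cli show --dir /usr/local/work/models/tf_models/resnet50/ --tag_set serve --signature_def serving_default
```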
@@ -87,14 +85,14 @@ class TensorflowInferenceService(AbstractInferenceService):
        log.info("===============> load tensorflow model success <===============")

    def inference(self, image):
        data = {"data_name": image['data_name']}
        # Get the image supplied by the user
        log.info("===============> start load " + image['data_name'] + " <===============")
        # Inference inputs; the current preset classification models all take a single input
        input_dict = {}
        input_keys = self.signature_input_keys
        input_data = {}
        im = preprocess_input(self.load_image(image['data_path']), mode=self.args.prepare_mode)
        if len(list(im.shape)) == 3:
            input_data[input_keys[0]] = np.expand_dims(im, axis=0)
@@ -43,52 +43,52 @@ def download_image(images_path):
            save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + image_path.split("/")[-1].split(".")[-1])


def upload_data(files):
    """
    Save files uploaded from the frontend to local disk
    """
    save_data_dir = "/usr/local/data/"
    if not os.path.exists(save_data_dir):
        os.mkdir(save_data_dir)
    data_list = list()
    for file in files:
        try:
            suffix = Path(file.filename).suffix
            with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_data_dir) as tmp:
                shutil.copyfileobj(file.file, tmp)
            tmp_file_name = Path(tmp.name).name
            data = {"data_name": file.filename, "data_path": save_data_dir + tmp_file_name}
            data_list.append(data)
        finally:
            file.file.close()
    return data_list
def upload_image_by_base64(data_list):
    """
    Save base64-encoded image data to local disk
    """
    save_data_dir = "/usr/local/data/"
    if not os.path.exists(save_data_dir):
        os.mkdir(save_data_dir)
    data_list_b64 = list()
    for data in data_list:
        file_path = save_data_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + data.data_name.split(".")[-1]
        file_b64 = base64.b64decode(data.data_file)
        with open(file_path, 'wb') as file:
            file.write(file_b64)
        data_b64 = {"data_name": data.data_name, "data_path": file_path}
        data_list_b64.append(data_b64)
    return data_list_b64
def writer_json_file(json_path, data_name, data):
    """
    Save the inference result as a JSON file
    """
    if not os.path.exists(json_path):
        os.mkdir(json_path)
    filename = json_path + data_name + '.json'
    with open(filename, 'w', encoding='utf-8') as file_obj:
        file_obj.write(json.dumps(data, ensure_ascii=False))
@@ -0,0 +1,88 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import re
import inspect


class FindClassInFile:
    """
    Find the class defined in a given module.

    ### note ###
    There must be only one class in the given module.
    If there is more than one class, all of them will be omitted except the first.
    ############

    method: find
        args:
            module: object -> the given file or module
            encoding: string -> "utf8" by default
        output: tuple(string -> name of the class, object -> the class)
    usage:
        # >>> import module
        #
        # >>> find_class_in_file = FindClassInFile()
        # >>> cls = find_class_in_file.find(module)
        #
        # >>> cls_instance = cls[1](args)
    """
    def __init__(self):
        pass

    def _open_file(self, path, encoding="utf8"):
        with open(path, "r", encoding=encoding) as f:
            data = f.readlines()
        for line in data:
            yield line

    def find(self, module, encoding="utf8"):
        path = module.__file__
        lines = self._open_file(path=path, encoding=encoding)
        cls = ""
        for line in lines:
            # Take the first class definition that appears in the source
            if "class " in line:
                cls = re.findall("class (.*?)[:(]", line)[0]
            if cls:
                break
        return self._valid(module, cls)

    def _valid(self, module, cls):
        members = inspect.getmembers(module)
        cand = [(i, j) for i, j in members if inspect.isclass(j) and (not inspect.isabstract(j)) and (i == cls)]
        if not cand:
            raise ValueError("class not found in {}".format(module))
        return cand[0]


if __name__ == "__main__":
    # Local smoke test; assumes the repository root is on PYTHONPATH
    import service.common_inference_service as common_inference_service
    find_class_in_file = FindClassInFile()
    cls = find_class_in_file.find(common_inference_service)
    print(cls)