
update serving

tags/v0.4.0
之江天枢 3 years ago
parent commit 857a57e030
21 changed files with 589 additions and 304 deletions
  1. tianshu_serving/README.md (+22 -19)
  2. tianshu_serving/batch_server.py (+15 -17)
  3. tianshu_serving/config.py (+6 -5)
  4. tianshu_serving/customize/__init__.py (+0 -0)
  5. tianshu_serving/customize/common_inference_service.py (+84 -0)
  6. tianshu_serving/grpc_client.py (+8 -8)
  7. tianshu_serving/grpc_server.py (+6 -4)
  8. tianshu_serving/http_server.py (+10 -5)
  9. tianshu_serving/logs/serving.log.2020-12-18 (+0 -0)
  10. tianshu_serving/logs/serving.log.2021-01-14 (+1 -0)
  11. tianshu_serving/logs/serving.log.2021-02-03 (+1 -0)
  12. tianshu_serving/proto/inference.proto (+5 -5)
  13. tianshu_serving/proto/inference_pb2.py (+141 -132)
  14. tianshu_serving/proto/inference_pb2_grpc.py (+24 -24)
  15. tianshu_serving/service/common_inference_service.py (+87 -0)
  16. tianshu_serving/service/inference_service_manager.py (+32 -19)
  17. tianshu_serving/service/oneflow_inference_service.py (+7 -10)
  18. tianshu_serving/service/pytorch_inference_service.py (+16 -18)
  19. tianshu_serving/service/tensorflow_inference_service.py (+9 -11)
  20. tianshu_serving/utils/file_utils.py (+27 -27)
  21. tianshu_serving/utils/find_class_in_file.py (+88 -0)

tianshu_serving/README.md (+22 -19)

@@ -5,15 +5,11 @@
+ **Supports model deployment for the oneflow, tensorflow, and pytorch frameworks** </br>

1. Start the HTTP online inference service with the following command:
```
python http_server.py --platform='<framework name>' --model_path='<model path>'
```
Open localhost:5000/docs for the Swagger page and call localhost:5000/inference to upload an image and obtain the inference result, which looks like this:
```
{
  "image_name": "哈士奇.jpg",
  "predictions": [
@@ -39,28 +35,35 @@
    }
  ]
}
```
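For a quick local check of the HTTP endpoint, a minimal client sketch using Python requests (host, port and the sample file name are assumptions; the multipart field name "files" matches the updated http_server.py signature):
```
import requests

# Assumed local deployment; adjust host, port and file name to your setup.
url = "http://localhost:5000/inference"

with open("dog.jpg", "rb") as f:
    # The endpoint accepts a list of uploads under the multipart field "files".
    response = requests.post(url, files=[("files", ("dog.jpg", f, "image/jpeg"))])

print(response.json())
```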
2. Similarly, start the gRPC online inference service with the following command:
```
python grpc_server.py --platform='<framework name>' --model_path='<model path>'
```
Then run grpc_client.py to upload images and obtain the inference results, or write your own gRPC client against the service IP and port.
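A minimal sketch of such a self-written client, assuming the generated inference_pb2/inference_pb2_grpc modules are importable and the server listens on localhost:8500 (adjust host, port and import paths to your deployment):
```
import base64

import grpc

from proto import inference_pb2, inference_pb2_grpc  # import paths assumed

# Assumed address; use the host and port your grpc_server.py listens on.
channel = grpc.insecure_channel("localhost:8500")
client = inference_pb2_grpc.InferenceServiceStub(channel)

request = inference_pb2.DataRequest()
item = request.data_list.add()
with open("dog.jpg", "rb") as f:  # sample file name is a placeholder
    item.data_file = base64.b64encode(f.read()).decode("utf-8")
item.data_name = "dog.jpg"

response = client.inference(request)
print(response.json_result)
```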
3. Multi-model deployment is supported: configure model_config_file.json under the config folder for multiple models and pass the desired model name when starting the HTTP or gRPC service, or modify the parameters of the inference interface so that a single service can run inference for several models.
+ **Supports multi-model deployment** </br>

Configure model_config_file.json under the config folder for multiple models and pass the desired model name when starting the HTTP or gRPC service, or modify the parameters of the inference interface so that a single service can run inference for several models.
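A sketch of what model_config_file.json might look like: the field names model_name, model_path and platform are the ones read by inference_service_manager.py, while the exact top-level layout is not shown in this commit and the values below are placeholders:
```
[
    {"model_name": "resnet50_torch", "model_path": "/usr/local/work/models/pytorch_models/resnet50/", "platform": "pytorch"},
    {"model_name": "resnet50_tf", "model_path": "/usr/local/work/models/tf_models/resnet50/", "platform": "tensorflow"}
]
```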
+ **Supports distributed model deployment and inference** </br>

Distributed inference is needed when a large number of images must be processed; run the following command:
```
python batch_server.py --platform='<framework name>' --model_path='<model path>' --input_path='<batch image directory>' --output_path='<output JSON directory>'
```
All input images are stored under the input directory; the output JSON files are saved to the output_path directory, one per image and named after it.
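In distributed mode each node picks its share of the input according to its position in the NODE_IPS environment variable (see batch_server.py below). A sketch of launching the batch job on two nodes, with IPs and paths as placeholders:
```
# Run on every node with the same comma-separated NODE_IPS value (placeholder IPs).
export NODE_IPS=10.0.0.1,10.0.0.2
python batch_server.py --platform='pytorch' --model_path='/usr/local/work/models/pytorch_models/resnet50/' \
    --input_path='/usr/local/input/' --output_path='/usr/local/output/' --enable_distributed=True
```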

+ **Supports custom inference scripts** </br>
To use a custom inference script, write one that follows the rules described in the comments of common_inference_service.py and replace the original common_inference_service.py with it. Then add the use_script parameter to the start command to use the custom script at inference time, as follows:
```
python grpc_server.py --platform='<framework name>' --model_path='<model path>' --use_script=True
```
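A minimal skeleton of such a script under the rules above (a single class, an __init__ that receives args and loads the model, and a method named inference that takes a dict with data_name/data_path and returns a result dict); the model logic here is only a placeholder:
```
import os


# Only one class may be defined in the script.
class CommonInferenceService:

    # __init__ must accept args (args.model_path, args.use_gpu, ...) and load the model.
    def __init__(self, args):
        self.args = args
        self.model = self.load_model()

    def load_model(self):
        # Placeholder "model": a real script would load a framework model from self.args.model_path.
        return lambda path: {"label": "unknown", "size_bytes": os.path.getsize(path)}

    # The method name "inference" is fixed; it receives {"data_name": ..., "data_path": ...}.
    def inference(self, data):
        return {"data_name": data["data_name"], "predictions": [self.model(data["data_path"])]}
```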
+ **The code also covers various parameter settings, log file output, optional TLS, and more** </br>

tianshu_serving/batch_server.py (+15 -17)

@@ -27,49 +27,47 @@ log = Logger().logger
def get_host_ip():
"""
Query the local IP address
:return:
return
"""
global s
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ip = s.getsockname()[0]
finally:
s.close()
hostname = socket.gethostname()
ip = socket.gethostbyname(hostname)
return ip




def read_directory(images_path):
def read_directory(data_path):
"""
Read the directory and split the files
:return:
"""
files = os.listdir(images_path)
files = os.listdir(data_path)
num_files = len(files)
index_list = list(range(num_files))
images = list()
data_list = list()
for index in index_list:
# whether distributed mode is enabled
if args.enable_distributed:
ip = get_host_ip()
log.info("NODE_IPS:{}", os.getenv('NODE_IPS'))
ip_list = os.getenv('NODE_IPS').split(",")
num_ips = len(ip_list)
ip_index = ip_list.index(ip)
if ip_index == index % num_ips:
filename = files[index]
image = {"image_name": filename, "image_path": images_path + filename}
images.append(image)
data = {"data_name": filename, "data_path": data_path + filename}
data_list.append(data)
else:
filename = files[index]
image = {"image_name": filename, "image_path": images_path + filename}
images.append(image)
return images
data = {"data_name": filename, "data_path": data_path + filename}
data_list.append(data)
return data_list




def main():
images = read_directory(args.input_path)
inference_service.inference_and_save_json(args.model_name, args.output_path, images)
data_list = read_directory(args.input_path)
inference_service.inference_and_save_json(args.model_name, args.output_path, data_list)
if args.enable_distributed:
ip = get_host_ip()
log.info("NODE_IPS:{}", os.getenv('NODE_IPS'))
ip_list = os.getenv('NODE_IPS').split(",")
# the master node must wait for the worker nodes to finish inference
if ip == ip_list[0]:

tianshu_serving/config.py (+6 -5)

@@ -58,12 +58,12 @@ def get_parser(parser=None):
parser.add_argument("--job_name", type=str, default="inference", help="oneflow job name") parser.add_argument("--job_name", type=str, default="inference", help="oneflow job name")
parser.add_argument("--prepare_mode", type=str, default="tfhub", parser.add_argument("--prepare_mode", type=str, default="tfhub",
help="tensorflow prepare mode(tfhub、caffe、tf、torch)") help="tensorflow prepare mode(tfhub、caffe、tf、torch)")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="is use gpu")
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="whether to use gpu")
parser.add_argument('--channel_last', type=str2bool, nargs='?', const=False, parser.add_argument('--channel_last', type=str2bool, nargs='?', const=False,
help='Whether to use use channel last mode(nhwc)') help='Whether to use use channel last mode(nhwc)')
parser.add_argument("--model_path", type=str, default="/usr/local/model/pytorch_models/resnet50/",
parser.add_argument("--model_path", type=str, default="/usr/local/work/models/pytorch_models/resnet50/",
help="model load directory if need") help="model load directory if need")
parser.add_argument("--image_path", type=str, default='/usr/local/data/fish.jpg', help="image path")
parser.add_argument("--data_path", type=str, default='/usr/local/work/dog.jpg', help="input data path")
parser.add_argument("--reshape_size", type=int_list, default='[224]', parser.add_argument("--reshape_size", type=int_list, default='[224]',
help="The reshape size of the image(eg. 224)") help="The reshape size of the image(eg. 224)")
parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes") parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes")
@@ -78,8 +78,9 @@ def get_parser(parser=None):
parser.add_argument("--model_config_file", type=str, default="", help="The file of the model config(eg. '')") parser.add_argument("--model_config_file", type=str, default="", help="The file of the model config(eg. '')")
parser.add_argument("--enable_distributed", type=ast.literal_eval, default=False, help="If enable use distributed " parser.add_argument("--enable_distributed", type=ast.literal_eval, default=False, help="If enable use distributed "
"environment") "environment")
parser.add_argument("--input_path", type=str, default="/usr/local/data/images/", help="images path")
parser.add_argument("--output_path", type=str, default="/usr/local/output_path/", help="json path")
parser.add_argument("--input_path", type=str, default="/usr/local/input/", help="input batch data path")
parser.add_argument("--output_path", type=str, default="/usr/local/output/", help="output json path")
parser.add_argument("--use_script", type=ast.literal_eval, default=False, help="whether to use custom inference script")


return parser




tianshu_serving/customize/__init__.py (+0 -0)


tianshu_serving/customize/common_inference_service.py (+84 -0)

@@ -0,0 +1,84 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import io
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger

log = Logger().logger

# Only one class may be defined
class CommonInferenceService:
# The __init__ method receives the args parameter (args.model_path is the model path, args.use_gpu controls GPU usage) and loads the model (the loading method is user-defined)
def __init__(self, args):
self.args = args
self.model = self.load_model()


def load_data(self, data_path):
image = open(data_path, 'rb').read()
image = Image.open(io.BytesIO(image))
if image.mode != 'RGB':
image = image.convert("RGB")
image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
image = image[None]
if self.args.use_gpu:
image = image.cuda()
return image

def load_model(self):
if os.path.isfile(self.args.model_path):
self.checkpoint = torch.load(self.args.model_path)
else:
for file in os.listdir(self.args.model_path):
self.checkpoint = torch.load(self.args.model_path + file)
model = self.checkpoint["model"]
model.load_state_dict(self.checkpoint['state_dict'])
for parameter in model.parameters():
parameter.requires_grad = False
if self.args.use_gpu:
model.cuda()
model.eval()
return model

# The method name "inference" is fixed
def inference(self, data):
result = {"data_name": data['data_name']}
data = self.load_data(data['data_path'])
preds = functional.softmax(self.model(data), dim=1)
predictions = torch.topk(preds.data, k=5, dim=1)
result['predictions'] = list()
for prob, label in zip(predictions[0][0], predictions[1][0]):
predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
result['predictions'].append(predictions)
return result
if __name__=="__main__":
import argparse
parser = argparse.ArgumentParser(description='tianshu serving')
parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path")
parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not")
parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not")
args = parser.parse_args()
server = CommonInferenceService(args)

image_path = "./cat.jpg"
image = {"data_name": "cat.jpg", "data_path": image_path}
re = server.inference(image)
print(re)

tianshu_serving/grpc_client.py (+8 -8)

@@ -21,8 +21,8 @@ log = Logger().logger
parser = configs.get_parser()
args = parser.parse_args()


_HOST = 'kohj2s.serving.dubhe.ai'
_PORT = '31365'
_HOST = '10.5.24.134'
_PORT = '8500'
MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024  # can be set as needed; 1 GB here




@@ -41,12 +41,12 @@ def run():
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ], )  # create the channel
client = inference_pb2_grpc.InferenceServiceStub(channel=channel)  # create the client stub
data_request = inference_pb2.DataRequest()
Image = data_request.images.add()
Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\哈士奇.jpg", "rb").read()), encoding='utf-8')
Image.image_name = "哈士奇.jpg"
Image = data_request.images.add()
Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\fish.jpg", "rb").read()), encoding='utf-8')
Image.image_name = "fish.jpg"
data1 = data_request.data_list.add()
data1.data_file = str(base64.b64encode(open("/usr/local/input/dog.jpg", "rb").read()), encoding='utf-8')
data1.data_name = "dog.jpg"
data2 = data_request.data_list.add()
data2.data_file = str(base64.b64encode(open("/usr/local/input/6.jpg", "rb").read()), encoding='utf-8')
data2.data_name = "6.jpg"
response = client.inference(data_request)
log.info(response.json_result.encode('utf-8').decode('unicode_escape'))




tianshu_serving/grpc_server.py (+6 -4)

@@ -46,19 +46,21 @@ class InferenceService(inference_pb2_grpc.InferenceServiceServicer):
Perform inference via the gRPC method
"""
def inference(self, request, context):
image_files = request.images
data_list = request.data_list
log.info("===============> grpc inference start <===============")
try:
images = file_utils.upload_image_by_base64(image_files)  # save the uploaded images locally
data_list_b64 = file_utils.upload_image_by_base64(data_list)  # save the uploaded data locally
except Exception as e:
log.error("upload data failed", e)
return inference_pb2.DataResponse(json_result=json.dumps(
response_convert(Response(success=False, data=str(e), error="upload image fail"))))
response_convert(Response(success=False, data=str(e), error="upload data failed"))))
try:
result = inference_service.inference(args.model_name, images)
result = inference_service.inference(args.model_name, data_list_b64)
log.info("===============> grpc inference success <===============")
return inference_pb2.DataResponse(json_result=json.dumps(
response_convert(Response(success=True, data=result))))
except Exception as e:
log.error("inference fail", e)
return inference_pb2.DataResponse(json_result=json.dumps(
response_convert(Response(success=False, data=str(e), error="inference fail"))))




tianshu_serving/http_server.py (+10 -5)

@@ -49,7 +49,7 @@ async def inference(images_path: List[str] = None):
threading.Thread(target=file_utils.download_image(images_path))  # start an async thread to download the images locally
images = list()
for image in images_path:
data = {"image_name": image.split("/")[-1], "image_path": image}
data = {"data_name": image.split("/")[-1], "data_path": image}
images.append(data)
try:
data = inference_service.inference(args.model_name, images)
@@ -59,17 +59,22 @@ async def inference(images_path: List[str] = None):




@app.post("/inference") @app.post("/inference")
async def inference(image_files: List[UploadFile] = File(...)):
async def inference(files: List[UploadFile] = File(...)):
"""
上传本地文件推理
"""
log.info("===============> http inference start <===============") log.info("===============> http inference start <===============")
try: try:
images = file_utils.upload_image(image_files) # 上传图片到本地
data_list = file_utils.upload_data(files) # 上传图片到本地
except Exception as e: except Exception as e:
return Response(success=False, data=str(e), error="upload image fail")
log.error("upload data failed", e)
return Response(success=False, data=str(e), error="upload data failed")
try: try:
result = inference_service.inference(args.model_name, images)
result = inference_service.inference(args.model_name, data_list)
log.info("===============> http inference success <===============") log.info("===============> http inference success <===============")
return Response(success=True, data=result) return Response(success=True, data=result)
except Exception as e: except Exception as e:
log.error("inference fail", e)
return Response(success=False, data=str(e), error="inference fail") return Response(success=False, data=str(e), error="inference fail")






tianshu_serving/logs/serving.log.2020-12-18 (+0 -0)


tianshu_serving/logs/serving.log.2021-01-14 (+1 -0)

@@ -0,0 +1 @@
2021-01-14 19:55:07,045 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.665"}, {"label": "Siberian husky", "probability": "0.294"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.024"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.015"}, {"label": "timber wolf, grey wolf, gray wolf, Canis lupus", "probability": "0.000"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "1.000"}, {"label": "tench, Tinca tinca", "probability": "0.000"}, {"label": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "probability": "0.000"}, {"label": "axolotl, mud puppy, Ambystoma mexicanum", "probability": "0.000"}, {"label": "barracouta, snoek", "probability": "0.000"}]}], "error": ""}

tianshu_serving/logs/serving.log.2021-02-03 (+1 -0)

@@ -0,0 +1 @@
2021-02-03 15:05:00,069 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.679"}, {"label": "Siberian husky", "probability": "0.213"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.021"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.006"}, {"label": "white wolf, Arctic wolf, Canis lupus tundrarum", "probability": "0.001"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "0.871"}, {"label": "tench, Tinca tinca", "probability": "0.004"}, {"label": "puffer, pufferfish, blowfish, globefish", "probability": "0.002"}, {"label": "rock beauty, Holocanthus tricolor", "probability": "0.002"}, {"label": "barracouta, snoek", "probability": "0.001"}]}], "error": ""}

tianshu_serving/proto/inference.proto (+5 -5)

@@ -1,16 +1,16 @@
syntax = 'proto3';
service InferenceService {
rpc inference(DataRequest) returns (DataResponse) {}
}

message DataRequest{
repeated Image images = 1;
repeated Data data_list = 1;
}

message Image {
string image_file = 1;
string image_name = 2;
message Data {
string data_file = 1;
string data_name = 2;
}

message DataResponse{
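For reference, a sketch of the complete inference.proto after this change; the DataResponse body is cut off by the hunk above, and its single json_result field is recovered from the regenerated inference_pb2.py below:
```
syntax = 'proto3';

service InferenceService {
    rpc inference(DataRequest) returns (DataResponse) {}
}

message DataRequest {
    repeated Data data_list = 1;
}

message Data {
    string data_file = 1;
    string data_name = 2;
}

message DataResponse {
    string json_result = 1;
}
```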


tianshu_serving/proto/inference_pb2.py (+141 -132)

@@ -6,168 +6,177 @@ from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database

# @@protoc_insertion_point(imports)


_sym_db = _symbol_database.Default()





DESCRIPTOR = _descriptor.FileDescriptor(
name='inference.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x0finference.proto\"%\n\x0b\x44\x61taRequest\x12\x16\n\x06images\x18\x01 \x03(\x0b\x32\x06.Image\"/\n\x05Image\x12\x12\n\nimage_file\x18\x01 \x01(\t\x12\x12\n\nimage_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3'
name='inference.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x0finference.proto\"\'\n\x0b\x44\x61taRequest\x12\x18\n\tdata_list\x18\x01 \x03(\x0b\x32\x05.Data\",\n\x04\x44\x61ta\x12\x11\n\tdata_file\x18\x01 \x01(\t\x12\x11\n\tdata_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3'
)





_DATAREQUEST = _descriptor.Descriptor(
name='DataRequest',
full_name='DataRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='images', full_name='DataRequest.images', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=19,
serialized_end=56,
name='DataRequest',
full_name='DataRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='data_list', full_name='DataRequest.data_list', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=19,
serialized_end=58,
)


_IMAGE = _descriptor.Descriptor(
name='Image',
full_name='Image',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='image_file', full_name='Image.image_file', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='image_name', full_name='Image.image_name', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=58,
serialized_end=105,

_DATA = _descriptor.Descriptor(
name='Data',
full_name='Data',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='data_file', full_name='Data.data_file', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='data_name', full_name='Data.data_name', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=60,
serialized_end=104,
)



_DATARESPONSE = _descriptor.Descriptor(
name='DataResponse',
full_name='DataResponse',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='json_result', full_name='DataResponse.json_result', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=107,
serialized_end=142,
name='DataResponse',
full_name='DataResponse',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='json_result', full_name='DataResponse.json_result', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=106,
serialized_end=141,
)


_DATAREQUEST.fields_by_name['images'].message_type = _IMAGE
_DATAREQUEST.fields_by_name['data_list'].message_type = _DATA
DESCRIPTOR.message_types_by_name['DataRequest'] = _DATAREQUEST
DESCRIPTOR.message_types_by_name['Image'] = _IMAGE
DESCRIPTOR.message_types_by_name['Data'] = _DATA
DESCRIPTOR.message_types_by_name['DataResponse'] = _DATARESPONSE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)


DataRequest = _reflection.GeneratedProtocolMessageType('DataRequest', (_message.Message,), {
'DESCRIPTOR': _DATAREQUEST,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:DataRequest)
})
'DESCRIPTOR' : _DATAREQUEST,
'__module__' : 'inference_pb2'
# @@protoc_insertion_point(class_scope:DataRequest)
})
_sym_db.RegisterMessage(DataRequest)


Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), {
'DESCRIPTOR': _IMAGE,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:Image)
})
_sym_db.RegisterMessage(Image)
Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), {
'DESCRIPTOR' : _DATA,
'__module__' : 'inference_pb2'
# @@protoc_insertion_point(class_scope:Data)
})
_sym_db.RegisterMessage(Data)


DataResponse = _reflection.GeneratedProtocolMessageType('DataResponse', (_message.Message,), {
'DESCRIPTOR': _DATARESPONSE,
'__module__': 'inference_pb2'
# @@protoc_insertion_point(class_scope:DataResponse)
})
'DESCRIPTOR' : _DATARESPONSE,
'__module__' : 'inference_pb2'
# @@protoc_insertion_point(class_scope:DataResponse)
})
_sym_db.RegisterMessage(DataResponse)




_INFERENCESERVICE = _descriptor.ServiceDescriptor(
name='InferenceService',
full_name='InferenceService',
file=DESCRIPTOR,
name='InferenceService',
full_name='InferenceService',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=143,
serialized_end=205,
methods=[
_descriptor.MethodDescriptor(
name='inference',
full_name='InferenceService.inference',
index=0,
containing_service=None,
input_type=_DATAREQUEST,
output_type=_DATARESPONSE,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=144,
serialized_end=206,
methods=[
_descriptor.MethodDescriptor(
name='inference',
full_name='InferenceService.inference',
index=0,
containing_service=None,
input_type=_DATAREQUEST,
output_type=_DATARESPONSE,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
),
])
_sym_db.RegisterServiceDescriptor(_INFERENCESERVICE)


DESCRIPTOR.services_by_name['InferenceService'] = _INFERENCESERVICE


tianshu_serving/proto/inference_pb2_grpc.py (+24 -24)

@@ -15,10 +15,10 @@ class InferenceServiceStub(object):
channel: A grpc.Channel.
"""
self.inference = channel.unary_unary(
'/InferenceService/inference',
request_serializer=inference__pb2.DataRequest.SerializeToString,
response_deserializer=inference__pb2.DataResponse.FromString,
)


class InferenceServiceServicer(object):
@@ -33,34 +33,34 @@ class InferenceServiceServicer(object):

def add_InferenceServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'inference': grpc.unary_unary_rpc_method_handler(
servicer.inference,
request_deserializer=inference__pb2.DataRequest.FromString,
response_serializer=inference__pb2.DataResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'InferenceService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class InferenceService(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def inference(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/InferenceService/inference',
inference__pb2.DataRequest.SerializeToString,
inference__pb2.DataResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

tianshu_serving/service/common_inference_service.py (+87 -0)

@@ -0,0 +1,87 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import io
import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger

log = Logger().logger

# Only one class may be defined
class CommonInferenceService:
# Receive the args parameter in __init__ and load the model (args.model_path is the model path, args.use_gpu controls GPU usage; the model-loading method is user-defined)
def __init__(self, args):
self.args = args
self.model = self.load_model()


def load_data(self, data_path):
image = open(data_path, 'rb').read()
image = Image.open(io.BytesIO(image))
if image.mode != 'RGB':
image = image.convert("RGB")
image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
image = image[None]
if self.args.use_gpu:
image = image.cuda()
return image

def load_model(self):
if os.path.isfile(self.args.model_path):
self.checkpoint = torch.load(self.args.model_path)
else:
for file in os.listdir(self.args.model_path):
self.checkpoint = torch.load(self.args.model_path + file)
model = self.checkpoint["model"]
model.load_state_dict(self.checkpoint['state_dict'])
for parameter in model.parameters():
parameter.requires_grad = False
if self.args.use_gpu:
model.cuda()
model.eval()
return model

# The method name "inference" is fixed
def inference(self, data):
result = {"data_name": data['data_name']}
log.info("===============> start load " + data['data_name'] + " <===============")
data = self.load_data(data['data_path'])
preds = functional.softmax(self.model(data), dim=1)
predictions = torch.topk(preds.data, k=5, dim=1)
result['predictions'] = list()
for prob, label in zip(predictions[0][0], predictions[1][0]):
predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}
result['predictions'].append(predictions)
return result

# Optional; can be used for local debugging
if __name__=="__main__":
import argparse
parser = argparse.ArgumentParser(description='dubhe serving')
parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path")
parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not")
parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not")
args = parser.parse_args()
server = CommonInferenceService(args)

image_path = "./cat.jpg"
image = {"data_name": "cat.jpg", "data_path": image_path}
re = server.inference(image)
print(re)

tianshu_serving/service/inference_service_manager.py (+32 -19)

@@ -15,8 +15,10 @@ import time
from service.oneflow_inference_service import OneFlowInferenceService
from service.tensorflow_inference_service import TensorflowInferenceService
from service.pytorch_inference_service import PytorchInferenceService
import service.common_inference_service as common_inference_service
from logger import Logger
from utils import file_utils
from utils.find_class_in_file import FindClassInFile


log = Logger().logger


@@ -36,47 +38,58 @@ class InferenceServiceManager:
for model_config in model_config_list:
model_name = model_config["model_name"]
model_path = model_config["model_path"]
self.args.model_name = model_name
self.args.model_path = model_path
model_platform = model_config.get("platform")

if model_platform == "oneflow":
self.inference_service = OneFlowInferenceService(model_name, model_path)
self.inference_service = OneFlowInferenceService(self.args)
elif model_platform == "tensorflow" or model_platform == "keras":
self.inference_service = TensorflowInferenceService(model_name, model_path)
self.inference_service = TensorflowInferenceService(self.args)
elif model_platform == "pytorch":
self.inference_service = PytorchInferenceService(model_name, model_path)
self.inference_service = PytorchInferenceService(self.args)

self.model_name_service_map[model_name] = self.inference_service
else:
# Read from command-line parameter
if self.args.platform == "oneflow":
self.inference_service = OneFlowInferenceService(self.args.model_name, self.args.model_path)
elif self.args.platform == "tensorflow" or self.args.platform == "keras":
self.inference_service = TensorflowInferenceService(self.args.model_name, self.args.model_path)
elif self.args.platform == "pytorch":
self.inference_service = PytorchInferenceService(self.args.model_name, self.args.model_path)
if self.args.use_script:
# use the custom inference script
find_class_in_file = FindClassInFile()
cls = find_class_in_file.find(common_inference_service)
self.inference_service = cls[1](self.args)

else:
# use the default inference script
if self.args.platform == "oneflow":
self.inference_service = OneFlowInferenceService(self.args)
elif self.args.platform == "tensorflow" or self.args.platform == "keras":
self.inference_service = TensorflowInferenceService(self.args)
elif self.args.platform == "pytorch":
self.inference_service = PytorchInferenceService(self.args)

self.model_name_service_map[self.args.model_name] = self.inference_service


def inference(self, model_name, images):
def inference(self, model_name, data_list):
"""
Online-service inference method
"""
inferenceService = self.model_name_service_map[model_name]
result = list()
for image in images:
data = inferenceService.inference(image)
if len(images) == 1:
return data
for data in data_list:
output = inferenceService.inference(data)
if len(data_list) == 1:
return output
else:
result.append(data)
result.append(output)
return result


def inference_and_save_json(self, model_name, json_path, images):
def inference_and_save_json(self, model_name, json_path, data_list):
"""
Batch-service inference method
"""
inferenceService = self.model_name_service_map[model_name]
for image in images:
data = inferenceService.inference(image)
file_utils.writer_json_file(json_path, image['image_name'], data)
for data in data_list:
result = inferenceService.inference(data)
file_utils.writer_json_file(json_path, data['data_name'], result)
time.sleep(1)

tianshu_serving/service/oneflow_inference_service.py (+7 -10)

@@ -20,12 +20,8 @@ import google.protobuf.text_format as text_format
import os
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from logger import Logger
import config as configs
from service.abstract_inference_service import AbstractInferenceService


parser = configs.get_parser()
args = parser.parse_args()

log = Logger().logger




@@ -33,10 +29,11 @@ class OneFlowInferenceService(AbstractInferenceService):
""" """
oneflow 框架推理service oneflow 框架推理service
""" """
def __init__(self, model_name, model_path):
def __init__(self, args):
super().__init__() super().__init__()
self.model_name = model_name
self.model_path = model_path
self.args = args
self.model_name = args.model_name
self.model_path = args.model_path
flow.clear_default_session() flow.clear_default_session()
self.infer_session = flow.SimpleSession() self.infer_session = flow.SimpleSession()
self.load_model() self.load_model()
@@ -91,9 +88,9 @@ class OneFlowInferenceService(AbstractInferenceService):
return saved_model_proto


def inference(self, image):
data = {"image_name": image['image_name']}
log.info("===============> start load " + image['image_name'] + " <===============")
images = self.load_image(image['image_path'])
data = {"data_name": image['data_name']}
log.info("===============> start load " + image['data_name'] + " <===============")
images = self.load_image(image['data_path'])

predictions = self.infer_session.run('inference', image=images)




tianshu_serving/service/pytorch_inference_service.py (+16 -18)

@@ -16,15 +16,12 @@ import torch
import torch.nn.functional as functional
from PIL import Image
from torchvision import transforms
import config
import requests
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from io import BytesIO
from logger import Logger
from service.abstract_inference_service import AbstractInferenceService


parser = config.get_parser()
args = parser.parse_args()
log = Logger().logger




@@ -33,10 +30,11 @@ class PytorchInferenceService(AbstractInferenceService):
pytorch framework inference service
"""

def __init__(self, model_name, model_path):
def __init__(self, args):
super().__init__()
self.model_name = model_name
self.model_path = model_path
self.args = args
self.model_name = args.model_name
self.model_path = args.model_path
self.model = self.load_model()
self.checkpoint = None


@@ -52,38 +50,38 @@ class PytorchInferenceService(AbstractInferenceService):
image = Image.open(io.BytesIO(image))
if image.mode != 'RGB':
image = image.convert("RGB")
image = transforms.Resize((args.reshape_size[0], args.reshape_size[1]))(image)
image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image)
image = transforms.ToTensor()(image)
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
image = image[None]
if args.use_gpu:
if self.args.use_gpu:
image = image.cuda()
log.info("===============> load image success <===============")
return image

def load_model(self):
log.info("===============> start load pytorch model :" + args.model_path + " <===============")
if os.path.isfile(args.model_path):
log.info("===============> start load pytorch model :" + self.args.model_path + " <===============")
if os.path.isfile(self.args.model_path):
self.checkpoint = torch.load(self.model_path)
else:
for file in os.listdir(args.model_path):
for file in os.listdir(self.args.model_path):
self.checkpoint = torch.load(self.model_path + file)
model = self.checkpoint[args.model_structure]
model = self.checkpoint[self.args.model_structure]
model.load_state_dict(self.checkpoint['state_dict'])
for parameter in model.parameters():
parameter.requires_grad = False
if args.use_gpu:
if self.args.use_gpu:
model.cuda()
model.eval()
log.info("===============> load pytorch model success <===============")
return model

def inference(self, image):
data = {"image_name": image['image_name']}
log.info("===============> start load " + image['image_name'] + " <===============")
image = self.load_image(image['image_path'])
predis = functional.softmax(self.model(image), dim=1)
results = torch.topk(predis.data, k=5, dim=1)
data = {"data_name": image['data_name']}
log.info("===============> start load " + image['data_name'] + " <===============")
image = self.load_image(image['data_path'])
preds = functional.softmax(self.model(image), dim=1)
results = torch.topk(preds.data, k=5, dim=1)
data['predictions'] = list()
for prob, label in zip(results[0][0], results[1][0]):
result = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))}


tianshu_serving/service/tensorflow_inference_service.py (+9 -11)

@@ -13,7 +13,6 @@ limitations under the License.
import tensorflow as tf
import requests
import numpy as np
import config as configs
from imagenet1000_clsidx_to_labels import clsidx_2_labels
from service.abstract_inference_service import AbstractInferenceService
from utils.imagenet_preprocessing_utils import preprocess_input
@@ -21,8 +20,6 @@ from logger import Logger
from PIL import Image
from io import BytesIO

parser = configs.get_parser()
args = parser.parse_args()
log = Logger().logger




@@ -30,11 +27,12 @@ class TensorflowInferenceService(AbstractInferenceService):
""" """
tensorflow 框架推理service tensorflow 框架推理service
""" """
def __init__(self, model_name, model_path):
def __init__(self, args):
super().__init__() super().__init__()
self.session = tf.compat.v1.Session(graph=tf.Graph()) self.session = tf.compat.v1.Session(graph=tf.Graph())
self.model_name = model_name
self.model_path = model_path
self.args = args
self.model_name = args.model_name
self.model_path = args.model_path
self.signature_input_keys = [] self.signature_input_keys = []
self.signature_input_tensor_names = [] self.signature_input_tensor_names = []
self.signature_output_keys = [] self.signature_output_keys = []
@@ -69,11 +67,11 @@ class TensorflowInferenceService(AbstractInferenceService):
self.session, [tf.compat.v1.saved_model.tag_constants.SERVING], self.model_path)

# validate the user-supplied signature name before loading the model
if args.signature_name not in meta_graph.signature_def:
if self.args.signature_name not in meta_graph.signature_def:
log.error("==============> Invalid signature name <==================")

# get the input and output node information of the meta graph from the signature
signature = meta_graph.signature_def[args.signature_name]
signature = meta_graph.signature_def[self.args.signature_name]
input_keys, input_tensor_names = get_tensors(signature.inputs)
output_keys, output_tensor_names = get_tensors(signature.outputs)


@@ -87,14 +85,14 @@ class TensorflowInferenceService(AbstractInferenceService):
log.info("===============> load tensorflow model success <===============") log.info("===============> load tensorflow model success <===============")


def inference(self, image): def inference(self, image):
data = {"image_name": image['image_name']}
data = {"data_name": image['data_name']}
# 获得用户输入的图片 # 获得用户输入的图片
log.info("===============> start load " + image['image_name'] + " <===============")
log.info("===============> start load " + image['data_name'] + " <===============")
# 推理所需的输入,目前的分类预置模型都只有一个输入 # 推理所需的输入,目前的分类预置模型都只有一个输入
input_dict = {} input_dict = {}
input_keys = self.signature_input_keys input_keys = self.signature_input_keys
input_data = {} input_data = {}
im = preprocess_input(self.load_image(image['image_path']), mode=args.prepare_mode)
im = preprocess_input(self.load_image(image['data_path']), mode=self.args.prepare_mode)
if len(list(im.shape)) == 3: if len(list(im.shape)) == 3:
input_data[input_keys[0]] = np.expand_dims(im, axis=0) input_data[input_keys[0]] = np.expand_dims(im, axis=0)




tianshu_serving/utils/file_utils.py (+27 -27)

@@ -43,52 +43,52 @@ def download_image(images_path):
save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + image_path.split("/")[-1].split(".")[-1])


def upload_image(image_files):
def upload_data(files):
"""
Save images uploaded from the front end to local storage
"""
save_image_dir = "/usr/local/images/"
if not os.path.exists(save_image_dir):
os.mkdir(save_image_dir)
images = list()
for image_file in image_files:
save_data_dir = "/usr/local/data/"
if not os.path.exists(save_data_dir):
os.mkdir(save_data_dir)
data_list = list()
for file in files:
try:
suffix = Path(image_file.filename).suffix
with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_image_dir) as tmp:
shutil.copyfileobj(image_file.file, tmp)
suffix = Path(file.filename).suffix
with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_data_dir) as tmp:
shutil.copyfileobj(file.file, tmp)
tmp_file_name = Path(tmp.name).name
file = {"image_name": image_file.filename, "image_path": save_image_dir + tmp_file_name}
images.append(file)
data = {"data_name": file.filename, "data_path": save_data_dir + tmp_file_name}
data_list.append(data)
finally:
image_file.file.close()
return images
file.file.close()
return data_list


def upload_image_by_base64(image_files):
def upload_image_by_base64(data_list):
"""
Save base64-encoded image data to local storage
"""
save_image_dir = "/usr/local/images/"
if not os.path.exists(save_image_dir):
os.mkdir(save_image_dir)
images = list()
for img_file in image_files:
file_path = save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + img_file.image_name.split(".")[-1]
img_data = base64.b64decode(img_file.image_file)
save_data_dir = "/usr/local/data/"
if not os.path.exists(save_data_dir):
os.mkdir(save_data_dir)
data_list_b64 = list()
for data in data_list:
file_path = save_data_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + data.data_name.split(".")[-1]
file_b64 = base64.b64decode(data.data_file)
file = open(file_path, 'wb')
file.write(img_data)
file.write(file_b64)
file.close()
image = {"image_name": img_file.image_name, "image_path": file_path}
images.append(image)
return images
data_b64 = {"data_name": data.data_name, "data_path": file_path}
data_list_b64.append(data_b64)
return data_list_b64


def writer_json_file(json_path, image_name, data):
def writer_json_file(json_path, data_name, data):
"""
Save as a JSON file
"""
if not os.path.exists(json_path):
os.mkdir(json_path)
filename = json_path + image_name + '.json'
filename = json_path + data_name + '.json'
with open(filename, 'w', encoding='utf-8') as file_obj:
file_obj.write(json.dumps(data, ensure_ascii=False))

tianshu_serving/utils/find_class_in_file.py (+88 -0)

@@ -0,0 +1,88 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import re
import inspect
# from ac_opf.models import ac_model


class FindClassInFile:
"""
find class in the given module

### note ###
There is only one class in the given module.
If there is more than one class, all of them will be omitted except the first.
############


method: find
args:
module: object-> the given file or module
encoding: string-> "utf8" by default
output: tuple(string-> name of the class, object-> class)


usage:

# >>> import module
#
# >>> find_class_in_file = FindClassInFile()
# >>> cls = find_class_in_file.find(module)
#
# >>> cls_instance = cls[1](args)

"""

def __init__(self):
pass

def _open_file(self, path, encoding="utf8"):
with open(path, "r", encoding=encoding) as f:
data = f.readlines()

for line in data:
yield line

def find(self, module, encoding="utf8"):
path = module.__file__
lines = self._open_file(path=path, encoding=encoding)

cls = ""
for line in lines:
if "class " in line:
cls = re.findall("class (.*?)[:(]", line)[0]
if cls:
break

return self._valid(module, cls)

def _valid(self, module, cls):
members = inspect.getmembers(module)
cand = [(i, j) for i, j in members if inspect.isclass(j) and (not inspect.isabstract(j)) and (i == cls)]
if not cand:
print("class not found in {}".format(module))
return cand[0]


if __name__ == "__main__":
find_class_in_file = FindClassInFile()
# cls = find_class_in_file.find(ac_model)

# print(cls)







