@@ -5,15 +5,11 @@ | |||
+ **支持oneflow、tensorflow、pytorch三种框架模型部署** </br> | |||
1、通过如下命令启动http在线推理服务 | |||
``` | |||
python http_server.py --platform='框架名称' --model_path='模型地址' | |||
``` | |||
``` | |||
python http_server.py --platform='框架名称' --model_path='模型地址' | |||
``` | |||
通过访问localhost:5000/docs进入swagger页面,调用localhost:5000/inference进行图片上传得道推理结果,结果如下所示: | |||
``` | |||
``` | |||
{ | |||
"image_name": "哈士奇.jpg", | |||
"predictions": [ | |||
@@ -39,28 +35,35 @@ | |||
} | |||
] | |||
} | |||
``` | |||
``` | |||
2、同理通过如下命令启动grpc在线推理服务 | |||
``` | |||
python grpc_server.py --platform='框架名称' --model_path='模型地址' | |||
``` | |||
``` | |||
python grpc_server.py --platform='框架名称' --model_path='模型地址' | |||
``` | |||
再启动grpc_client.py进行上传图片推理得道结果,或者根据ip端口自行编写grpc客户端 | |||
3、支持多模型部署,可以自行配置config文件夹下的model_config_file.json进行多模型配置,启动http或grpc时输入不同的模型名称即可,或者自行修改inference接口入参来达到启动单一服务多模型推理的功能 | |||
+ **支持多模型部署** </br> | |||
用户可以自行配置config文件夹下的model_config_file.json进行多模型配置,启动http或grpc时输入不同的模型名称即可,或者自行修改inference接口入参来达到启动单一服务多模型推理的功能 | |||
+ **支持分布式模型部署推理** </br> | |||
需要推理大量图片时需要分布式推理功能,执行如下命令: | |||
``` | |||
python batch_server.py --platform='框架名称' --model_path='模型地址' --input_path='批量图片地址' --output_path='输出JSON文件地址' | |||
``` | |||
``` | |||
python batch_server.py --platform='框架名称' --model_path='模型地址' --input_path='批量图片地址' --output_path='输出JSON文件地址' | |||
``` | |||
输入的所有图片保存在input文件夹下,输入json文件保存在output_path文件夹,json名称与图片名称对应 | |||
+ **支持使用自定义推理脚本** </br> | |||
用户需要使用自定义推理脚本时,可根据common_inference_service.py脚本中注释的规则,自定义推理脚本,并替换原有的common_inference_service.py脚本。此外,在启动命令中添加use_script参数,即可在推理时使用自定义的推理脚本。命令如下所示: | |||
``` | |||
python grpc_server.py --platform='框架名称' --model_path='模型地址' --user_script=True | |||
``` | |||
+ **代码还包含了各种参数配置,日志文件输出、是否启用TLS等** </br> | |||
@@ -27,49 +27,47 @@ log = Logger().logger | |||
def get_host_ip(): | |||
""" | |||
查询本机ip地址 | |||
:return: | |||
return | |||
""" | |||
global s | |||
try: | |||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |||
ip = s.getsockname()[0] | |||
finally: | |||
s.close() | |||
hostname = socket.gethostname() | |||
ip = socket.gethostbyname(hostname) | |||
return ip | |||
def read_directory(images_path): | |||
def read_directory(data_path): | |||
""" | |||
读取文件夹并进行拆分文件 | |||
:return: | |||
""" | |||
files = os.listdir(images_path) | |||
files = os.listdir(data_path) | |||
num_files = len(files) | |||
index_list = list(range(num_files)) | |||
images = list() | |||
data_list = list() | |||
for index in index_list: | |||
# 是否开启分布式 | |||
if args.enable_distributed: | |||
ip = get_host_ip() | |||
log.info("NODE_IPS:{}", os.getenv('NODE_IPS')) | |||
ip_list = os.getenv('NODE_IPS').split(",") | |||
num_ips = len(ip_list) | |||
ip_index = ip_list.index(ip) | |||
if ip_index == index % num_ips: | |||
filename = files[index] | |||
image = {"image_name": filename, "image_path": images_path + filename} | |||
images.append(image) | |||
data = {"data_name": filename, "data_path": data_path + filename} | |||
data_list.append(data) | |||
else: | |||
filename = files[index] | |||
image = {"image_name": filename, "image_path": images_path + filename} | |||
images.append(image) | |||
return images | |||
data = {"data_name": filename, "data_path": data_path + filename} | |||
data_list.append(data) | |||
return data_list | |||
def main(): | |||
images = read_directory(args.input_path) | |||
inference_service.inference_and_save_json(args.model_name, args.output_path, images) | |||
data_list = read_directory(args.input_path) | |||
inference_service.inference_and_save_json(args.model_name, args.output_path, data_list) | |||
if args.enable_distributed: | |||
ip = get_host_ip() | |||
log.info("NODE_IPS:{}", os.getenv('NODE_IPS')) | |||
ip_list = os.getenv('NODE_IPS').split(",") | |||
# 主节点必须等待从节点推理完成 | |||
if ip == ip_list[0]: | |||
@@ -58,12 +58,12 @@ def get_parser(parser=None): | |||
parser.add_argument("--job_name", type=str, default="inference", help="oneflow job name") | |||
parser.add_argument("--prepare_mode", type=str, default="tfhub", | |||
help="tensorflow prepare mode(tfhub、caffe、tf、torch)") | |||
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="is use gpu") | |||
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="whether to use gpu") | |||
parser.add_argument('--channel_last', type=str2bool, nargs='?', const=False, | |||
help='Whether to use use channel last mode(nhwc)') | |||
parser.add_argument("--model_path", type=str, default="/usr/local/model/pytorch_models/resnet50/", | |||
parser.add_argument("--model_path", type=str, default="/usr/local/work/models/pytorch_models/resnet50/", | |||
help="model load directory if need") | |||
parser.add_argument("--image_path", type=str, default='/usr/local/data/fish.jpg', help="image path") | |||
parser.add_argument("--data_path", type=str, default='/usr/local/work/dog.jpg', help="input data path") | |||
parser.add_argument("--reshape_size", type=int_list, default='[224]', | |||
help="The reshape size of the image(eg. 224)") | |||
parser.add_argument("--num_classes", type=int, default=1000, help="num of pic classes") | |||
@@ -78,8 +78,9 @@ def get_parser(parser=None): | |||
parser.add_argument("--model_config_file", type=str, default="", help="The file of the model config(eg. '')") | |||
parser.add_argument("--enable_distributed", type=ast.literal_eval, default=False, help="If enable use distributed " | |||
"environment") | |||
parser.add_argument("--input_path", type=str, default="/usr/local/data/images/", help="images path") | |||
parser.add_argument("--output_path", type=str, default="/usr/local/output_path/", help="json path") | |||
parser.add_argument("--input_path", type=str, default="/usr/local/input/", help="input batch data path") | |||
parser.add_argument("--output_path", type=str, default="/usr/local/output/", help="output json path") | |||
parser.add_argument("--use_script", type=ast.literal_eval, default=False, help="whether to use custom inference script") | |||
return parser | |||
@@ -0,0 +1,84 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
""" | |||
import os | |||
import io | |||
import torch | |||
import torch.nn.functional as functional | |||
from PIL import Image | |||
from torchvision import transforms | |||
from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
from logger import Logger | |||
log = Logger().logger | |||
#只能定义一个class | |||
class CommonInferenceService: | |||
# __init__初始化方法中接收args参数(其中模型路径参数为args.model_path,是否使用gpu参数为args.use_gpu),并加载模型(方法用户可自定义) | |||
def __init__(self, args): | |||
self.args = args | |||
self.model = self.load_model() | |||
def load_data(self, data_path): | |||
image = open(data_path, 'rb').read() | |||
image = Image.open(io.BytesIO(image)) | |||
if image.mode != 'RGB': | |||
image = image.convert("RGB") | |||
image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) | |||
image = transforms.ToTensor()(image) | |||
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) | |||
image = image[None] | |||
if self.args.use_gpu: | |||
image = image.cuda() | |||
return image | |||
def load_model(self): | |||
if os.path.isfile(self.args.model_path): | |||
self.checkpoint = torch.load(self.args.model_path) | |||
else: | |||
for file in os.listdir(self.args.model_path): | |||
self.checkpoint = torch.load(self.args.model_path + file) | |||
model = self.checkpoint["model"] | |||
model.load_state_dict(self.checkpoint['state_dict']) | |||
for parameter in model.parameters(): | |||
parameter.requires_grad = False | |||
if self.args.use_gpu: | |||
model.cuda() | |||
model.eval() | |||
return model | |||
# inference方法名称固定 | |||
def inference(self, data): | |||
result = {"data_name": data['data_name']} | |||
data = self.load_data(data['data_path']) | |||
preds = functional.softmax(self.model(data), dim=1) | |||
predictions = torch.topk(preds.data, k=5, dim=1) | |||
result['predictions'] = list() | |||
for prob, label in zip(predictions[0][0], predictions[1][0]): | |||
predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} | |||
result['predictions'].append(predictions) | |||
return result | |||
if __name__=="__main__": | |||
import argparse | |||
parser = argparse.ArgumentParser(description='tianshu serving') | |||
parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path") | |||
parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not") | |||
parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not") | |||
args = parser.parse_args() | |||
server = CommonInferenceService(args) | |||
image_path = "./cat.jpg" | |||
image = {"data_name": "cat.jpg", "data_path": image_path} | |||
re = server.inference(image) | |||
print(re) |
@@ -21,8 +21,8 @@ log = Logger().logger | |||
parser = configs.get_parser() | |||
args = parser.parse_args() | |||
_HOST = 'kohj2s.serving.dubhe.ai' | |||
_PORT = '31365' | |||
_HOST = '10.5.24.134' | |||
_PORT = '8500' | |||
MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024 # 可根据具体需求设置,此处设为1G | |||
@@ -41,12 +41,12 @@ def run(): | |||
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH), ], ) # 创建连接 | |||
client = inference_pb2_grpc.InferenceServiceStub(channel=channel) # 创建客户端 | |||
data_request = inference_pb2.DataRequest() | |||
Image = data_request.images.add() | |||
Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\哈士奇.jpg", "rb").read()), encoding='utf-8') | |||
Image.image_name = "哈士奇.jpg" | |||
Image = data_request.images.add() | |||
Image.image_file = str(base64.b64encode(open("F:\\Files\\pic\\fish.jpg", "rb").read()), encoding='utf-8') | |||
Image.image_name = "fish.jpg" | |||
data1 = data_request.data_list.add() | |||
data1.data_file = str(base64.b64encode(open("/usr/local/input/dog.jpg", "rb").read()), encoding='utf-8') | |||
data1.data_name = "dog.jpg" | |||
data2 = data_request.data_list.add() | |||
data2.data_file = str(base64.b64encode(open("/usr/local/input/6.jpg", "rb").read()), encoding='utf-8') | |||
data2.data_name = "6.jpg" | |||
response = client.inference(data_request) | |||
log.info(response.json_result.encode('utf-8').decode('unicode_escape')) | |||
@@ -46,19 +46,21 @@ class InferenceService(inference_pb2_grpc.InferenceServiceServicer): | |||
调用grpc方法进行推理 | |||
""" | |||
def inference(self, request, context): | |||
image_files = request.images | |||
data_list = request.data_list | |||
log.info("===============> grpc inference start <===============") | |||
try: | |||
images = file_utils.upload_image_by_base64(image_files) # 上传图片到本地 | |||
data_list_b64 = file_utils.upload_image_by_base64(data_list) # 上传图片到本地 | |||
except Exception as e: | |||
log.error("upload data failed", e) | |||
return inference_pb2.DataResponse(json_result=json.dumps( | |||
response_convert(Response(success=False, data=str(e), error="upload image fail")))) | |||
response_convert(Response(success=False, data=str(e), error="upload data failed")))) | |||
try: | |||
result = inference_service.inference(args.model_name, images) | |||
result = inference_service.inference(args.model_name, data_list_b64) | |||
log.info("===============> grpc inference success <===============") | |||
return inference_pb2.DataResponse(json_result=json.dumps( | |||
response_convert(Response(success=True, data=result)))) | |||
except Exception as e: | |||
log.error("inference fail", e) | |||
return inference_pb2.DataResponse(json_result=json.dumps( | |||
response_convert(Response(success=False, data=str(e), error="inference fail")))) | |||
@@ -49,7 +49,7 @@ async def inference(images_path: List[str] = None): | |||
threading.Thread(target=file_utils.download_image(images_path)) # 开启异步线程下载图片到本地 | |||
images = list() | |||
for image in images_path: | |||
data = {"image_name": image.split("/")[-1], "image_path": image} | |||
data = {"data_name": image.split("/")[-1], "data_path": image} | |||
images.append(data) | |||
try: | |||
data = inference_service.inference(args.model_name, images) | |||
@@ -59,17 +59,22 @@ async def inference(images_path: List[str] = None): | |||
@app.post("/inference") | |||
async def inference(image_files: List[UploadFile] = File(...)): | |||
async def inference(files: List[UploadFile] = File(...)): | |||
""" | |||
上传本地文件推理 | |||
""" | |||
log.info("===============> http inference start <===============") | |||
try: | |||
images = file_utils.upload_image(image_files) # 上传图片到本地 | |||
data_list = file_utils.upload_data(files) # 上传图片到本地 | |||
except Exception as e: | |||
return Response(success=False, data=str(e), error="upload image fail") | |||
log.error("upload data failed", e) | |||
return Response(success=False, data=str(e), error="upload data failed") | |||
try: | |||
result = inference_service.inference(args.model_name, images) | |||
result = inference_service.inference(args.model_name, data_list) | |||
log.info("===============> http inference success <===============") | |||
return Response(success=True, data=result) | |||
except Exception as e: | |||
log.error("inference fail", e) | |||
return Response(success=False, data=str(e), error="inference fail") | |||
@@ -0,0 +1 @@ | |||
2021-01-14 19:55:07,045 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.665"}, {"label": "Siberian husky", "probability": "0.294"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.024"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.015"}, {"label": "timber wolf, grey wolf, gray wolf, Canis lupus", "probability": "0.000"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "1.000"}, {"label": "tench, Tinca tinca", "probability": "0.000"}, {"label": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "probability": "0.000"}, {"label": "axolotl, mud puppy, Ambystoma mexicanum", "probability": "0.000"}, {"label": "barracouta, snoek", "probability": "0.000"}]}], "error": ""} |
@@ -0,0 +1 @@ | |||
2021-02-03 15:05:00,069 - E:/Python Project/TS_Serving/grpc_client.py[line:51] - INFO: {"success": true, "data": [{"image_name": "哈士奇.jpg", "predictions": [{"label": "Eskimo dog, husky", "probability": "0.679"}, {"label": "Siberian husky", "probability": "0.213"}, {"label": "dogsled, dog sled, dog sleigh", "probability": "0.021"}, {"label": "malamute, malemute, Alaskan malamute", "probability": "0.006"}, {"label": "white wolf, Arctic wolf, Canis lupus tundrarum", "probability": "0.001"}]}, {"image_name": "fish.jpg", "predictions": [{"label": "goldfish, Carassius auratus", "probability": "0.871"}, {"label": "tench, Tinca tinca", "probability": "0.004"}, {"label": "puffer, pufferfish, blowfish, globefish", "probability": "0.002"}, {"label": "rock beauty, Holocanthus tricolor", "probability": "0.002"}, {"label": "barracouta, snoek", "probability": "0.001"}]}], "error": ""} |
@@ -1,16 +1,16 @@ | |||
syntax = 'proto3'; | |||
service InferenceService { | |||
rpc inference(DataRequest) returns (DataResponse) {} | |||
} | |||
message DataRequest{ | |||
repeated Image images = 1; | |||
repeated Data data_list = 1; | |||
} | |||
message Image { | |||
string image_file = 1; | |||
string image_name = 2; | |||
message Data { | |||
string data_file = 1; | |||
string data_name = 2; | |||
} | |||
message DataResponse{ | |||
@@ -6,168 +6,177 @@ from google.protobuf import descriptor as _descriptor | |||
from google.protobuf import message as _message | |||
from google.protobuf import reflection as _reflection | |||
from google.protobuf import symbol_database as _symbol_database | |||
# @@protoc_insertion_point(imports) | |||
_sym_db = _symbol_database.Default() | |||
DESCRIPTOR = _descriptor.FileDescriptor( | |||
name='inference.proto', | |||
package='', | |||
syntax='proto3', | |||
serialized_options=None, | |||
create_key=_descriptor._internal_create_key, | |||
serialized_pb=b'\n\x0finference.proto\"%\n\x0b\x44\x61taRequest\x12\x16\n\x06images\x18\x01 \x03(\x0b\x32\x06.Image\"/\n\x05Image\x12\x12\n\nimage_file\x18\x01 \x01(\t\x12\x12\n\nimage_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3' | |||
name='inference.proto', | |||
package='', | |||
syntax='proto3', | |||
serialized_options=None, | |||
create_key=_descriptor._internal_create_key, | |||
serialized_pb=b'\n\x0finference.proto\"\'\n\x0b\x44\x61taRequest\x12\x18\n\tdata_list\x18\x01 \x03(\x0b\x32\x05.Data\",\n\x04\x44\x61ta\x12\x11\n\tdata_file\x18\x01 \x01(\t\x12\x11\n\tdata_name\x18\x02 \x01(\t\"#\n\x0c\x44\x61taResponse\x12\x13\n\x0bjson_result\x18\x01 \x01(\t2>\n\x10InferenceService\x12*\n\tinference\x12\x0c.DataRequest\x1a\r.DataResponse\"\x00\x62\x06proto3' | |||
) | |||
_DATAREQUEST = _descriptor.Descriptor( | |||
name='DataRequest', | |||
full_name='DataRequest', | |||
filename=None, | |||
file=DESCRIPTOR, | |||
containing_type=None, | |||
create_key=_descriptor._internal_create_key, | |||
fields=[ | |||
_descriptor.FieldDescriptor( | |||
name='images', full_name='DataRequest.images', index=0, | |||
number=1, type=11, cpp_type=10, label=3, | |||
has_default_value=False, default_value=[], | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
], | |||
extensions=[ | |||
], | |||
nested_types=[], | |||
enum_types=[ | |||
], | |||
serialized_options=None, | |||
is_extendable=False, | |||
syntax='proto3', | |||
extension_ranges=[], | |||
oneofs=[ | |||
], | |||
serialized_start=19, | |||
serialized_end=56, | |||
name='DataRequest', | |||
full_name='DataRequest', | |||
filename=None, | |||
file=DESCRIPTOR, | |||
containing_type=None, | |||
create_key=_descriptor._internal_create_key, | |||
fields=[ | |||
_descriptor.FieldDescriptor( | |||
name='data_list', full_name='DataRequest.data_list', index=0, | |||
number=1, type=11, cpp_type=10, label=3, | |||
has_default_value=False, default_value=[], | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
], | |||
extensions=[ | |||
], | |||
nested_types=[], | |||
enum_types=[ | |||
], | |||
serialized_options=None, | |||
is_extendable=False, | |||
syntax='proto3', | |||
extension_ranges=[], | |||
oneofs=[ | |||
], | |||
serialized_start=19, | |||
serialized_end=58, | |||
) | |||
_IMAGE = _descriptor.Descriptor( | |||
name='Image', | |||
full_name='Image', | |||
filename=None, | |||
file=DESCRIPTOR, | |||
containing_type=None, | |||
create_key=_descriptor._internal_create_key, | |||
fields=[ | |||
_descriptor.FieldDescriptor( | |||
name='image_file', full_name='Image.image_file', index=0, | |||
number=1, type=9, cpp_type=9, label=1, | |||
has_default_value=False, default_value=b"".decode('utf-8'), | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
_descriptor.FieldDescriptor( | |||
name='image_name', full_name='Image.image_name', index=1, | |||
number=2, type=9, cpp_type=9, label=1, | |||
has_default_value=False, default_value=b"".decode('utf-8'), | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
], | |||
extensions=[ | |||
], | |||
nested_types=[], | |||
enum_types=[ | |||
], | |||
serialized_options=None, | |||
is_extendable=False, | |||
syntax='proto3', | |||
extension_ranges=[], | |||
oneofs=[ | |||
], | |||
serialized_start=58, | |||
serialized_end=105, | |||
_DATA = _descriptor.Descriptor( | |||
name='Data', | |||
full_name='Data', | |||
filename=None, | |||
file=DESCRIPTOR, | |||
containing_type=None, | |||
create_key=_descriptor._internal_create_key, | |||
fields=[ | |||
_descriptor.FieldDescriptor( | |||
name='data_file', full_name='Data.data_file', index=0, | |||
number=1, type=9, cpp_type=9, label=1, | |||
has_default_value=False, default_value=b"".decode('utf-8'), | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
_descriptor.FieldDescriptor( | |||
name='data_name', full_name='Data.data_name', index=1, | |||
number=2, type=9, cpp_type=9, label=1, | |||
has_default_value=False, default_value=b"".decode('utf-8'), | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
], | |||
extensions=[ | |||
], | |||
nested_types=[], | |||
enum_types=[ | |||
], | |||
serialized_options=None, | |||
is_extendable=False, | |||
syntax='proto3', | |||
extension_ranges=[], | |||
oneofs=[ | |||
], | |||
serialized_start=60, | |||
serialized_end=104, | |||
) | |||
_DATARESPONSE = _descriptor.Descriptor( | |||
name='DataResponse', | |||
full_name='DataResponse', | |||
filename=None, | |||
file=DESCRIPTOR, | |||
containing_type=None, | |||
create_key=_descriptor._internal_create_key, | |||
fields=[ | |||
_descriptor.FieldDescriptor( | |||
name='json_result', full_name='DataResponse.json_result', index=0, | |||
number=1, type=9, cpp_type=9, label=1, | |||
has_default_value=False, default_value=b"".decode('utf-8'), | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
], | |||
extensions=[ | |||
], | |||
nested_types=[], | |||
enum_types=[ | |||
], | |||
serialized_options=None, | |||
is_extendable=False, | |||
syntax='proto3', | |||
extension_ranges=[], | |||
oneofs=[ | |||
], | |||
serialized_start=107, | |||
serialized_end=142, | |||
name='DataResponse', | |||
full_name='DataResponse', | |||
filename=None, | |||
file=DESCRIPTOR, | |||
containing_type=None, | |||
create_key=_descriptor._internal_create_key, | |||
fields=[ | |||
_descriptor.FieldDescriptor( | |||
name='json_result', full_name='DataResponse.json_result', index=0, | |||
number=1, type=9, cpp_type=9, label=1, | |||
has_default_value=False, default_value=b"".decode('utf-8'), | |||
message_type=None, enum_type=None, containing_type=None, | |||
is_extension=False, extension_scope=None, | |||
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), | |||
], | |||
extensions=[ | |||
], | |||
nested_types=[], | |||
enum_types=[ | |||
], | |||
serialized_options=None, | |||
is_extendable=False, | |||
syntax='proto3', | |||
extension_ranges=[], | |||
oneofs=[ | |||
], | |||
serialized_start=106, | |||
serialized_end=141, | |||
) | |||
_DATAREQUEST.fields_by_name['images'].message_type = _IMAGE | |||
_DATAREQUEST.fields_by_name['data_list'].message_type = _DATA | |||
DESCRIPTOR.message_types_by_name['DataRequest'] = _DATAREQUEST | |||
DESCRIPTOR.message_types_by_name['Image'] = _IMAGE | |||
DESCRIPTOR.message_types_by_name['Data'] = _DATA | |||
DESCRIPTOR.message_types_by_name['DataResponse'] = _DATARESPONSE | |||
_sym_db.RegisterFileDescriptor(DESCRIPTOR) | |||
DataRequest = _reflection.GeneratedProtocolMessageType('DataRequest', (_message.Message,), { | |||
'DESCRIPTOR': _DATAREQUEST, | |||
'__module__': 'inference_pb2' | |||
# @@protoc_insertion_point(class_scope:DataRequest) | |||
}) | |||
'DESCRIPTOR' : _DATAREQUEST, | |||
'__module__' : 'inference_pb2' | |||
# @@protoc_insertion_point(class_scope:DataRequest) | |||
}) | |||
_sym_db.RegisterMessage(DataRequest) | |||
Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), { | |||
'DESCRIPTOR': _IMAGE, | |||
'__module__': 'inference_pb2' | |||
# @@protoc_insertion_point(class_scope:Image) | |||
}) | |||
_sym_db.RegisterMessage(Image) | |||
Data = _reflection.GeneratedProtocolMessageType('Data', (_message.Message,), { | |||
'DESCRIPTOR' : _DATA, | |||
'__module__' : 'inference_pb2' | |||
# @@protoc_insertion_point(class_scope:Data) | |||
}) | |||
_sym_db.RegisterMessage(Data) | |||
DataResponse = _reflection.GeneratedProtocolMessageType('DataResponse', (_message.Message,), { | |||
'DESCRIPTOR': _DATARESPONSE, | |||
'__module__': 'inference_pb2' | |||
# @@protoc_insertion_point(class_scope:DataResponse) | |||
}) | |||
'DESCRIPTOR' : _DATARESPONSE, | |||
'__module__' : 'inference_pb2' | |||
# @@protoc_insertion_point(class_scope:DataResponse) | |||
}) | |||
_sym_db.RegisterMessage(DataResponse) | |||
_INFERENCESERVICE = _descriptor.ServiceDescriptor( | |||
name='InferenceService', | |||
full_name='InferenceService', | |||
file=DESCRIPTOR, | |||
name='InferenceService', | |||
full_name='InferenceService', | |||
file=DESCRIPTOR, | |||
index=0, | |||
serialized_options=None, | |||
create_key=_descriptor._internal_create_key, | |||
serialized_start=143, | |||
serialized_end=205, | |||
methods=[ | |||
_descriptor.MethodDescriptor( | |||
name='inference', | |||
full_name='InferenceService.inference', | |||
index=0, | |||
containing_service=None, | |||
input_type=_DATAREQUEST, | |||
output_type=_DATARESPONSE, | |||
serialized_options=None, | |||
create_key=_descriptor._internal_create_key, | |||
serialized_start=144, | |||
serialized_end=206, | |||
methods=[ | |||
_descriptor.MethodDescriptor( | |||
name='inference', | |||
full_name='InferenceService.inference', | |||
index=0, | |||
containing_service=None, | |||
input_type=_DATAREQUEST, | |||
output_type=_DATARESPONSE, | |||
serialized_options=None, | |||
create_key=_descriptor._internal_create_key, | |||
), | |||
]) | |||
), | |||
]) | |||
_sym_db.RegisterServiceDescriptor(_INFERENCESERVICE) | |||
DESCRIPTOR.services_by_name['InferenceService'] = _INFERENCESERVICE | |||
@@ -15,10 +15,10 @@ class InferenceServiceStub(object): | |||
channel: A grpc.Channel. | |||
""" | |||
self.inference = channel.unary_unary( | |||
'/InferenceService/inference', | |||
request_serializer=inference__pb2.DataRequest.SerializeToString, | |||
response_deserializer=inference__pb2.DataResponse.FromString, | |||
) | |||
'/InferenceService/inference', | |||
request_serializer=inference__pb2.DataRequest.SerializeToString, | |||
response_deserializer=inference__pb2.DataResponse.FromString, | |||
) | |||
class InferenceServiceServicer(object): | |||
@@ -33,34 +33,34 @@ class InferenceServiceServicer(object): | |||
def add_InferenceServiceServicer_to_server(servicer, server): | |||
rpc_method_handlers = { | |||
'inference': grpc.unary_unary_rpc_method_handler( | |||
servicer.inference, | |||
request_deserializer=inference__pb2.DataRequest.FromString, | |||
response_serializer=inference__pb2.DataResponse.SerializeToString, | |||
), | |||
'inference': grpc.unary_unary_rpc_method_handler( | |||
servicer.inference, | |||
request_deserializer=inference__pb2.DataRequest.FromString, | |||
response_serializer=inference__pb2.DataResponse.SerializeToString, | |||
), | |||
} | |||
generic_handler = grpc.method_handlers_generic_handler( | |||
'InferenceService', rpc_method_handlers) | |||
'InferenceService', rpc_method_handlers) | |||
server.add_generic_rpc_handlers((generic_handler,)) | |||
# This class is part of an EXPERIMENTAL API. | |||
# This class is part of an EXPERIMENTAL API. | |||
class InferenceService(object): | |||
"""Missing associated documentation comment in .proto file.""" | |||
@staticmethod | |||
def inference(request, | |||
target, | |||
options=(), | |||
channel_credentials=None, | |||
call_credentials=None, | |||
insecure=False, | |||
compression=None, | |||
wait_for_ready=None, | |||
timeout=None, | |||
metadata=None): | |||
target, | |||
options=(), | |||
channel_credentials=None, | |||
call_credentials=None, | |||
insecure=False, | |||
compression=None, | |||
wait_for_ready=None, | |||
timeout=None, | |||
metadata=None): | |||
return grpc.experimental.unary_unary(request, target, '/InferenceService/inference', | |||
inference__pb2.DataRequest.SerializeToString, | |||
inference__pb2.DataResponse.FromString, | |||
options, channel_credentials, | |||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) | |||
inference__pb2.DataRequest.SerializeToString, | |||
inference__pb2.DataResponse.FromString, | |||
options, channel_credentials, | |||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
@@ -0,0 +1,87 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
""" | |||
import os | |||
import io | |||
import torch | |||
import torch.nn.functional as functional | |||
from PIL import Image | |||
from torchvision import transforms | |||
from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
from logger import Logger | |||
log = Logger().logger | |||
# 只能定义一个class | |||
class CommonInferenceService: | |||
# 请在__init__初始化方法中接收args参数,并加载模型(其中模型路径参数为args.model_path,是否使用gpu参数为args.use_gpu,模型加载方法用户可自定义) | |||
def __init__(self, args): | |||
self.args = args | |||
self.model = self.load_model() | |||
def load_data(self, data_path): | |||
image = open(data_path, 'rb').read() | |||
image = Image.open(io.BytesIO(image)) | |||
if image.mode != 'RGB': | |||
image = image.convert("RGB") | |||
image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) | |||
image = transforms.ToTensor()(image) | |||
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) | |||
image = image[None] | |||
if self.args.use_gpu: | |||
image = image.cuda() | |||
return image | |||
def load_model(self): | |||
if os.path.isfile(self.args.model_path): | |||
self.checkpoint = torch.load(self.args.model_path) | |||
else: | |||
for file in os.listdir(self.args.model_path): | |||
self.checkpoint = torch.load(self.args.model_path + file) | |||
model = self.checkpoint["model"] | |||
model.load_state_dict(self.checkpoint['state_dict']) | |||
for parameter in model.parameters(): | |||
parameter.requires_grad = False | |||
if self.args.use_gpu: | |||
model.cuda() | |||
model.eval() | |||
return model | |||
# inference方法名称固定 | |||
def inference(self, data): | |||
result = {"data_name": data['data_name']} | |||
log.info("===============> start load " + data['data_name'] + " <===============") | |||
data = self.load_data(data['data_path']) | |||
preds = functional.softmax(self.model(data), dim=1) | |||
predictions = torch.topk(preds.data, k=5, dim=1) | |||
result['predictions'] = list() | |||
for prob, label in zip(predictions[0][0], predictions[1][0]): | |||
predictions = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} | |||
result['predictions'].append(predictions) | |||
return result | |||
# 非必须,可用于本地调试 | |||
if __name__=="__main__": | |||
import argparse | |||
parser = argparse.ArgumentParser(description='dubhe serving') | |||
parser.add_argument('--model_path', type=str, default='./res4serving.pth', help="model path") | |||
parser.add_argument('--use_gpu', type=bool, default=True, help="use gpu or not") | |||
parser.add_argument('--reshape_size', type=list, default=[224,224], help="use gpu or not") | |||
args = parser.parse_args() | |||
server = CommonInferenceService(args) | |||
image_path = "./cat.jpg" | |||
image = {"data_name": "cat.jpg", "data_path": image_path} | |||
re = server.inference(image) | |||
print(re) |
@@ -15,8 +15,10 @@ import time | |||
from service.oneflow_inference_service import OneFlowInferenceService | |||
from service.tensorflow_inference_service import TensorflowInferenceService | |||
from service.pytorch_inference_service import PytorchInferenceService | |||
import service.common_inference_service as common_inference_service | |||
from logger import Logger | |||
from utils import file_utils | |||
from utils.find_class_in_file import FindClassInFile | |||
log = Logger().logger | |||
@@ -36,47 +38,58 @@ class InferenceServiceManager: | |||
for model_config in model_config_list: | |||
model_name = model_config["model_name"] | |||
model_path = model_config["model_path"] | |||
self.args.model_name = model_name | |||
self.args.model_path = model_path | |||
model_platform = model_config.get("platform") | |||
if model_platform == "oneflow": | |||
self.inference_service = OneFlowInferenceService(model_name, model_path) | |||
self.inference_service = OneFlowInferenceService(self.args) | |||
elif model_platform == "tensorflow" or model_platform == "keras": | |||
self.inference_service = TensorflowInferenceService(model_name, model_path) | |||
self.inference_service = TensorflowInferenceService(self.args) | |||
elif model_platform == "pytorch": | |||
self.inference_service = PytorchInferenceService(model_name, model_path) | |||
self.inference_service = PytorchInferenceService(self.args) | |||
self.model_name_service_map[model_name] = self.inference_service | |||
else: | |||
# Read from command-line parameter | |||
if self.args.platform == "oneflow": | |||
self.inference_service = OneFlowInferenceService(self.args.model_name, self.args.model_path) | |||
elif self.args.platform == "tensorflow" or self.args.platform == "keras": | |||
self.inference_service = TensorflowInferenceService(self.args.model_name, self.args.model_path) | |||
elif self.args.platform == "pytorch": | |||
self.inference_service = PytorchInferenceService(self.args.model_name, self.args.model_path) | |||
if self.args.use_script: | |||
# 使用自定义推理脚本 | |||
find_class_in_file = FindClassInFile() | |||
cls = find_class_in_file.find(common_inference_service) | |||
self.inference_service = cls[1](self.args) | |||
else : | |||
# 使用默认推理脚本 | |||
if self.args.platform == "oneflow": | |||
self.inference_service = OneFlowInferenceService(self.args) | |||
elif self.args.platform == "tensorflow" or self.args.platform == "keras": | |||
self.inference_service = TensorflowInferenceService(self.args) | |||
elif self.args.platform == "pytorch": | |||
self.inference_service = PytorchInferenceService(self.args) | |||
self.model_name_service_map[self.args.model_name] = self.inference_service | |||
def inference(self, model_name, images): | |||
def inference(self, model_name, data_list): | |||
""" | |||
在线服务推理方法 | |||
""" | |||
inferenceService = self.model_name_service_map[model_name] | |||
result = list() | |||
for image in images: | |||
data = inferenceService.inference(image) | |||
if len(images) == 1: | |||
return data | |||
for data in data_list: | |||
output = inferenceService.inference(data) | |||
if len(data_list) == 1: | |||
return output | |||
else: | |||
result.append(data) | |||
result.append(output) | |||
return result | |||
def inference_and_save_json(self, model_name, json_path, images): | |||
def inference_and_save_json(self, model_name, json_path, data_list): | |||
""" | |||
批量服务推理方法 | |||
""" | |||
inferenceService = self.model_name_service_map[model_name] | |||
for image in images: | |||
data = inferenceService.inference(image) | |||
file_utils.writer_json_file(json_path, image['image_name'], data) | |||
for data in data_list: | |||
result = inferenceService.inference(data) | |||
file_utils.writer_json_file(json_path, data['data_name'], result) | |||
time.sleep(1) |
@@ -20,12 +20,8 @@ import google.protobuf.text_format as text_format | |||
import os | |||
from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
from logger import Logger | |||
import config as configs | |||
from service.abstract_inference_service import AbstractInferenceService | |||
parser = configs.get_parser() | |||
args = parser.parse_args() | |||
log = Logger().logger | |||
@@ -33,10 +29,11 @@ class OneFlowInferenceService(AbstractInferenceService): | |||
""" | |||
oneflow 框架推理service | |||
""" | |||
def __init__(self, model_name, model_path): | |||
def __init__(self, args): | |||
super().__init__() | |||
self.model_name = model_name | |||
self.model_path = model_path | |||
self.args = args | |||
self.model_name = args.model_name | |||
self.model_path = args.model_path | |||
flow.clear_default_session() | |||
self.infer_session = flow.SimpleSession() | |||
self.load_model() | |||
@@ -91,9 +88,9 @@ class OneFlowInferenceService(AbstractInferenceService): | |||
return saved_model_proto | |||
def inference(self, image): | |||
data = {"image_name": image['image_name']} | |||
log.info("===============> start load " + image['image_name'] + " <===============") | |||
images = self.load_image(image['image_path']) | |||
data = {"data_name": image['data_name']} | |||
log.info("===============> start load " + image['data_name'] + " <===============") | |||
images = self.load_image(image['data_path']) | |||
predictions = self.infer_session.run('inference', image=images) | |||
@@ -16,15 +16,12 @@ import torch | |||
import torch.nn.functional as functional | |||
from PIL import Image | |||
from torchvision import transforms | |||
import config | |||
import requests | |||
from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
from io import BytesIO | |||
from logger import Logger | |||
from service.abstract_inference_service import AbstractInferenceService | |||
parser = config.get_parser() | |||
args = parser.parse_args() | |||
log = Logger().logger | |||
@@ -33,10 +30,11 @@ class PytorchInferenceService(AbstractInferenceService): | |||
pytorch 框架推理service | |||
""" | |||
def __init__(self, model_name, model_path): | |||
def __init__(self, args): | |||
super().__init__() | |||
self.model_name = model_name | |||
self.model_path = model_path | |||
self.args = args | |||
self.model_name = args.model_name | |||
self.model_path = args.model_path | |||
self.model = self.load_model() | |||
self.checkpoint = None | |||
@@ -52,38 +50,38 @@ class PytorchInferenceService(AbstractInferenceService): | |||
image = Image.open(io.BytesIO(image)) | |||
if image.mode != 'RGB': | |||
image = image.convert("RGB") | |||
image = transforms.Resize((args.reshape_size[0], args.reshape_size[1]))(image) | |||
image = transforms.Resize((self.args.reshape_size[0], self.args.reshape_size[1]))(image) | |||
image = transforms.ToTensor()(image) | |||
image = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image) | |||
image = image[None] | |||
if args.use_gpu: | |||
if self.args.use_gpu: | |||
image = image.cuda() | |||
log.info("===============> load image success <===============") | |||
return image | |||
def load_model(self): | |||
log.info("===============> start load pytorch model :" + args.model_path + " <===============") | |||
if os.path.isfile(args.model_path): | |||
log.info("===============> start load pytorch model :" + self.args.model_path + " <===============") | |||
if os.path.isfile(self.args.model_path): | |||
self.checkpoint = torch.load(self.model_path) | |||
else: | |||
for file in os.listdir(args.model_path): | |||
for file in os.listdir(self.args.model_path): | |||
self.checkpoint = torch.load(self.model_path + file) | |||
model = self.checkpoint[args.model_structure] | |||
model = self.checkpoint[self.args.model_structure] | |||
model.load_state_dict(self.checkpoint['state_dict']) | |||
for parameter in model.parameters(): | |||
parameter.requires_grad = False | |||
if args.use_gpu: | |||
if self.args.use_gpu: | |||
model.cuda() | |||
model.eval() | |||
log.info("===============> load pytorch model success <===============") | |||
return model | |||
def inference(self, image): | |||
data = {"image_name": image['image_name']} | |||
log.info("===============> start load " + image['image_name'] + " <===============") | |||
image = self.load_image(image['image_path']) | |||
predis = functional.softmax(self.model(image), dim=1) | |||
results = torch.topk(predis.data, k=5, dim=1) | |||
data = {"data_name": image['data_name']} | |||
log.info("===============> start load " + image['data_name'] + " <===============") | |||
image = self.load_image(image['data_path']) | |||
preds = functional.softmax(self.model(image), dim=1) | |||
results = torch.topk(preds.data, k=5, dim=1) | |||
data['predictions'] = list() | |||
for prob, label in zip(results[0][0], results[1][0]): | |||
result = {"label": clsidx_2_labels[int(label)], "probability": "{:.3f}".format(float(prob))} | |||
@@ -13,7 +13,6 @@ limitations under the License. | |||
import tensorflow as tf | |||
import requests | |||
import numpy as np | |||
import config as configs | |||
from imagenet1000_clsidx_to_labels import clsidx_2_labels | |||
from service.abstract_inference_service import AbstractInferenceService | |||
from utils.imagenet_preprocessing_utils import preprocess_input | |||
@@ -21,8 +20,6 @@ from logger import Logger | |||
from PIL import Image | |||
from io import BytesIO | |||
parser = configs.get_parser() | |||
args = parser.parse_args() | |||
log = Logger().logger | |||
@@ -30,11 +27,12 @@ class TensorflowInferenceService(AbstractInferenceService): | |||
""" | |||
tensorflow 框架推理service | |||
""" | |||
def __init__(self, model_name, model_path): | |||
def __init__(self, args): | |||
super().__init__() | |||
self.session = tf.compat.v1.Session(graph=tf.Graph()) | |||
self.model_name = model_name | |||
self.model_path = model_path | |||
self.args = args | |||
self.model_name = args.model_name | |||
self.model_path = args.model_path | |||
self.signature_input_keys = [] | |||
self.signature_input_tensor_names = [] | |||
self.signature_output_keys = [] | |||
@@ -69,11 +67,11 @@ class TensorflowInferenceService(AbstractInferenceService): | |||
self.session, [tf.compat.v1.saved_model.tag_constants.SERVING], self.model_path) | |||
# 加载模型之前先校验用户传入signature name | |||
if args.signature_name not in meta_graph.signature_def: | |||
if self.args.signature_name not in meta_graph.signature_def: | |||
log.error("==============> Invalid signature name <==================") | |||
# 从signature中获取meta graph中输入和输出的节点信息 | |||
signature = meta_graph.signature_def[args.signature_name] | |||
signature = meta_graph.signature_def[self.args.signature_name] | |||
input_keys, input_tensor_names = get_tensors(signature.inputs) | |||
output_keys, output_tensor_names = get_tensors(signature.outputs) | |||
@@ -87,14 +85,14 @@ class TensorflowInferenceService(AbstractInferenceService): | |||
log.info("===============> load tensorflow model success <===============") | |||
def inference(self, image): | |||
data = {"image_name": image['image_name']} | |||
data = {"data_name": image['data_name']} | |||
# 获得用户输入的图片 | |||
log.info("===============> start load " + image['image_name'] + " <===============") | |||
log.info("===============> start load " + image['data_name'] + " <===============") | |||
# 推理所需的输入,目前的分类预置模型都只有一个输入 | |||
input_dict = {} | |||
input_keys = self.signature_input_keys | |||
input_data = {} | |||
im = preprocess_input(self.load_image(image['image_path']), mode=args.prepare_mode) | |||
im = preprocess_input(self.load_image(image['data_path']), mode=self.args.prepare_mode) | |||
if len(list(im.shape)) == 3: | |||
input_data[input_keys[0]] = np.expand_dims(im, axis=0) | |||
@@ -43,52 +43,52 @@ def download_image(images_path): | |||
save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + image_path.split("/")[-1].split(".")[-1]) | |||
def upload_image(image_files): | |||
def upload_data(files): | |||
""" | |||
前端上传图片保存到本地 | |||
""" | |||
save_image_dir = "/usr/local/images/" | |||
if not os.path.exists(save_image_dir): | |||
os.mkdir(save_image_dir) | |||
images = list() | |||
for image_file in image_files: | |||
save_data_dir = "/usr/local/data/" | |||
if not os.path.exists(save_data_dir): | |||
os.mkdir(save_data_dir) | |||
data_list = list() | |||
for file in files: | |||
try: | |||
suffix = Path(image_file.filename).suffix | |||
with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_image_dir) as tmp: | |||
shutil.copyfileobj(image_file.file, tmp) | |||
suffix = Path(file.filename).suffix | |||
with NamedTemporaryFile(delete=False, suffix=suffix, dir=save_data_dir) as tmp: | |||
shutil.copyfileobj(file.file, tmp) | |||
tmp_file_name = Path(tmp.name).name | |||
file = {"image_name": image_file.filename, "image_path": save_image_dir + tmp_file_name} | |||
images.append(file) | |||
data = {"data_name": file.filename, "data_path": save_data_dir + tmp_file_name} | |||
data_list.append(data) | |||
finally: | |||
image_file.file.close() | |||
return images | |||
file.file.close | |||
return data_list | |||
def upload_image_by_base64(image_files): | |||
def upload_image_by_base64(data_list): | |||
""" | |||
base64图片信息保存到本地 | |||
""" | |||
save_image_dir = "/usr/local/images/" | |||
if not os.path.exists(save_image_dir): | |||
os.mkdir(save_image_dir) | |||
images = list() | |||
for img_file in image_files: | |||
file_path = save_image_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + img_file.image_name.split(".")[-1] | |||
img_data = base64.b64decode(img_file.image_file) | |||
save_data_dir = "/usr/local/data/" | |||
if not os.path.exists(save_data_dir): | |||
os.mkdir(save_data_dir) | |||
data_list_b64 = list() | |||
for data in data_list: | |||
file_path = save_data_dir + str(int(round(time.time() * MAX_TIME_LENGTH))) + "." + data.data_name.split(".")[-1] | |||
file_b64 = base64.b64decode(data.data_file) | |||
file = open(file_path, 'wb') | |||
file.write(img_data) | |||
file.write(file_b64) | |||
file.close() | |||
image = {"image_name": img_file.image_name, "image_path": file_path} | |||
images.append(image) | |||
return images | |||
data_b64 = {"data_name": data.data_name, "data_path": file_path} | |||
data_list_b64.append(data_b64) | |||
return data_list_b64 | |||
def writer_json_file(json_path, image_name, data): | |||
def writer_json_file(json_path, data_name, data): | |||
""" | |||
保存为json文件 | |||
""" | |||
if not os.path.exists(json_path): | |||
os.mkdir(json_path) | |||
filename = json_path + image_name + '.json' | |||
filename = json_path + data_name + '.json' | |||
with open(filename, 'w', encoding='utf-8') as file_obj: | |||
file_obj.write(json.dumps(data, ensure_ascii=False)) |
@@ -0,0 +1,88 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
""" | |||
import re | |||
import inspect | |||
# from ac_opf.models import ac_model | |||
class FindClassInFile: | |||
""" | |||
find class in the given module | |||
### note ### | |||
There is only one class in the given module. | |||
If there are more than one classes, all of them will be omit except the first | |||
############ | |||
method: find | |||
args: | |||
module: object-> the given file or module | |||
encoding: string-> "utf8" by default | |||
output: tuple(string-> name of the class, object-> class) | |||
usage: | |||
# >>> import module | |||
# | |||
# >>> find_class_in_file = FindClassInFile() | |||
# >>> cls = find_class_in_file.find(module) | |||
# | |||
# >>> cls_instance = cls[1](args) | |||
""" | |||
def __init__(self): | |||
pass | |||
def _open_file(self, path, encoding="utf8"): | |||
with open(path, "r", encoding=encoding) as f: | |||
data = f.readlines() | |||
for line in data: | |||
yield line | |||
def find(self, module, encoding="utf8"): | |||
path = module.__file__ | |||
lines = self._open_file(path=path, encoding=encoding) | |||
cls = "" | |||
for line in lines: | |||
if "class " in line: | |||
cls = re.findall("class (.*?)[:(]", line)[0] | |||
if cls: | |||
break | |||
return self._valid(module, cls) | |||
def _valid(self, module, cls): | |||
members = inspect.getmembers(module) | |||
cand = [(i, j) for i, j in members if inspect.isclass(j) and (not inspect.isabstract(j)) and (i == cls)] | |||
if not cand: | |||
print("class not found in {}".format(module)) | |||
return cand[0] | |||
if __name__ == "__main__": | |||
find_class_in_file = FindClassInFile() | |||
# cls = find_class_in_file.find(ac_model) | |||
# print(cls) | |||