|
- """
- Copyright 2020 Tianshu AI Platform. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
- # oneflow checkpoint 到savedmodel 转换脚本
- # 转换出来job_name是'inference',也就是 make_infer_func 中定义的同名函数
-
- import oneflow as flow
- import os
- import shutil
-
- os.environ['CUDA_VISIBLE_DEVICES'] = '3'
-
- # 输入参数(需要根据模型修改)
- image_size = (3, 299, 299)
- model_name = "inceptionv3"
- saved_model_path = "inceptionv3"
- batch_size = 8
-
- # 导入模型网络结构
- # from of_models.resnet_model import resnet50
- # from of_models.alexnet_model import alexnet
- # from of_models.vgg_model import vgg16bn
- from models.of_models import inceptionv3
-
- model = inceptionv3
-
- # 模型checkpoint文件路径
- # DEFAULT_CHECKPOINT_DIR = "./checkpoint_of/resnet_v15_of_best_model_val_top1_77318"
- # DEFAULT_CHECKPOINT_DIR = "./checkpoint_of/alexnet_of_best_model_val_top1_54762"
- # DEFAULT_CHECKPOINT_DIR = "./checkpoint_of/vgg16_of_best_model_val_top1_721"
- DEFAULT_CHECKPOINT_DIR = "./checkpoint_of/snapshot_epoch_75"
-
-
- def init_env():
- flow.env.init()
- flow.config.machine_num(1)
- flow.config.cpu_device_num(1)
- flow.config.gpu_device_num(1)
- flow.config.enable_debug_mode(True)
-
-
- def make_infer_func(batch_size, image_size):
- input_lbns = {}
- output_lbns = {}
- image_shape = (batch_size,) + tuple(image_size)
-
- @flow.global_function(type="predict")
- def inference(
- image: flow.typing.Numpy.Placeholder(image_shape, dtype=flow.float32)
- ) -> flow.typing.Numpy:
- input_lbns["image"] = image.logical_blob_name
- output = model(image)
- output = flow.nn.softmax(output)
- output_lbns["output"] = output.logical_blob_name
- return output
-
- return inference, input_lbns, output_lbns
-
-
- def ckpt_to_savedmodel():
- init_env()
- resnet_infer, input_lbns, output_lbns = make_infer_func(
- batch_size, image_size
- )
- # origin resnet inference model
- checkpoint = flow.train.CheckPoint()
- checkpoint.load(DEFAULT_CHECKPOINT_DIR)
-
- # save model
- # 修改路径
- if os.path.exists(saved_model_path) and os.path.isdir(saved_model_path):
- shutil.rmtree(saved_model_path)
-
- model_version = 1
- saved_model_builder = flow.SavedModelBuilderV2(saved_model_path)
- job_builder = (
- saved_model_builder.ModelName(model_name)
- .Version(model_version)
- .Job(resnet_infer)
- )
- for input_name, lbn in input_lbns.items():
- job_builder.Input(input_name, lbn)
- for output_name, lbn in output_lbns.items():
- job_builder.Output(output_name, lbn)
- job_builder.Complete().Save()
-
-
- if __name__ == "__main__":
- ckpt_to_savedmodel()
|