|
import argparse
import os

import onnx
import torch
import torchvision
from torch.autograd import Variable
print(torch.__version__)

# Command-line options: the model file to convert and the (n, c, h, w) shape
# of the random dummy input used when exporting it to ONNX.
parser = argparse.ArgumentParser(
    description='Export a pickled PyTorch model to ONNX')

# Required: the script builds '/dataset/' + model from it, which would fail
# with an opaque TypeError if the option were omitted.
parser.add_argument('--model',
                    type=str,
                    required=True,
                    help='model file to convert, relative to /dataset')
parser.add_argument('--n',
                    type=int,
                    default=256,
                    help='batch size for input shape type')
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type')
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type')
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type')
-
if __name__ == "__main__":
    args = parser.parse_args()
    print('args:')
    print(args)

    # Models are read from the fixed /dataset mount; args.model is a path
    # relative to it.
    model_file = '/dataset/' + args.model
    print(model_file)
    # NOTE(review): torch.load unpickles arbitrary Python objects, so only
    # load trusted files. On torch >= 2.6 the default weights_only=True
    # refuses whole pickled modules; confirm the deployed torch version.
    # map_location='cpu' lets a GPU-saved model load on a CPU-only host and
    # keeps it on the same device as the CPU dummy input created below.
    model = torch.load(model_file, map_location='cpu')
    print(model)
    print(type(model))
    if not isinstance(model, torch.nn.Module):
        # e.g. a bare state_dict: it cannot be exported without the model
        # class, and would otherwise fail on named_parameters() below.
        raise TypeError(
            f'{model_file} must contain a pickled torch.nn.Module, '
            f'got {type(model).__name__}'
        )
    for name, param in model.named_parameters():
        print("k:", name)
        print("v:", param.shape)

    # Drop only the final extension of the file name (if any) and write the
    # ONNX model under /model. str.rindex raised ValueError for names without
    # a '.', and it mistook a '.' in a directory component for an extension;
    # os.path.splitext handles both cases.
    stem = os.path.splitext(args.model)[0]
    out_file = '/model/' + stem + ".onnx"
    print(out_file)

    input_name = ['input']
    output_name = ['output']
    # Random dummy input of shape (n, c, h, w); export runs the model on it
    # to record the graph. A plain tensor replaces the deprecated Variable.
    dummy_input = torch.randn(args.n, args.c, args.h, args.w)
    # Export in inference mode (dropout disabled, batch-norm running stats).
    model.eval()
    torch.onnx.export(model, dummy_input, out_file,
                      input_names=input_name, output_names=output_name,
                      verbose=True)
-
|