|
"""
Convert PyTorch checkpoint weights to OneFlow format (PyTorch -> OneFlow weight conversion).
"""
- from kamal import vision
- import torch
- import argparse
- import oneflow as flow
- import pdb
-
-
def torch_state_dict_to_numpy(state_dict):
    """Return a copy of a PyTorch ``state_dict`` with every tensor as a NumPy array.

    Entries whose key ends with ``num_batches_tracked`` (PyTorch BatchNorm
    batch counters) are dropped; every other entry is detached, moved to CPU
    and converted with ``.numpy()``.

    Args:
        state_dict: mapping of parameter/buffer name -> ``torch.Tensor``.

    Returns:
        dict: ``{name: numpy.ndarray}`` in the same key order as the input.
    """
    # Filter first so the skipped counters are never converted.
    return {
        key: value.detach().cpu().numpy()
        for key, value in state_dict.items()
        if not str(key).endswith("num_batches_tracked")
    }


def load_torch_model(ckpt_path):
    """Unpickle a whole PyTorch model object from ``ckpt_path``.

    The checkpoint is expected to hold a model object (not a bare state dict),
    since callers use ``.state_dict()`` on the result.

    Security: this unpickles arbitrary Python objects. Only use it on
    checkpoint files you trust.
    """
    import inspect

    # Load onto CPU so checkpoints saved on a GPU also work on CPU-only hosts;
    # the tensors are converted to NumPy on the CPU afterwards anyway.
    kwargs = {"map_location": "cpu"}
    # Newer PyTorch defaults to ``weights_only=True``, which refuses whole
    # pickled modules. Opt out explicitly where the argument exists; older
    # releases do not accept it at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(ckpt_path, **kwargs)


def convert_checkpoint(ckpt_path, num_classes, out_path):
    """Copy a PyTorch ResNet-50 checkpoint into a kamal ResNet-50 and save it with OneFlow.

    Args:
        ckpt_path: path to the pickled PyTorch model.
        num_classes: number of output classes of the target ResNet-50.
        out_path: destination passed to ``flow.save``; its parent directory is
            created if it does not exist.
    """
    import os

    params = torch_state_dict_to_numpy(load_torch_model(ckpt_path).state_dict())

    teacher = vision.models.classification.resnet50(
        num_classes=num_classes, pretrained=False
    )
    # NOTE(review): relies on the kamal model's load_state_dict accepting NumPy
    # arrays, exactly as the original script did -- confirm for your OneFlow version.
    teacher.load_state_dict(params)

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    flow.save(teacher.state_dict(), out_path)


def parse_args(argv=None):
    """Parse command-line options; ``argv=None`` reads ``sys.argv``.

    Defaults reproduce the values previously hard-coded in the script.
    """
    parser = argparse.ArgumentParser(
        description="Convert PyTorch ResNet-50 teacher checkpoints to OneFlow format."
    )
    parser.add_argument('--car_ckpt', required=True)
    parser.add_argument('--dog_ckpt')
    parser.add_argument('--car_num_classes', type=int, default=102)
    parser.add_argument('--car_out', default="./ckpt/aircraft_res50_model")
    parser.add_argument('--dog_num_classes', type=int, default=397)
    parser.add_argument('--dog_out', default="./checkpoint/sun_res50_model")
    return parser.parse_args(argv)


def main(argv=None):
    """Convert the required ``--car_ckpt`` and, if given, the optional ``--dog_ckpt``."""
    args = parse_args(argv)
    convert_checkpoint(args.car_ckpt, args.car_num_classes, args.car_out)
    # Previously --dog_ckpt was accepted but silently ignored.
    if args.dog_ckpt is not None:
        convert_checkpoint(args.dog_ckpt, args.dog_num_classes, args.dog_out)


if __name__ == "__main__":
    main()
|