""" pytorch -- oneflow 格式权重转换 """ from kamal import vision import torch import argparse import oneflow as flow import pdb parser = argparse.ArgumentParser() parser.add_argument( '--car_ckpt', required=True ) parser.add_argument( '--dog_ckpt' ) args = parser.parse_args() cars_parameters = torch.load(args.car_ckpt).state_dict() # dogs_parameters = torch.load(args.dog_ckpt).state_dict() cars_para, dogs_para = {}, {} # pdb.set_trace() for key, value in cars_parameters.items(): val = value.detach().cpu().numpy() if not str(key).endswith('num_batches_tracked'): cars_para[key] = val # for key, value in dogs_parameters.items(): # val = value.detach().cpu().numpy() # if not str(key).endswith('num_batches_tracked'): # dogs_para[key] = val car_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False) # dog_teacher = vision.models.classification.resnet50(num_classes=397, pretrained=False) car_teacher.load_state_dict(cars_para) # dog_teacher.load_state_dict(dogs_para) # torch.save(car_teacher, 'ckpt/aircraft_res50.pth') # torch.save(dog_teacher, 'checkpoint/sun_res50.pth') flow.save(car_teacher.state_dict(), "./ckpt/aircraft_res50_model") # flow.save(dog_teacher.state_dict(), "./checkpoint/sun_res50_model")