You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

flow.py 1.3 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041
  1. """
  2. pytorch -- oneflow 格式权重转换
  3. """
  4. from kamal import vision
  5. import torch
  6. import argparse
  7. import oneflow as flow
  8. import pdb
  9. parser = argparse.ArgumentParser()
  10. parser.add_argument( '--car_ckpt', required=True )
  11. parser.add_argument( '--dog_ckpt' )
  12. args = parser.parse_args()
  13. cars_parameters = torch.load(args.car_ckpt).state_dict()
  14. # dogs_parameters = torch.load(args.dog_ckpt).state_dict()
  15. cars_para, dogs_para = {}, {}
  16. # pdb.set_trace()
  17. for key, value in cars_parameters.items():
  18. val = value.detach().cpu().numpy()
  19. if not str(key).endswith('num_batches_tracked'):
  20. cars_para[key] = val
  21. # for key, value in dogs_parameters.items():
  22. # val = value.detach().cpu().numpy()
  23. # if not str(key).endswith('num_batches_tracked'):
  24. # dogs_para[key] = val
  25. car_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
  26. # dog_teacher = vision.models.classification.resnet50(num_classes=397, pretrained=False)
  27. car_teacher.load_state_dict(cars_para)
  28. # dog_teacher.load_state_dict(dogs_para)
  29. # torch.save(car_teacher, 'ckpt/aircraft_res50.pth')
  30. # torch.save(dog_teacher, 'checkpoint/sun_res50.pth')
  31. flow.save(car_teacher.state_dict(), "./ckpt/aircraft_res50_model")
  32. # flow.save(dog_teacher.state_dict(), "./checkpoint/sun_res50_model")

模型炼知是由浙江大学VIPA团队于2019-2020年期间提出,其目的是建立轻量化的知识融合算法和解决深度模型迁移性度量问题。 本仓库包含TTL、THL、TFL三个模型炼知示例算法,用于计算机视觉领域,通过将多个同构或异构教师重组,实现知识融合,获得定制化的、全能型的学生模型,解决所有教师任务,学生模型性能相比于传统训练结果显著提高。因此,模型炼知具有深入研究和实际应用价值。

Contributors (1)