You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

THL.py 6.4 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from kamal import vision, engine, utils, amalgamation, metrics, callbacks
  2. from kamal.vision import sync_transforms as sT
  3. import pdb
  4. import oneflow
  5. import argparse
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
  8. parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
  9. parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model")
  10. parser.add_argument('--lr', type=float, default=1e-3)
  11. parser.add_argument('--batch_size', type=int, default=32)
  12. args = parser.parse_args()
  13. def main():
  14. # 数据集
  15. car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
  16. car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
  17. dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
  18. dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
  19. aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
  20. aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')
  21. # 教师/学生
  22. car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
  23. dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
  24. aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
  25. student = vision.models.classification.resnet50(num_classes=196+120+102, pretrained=False)
  26. # 权重参数
  27. cars_parameters = oneflow.load(args.car_ckpt)
  28. dogs_parameters = oneflow.load(args.dog_ckpt)
  29. aircraft_parameters = oneflow.load(args.aircraft_ckpt)
  30. car_teacher.load_state_dict(cars_parameters)
  31. dog_teacher.load_state_dict(dogs_parameters)
  32. aircraft_teacher.load_state_dict(aircraft_parameters)
  33. train_transform = sT.Compose( [
  34. sT.RandomResizedCrop(224),
  35. sT.RandomHorizontalFlip(),
  36. sT.ToTensor(),
  37. sT.Normalize( mean=[0.485, 0.456, 0.406],
  38. std=[0.229, 0.224, 0.225] )
  39. ] )
  40. val_transform = sT.Compose( [
  41. sT.Resize(256),
  42. sT.CenterCrop( 224 ),
  43. sT.ToTensor(),
  44. sT.Normalize( mean=[0.485, 0.456, 0.406],
  45. std=[0.229, 0.224, 0.225] )
  46. ] )
  47. car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = train_transform
  48. car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = val_transform
  49. car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
  50. dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))})
  51. aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))})
  52. train_dst = oneflow.utils.data.ConcatDataset( [car_train_dst, dog_train_dst, aircraft_train_dst] )
  53. # pdb.set_trace()
  54. train_loader = oneflow.utils.data.DataLoader( train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True )
  55. car_loader = oneflow.utils.data.DataLoader( car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True )
  56. dog_loader = oneflow.utils.data.DataLoader( dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
  57. aircraft_loader = oneflow.utils.data.DataLoader( aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True )
  58. car_evaluator = engine.evaluator.BasicEvaluator( car_loader, car_metric )
  59. dog_evaluator = engine.evaluator.BasicEvaluator( dog_loader, dog_metric )
  60. aircraft_evaluator = engine.evaluator.BasicEvaluator( aircraft_loader, aircraft_metric )
  61. # pdb.set_trace()
  62. TOTAL_ITERS=len(train_loader) * 100
  63. device = oneflow.device( 'cuda' if oneflow.cuda.is_available() else 'cpu' )
  64. optim = oneflow.optim.Adam( student.parameters(), lr=args.lr, weight_decay=1e-4)
  65. sched = oneflow.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=TOTAL_ITERS )
  66. trainer = amalgamation.LayerWiseAmalgamator(
  67. logger=utils.logger.get_logger('layerwise-ka'),
  68. # tb_writer=SummaryWriter( log_dir='run/layerwise_ka-%s'%( time.asctime().replace( ' ', '_' ) ) )
  69. )
  70. trainer.add_callback(
  71. engine.DefaultEvents.AFTER_STEP(every=10),
  72. callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
  73. trainer.add_callback(
  74. engine.DefaultEvents.AFTER_EPOCH,
  75. callbacks=[
  76. callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='thl_car'),
  77. callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='thl_dog'),
  78. callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='thl_aircraft'),
  79. ] )
  80. trainer.add_callback(
  81. engine.DefaultEvents.AFTER_STEP,
  82. callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
  83. layer_groups = []
  84. layer_channels = []
  85. for stu_block, car_block, dog_block, aircraft_block in zip( student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules() ):
  86. if isinstance( stu_block, oneflow.nn.Conv2d ):
  87. layer_groups.append( [ stu_block, car_block, dog_block, aircraft_block ] )
  88. layer_channels.append( [ stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels ] )
  89. trainer.setup( student=student,
  90. teachers=[car_teacher, dog_teacher, aircraft_teacher],
  91. layer_groups=layer_groups,
  92. layer_channels=layer_channels,
  93. dataloader=train_loader,
  94. optimizer=optim,
  95. device=device,
  96. weights=[1., 1., 1.] )
  97. trainer.run(start_iter=0, max_iter=TOTAL_ITERS)
  98. if __name__=='__main__':
  99. main()

模型炼知是由浙江大学VIPA团队于2019-2020年期间提出的,其目的是建立轻量化的知识融合算法和解决深度模型迁移性度量问题。本仓库包含TTL、THL、TFL三个模型炼知示例算法,用于计算机视觉领域,通过将多个同构或异构教师重组,实现知识融合,获得定制化的、全能型的学生模型,解决所有教师任务,学生模型性能相比于传统训练结果显著提高。因此,模型炼知具有深入研究和实际应用价值。

Contributors (1)