|
- from kamal import vision, engine, utils, amalgamation, metrics, callbacks
- from kamal.vision import sync_transforms as sT
- import pdb
- import oneflow
- import argparse
-
# Command-line configuration, parsed once at import time and read by main().
parser = argparse.ArgumentParser(
    description='Layer-wise knowledge amalgamation of StanfordCars, StanfordDogs '
                'and FGVCAircraft ResNet-50 teachers into a single student.')
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model",
                    help='StanfordCars teacher weights (196-class ResNet-50), '
                         'loaded with oneflow.load (default: %(default)s)')
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model",
                    help='StanfordDogs teacher weights (120-class ResNet-50), '
                         'loaded with oneflow.load (default: %(default)s)')
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model",
                    help='FGVCAircraft teacher weights (102-class ResNet-50), '
                         'loaded with oneflow.load (default: %(default)s)')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='initial Adam learning rate for the student (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='mini-batch size for training and evaluation (default: %(default)s)')
args = parser.parse_args()
-
-
def _class_range_accuracy(start, stop):
    """Build an Accuracy metric scored only on student logits ``[start, stop)``.

    The student predicts the concatenation of every teacher's label space, so
    each task is evaluated on its own contiguous block of output columns.
    Taking ``start``/``stop`` as function arguments gives each lambda its own
    bound range (no late-binding surprises).
    """
    return metrics.Accuracy(attach_to=lambda o, t: (o[:, start:stop], t))


def _make_val_loader(dataset, batch_size):
    """Build an unshuffled DataLoader over a test split.

    ``drop_last=False`` so that every test image is scored; dropping the final
    partial batch would silently evaluate on a truncated test set.
    """
    return oneflow.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)


def _collect_layer_groups(student, teachers):
    """Pair each student ``Conv2d`` with the module at the same position in every teacher.

    Walks ``student.modules()`` and each teacher's ``modules()`` in lockstep,
    which relies on all networks sharing one architecture (here all are built
    by ``resnet50``; only the classifier width differs, and that layer is not
    a ``Conv2d``, so it is skipped).

    Returns:
        (layer_groups, layer_channels): ``layer_groups[i]`` is
        ``[student_conv, *teacher_convs]`` and ``layer_channels[i]`` holds the
        corresponding ``out_channels`` values in the same order.
    """
    layer_groups = []
    layer_channels = []
    for blocks in zip(student.modules(), *(teacher.modules() for teacher in teachers)):
        if isinstance(blocks[0], oneflow.nn.Conv2d):
            layer_groups.append(list(blocks))
            layer_channels.append([block.out_channels for block in blocks])
    return layer_groups, layer_channels


def main():
    """Amalgamate car, dog and aircraft ResNet-50 teachers into one student.

    The student is a ResNet-50 whose classifier covers the union of the three
    label spaces, laid out as ``[cars | dogs | aircraft]``. It is trained with
    ``amalgamation.LayerWiseAmalgamator`` on the concatenated training sets;
    after every epoch it is evaluated and checkpointed separately on each test
    split, using only that task's slice of the logits.

    Reads the module-level ``args`` (teacher checkpoint paths, ``lr``,
    ``batch_size``).
    """
    num_car, num_dog, num_aircraft = 196, 120, 102
    num_epochs = 100

    # Datasets
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')

    # Teachers / student
    car_teacher = vision.models.classification.resnet50(num_classes=num_car, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=num_dog, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet50(num_classes=num_aircraft, pretrained=False)
    student = vision.models.classification.resnet50(
        num_classes=num_car + num_dog + num_aircraft, pretrained=False)

    # Pretrained teacher weights
    car_teacher.load_state_dict(oneflow.load(args.car_ckpt))
    dog_teacher.load_state_dict(oneflow.load(args.dog_ckpt))
    aircraft_teacher.load_state_dict(oneflow.load(args.aircraft_ckpt))

    # Standard ImageNet channel statistics, shared by both pipelines.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=mean, std=std),
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=mean, std=std),
    ])

    car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = val_transform

    # Each task is scored on its own slice of the student's logits.
    dog_start = num_car
    aircraft_start = num_car + num_dog
    car_metric = metrics.MetricCompose(
        metric_dict={'car_acc': _class_range_accuracy(0, num_car)})
    dog_metric = metrics.MetricCompose(
        metric_dict={'dog_acc': _class_range_accuracy(dog_start, dog_start + num_dog)})
    aircraft_metric = metrics.MetricCompose(
        metric_dict={'aircraft_acc': _class_range_accuracy(aircraft_start, aircraft_start + num_aircraft)})

    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst])
    train_loader = oneflow.utils.data.DataLoader(
        train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    car_loader = _make_val_loader(car_val_dst, args.batch_size)
    dog_loader = _make_val_loader(dog_val_dst, args.batch_size)
    aircraft_loader = _make_val_loader(aircraft_val_dst, args.batch_size)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)

    total_iters = len(train_loader) * num_epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    # The scheduler callback is attached to AFTER_STEP below, so T_max is
    # expressed in iterations and spans the whole run.
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=total_iters)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
    )

    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(
            keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='thl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='thl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='thl_aircraft'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    teachers = [car_teacher, dog_teacher, aircraft_teacher]
    layer_groups, layer_channels = _collect_layer_groups(student, teachers)
    trainer.setup(student=student,
                  teachers=teachers,
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  # NOTE(review): presumably one weight per loss term
                  # (kd / amal / recons) — confirm against LayerWiseAmalgamator.
                  weights=[1., 1., 1.])

    trainer.run(start_iter=0, max_iter=total_iters)
-
if __name__ == '__main__':
    # Launch training only when executed as a script, not when imported.
    main()
|