"""Layer-wise knowledge amalgamation of three ResNet-50 teachers into one student.

Three single-task ResNet-50 teachers (Stanford Cars, Stanford Dogs,
FGVC-Aircraft) are restored from checkpoints. A ResNet-50 student whose output
is the concatenation ``[cars | dogs | aircraft]`` is trained with
``amalgamation.LayerWiseAmalgamator``. After every epoch the student is
evaluated separately on each task's test split and checkpointed.
"""
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT

import pdb
import oneflow
import argparse

# Output sizes of the three teachers. The student's logits are laid out as the
# concatenation [cars | dogs | aircraft], in exactly this order.
NUM_CAR_CLASSES = 196
NUM_DOG_CLASSES = 120
# NOTE(review): FGVC-Aircraft has 100 variants; 102 must match the shape stored
# in the aircraft checkpoint -- confirm before changing.
NUM_AIRCRAFT_CLASSES = 102

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model",
                    help='checkpoint of the Stanford Cars teacher')
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model",
                    help='checkpoint of the Stanford Dogs teacher')
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model",
                    help='checkpoint of the FGVC-Aircraft teacher')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='Adam learning rate for the student')
parser.add_argument('--batch_size', type=int, default=32,
                    help='batch size for training and evaluation')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of passes over the concatenated training set')
args = parser.parse_args()


def main():
    """Build data, teachers and student, then run layer-wise amalgamation."""
    # ---- datasets ----
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')

    # ---- teachers / student ----
    car_teacher = vision.models.classification.resnet50(num_classes=NUM_CAR_CLASSES, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=NUM_DOG_CLASSES, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet50(num_classes=NUM_AIRCRAFT_CLASSES, pretrained=False)
    student = vision.models.classification.resnet50(
        num_classes=NUM_CAR_CLASSES + NUM_DOG_CLASSES + NUM_AIRCRAFT_CLASSES,
        pretrained=False,
    )

    # ---- teacher weights ----
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    aircraft_parameters = oneflow.load(args.aircraft_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)
    aircraft_teacher.load_state_dict(aircraft_parameters)

    # ---- transforms ----
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = val_transform

    # ---- per-task metrics ----
    # Each metric scores only its own slice of the student's concatenated logits;
    # the slices are derived from the class counts so they stay consistent with
    # the student head built above.
    car_slice = slice(0, NUM_CAR_CLASSES)
    dog_slice = slice(NUM_CAR_CLASSES, NUM_CAR_CLASSES + NUM_DOG_CLASSES)
    aircraft_slice = slice(NUM_CAR_CLASSES + NUM_DOG_CLASSES,
                           NUM_CAR_CLASSES + NUM_DOG_CLASSES + NUM_AIRCRAFT_CLASSES)
    car_metric = metrics.MetricCompose(metric_dict={
        'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, car_slice], t))})
    dog_metric = metrics.MetricCompose(metric_dict={
        'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, dog_slice], t))})
    aircraft_metric = metrics.MetricCompose(metric_dict={
        'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, aircraft_slice], t))})

    # ---- data loaders ----
    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst])
    train_loader = oneflow.utils.data.DataLoader(
        train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)

    def make_eval_loader(dataset):
        # drop_last=False: evaluation must cover every test image; dropping the
        # final partial batch would silently bias the reported accuracy.
        return oneflow.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)

    car_loader = make_eval_loader(car_val_dst)
    dog_loader = make_eval_loader(dog_val_dst)
    aircraft_loader = make_eval_loader(aircraft_val_dst)

    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)

    # ---- optimisation ----
    TOTAL_ITERS = len(train_loader) * args.epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    # The scheduler is stepped after every iteration (see LRSchedulerCallback
    # below), so T_max is expressed in iterations, not epochs.
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)

    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # Optional TensorBoard logging (needs `import time` and a SummaryWriter import):
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(
            keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator,
                                  metric_name='car_acc', ckpt_prefix='thl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator,
                                  metric_name='dog_acc', ckpt_prefix='thl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator,
                                  metric_name='aircraft_acc', ckpt_prefix='thl_aircraft'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # ---- layer pairing ----
    # All four networks are ResNet-50s, so walking their modules in lockstep
    # pairs each student Conv2d with the teachers' module at the same position.
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block, aircraft_block in zip(
            student.modules(), car_teacher.modules(),
            dog_teacher.modules(), aircraft_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block, aircraft_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels,
                                   dog_block.out_channels, aircraft_block.out_channels])

    trainer.setup(
        student=student,
        teachers=[car_teacher, dog_teacher, aircraft_teacher],
        layer_groups=layer_groups,
        layer_channels=layer_channels,
        dataloader=train_loader,
        optimizer=optim,
        device=device,
        weights=[1., 1., 1.],
    )
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)


if __name__ == '__main__':
    main()