@@ -0,0 +1,40 @@
# Model Knowledge Amalgamation (KA)
## Introduction
Model knowledge amalgamation (KA) is a lightweight algorithm package for knowledge fusion and model-transferability measurement. For details, see [KAE](https://github.com/zju-vipa/KamalEngine/tree/master/examples).
## Table of contents
* [Introduction](#introduction)
* [Algorithms](#algorithms)
* [Authors](#authors)
## Algorithms
#### 0. Datasets
* data/StanfordCars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html
* data/StanfordDogs: http://vision.stanford.edu/aditya86/ImageNetDogs/
* data/FGVCAircraft: http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
* data/Flowers102: http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
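
The training scripts read these datasets from a local `./DataSets/` directory; the layout below is inferred from the dataset constructors used in the scripts:

```
DataSets/
├── StanfordCars/    # StanfordCars(root, split='train'|'test')
├── StanfordDogs/    # StanfordDogs(root, split='train'|'test')
├── FGVCAircraft/    # FGVCAircraft(root, split='trainval'|'test')
└── Flower102/       # Flowers102(root, split='train'|'valid')
```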
#### 1. Dependencies
* [requirements.txt](requirements.txt)
#### 2. Algorithms
* TW_Layerwise (two teachers)
```bash
python3 TTL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model
```
* TH_Layerwise (three teachers)
```bash
python3 THL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model --aircraft_ckpt ./ckpt/aircraft_res50_model
```
* TF_Layerwise (four teachers)
```bash
python3 TFL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model --aircraft_ckpt ./ckpt/aircraft_res50_model --flower_ckpt ./ckpt/flower_res50_model
```
* Weight conversion (PyTorch -> OneFlow); the script filename in this command is an assumption, see the conversion script below
```bash
python3 convert_weights.py --car_ckpt ./ckpt/aircraft_res50.pth
```
#### 3. Weights
The MODEL_ckpt folder stores the weights produced by the three amalgamation algorithms, in OneFlow format, for reference when reproducing results. The ckpt folder holds the PyTorch teacher weights converted to OneFlow format, used to load the teacher models; since it is large, it is hosted at the link below instead.
* Teacher weights [link]: https://pan.baidu.com/s/19pRLZKQvjgEQ7CuD_RiPPw  extraction code: FLOW
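
As a reference, a minimal sketch of loading a converted teacher checkpoint (the paths and the 196-class car model follow the training scripts; `oneflow.save` writes a checkpoint directory, so `oneflow.load` takes a directory path):

```python
import oneflow
from kamal import vision

car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
car_teacher.load_state_dict(oneflow.load("./ckpt/car_res50_model"))  # checkpoint directory
car_teacher.eval()
```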
## Authors
This project is a simplified version of KAE, developed by the [VIPA Lab](http://vipazoo.cn) at Zhejiang University and [Zhejiang Lab](http://www.zhejianglab.com/).
@@ -0,0 +1,124 @@
# TFL.py: layer-wise amalgamation of four ResNet-50 teachers
# (StanfordCars, StanfordDogs, FGVCAircraft, Flowers102) into a single student.
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import oneflow
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model")
parser.add_argument('--flower_ckpt', default="./ckpt/flower_res50_model")
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=16)
args = parser.parse_args()

def main():
    # datasets
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')
    flower_train_dst = vision.datasets.Flowers102('./DataSets/Flower102/', split='train')
    flower_val_dst = vision.datasets.Flowers102('./DataSets/Flower102/', split='valid')

    # teachers / student; the student predicts the concatenation of all teacher label spaces
    car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
    flower_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
    student = vision.models.classification.resnet50(num_classes=196+120+102+102, pretrained=False)

    # load the converted OneFlow teacher weights
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    aircraft_parameters = oneflow.load(args.aircraft_ckpt)
    flowers_parameters = oneflow.load(args.flower_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)
    aircraft_teacher.load_state_dict(aircraft_parameters)
    flower_teacher.load_state_dict(flowers_parameters)

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = flower_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = flower_val_dst.transform = val_transform

    # each metric looks only at its teacher's slice of the student's concatenated logits
    car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
    dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))})
    aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))})
    flower_metric = metrics.MetricCompose(metric_dict={'flower_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120+102:196+120+102+102], t))})

    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst, flower_train_dst])
    train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0)
    car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
    dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
    aircraft_loader = oneflow.utils.data.DataLoader(aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    flower_loader = oneflow.utils.data.DataLoader(flower_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)
    flower_evaluator = engine.evaluator.BasicEvaluator(flower_loader, flower_metric)

    TOTAL_ITERS = len(train_loader) * 100  # 100 epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='tfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='tfl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='tfl_aircraft'),
            callbacks.EvalAndCkpt(model=student, evaluator=flower_evaluator, metric_name='flower_acc', ckpt_prefix='tfl_flower'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # group matching Conv2d layers of the student and all teachers for the amalgamation blocks
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block, aircraft_block, flower_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules(), flower_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block, aircraft_block, flower_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels, flower_block.out_channels])
    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher, aircraft_teacher, flower_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])  # weights for loss_kd, loss_amal, loss_recons
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

if __name__ == '__main__':
    main()
@@ -0,0 +1,116 @@
# THL.py: layer-wise amalgamation of three ResNet-50 teachers
# (StanfordCars, StanfordDogs, FGVCAircraft) into a single student.
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import oneflow
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model")
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=32)
args = parser.parse_args()

def main():
    # datasets
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')

    # teachers / student; the student predicts the concatenation of all teacher label spaces
    car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
    student = vision.models.classification.resnet50(num_classes=196+120+102, pretrained=False)

    # load the converted OneFlow teacher weights
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    aircraft_parameters = oneflow.load(args.aircraft_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)
    aircraft_teacher.load_state_dict(aircraft_parameters)

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = val_transform

    # each metric looks only at its teacher's slice of the student's concatenated logits
    car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
    dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))})
    aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))})

    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst])
    train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    aircraft_loader = oneflow.utils.data.DataLoader(aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)

    TOTAL_ITERS = len(train_loader) * 100  # 100 epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='thl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='thl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='thl_aircraft'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # group matching Conv2d layers of the student and all teachers for the amalgamation blocks
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block, aircraft_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block, aircraft_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels])
    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])  # weights for loss_kd, loss_amal, loss_recons
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

if __name__ == '__main__':
    main()
@@ -0,0 +1,105 @@
# TTL.py: layer-wise amalgamation of two ResNet-50 teachers
# (StanfordCars, StanfordDogs) into a single student.
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import oneflow
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
parser.add_argument('--lr', type=float, default=1e-3)
args = parser.parse_args()

def main():
    # datasets
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train', s=0.1)
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train', s=0.1)
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')

    # teachers / student; the student predicts the concatenation of both teacher label spaces
    car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
    student = vision.models.classification.resnet50(num_classes=196+120, pretrained=False)

    # load the converted OneFlow teacher weights
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)

    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = dog_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = val_transform

    # each metric looks only at its teacher's slice of the student's concatenated logits
    car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
    dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:], t))})

    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst])
    train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=0)
    car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=32, shuffle=False, num_workers=0)
    dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=32, shuffle=False, num_workers=0)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)

    TOTAL_ITERS = len(train_loader) * 100  # 100 epochs
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='ttl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='ttl_dog'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))

    # group matching Conv2d layers of the student and both teachers for the amalgamation blocks
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels])
    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])  # weights for loss_kd, loss_amal, loss_recons
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

if __name__ == '__main__':
    main()
@@ -0,0 +1,41 @@
"""
Convert PyTorch weights to OneFlow format.

The script is edited per teacher: pass the teacher's PyTorch checkpoint via
--car_ckpt, and set num_classes and the output path to match (the
commented-out lines show the pattern for a second teacher).
"""
from kamal import vision
import torch
import argparse
import oneflow as flow

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', required=True)
parser.add_argument('--dog_ckpt')
args = parser.parse_args()

# the PyTorch checkpoint stores the whole model; take its state_dict
cars_parameters = torch.load(args.car_ckpt).state_dict()
# dogs_parameters = torch.load(args.dog_ckpt).state_dict()
cars_para, dogs_para = {}, {}
# copy every tensor to numpy, dropping BatchNorm's num_batches_tracked
# buffers, which have no counterpart in the OneFlow model
for key, value in cars_parameters.items():
    val = value.detach().cpu().numpy()
    if not str(key).endswith('num_batches_tracked'):
        cars_para[key] = val
# for key, value in dogs_parameters.items():
#     val = value.detach().cpu().numpy()
#     if not str(key).endswith('num_batches_tracked'):
#         dogs_para[key] = val
car_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
# dog_teacher = vision.models.classification.resnet50(num_classes=397, pretrained=False)
car_teacher.load_state_dict(cars_para)
# dog_teacher.load_state_dict(dogs_para)
# torch.save(car_teacher, 'ckpt/aircraft_res50.pth')
# torch.save(dog_teacher, 'checkpoint/sun_res50.pth')
flow.save(car_teacher.state_dict(), "./ckpt/aircraft_res50_model")
# flow.save(dog_teacher.state_dict(), "./checkpoint/sun_res50_model")
@@ -0,0 +1,3 @@
from .core import metrics, engine, callbacks
from . import amalgamation, vision
@@ -0,0 +1 @@
from .layerwise_amalgamation import LayerWiseAmalgamator
@@ -0,0 +1,113 @@
import oneflow
import oneflow.nn as nn
from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.core import tasks
import typing
import time
from kamal.utils import move_to_device, set_mode

class AmalBlock(nn.Module):
    """1x1-conv block that learns a common feature space for one layer group.

    cs: student channels at this layer; cts: channels of each teacher.
    enc compresses the concatenated teacher features into the common space,
    dec reconstructs the teacher features from it, and fam maps the student
    feature into the same space.
    """
    def __init__(self, cs, cts):
        super(AmalBlock, self).__init__()
        self.cs, self.cts = cs, cts
        self.enc = nn.Conv2d(in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.fam = nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.dec = nn.Conv2d(in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, fs, fts):
        rep = self.enc(oneflow.cat(fts, dim=1))   # common representation
        _fts = self.dec(rep)                      # reconstructed teacher features
        _fts = oneflow.split(_fts, self.cts, dim=1)
        _fs = self.fam(fs)                        # aligned student feature
        return rep, _fs, _fts

class LayerWiseAmalgamator(Engine):
    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: oneflow.utils.data.DataLoader,
        optimizer: oneflow.optim.Optimizer,
        weights=[1., 1., 1.],  # weights for loss_kd, loss_amal, loss_recons
        device=None,
    ):
        if device is None:
            device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.optimizer = optimizer
        self._weights = weights
        # one AmalBlock plus feature hooks per layer group (student first, then teachers)
        amal_blocks = []
        for group, C in zip(layer_groups, layer_channels):
            hooks = [FeatureHook(layer) for layer in group]
            amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train()
            amal_blocks.append((amal_block, hooks, C))
        self._amal_blocks = amal_blocks

    @property
    def device(self):
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None):
        # the amalgamation blocks are trained with their own optimizer/scheduler
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend(list(block.parameters()))
        if isinstance(self.optimizer, oneflow.optim.SGD):
            self._amal_optimizer = oneflow.optim.SGD(block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4)
        else:
            self._amal_optimizer = oneflow.optim.Adam(block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4)
        self._amal_scheduler = oneflow.optim.lr_scheduler.CosineAnnealingLR(self._amal_optimizer, T_max=max_iter)
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super(LayerWiseAmalgamator, self).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self.student(data)
        with oneflow.no_grad():
            t_out = [teacher(data) for teacher in self.teachers]
        loss_amal = 0
        loss_recons = 0
        mse_loss = nn.MSELoss()
        for amal_block, hooks, C in self._amal_blocks:
            features = [h.feat_out for h in hooks]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block(fs, fts)
            loss_amal += mse_loss(_fs, rep.detach())  # pull student features towards the common space
            loss_recons += sum([mse_loss(_ft, ft) for (_ft, ft) in zip(_fts, fts)])  # keep the space informative
        loss_kd = tasks.loss.kldiv(s_out, oneflow.cat(t_out, dim=1))
        # loss_kd = F.mse_loss(s_out, oneflow.cat(t_out, dim=1))
        loss_dict = {"loss_kd": self._weights[0] * loss_kd,
                     "loss_amal": self._weights[1] * loss_amal,
                     "loss_recons": self._weights[2] * loss_recons}
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        self._amal_optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._amal_optimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
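
To illustrate the tensor shapes flowing through `AmalBlock`, a small standalone sketch (batch size and channel counts invented for the example):

```python
import oneflow
from kamal.amalgamation.layerwise_amalgamation import AmalBlock

block = AmalBlock(cs=64, cts=[64, 64])  # student: 64 channels; two teachers: 64 each
fs = oneflow.randn(2, 64, 56, 56)                        # student feature map
fts = [oneflow.randn(2, 64, 56, 56) for _ in range(2)]   # teacher feature maps
rep, _fs, _fts = block(fs, fts)
print(rep.shape)                # (2, 64, 56, 56): common representation
print(_fs.shape)                # (2, 64, 56, 56): aligned student feature
print([f.shape for f in _fts])  # two (2, 64, 56, 56) reconstructed teacher features
```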
@@ -0,0 +1,2 @@
from . import engine, tasks, metrics, callbacks, exceptions
from .attach import AttachTo
@@ -0,0 +1,28 @@
from typing import Sequence, Callable
from numbers import Number
from kamal.core import exceptions

class AttachTo(object):
    """ Attach a task, metric or visualizer to specified model outputs. """
    def __init__(self, attach_to=None):
        if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable)):
            raise exceptions.InvalidMapping
        self._attach_to = attach_to

    def __call__(self, *tensors):
        if self._attach_to is not None:
            # a callable receives the raw tensors and returns the selected views
            if isinstance(self._attach_to, Callable):
                return self._attach_to(*tensors)
            # otherwise treat attach_to as one index per tensor
            if isinstance(self._attach_to, Sequence):
                _attach_to = self._attach_to
            else:
                _attach_to = [self._attach_to for _ in range(len(tensors))]
            _attach_to = _attach_to[:len(tensors)]
            tensors = [tensor[index] for (tensor, index) in zip(tensors, _attach_to)]
        if len(tensors) == 1:
            tensors = tensors[0]
        return tensors

    def __repr__(self):
        return "AttachTo: %s" % (self._attach_to)
@@ -0,0 +1,4 @@
from .logging import MetricsLogging, ProgressCallback
from .base import Callback
from .eval_and_ckpt import EvalAndCkpt
from .scheduler import LRSchedulerCallback
@@ -0,0 +1,26 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
import abc

class Callback(abc.ABC):
    r""" Base class for callbacks. """
    def __init__(self):
        pass

    @abc.abstractmethod
    def __call__(self, engine):
        pass
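
As an illustration, a minimal custom callback under this interface; the commented registration line mirrors how the training scripts attach callbacks:

```python
from kamal.core.callbacks import Callback

class PrintIterCallback(Callback):
    """Toy callback: print the engine's current iteration when triggered."""
    def __call__(self, engine):
        print("iter:", engine.state.iter)

# registration, as in the training scripts:
# trainer.add_callback(engine.DefaultEvents.AFTER_STEP(every=100), callbacks=PrintIterCallback())
```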
@@ -0,0 +1,133 @@
from .base import Callback
import weakref
from kamal import utils
from typing import Sequence, Optional
import os, shutil
import oneflow

class EvalAndCkpt(Callback):
    def __init__(self,
                 model,
                 evaluator,
                 metric_name: str,
                 metric_mode: str = 'max',
                 save_type: Optional[Sequence] = ('best', 'latest'),
                 ckpt_dir: str = 'checkpoints',
                 ckpt_prefix: str = None,
                 log_tag: str = 'model',
                 weights_only: bool = True,
                 verbose: bool = False):
        super(EvalAndCkpt, self).__init__()
        self.metric_name = metric_name
        assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'"
        self._metric_mode = metric_mode
        self._model = weakref.ref(model)
        self._evaluator = evaluator
        self._ckpt_dir = ckpt_dir
        self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix + '_')
        if isinstance(save_type, str):
            save_type = (save_type,)
        self._save_type = save_type
        if self._save_type is not None:
            for save_type in self._save_type:
                assert save_type in ('best', 'latest', 'interval'), \
                    'save_type should be None or a subset of ("best", "latest", "interval")'
        self._log_tag = log_tag
        self._weights_only = weights_only
        self._verbose = verbose
        self._best_score = -float('inf') if self._metric_mode == 'max' else float('inf')
        self._best_ckpt = None
        self._latest_ckpt = None

    @property
    def best_ckpt(self):
        return self._best_ckpt

    @property
    def latest_ckpt(self):
        return self._latest_ckpt

    @property
    def best_score(self):
        return self._best_score

    def __call__(self, trainer):
        model = self._model()
        results = self._evaluator.eval(model, device=trainer.device)
        results = utils.flatten_dict(results)
        scalar_results = {k: v.item() for k, v in results.items()}
        current_score = scalar_results[self.metric_name]
        if trainer.logger is not None:
            trainer.logger.info("[Eval %s] Iter %d/%d: %s" % (self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results))
        trainer.state.metrics.update(scalar_results)
        # visualize results if trainer.tb_writer is not None
        if trainer.tb_writer is not None:
            for k, v in scalar_results.items():
                log_tag = "%s:%s" % (self._log_tag, k)
                trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter)
        if self._save_type is not None:
            pth_path_list = []
            # interval model
            if 'interval' in self._save_type:
                pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f"
                                        % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                pth_path_list.append(pth_path)
            # the latest model
            if 'latest' in self._save_type:
                pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f"
                                        % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                # remove the old ckpt (oneflow.save writes a directory, hence rmtree)
                if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt):
                    shutil.rmtree(self._latest_ckpt)
                pth_path_list.append(pth_path)
                self._latest_ckpt = pth_path
            # the best model
            if 'best' in self._save_type:
                if (current_score >= self._best_score and self._metric_mode == 'max') or \
                   (current_score <= self._best_score and self._metric_mode == 'min'):
                    pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f" %
                                            (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                    # remove the old ckpt
                    if self._best_ckpt is not None and os.path.exists(self._best_ckpt):
                        shutil.rmtree(self._best_ckpt)
                    pth_path_list.append(pth_path)
                    self._best_score = current_score
                    self._best_ckpt = pth_path
            # save model
            if self._verbose and trainer.logger:
                trainer.logger.info("Model saved as:")
            obj = model.state_dict() if self._weights_only else model
            os.makedirs(self._ckpt_dir, exist_ok=True)
            for pth_path in pth_path_list:
                oneflow.save(obj, pth_path)
                if self._verbose and trainer.logger:
                    trainer.logger.info("\t%s" % (pth_path))

    def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False):
        if ckpt_dir is None:
            ckpt_dir = self._ckpt_dir
        if ckpt_prefix is None:
            ckpt_prefix = self._ckpt_prefix
        if self._save_type is not None:
            if 'best' in self._save_type and self._best_ckpt is not None:
                os.makedirs(ckpt_dir, exist_ok=True)
                save_name = "%sbest%s.pth" % (ckpt_prefix, "" if not add_md5 else "-%s" % utils.md5(self._best_ckpt))
                # oneflow checkpoints are directories, so copy recursively
                shutil.copytree(self._best_ckpt, os.path.join(ckpt_dir, save_name))
@@ -0,0 +1,59 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
from .base import Callback
import numbers
from tqdm import tqdm

class MetricsLogging(Callback):
    def __init__(self, keys):
        super(MetricsLogging, self).__init__()
        self._keys = keys

    def __call__(self, engine):
        if engine.logger is None:
            return
        state = engine.state
        content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)" % (
            state.iter, state.max_iter,
            state.current_epoch, state.max_epoch,
            state.current_batch_index, state.max_batch_index
        )
        for key in self._keys:
            value = state.metrics.get(key, None)
            if value is not None:
                if isinstance(value, numbers.Number):
                    content += " %s=%.4f" % (key, value)
                    if engine.tb_writer is not None:
                        engine.tb_writer.add_scalar(key, value, global_step=state.iter)
                elif isinstance(value, (list, tuple)):
                    content += " %s=%s" % (key, value)
        engine.logger.info(content)

class ProgressCallback(Callback):
    def __init__(self, max_iter=100, tag=None):
        super(ProgressCallback, self).__init__()
        self._max_iter = max_iter
        self._tag = tag
        self._pbar = None  # created lazily by reset()

    def __call__(self, engine):
        if self._pbar is None:
            self.reset()
        self._pbar.update(1)
        if self._pbar.n == self._max_iter:
            self._pbar.close()

    def reset(self):
        self._pbar = tqdm(total=self._max_iter, desc=self._tag)
@@ -0,0 +1,32 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
from .base import Callback
from typing import Sequence

class LRSchedulerCallback(Callback):
    r""" LR scheduler callback. """
    def __init__(self, schedulers=None):
        super(LRSchedulerCallback, self).__init__()
        # only wrap a single scheduler; leave None as None
        if schedulers is not None and not isinstance(schedulers, Sequence):
            schedulers = (schedulers,)
        self._schedulers = schedulers

    def __call__(self, trainer):
        if self._schedulers is None:
            return
        for sched in self._schedulers:
            sched.step()
@@ -0,0 +1,7 @@
from . import evaluator
from . import trainer
from . import hooks
from . import events
from .engine import DefaultEvents, Event
@@ -0,0 +1,188 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
import abc, weakref
from typing import Any, Callable, Sequence
from kamal.core.engine.events import DefaultEvents, Event
from kamal.utils import get_logger
from collections import defaultdict
import numbers
import contextlib

class State(object):
    def __init__(self):
        self.iter = 0
        self.max_iter = None
        self.epoch_length = None
        self.dataloader = None
        self.seed = None
        self.metrics = dict()
        self.batch = None

    @property
    def current_epoch(self):
        if self.epoch_length is not None:
            return self.iter // self.epoch_length
        return None

    @property
    def max_epoch(self):
        if self.epoch_length is not None:
            return self.max_iter // self.epoch_length
        return None

    @property
    def current_batch_index(self):
        if self.epoch_length is not None:
            return self.iter % self.epoch_length
        return None

    @property
    def max_batch_index(self):
        return self.epoch_length

    def __repr__(self):
        rep = "State:\n"
        for attr, value in self.__dict__.items():
            if not isinstance(value, (numbers.Number, str, dict)):
                value = type(value)
            rep += "\t{}: {}\n".format(attr, value)
        return rep

class Engine(abc.ABC):
    def __init__(self, logger=None, tb_writer=None):
        self._logger = logger if logger else get_logger(name='kamal', color=True)
        self._tb_writer = tb_writer
        self._callbacks = defaultdict(list)
        self._allowed_events = [*DefaultEvents]
        self._state = State()

    def reset(self):
        self._state = State()

    def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None):
        self.state.iter = self._state.start_iter = start_iter
        self.state.max_iter = max_iter
        self.state.epoch_length = epoch_length if epoch_length else len(dataloader)
        self.state.dataloader = dataloader
        self.state.dataloader_iter = iter(dataloader)
        self.state.step_fn = step_fn
        self.trigger_events(DefaultEvents.BEFORE_RUN)
        for self.state.iter in range(start_iter, max_iter):
            if self.state.epoch_length is not None and \
               self.state.iter % self.state.epoch_length == 0:  # epoch start
                self.trigger_events(DefaultEvents.BEFORE_EPOCH)
            self.trigger_events(DefaultEvents.BEFORE_STEP)
            self.state.batch = self._get_batch()
            step_output = step_fn(self, self.state.batch)
            if isinstance(step_output, dict):
                self.state.metrics.update(step_output)
            self.trigger_events(DefaultEvents.AFTER_STEP)
            if self.state.epoch_length is not None and \
               (self.state.iter + 1) % self.state.epoch_length == 0:  # epoch end
                self.trigger_events(DefaultEvents.AFTER_EPOCH)
        self.trigger_events(DefaultEvents.AFTER_RUN)

    def _get_batch(self):
        try:
            batch = next(self.state.dataloader_iter)
        except StopIteration:
            self.state.dataloader_iter = iter(self.state.dataloader)  # reset iterator
            batch = next(self.state.dataloader_iter)
        if not isinstance(batch, (list, tuple)):
            batch = [batch, ]  # no targets
        return batch

    @property
    def state(self):
        return self._state

    @property
    def logger(self):
        return self._logger

    @property
    def tb_writer(self):
        return self._tb_writer

    def add_callback(self, event: Event, callbacks):
        if not isinstance(callbacks, Sequence):
            callbacks = [callbacks]
        if event in self._allowed_events:
            for callback in callbacks:
                if callback not in self._callbacks[event]:
                    if event.trigger != event.default_trigger:
                        callback = self._trigger_wrapper(self, event.trigger, callback)
                    self._callbacks[event].append(callback)
        callbacks = [RemovableCallback(self, event, c) for c in callbacks]
        return (callbacks[0] if len(callbacks) == 1 else callbacks)

    def remove_callback(self, event, callback):
        for c in self._callbacks[event]:
            if c == callback:
                self._callbacks[event].remove(c)
                return True
        return False

    @staticmethod
    def _trigger_wrapper(engine, trigger, callback):
        def wrapper(*args, **kwargs) -> Any:
            if trigger(engine):
                return callback(engine)
        return wrapper

    def trigger_events(self, *events):
        for e in events:
            if e in self._allowed_events:
                for callback in self._callbacks[e]:
                    callback(self)

    def register_events(self, *events):
        for e in events:
            if e not in self._allowed_events:
                self._allowed_events.append(e)

    @contextlib.contextmanager
    def save_current_callbacks(self):
        temp = self._callbacks
        self._callbacks = defaultdict(list)
        yield
        self._callbacks = temp

class RemovableCallback:
    def __init__(self, engine, event, callback):
        self._engine = weakref.ref(engine)
        self._callback = weakref.ref(callback)
        self._event = weakref.ref(event)

    @property
    def callback(self):
        return self._callback()

    def remove(self):
        engine = self._engine()
        callback = self._callback()
        event = self._event()
        return engine.remove_callback(event, callback)
@@ -0,0 +1,62 @@
import oneflow
from kamal.core import metrics
from kamal.utils import set_mode
from typing import Callable
from .engine import Engine
from .events import DefaultEvents
from kamal.core import callbacks
import weakref
from kamal.utils import move_to_device, split_batch

class BasicEvaluator(Engine):
    def __init__(self,
                 dataloader: oneflow.utils.data.DataLoader,
                 metric: metrics.MetricCompose,
                 eval_fn: Callable = None,
                 tag: str = 'Eval',
                 progress: bool = False):
        super(BasicEvaluator, self).__init__()
        self.dataloader = dataloader
        self.metric = metric
        self.progress = progress
        if progress:
            self.progress_callback = self.add_callback(
                DefaultEvents.AFTER_STEP, callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag))
        self._model = None
        self._tag = tag
        if eval_fn is None:
            eval_fn = BasicEvaluator.default_eval_fn
        self.eval_fn = eval_fn

    def eval(self, model, device=None):
        device = device if device is not None else \
            oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
        self._model = weakref.ref(model)  # use weakref here
        self.device = device
        self.metric.reset()
        model.to(device)
        if self.progress:
            self.progress_callback.callback.reset()
        with oneflow.no_grad(), set_mode(model, training=False):
            super(BasicEvaluator, self).run(self.step_fn, self.dataloader, max_iter=len(self.dataloader))
        return self.metric.get_results()

    @property
    def model(self):
        if self._model is not None:
            return self._model()
        return None

    def step_fn(self, engine, batch):
        batch = move_to_device(batch, self.device)
        self.eval_fn(engine, batch)

    @staticmethod
    def default_eval_fn(evaluator, batch):
        model = evaluator.model
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        evaluator.metric.update(outputs, targets)
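
A toy sketch of standalone use; it assumes `oneflow.utils.data.TensorDataset` mirrors its PyTorch counterpart and that `kamal.utils.split_batch` unpacks `(inputs, targets)` batches:

```python
import oneflow
from kamal import vision, engine
from kamal.core import metrics

model = vision.models.classification.resnet50(num_classes=10, pretrained=False)
metric = metrics.MetricCompose(metric_dict={'acc': metrics.Accuracy()})

# tiny random dataset standing in for a real validation set
inputs = oneflow.randn(8, 3, 224, 224)
targets = oneflow.randint(0, 10, (8,))
loader = oneflow.utils.data.DataLoader(
    oneflow.utils.data.TensorDataset(inputs, targets), batch_size=4)

evaluator = engine.evaluator.BasicEvaluator(loader, metric)
print(evaluator.eval(model))  # e.g. {'acc': tensor(0.1250)}
```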
@@ -0,0 +1,90 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
from typing import Callable, Optional
from enum import Enum

class Event(object):
    def __init__(self, value: str, event_trigger: Optional[Callable] = None):
        if event_trigger is None:
            event_trigger = Event.default_trigger
        self._trigger = event_trigger
        self._name_ = self._value_ = value

    @property
    def trigger(self):
        return self._trigger

    @property
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @property
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @staticmethod
    def default_trigger(engine):
        return True

    @staticmethod
    def once_trigger():
        is_triggered = False
        def wrapper(engine):
            nonlocal is_triggered  # without this, the flag would be a fresh local on every call
            if is_triggered:
                return False
            is_triggered = True
            return True
        return wrapper

    @staticmethod
    def every_trigger(every: int):
        def wrapper(engine):
            return every > 0 and (engine.state.iter % every) == 0
        return wrapper

    def __call__(self, every: Optional[int] = None, once: Optional[bool] = None):
        if every is not None:
            assert once is None
            return Event(self.value, event_trigger=Event.every_trigger(every))
        if once is not None:
            return Event(self.value, event_trigger=Event.once_trigger())
        return Event(self.value)

    def __hash__(self):
        return hash(self._name_)

    def __eq__(self, other):
        if hasattr(other, 'value'):
            return self.value == other.value
        return NotImplemented

class DefaultEvents(Event, Enum):
    BEFORE_RUN = "before_train"
    AFTER_RUN = "after_train"
    BEFORE_EPOCH = "before_epoch"
    AFTER_EPOCH = "after_epoch"
    BEFORE_STEP = "before_step"
    AFTER_STEP = "after_step"
    BEFORE_GET_BATCH = "before_get_batch"
    AFTER_GET_BATCH = "after_get_batch"
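
For reference, how the trigger factories are used (the `every=10` form appears in the training scripts):

```python
from kamal.core.engine.events import DefaultEvents

always = DefaultEvents.AFTER_STEP             # fires every time the event is triggered
every10 = DefaultEvents.AFTER_STEP(every=10)  # fires only when engine.state.iter % 10 == 0
once = DefaultEvents.AFTER_STEP(once=True)    # fires on the first trigger, then never again
```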
@@ -0,0 +1,32 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
class FeatureHook(object):
    """Caches a module's most recent input/output via a forward hook."""
    def __init__(self, module):
        self.module = module
        self.feat_in = None
        self.feat_out = None
        self.register()

    def register(self):
        self._hook = self.module.register_forward_hook(self.hook_fn_forward)

    def remove(self):
        self._hook.remove()

    def hook_fn_forward(self, module, fea_in, fea_out):
        self.feat_in = fea_in[0]
        self.feat_out = fea_out
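
A minimal sketch of the hook in isolation (the model and layer choice are illustrative; this is how the amalgamation scripts collect per-layer features):

```python
import oneflow
from kamal import vision
from kamal.core.engine.hooks import FeatureHook

net = vision.models.classification.resnet50(num_classes=10, pretrained=False)
# hook the first Conv2d, as the layer_groups loop in the scripts does
conv = next(m for m in net.modules() if isinstance(m, oneflow.nn.Conv2d))
hook = FeatureHook(conv)
net(oneflow.randn(1, 3, 224, 224))
print(hook.feat_out.shape)  # feature map cached by the forward hook
hook.remove()
```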
@@ -0,0 +1,115 @@
import oneflow
import oneflow.nn as nn
from kamal.core.engine.engine import Engine
from kamal.core import tasks
from kamal.utils import set_mode, move_to_device, split_batch
from typing import Sequence
import time

class BasicTrainer(Engine):
    def __init__(self,
                 logger=None,
                 tb_writer=None):
        super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer)

    def setup(self,
              model: nn.Module,
              task: tasks.Task,
              dataloader: oneflow.utils.data.DataLoader,
              optimizer: oneflow.optim.Optimizer,
              device: oneflow.device = None):
        if device is None:
            device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
        self.device = device
        if isinstance(task, Sequence):
            task = tasks.TaskCompose(task)
        self.task = task
        self.model = model.to(self.device)
        # multi-GPU: wrap with DDP when launched with more than one process
        if oneflow.env.get_world_size() > 1:
            self.model = nn.parallel.DistributedDataParallel(self.model)
        self.dataloader = dataloader
        self.optimizer = optimizer
        return self

    def run(self, max_iter, start_iter=0, epoch_length=None):
        self.model.to(self.device)
        with set_mode(self.model, training=True):
            super(BasicTrainer, self).run(self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        model = self.model
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        loss_dict = self.task.get_loss(outputs, targets)  # get loss
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics

class KDTrainer(BasicTrainer):
    def setup(self,
              student: nn.Module,
              teacher: nn.Module,
              task: tasks.Task,
              dataloader: oneflow.utils.data.DataLoader,
              optimizer: oneflow.optim.Optimizer,
              device: oneflow.device = None):
        super(KDTrainer, self).setup(
            model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device)
        if isinstance(teacher, (list, tuple)):
            if len(teacher) == 1:
                teacher = teacher[0]
            else:
                teacher = nn.ModuleList(teacher)
        self.student = self.model
        self.teacher = teacher
        return self

    def run(self, max_iter, start_iter=0, epoch_length=None):
        self.student.to(self.device)
        self.teacher.to(self.device)
        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False):
            # call Engine.run directly, bypassing BasicTrainer.run
            super(BasicTrainer, self).run(
                self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        model = self.model
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        with oneflow.no_grad():  # teachers provide soft targets only; no gradients needed
            if isinstance(self.teacher, nn.ModuleList):
                soft_targets = [t(inputs) for t in self.teacher]
            else:
                soft_targets = self.teacher(inputs)
        loss_dict = self.task.get_loss(outputs, soft_targets)  # get loss
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
@@ -0,0 +1,5 @@
# these are raised (e.g. by AttachTo), so they must derive from Exception
class DataTypeError(Exception):
    pass

class InvalidMapping(Exception):
    pass
@@ -0,0 +1,3 @@
from .stream_metrics import Metric, MetricCompose
from .accuracy import Accuracy
@@ -0,0 +1,23 @@
import oneflow
from kamal.core.metrics.stream_metrics import Metric

__all__ = ['Accuracy']

class Accuracy(Metric):
    def __init__(self, attach_to=None):
        super(Accuracy, self).__init__(attach_to=attach_to)
        self.reset()

    @oneflow.no_grad()
    def update(self, outputs, targets):
        outputs, targets = self._attach(outputs, targets)
        outputs = outputs.max(1)[1]  # predicted class indices
        self._correct += (outputs.view(-1) == targets.view(-1)).sum()
        self._cnt += oneflow.numel(targets)

    def get_results(self):
        return (self._correct / self._cnt).detach().cpu()

    def reset(self):
        self._correct = self._cnt = 0.0
| @@ -0,0 +1,63 @@ | |||
| from __future__ import division | |||
| import oneflow | |||
| from abc import ABC, abstractmethod | |||
| from typing import Mapping | |||
| from kamal.core.attach import AttachTo | |||
| class Metric(ABC): | |||
| def __init__(self, attach_to=None): | |||
| self._attach = AttachTo(attach_to) | |||
| @abstractmethod | |||
| def update(self, pred, target): | |||
| """ Overridden by subclasses """ | |||
| raise NotImplementedError() | |||
| @abstractmethod | |||
| def get_results(self): | |||
| """ Overridden by subclasses """ | |||
| raise NotImplementedError() | |||
| @abstractmethod | |||
| def reset(self): | |||
| """ Overridden by subclasses """ | |||
| raise NotImplementedError() | |||
| class MetricCompose(dict): | |||
| def __init__(self, metric_dict: Mapping): | |||
| self._metric_dict = metric_dict | |||
| def add_metrics( self, metric_dict: Mapping): | |||
| if isinstance(metric_dict, MetricCompose): | |||
| metric_dict = metric_dict.metrics | |||
| self._metric_dict.update(metric_dict) | |||
| return self | |||
| @property | |||
| def metrics(self): | |||
| return self._metric_dict | |||
| @oneflow.no_grad() | |||
| def update(self, outputs, targets): | |||
| for key, metric in self._metric_dict.items(): | |||
| if isinstance(metric, Metric): | |||
| metric.update(outputs, targets) | |||
| def get_results(self): | |||
| results = {} | |||
| for key, metric in self._metric_dict.items(): | |||
| if isinstance(metric, Metric): | |||
| results[key] = metric.get_results() | |||
| return results | |||
| def reset(self): | |||
| for key, metric in self._metric_dict.items(): | |||
| if isinstance(metric, Metric): | |||
| metric.reset() | |||
| def __getitem__(self, name): | |||
| return self._metric_dict[name] | |||
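| `MetricCompose` simply fans `update`/`get_results`/`reset` out to every contained `Metric`. For example (`outputs`/`targets` are stand-ins; only `Accuracy` is exported by this simplified package): | |||
| ```python | |||
| metrics_ = MetricCompose({'acc': Accuracy()}) | |||
| metrics_.update(outputs, targets)   # updates every contained Metric | |||
| print(metrics_.get_results())       # {'acc': tensor(...)} | |||
| metrics_.reset() | |||
| ``` | |||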
| @@ -0,0 +1,3 @@ | |||
| from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose | |||
| from . import loss | |||
| @@ -0,0 +1,2 @@ | |||
| from . import functional, loss | |||
| from .loss import * | |||
| @@ -0,0 +1,17 @@ | |||
| import oneflow | |||
| import oneflow.nn.functional as F | |||
| def kldiv(logits, targets, T=1.0): | |||
| """ Cross Entropy for soft targets | |||
| Parameters: | |||
| - logits (Tensor): logits score (e.g. outputs of fc layer) | |||
| - targets (Tensor): logits of soft targets | |||
| - T (float): temperature of distill | |||
| - reduction: reduction to the output | |||
| """ | |||
| p_targets = F.softmax(targets/T, dim=1) | |||
| logp_logits = F.log_softmax(logits/T, dim=1) | |||
| kl_div = oneflow.nn.KLDivLoss(reduction="none") | |||
| kld = kl_div(logp_logits, p_targets) * (T**2) | |||
| return kld.sum(1).mean() | |||
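| As a sanity check, the helper matches the temperature-scaled KL divergence written out by hand (random tensors, illustrative only): | |||
| ```python | |||
| import oneflow | |||
| import oneflow.nn.functional as F | |||
| student = oneflow.randn(4, 10)   # student logits | |||
| teacher = oneflow.randn(4, 10)   # teacher logits | |||
| T = 4.0 | |||
| p = F.softmax(teacher / T, dim=1) | |||
| manual = (p * (F.log_softmax(teacher / T, dim=1) | |||
|                - F.log_softmax(student / T, dim=1))).sum(1).mean() * T**2 | |||
| assert oneflow.allclose(kldiv(student, teacher, T=T), manual, atol=1e-5) | |||
| ``` | |||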
| @@ -0,0 +1,8 @@ | |||
| from .functional import * | |||
| class KLDiv(object): | |||
| def __init__(self, T=1.0): | |||
| self.T = T | |||
| def __call__(self, logits, targets): | |||
| return kldiv( logits, targets, T=self.T ) | |||
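| The class form just pins the temperature so the loss can be passed around as a plain callable (`student_logits`/`teacher_logits` are stand-ins): | |||
| ```python | |||
| criterion = KLDiv(T=4.0) | |||
| loss = criterion(student_logits, teacher_logits)   # scalar tensor | |||
| ``` | |||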
| @@ -0,0 +1,180 @@ | |||
| # Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================= | |||
| import abc | |||
| import oneflow.nn as nn | |||
| import oneflow.nn.functional as F | |||
| from typing import Callable, Dict, Any | |||
| from . import loss | |||
| from kamal.core import metrics | |||
| from kamal.core.attach import AttachTo | |||
| class Task(object): | |||
| def __init__(self, name): | |||
| self.name = name | |||
| @abc.abstractmethod | |||
| def get_loss( self, outputs, targets ) -> Dict: | |||
| pass | |||
| @abc.abstractmethod | |||
| def predict(self, outputs) -> Any: | |||
| pass | |||
| class GeneralTask(Task): | |||
| def __init__(self, | |||
| name: str, | |||
| loss_fn: Callable, | |||
| scaling:float=1.0, | |||
| pred_fn: Callable=lambda x: x, | |||
| attach_to=None): | |||
| super(GeneralTask, self).__init__(name) | |||
| self._attach = AttachTo(attach_to) | |||
| self.loss_fn = loss_fn | |||
| self.pred_fn = pred_fn | |||
| self.scaling = scaling | |||
| def get_loss(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| return { self.name: self.loss_fn( outputs, targets ) * self.scaling } | |||
| def predict(self, outputs): | |||
| outputs = self._attach(outputs) | |||
| return self.pred_fn(outputs) | |||
| def __repr__(self): | |||
| rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach) | |||
| return rep | |||
| class TaskCompose(list): | |||
| def __init__(self, tasks: list): | |||
| for task in tasks: | |||
| if isinstance(task, Task): | |||
| self.append(task) | |||
| def get_loss(self, outputs, targets): | |||
| loss_dict = {} | |||
| for task in self: | |||
| loss_dict.update( task.get_loss( outputs, targets ) ) | |||
| return loss_dict | |||
| def predict(self, outputs): | |||
| results = [] | |||
| for task in self: | |||
| results.append( task.predict( outputs ) ) | |||
| return results | |||
| def __repr__(self): | |||
| rep="TaskCompose: \n" | |||
| for task in self: | |||
| rep+="\t%s\n"%task | |||
| class StandardTask: | |||
| @staticmethod | |||
| def classification(name='ce', scaling=1.0, attach_to=None): | |||
| return GeneralTask( name=name, | |||
| loss_fn=nn.CrossEntropyLoss(), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x.max(1)[1], | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def binary_classification(name='bce', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=F.binary_cross_entropy_with_logits, | |||
| scaling=scaling, | |||
| pred_fn=lambda x: (x>0), # inputs are logits, so the 0.5-probability threshold sits at 0 | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def regression(name='mse', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=nn.MSELoss(), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x, | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def segmentation(name='ce', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=nn.CrossEntropyLoss(ignore_index=255), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x.max(1)[1], | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def monocular_depth(name='l1', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=nn.L1Loss(), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x, | |||
| attach_to=attach_to) | |||
| @staticmethod | |||
| def detection(): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=loss.KLDiv(T=T), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x.max(1)[1], | |||
| attach_to=attach_to) | |||
| class StandardMetrics(object): | |||
| @staticmethod | |||
| def classification(attach_to=None): | |||
| return metrics.MetricCompose( | |||
| metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)} | |||
| ) | |||
| @staticmethod | |||
| def regression(attach_to=None): | |||
| return metrics.MetricCompose( | |||
| metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)} | |||
| ) | |||
| @staticmethod | |||
| def segmentation(num_classes, ignore_idx=255, attach_to=None): | |||
| confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to) | |||
| return metrics.MetricCompose( | |||
| metric_dict={'acc': metrics.Accuracy(attach_to=attach_to), | |||
| 'confusion_matrix': confusion_matrix , | |||
| 'miou': metrics.mIoU(confusion_matrix)} | |||
| ) | |||
| @staticmethod | |||
| def monocular_depth(attach_to=None): | |||
| return metrics.MetricCompose( | |||
| metric_dict={ | |||
| 'rmse': metrics.RootMeanSquaredError(attach_to=attach_to), | |||
| 'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ), | |||
| 'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to), | |||
| 'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to), | |||
| 'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to), | |||
| 'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to ) | |||
| } | |||
| ) | |||
| @staticmethod | |||
| def loss_metric(loss_fn): | |||
| return metrics.MetricCompose( | |||
| metric_dict={ | |||
| 'loss': metrics.AverageMetric( loss_fn ) | |||
| } | |||
| ) | |||
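| Tasks compose; a hedged sketch mixing hard-label CE with soft-target distillation, KD-style (`outputs`/`targets` are stand-ins, and the omitted `attach_to` routing depends on how outputs and targets are packed per batch): | |||
| ```python | |||
| # attach_to routing is omitted here; it is required when CE and KLD must | |||
| # read different slices of the outputs/targets. | |||
| task = TaskCompose([ | |||
|     StandardTask.classification(name='ce', scaling=0.5), | |||
|     StandardTask.distillation(name='kld', T=4.0, scaling=0.5), | |||
| ]) | |||
| loss_dict = task.get_loss(outputs, targets)   # {'ce': ..., 'kld': ...} | |||
| total_loss = sum(loss_dict.values()) | |||
| ``` | |||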
| @@ -0,0 +1,2 @@ | |||
| from ._utils import * | |||
| from .logger import get_logger | |||
| @@ -0,0 +1,45 @@ | |||
| import oneflow | |||
| import contextlib | |||
| import oneflow.nn as nn | |||
| def split_batch(batch): | |||
| if isinstance(batch, (list, tuple)): | |||
| inputs, *targets = batch | |||
| if len(targets)==1: | |||
| targets = targets[0] | |||
| return inputs, targets | |||
| else: | |||
| return batch, None | |||
| @contextlib.contextmanager | |||
| def set_mode(model, training=True): | |||
| ori_mode = model.training | |||
| model.train(training) | |||
| yield | |||
| model.train(ori_mode) | |||
| def move_to_device(obj, device): | |||
| if isinstance(obj, oneflow.Tensor): | |||
| return obj.to(device=device) | |||
| elif isinstance( obj, (list, tuple) ): | |||
| return [ o.to(device=device) for o in obj ] | |||
| elif isinstance(obj, nn.Module): | |||
| return obj.to(device=device) | |||
| return obj # leave unrecognized objects unchanged rather than returning None | |||
| def flatten_dict(dic): | |||
| flattened = dict() | |||
| def _flatten(prefix, d): | |||
| for k, v in d.items(): | |||
| if isinstance(v, dict): | |||
| _flatten( prefix+'%s/'%k, v ) # prefix starts as '', so no None check is needed | |||
| else: | |||
| flattened[ (prefix+'%s/'%k).strip('/') ] = v | |||
| _flatten('', dic) | |||
| return flattened | |||
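| `flatten_dict` turns nested metric dicts into slash-separated tags, e.g. for TensorBoard logging: | |||
| ```python | |||
| >>> flatten_dict({'loss': {'ce': 0.7, 'kld': 0.2}, 'lr': 1e-3}) | |||
| {'loss/ce': 0.7, 'loss/kld': 0.2, 'lr': 0.001} | |||
| ``` | |||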
| @@ -0,0 +1,56 @@ | |||
| import logging | |||
| import os, sys | |||
| from termcolor import colored | |||
| class _ColorfulFormatter(logging.Formatter): | |||
| def __init__(self, *args, **kwargs): | |||
| super(_ColorfulFormatter, self).__init__(*args, **kwargs) | |||
| def formatMessage(self, record): | |||
| log = super(_ColorfulFormatter, self).formatMessage(record) | |||
| if record.levelno == logging.WARNING: | |||
| prefix = colored("WARNING", "yellow", attrs=["blink"]) | |||
| elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: | |||
| prefix = colored("ERROR", "red", attrs=["blink", "underline"]) | |||
| else: | |||
| return log | |||
| return prefix + " " + log | |||
| def get_logger(name='Kamal', output=None, color=True): | |||
| logger = logging.getLogger(name) | |||
| logger.setLevel(logging.DEBUG) | |||
| logger.propagate = False | |||
| # STDOUT | |||
| stdout_handler = logging.StreamHandler( stream=sys.stdout ) | |||
| stdout_handler.setLevel( logging.DEBUG ) | |||
| plain_formatter = logging.Formatter( | |||
| "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) | |||
| if color: | |||
| formatter = _ColorfulFormatter( | |||
| colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", | |||
| datefmt="%m/%d %H:%M:%S") | |||
| else: | |||
| formatter = plain_formatter | |||
| stdout_handler.setFormatter(formatter) | |||
| logger.addHandler(stdout_handler) | |||
| # FILE | |||
| if output is not None: | |||
| if output.endswith('.txt') or output.endswith('.log'): | |||
| os.makedirs(os.path.dirname(output), exist_ok=True) | |||
| filename = output | |||
| else: | |||
| os.makedirs(output, exist_ok=True) | |||
| filename = os.path.join(output, "log.txt") | |||
| file_handler = logging.FileHandler(filename) | |||
| file_handler.setFormatter(plain_formatter) | |||
| file_handler.setLevel(logging.DEBUG) | |||
| logger.addHandler(file_handler) | |||
| return logger | |||
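| Typical use (the output directory is arbitrary): | |||
| ```python | |||
| logger = get_logger('Kamal', output='./run/ka')   # colored stdout + ./run/ka/log.txt | |||
| logger.info('start amalgamation') | |||
| logger.warning('checkpoint missing, training from scratch')   # yellow WARNING prefix | |||
| ``` | |||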
| @@ -0,0 +1,3 @@ | |||
| from . import models | |||
| from . import datasets | |||
| from . import sync_transforms | |||
| @@ -0,0 +1,5 @@ | |||
| from .stanford_cars import StanfordCars | |||
| from .stanford_dogs import StanfordDogs | |||
| from .flowers102 import Flowers102 | |||
| from .sun397 import SUN397 | |||
| from .fgvc_aircraft import FGVCAircraft | |||
| @@ -0,0 +1,152 @@ | |||
| import oneflow.utils.data as data | |||
| from flowvision.datasets.folder import default_loader | |||
| import os | |||
| import numpy as np | |||
| import random | |||
| def make_dataset(dir, image_ids, targets): | |||
| assert(len(image_ids) == len(targets)) | |||
| images = [] | |||
| dir = os.path.expanduser(dir) | |||
| for i in range(len(image_ids)): | |||
| item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images', | |||
| '%s.jpg' % image_ids[i]), targets[i]) | |||
| images.append(item) | |||
| return images | |||
| def find_classes(classes_file): | |||
| # read classes file, separating out image IDs and class names | |||
| image_ids = [] | |||
| targets = [] | |||
| f = open(classes_file, 'r') | |||
| for line in f: | |||
| split_line = line.strip().split(' ') # strip the trailing newline so class names are clean | |||
| image_ids.append(split_line[0]) | |||
| targets.append(' '.join(split_line[1:])) | |||
| f.close() | |||
| # index class names | |||
| classes = np.unique(targets) | |||
| class_to_idx = {classes[i]: i for i in range(len(classes))} | |||
| targets = [class_to_idx[c] for c in targets] | |||
| return (image_ids, targets, classes, class_to_idx) | |||
| class FGVCAircraft(data.Dataset): | |||
| """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset. | |||
| Args: | |||
| root (string): Root directory path to dataset. | |||
| class_type (string, optional): The level of FGVC-Aircraft fine-grain classification | |||
| to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). | |||
| transform (callable, optional): A function/transform that takes in a PIL image | |||
| and returns a transformed version. E.g. ``transforms.RandomCrop`` | |||
| target_transform (callable, optional): A function/transform that takes in the | |||
| target and transforms it. | |||
| loader (callable, optional): A function to load an image given its path. | |||
| download (bool, optional): If true, downloads the dataset from the internet and | |||
| puts it in the root directory. If dataset is already downloaded, it is not | |||
| downloaded again. | |||
| """ | |||
| url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' | |||
| class_types = ('variant', 'family', 'manufacturer') | |||
| splits = ('train', 'val', 'trainval', 'test') | |||
| def __init__(self, root, class_type='variant', split='train', s=0.5, transform=None, | |||
| target_transform=None, loader=default_loader, download=False): | |||
| if split not in self.splits: | |||
| raise ValueError('Split "{}" not found. Valid splits are: {}'.format( | |||
| split, ', '.join(self.splits), | |||
| )) | |||
| if class_type not in self.class_types: | |||
| raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( | |||
| class_type, ', '.join(self.class_types), | |||
| )) | |||
| self.root = root | |||
| self.class_type = class_type | |||
| self.split = split | |||
| self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', | |||
| 'images_%s_%s.txt' % (self.class_type, self.split)) | |||
| (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) | |||
| """if split == 'trainval': | |||
| self.image_ids = image_ids | |||
| self.targets = targets | |||
| self.image_ids, self.targets = self.sample_by_class(s=s) | |||
| image_ids = self.image_ids | |||
| targets = self.targets""" | |||
| samples = make_dataset(self.root, image_ids, targets) | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| self.loader = loader | |||
| self.samples = samples | |||
| self.classes = classes | |||
| self.class_to_idx = class_to_idx | |||
| with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f: | |||
| self.object_categories = [ | |||
| line.strip('\n') for line in f.readlines()] | |||
| print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__())) | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (sample, target) where target is class_index of the target class. | |||
| """ | |||
| path, target = self.samples[index] | |||
| sample = self.loader(path) | |||
| if self.transform is not None: | |||
| sample = self.transform(sample) | |||
| if self.target_transform is not None: | |||
| target = self.target_transform(target) | |||
| return sample, target | |||
| def __len__(self): | |||
| return len(self.samples) | |||
| def __repr__(self): | |||
| fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' | |||
| fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) | |||
| fmt_str += ' Root Location: {}\n'.format(self.root) | |||
| tmp = ' Transforms (if any): ' | |||
| fmt_str += '{0}{1}\n'.format( | |||
| tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) | |||
| tmp = ' Target Transforms (if any): ' | |||
| fmt_str += '{0}{1}'.format( | |||
| tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) | |||
| return fmt_str | |||
| def _check_exists(self): | |||
| return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')) and \ | |||
| os.path.exists(self.classes_file) | |||
| def sample_by_class(self, s): | |||
| class_dit = {} | |||
| image_dit = {} | |||
| for class_name, image in zip(self.targets,self.image_ids): | |||
| if class_name not in class_dit.keys(): | |||
| class_dit[class_name] = [] | |||
| image_dit[class_name] = [] | |||
| class_dit[class_name].append(class_name) | |||
| image_dit[class_name].append(image) | |||
| labels, images = [], [] | |||
| for key in class_dit.keys(): | |||
| n1 = len(class_dit[key]) | |||
| n2 = len(image_dit[key]) | |||
| assert n1 == n2, "{} not equal {}".format(n1, n2) | |||
| random.shuffle(image_dit[key]) | |||
| labels += class_dit[key][:int(n1*s)] | |||
| images += image_dit[key][:int(n1*s)] | |||
| return images, labels | |||
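| Loading the split used by the aircraft teacher (the root path is an assumption mirroring the training scripts): | |||
| ```python | |||
| ds = FGVCAircraft('./DataSets/FGVCAircraft/', class_type='variant', split='trainval') | |||
| img, label = ds[0]            # PIL image and 0-based class index | |||
| print(len(ds), ds.classes[label]) | |||
| ``` | |||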
| @@ -0,0 +1,11 @@ | |||
| import flowvision | |||
| import os | |||
| class Flowers102(flowvision.datasets.ImageFolder): | |||
| def __init__(self, root, split='train'): | |||
| self.split = split | |||
| self.num_classes = 102 | |||
| path = os.path.join(root, 'train') if self.split=='train' else os.path.join(root, 'valid') | |||
| super().__init__(path) | |||
| print('Flowers-102, Split: %s, Size: %d' % (self.split, len(self.imgs))) | |||
| @@ -0,0 +1,81 @@ | |||
| import os | |||
| import glob | |||
| from PIL import Image | |||
| import numpy as np | |||
| from scipy.io import loadmat | |||
| from oneflow.utils import data | |||
| import random | |||
| class StanfordCars(data.Dataset): | |||
| """Dataset for Stanford Cars | |||
| """ | |||
| urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz', | |||
| 'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz', | |||
| 'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz', | |||
| 'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'} | |||
| def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None): | |||
| self.root = os.path.abspath( os.path.expanduser(root) ) | |||
| self.split = split | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| if download: | |||
| self.download() | |||
| if self.split == 'train': | |||
| annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat') | |||
| else: | |||
| annos = os.path.join(self.root, 'devkit', | |||
| 'cars_test_annos_withlabels.mat') | |||
| annos = loadmat(annos) | |||
| size = len(annos['annotations'][0]) | |||
| self.files = glob.glob(os.path.join( | |||
| self.root, 'cars_'+self.split, '*.jpg')) | |||
| self.files.sort() | |||
| """if split == 'train': | |||
| self.files = self.sample_by_class(s=s)""" | |||
| self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]]) | |||
| lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat')) | |||
| self.object_categories = [str(c[0]) | |||
| for c in lbl_annos['class_names'][0]] | |||
| print('Stanford Cars, Split: %s, Size: %d' % | |||
| (self.split, self.__len__())) | |||
| def __len__(self): | |||
| return len(self.files) | |||
| def __getitem__(self, idx): | |||
| # self.files already holds absolute paths collected by glob above | |||
| img = Image.open(self.files[idx]).convert("RGB") | |||
| lbl = self.labels[idx] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| lbl = self.target_transform(lbl) | |||
| return img, lbl | |||
| def sample_by_class(self, s): | |||
| class_dit = {} | |||
| for file in self.files: | |||
| class_name = file.split('/')[0] | |||
| if class_name not in class_dit.keys(): | |||
| class_dit[class_name] = [] | |||
| class_dit[class_name].append(file) | |||
| files = [] | |||
| for key in class_dit.keys(): | |||
| n = len(class_dit[key]) | |||
| random.shuffle(class_dit[key]) | |||
| files += class_dit[key][:int(n*s)] | |||
| return files | |||
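| The on-disk layout this class expects, inferred from the paths above, plus a load sketch (the root path is an assumption): | |||
| ```python | |||
| # Inferred layout under root (not documented in this file): | |||
| #   devkit/cars_train_annos.mat, devkit/cars_test_annos_withlabels.mat, | |||
| #   devkit/cars_meta.mat, cars_train/*.jpg, cars_test/*.jpg | |||
| train_dst = StanfordCars('./DataSets/StanfordCars/', split='train') | |||
| img, lbl = train_dst[0]       # PIL image, 0-based class index | |||
| ``` | |||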
| @@ -0,0 +1,67 @@ | |||
| import os | |||
| import numpy as np | |||
| from PIL import Image | |||
| from scipy.io import loadmat | |||
| from oneflow.utils import data | |||
| import random | |||
| class StanfordDogs(data.Dataset): | |||
| """Dataset for Stanford Dogs | |||
| """ | |||
| urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar", | |||
| "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar", | |||
| "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"} | |||
| def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None): | |||
| self.root = os.path.abspath( os.path.expanduser(root) ) | |||
| self.split = split | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| if download: | |||
| self.download() | |||
| list_file = os.path.join(self.root, 'lists', self.split+'_list.mat') | |||
| mat_file = loadmat(list_file) | |||
| size = len(mat_file['file_list']) | |||
| self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)] | |||
| """if split == 'train': | |||
| self.files = self.sample_by_class(s=s)""" | |||
| self.labels = np.array( | |||
| [mat_file['labels'][i][0]-1 for i in range(size)]) | |||
| categories = os.listdir(os.path.join(self.root, 'Images')) | |||
| categories.sort() | |||
| self.object_categories = [c[10:] for c in categories] | |||
| print('Stanford Dogs, Split: %s, Size: %d' % | |||
| (self.split, self.__len__())) | |||
| def __len__(self): | |||
| return len(self.files) | |||
| def __getitem__(self, idx): | |||
| img = Image.open(os.path.join(self.root, 'Images', | |||
| self.files[idx])).convert("RGB") | |||
| lbl = self.labels[idx] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| lbl = self.target_transform( lbl ) | |||
| return img, lbl | |||
| def sample_by_class(self, s): | |||
| class_dit = {} | |||
| for file in self.files: | |||
| class_name = file.split('/')[0] | |||
| if class_name not in class_dit.keys(): | |||
| class_dit[class_name] = [] | |||
| class_dit[class_name].append(file) | |||
| files = [] | |||
| for key in class_dit.keys(): | |||
| n = len(class_dit[key]) | |||
| random.shuffle(class_dit[key]) | |||
| files += class_dit[key][:int(n*s)] | |||
| return files | |||
| @@ -0,0 +1,11 @@ | |||
| import flowvision | |||
| import os | |||
| class SUN397(flowvision.datasets.ImageFolder): | |||
| def __init__(self, root, split='train'): | |||
| self.split = split | |||
| self.num_classes = 397 | |||
| path = os.path.join(root, 'train') if self.split=='train' else os.path.join(root, 'test') | |||
| super().__init__(path) | |||
| print('SUN397, Split: %s, Size: %d' % (self.split, len(self.imgs))) | |||
| @@ -0,0 +1 @@ | |||
| from . import classification | |||
| @@ -0,0 +1 @@ | |||
| from .resnet import * | |||