@@ -0,0 +1,40 @@
# Model Knowledge Amalgamation (模型炼知, KA)
## Introduction
Model Knowledge Amalgamation is a lightweight toolkit for knowledge fusion and for measuring model transferability. For details, see [KAE](https://github.com/zju-vipa/KamalEngine/tree/master/examples).
## Table of contents
* [Introduction](#Introduction)
* [Algorithms](#Algorithms)
* [Authors](#Authors)
## Algorithms
#### 0. Datasets
* data/StanfordCars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html
* data/StanfordDogs: http://vision.stanford.edu/aditya86/ImageNetDogs/
* data/FGVCAircraft: http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/
* data/Flowers102: https://www.robots.ox.ac.uk/~vgg/data/flowers/102/
#### 1. Dependencies
* [requirements.txt](requirements.txt)
#### 2. Algorithms
* TW_Layerwise (two teachers: cars + dogs)
```bash
python3 TTL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model
```
* TH_Layerwise (three teachers: cars + dogs + aircraft)
```bash
python3 THL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model --aircraft_ckpt ./ckpt/aircraft_res50_model
```
* TF_Layerwise (four teachers: cars + dogs + aircraft + flowers)
```bash
python3 TFL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model --aircraft_ckpt ./ckpt/aircraft_res50_model --flower_ckpt ./ckpt/flower_res50_model
```
* flow: convert PyTorch-format weights to OneFlow format (assuming the conversion script included below is saved as `flow.py`; `--car_ckpt` is the PyTorch checkpoint to convert)
```bash
python3 flow.py --car_ckpt ./path/to/teacher_res50.pth
```
#### 3. Weights
The `MODEL_ckpt` folder stores the weights produced by the three model reassembly/amalgamation algorithms, in OneFlow format, for reference when reproducing results. The `ckpt` folder holds the teacher weights converted from PyTorch to OneFlow (used to initialize the teacher models); since they take considerable space, they are provided via the link below instead.
* Teacher weights [link]: https://pan.baidu.com/s/19pRLZKQvjgEQ7CuD_RiPPw (extraction code: FLOW)
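
Once downloaded, a converted teacher checkpoint can be restored with `oneflow.load`. A minimal sketch (model construction and paths follow the defaults used by the training scripts):
```python
import oneflow
from kamal import vision

# Build the cars teacher and restore its converted OneFlow weights.
car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
car_teacher.load_state_dict(oneflow.load("./ckpt/car_res50_model"))
car_teacher.eval()
```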
## Authors
This project is a simplified version of KAE, which is developed by the [VIPA Lab](http://vipazoo.cn) at Zhejiang University and [Zhejiang Lab](http://www.zhejianglab.com/).
@@ -0,0 +1,124 @@
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import oneflow
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model")
parser.add_argument('--flower_ckpt', default="./ckpt/flower_res50_model")
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=16)
args = parser.parse_args()

def main():
    # Datasets
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')
    flower_train_dst = vision.datasets.Flowers102('./DataSets/Flower102/', split='train')
    flower_val_dst = vision.datasets.Flowers102('./DataSets/Flower102/', split='valid')
    # Teachers and student
    car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
    flower_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
    student = vision.models.classification.resnet50(num_classes=196+120+102+102, pretrained=False)
    # Load the teachers' pretrained weights
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    aircraft_parameters = oneflow.load(args.aircraft_ckpt)
    flowers_parameters = oneflow.load(args.flower_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)
    aircraft_teacher.load_state_dict(aircraft_parameters)
    flower_teacher.load_state_dict(flowers_parameters)
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = flower_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = flower_val_dst.transform = val_transform
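    # The student's 520-way output is the concatenation of the four label
    # spaces: [0:196] cars, [196:316] dogs, [316:418] aircraft, [418:520]
    # flowers. Each metric below therefore evaluates only its own slice.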
    car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
    dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))})
    aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))})
    flower_metric = metrics.MetricCompose(metric_dict={'flower_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120+102:196+120+102+102], t))})
    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst, flower_train_dst])
    train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0)
    car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
    dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
    aircraft_loader = oneflow.utils.data.DataLoader(aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    flower_loader = oneflow.utils.data.DataLoader(flower_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)
    flower_evaluator = engine.evaluator.BasicEvaluator(flower_loader, flower_metric)
    TOTAL_ITERS = len(train_loader) * 100
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='tfl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='tfl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='tfl_aircraft'),
            callbacks.EvalAndCkpt(model=student, evaluator=flower_evaluator, metric_name='flower_acc', ckpt_prefix='tfl_flower'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
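    # All five networks are ResNet-50s, so modules() enumerates them in
    # lockstep: every Conv2d of the student is grouped with the matching
    # Conv2d of each teacher for layer-wise feature amalgamation.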
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block, aircraft_block, flower_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules(), flower_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block, aircraft_block, flower_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels, flower_block.out_channels])
    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher, aircraft_teacher, flower_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])  # (loss_kd, loss_amal, loss_recons)
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

if __name__ == '__main__':
    main()
@@ -0,0 +1,116 @@
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import oneflow
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model")
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=32)
args = parser.parse_args()

def main():
    # Datasets
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train')
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train')
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval')
    aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test')
    # Teachers and student
    car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
    aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
    student = vision.models.classification.resnet50(num_classes=196+120+102, pretrained=False)
    # Load the teachers' pretrained weights
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    aircraft_parameters = oneflow.load(args.aircraft_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)
    aircraft_teacher.load_state_dict(aircraft_parameters)
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = val_transform
    car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
    dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))})
    aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))})
    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst])
    train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    aircraft_loader = oneflow.utils.data.DataLoader(aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric)
    TOTAL_ITERS = len(train_loader) * 100
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='thl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='thl_dog'),
            callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='thl_aircraft'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block, aircraft_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block, aircraft_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels])
    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher, aircraft_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

if __name__ == '__main__':
    main()
@@ -0,0 +1,105 @@
from kamal import vision, engine, utils, amalgamation, metrics, callbacks
from kamal.vision import sync_transforms as sT
import oneflow
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model")
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model")
parser.add_argument('--lr', type=float, default=1e-3)
args = parser.parse_args()

def main():
    # Datasets (this script amalgamates the cars and dogs teachers)
    car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train', s=0.1)
    car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test')
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train', s=0.1)
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    # Teachers and student
    car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
    dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False)
    student = vision.models.classification.resnet50(num_classes=196+120, pretrained=False)
    # Load the teachers' pretrained weights
    cars_parameters = oneflow.load(args.car_ckpt)
    dogs_parameters = oneflow.load(args.dog_ckpt)
    car_teacher.load_state_dict(cars_parameters)
    dog_teacher.load_state_dict(dogs_parameters)
    train_transform = sT.Compose([
        sT.RandomResizedCrop(224),
        sT.RandomHorizontalFlip(),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    val_transform = sT.Compose([
        sT.Resize(256),
        sT.CenterCrop(224),
        sT.ToTensor(),
        sT.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])
    ])
    car_train_dst.transform = dog_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = val_transform
    # Student logits: [0:196] cars, [196:316] dogs.
    car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))})
    dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:], t))})
    train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst])
    train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=32, shuffle=True, num_workers=0)
    car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=32, shuffle=False, num_workers=0)
    dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=32, shuffle=False, num_workers=0)
    car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric)
    dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric)
    TOTAL_ITERS = len(train_loader) * 100
    device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
    optim = oneflow.optim.Adam(student.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = oneflow.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=TOTAL_ITERS)
    trainer = amalgamation.LayerWiseAmalgamator(
        logger=utils.logger.get_logger('layerwise-ka'),
        # tb_writer=SummaryWriter(log_dir='run/layerwise_ka-%s' % (time.asctime().replace(' ', '_')))
    )
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP(every=10),
        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr')))
    trainer.add_callback(
        engine.DefaultEvents.AFTER_EPOCH,
        callbacks=[
            callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='ttl_car'),
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='ttl_dog'),
        ])
    trainer.add_callback(
        engine.DefaultEvents.AFTER_STEP,
        callbacks=callbacks.LRSchedulerCallback(schedulers=[sched]))
    layer_groups = []
    layer_channels = []
    for stu_block, car_block, dog_block in zip(student.modules(), car_teacher.modules(), dog_teacher.modules()):
        if isinstance(stu_block, oneflow.nn.Conv2d):
            layer_groups.append([stu_block, car_block, dog_block])
            layer_channels.append([stu_block.out_channels, car_block.out_channels, dog_block.out_channels])
    trainer.setup(student=student,
                  teachers=[car_teacher, dog_teacher],
                  layer_groups=layer_groups,
                  layer_channels=layer_channels,
                  dataloader=train_loader,
                  optimizer=optim,
                  device=device,
                  weights=[1., 1., 1.])
    trainer.run(start_iter=0, max_iter=TOTAL_ITERS)

if __name__ == '__main__':
    main()
@@ -0,0 +1,41 @@
"""
Convert PyTorch-format weights to OneFlow format.
"""
from kamal import vision
import torch
import argparse
import oneflow as flow

parser = argparse.ArgumentParser()
parser.add_argument('--car_ckpt', required=True)
parser.add_argument('--dog_ckpt')
args = parser.parse_args()

# The PyTorch checkpoint was saved as a full model object, so state_dict()
# is taken after loading it.
cars_parameters = torch.load(args.car_ckpt).state_dict()
# dogs_parameters = torch.load(args.dog_ckpt).state_dict()
cars_para, dogs_para = {}, {}
for key, value in cars_parameters.items():
    val = value.detach().cpu().numpy()
    # Drop the num_batches_tracked buffers; the OneFlow ResNet-50 built
    # below does not expect them in its state dict.
    if not str(key).endswith('num_batches_tracked'):
        cars_para[key] = val
# for key, value in dogs_parameters.items():
#     val = value.detach().cpu().numpy()
#     if not str(key).endswith('num_batches_tracked'):
#         dogs_para[key] = val
car_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False)
# dog_teacher = vision.models.classification.resnet50(num_classes=397, pretrained=False)
car_teacher.load_state_dict(cars_para)
# dog_teacher.load_state_dict(dogs_para)
# torch.save(car_teacher, 'ckpt/aircraft_res50.pth')
# torch.save(dog_teacher, 'checkpoint/sun_res50.pth')
flow.save(car_teacher.state_dict(), "./ckpt/aircraft_res50_model")
# flow.save(dog_teacher.state_dict(), "./checkpoint/sun_res50_model")
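# Example usage (hypothetical path; this script is assumed to be saved as
# flow.py, matching the "flow" entry in the README):
#   python3 flow.py --car_ckpt ./ckpt/aircraft_res50.pth
# flow.save writes the converted weights as a checkpoint directory, which
# the training scripts later restore with oneflow.load.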
@@ -0,0 +1,3 @@
from .core import metrics, engine, callbacks
from . import amalgamation, vision
@@ -0,0 +1 @@
from .layerwise_amalgamation import LayerWiseAmalgamator
@@ -0,0 +1,113 @@
import oneflow
import oneflow.nn as nn
from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.core import tasks
import typing
import time
from kamal.utils import move_to_device, set_mode

class AmalBlock(nn.Module):
    def __init__(self, cs, cts):
        super(AmalBlock, self).__init__()
        # cs: student channel count; cts: channel counts of the teachers.
        self.cs, self.cts = cs, cts
        self.enc = nn.Conv2d(in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.fam = nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.dec = nn.Conv2d(in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True)
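    # enc compresses the concatenated teacher features into a common code of
    # student width; dec reconstructs the per-teacher features from that code;
    # fam projects the student feature so it can be matched against the code.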
    def forward(self, fs, fts):
        rep = self.enc(oneflow.cat(fts, dim=1))
        _fts = self.dec(rep)
        _fts = oneflow.split(_fts, self.cts, dim=1)
        _fs = self.fam(fs)
        return rep, _fs, _fts
class LayerWiseAmalgamator(Engine):
    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: oneflow.utils.data.DataLoader,
        optimizer: oneflow.optim.Optimizer,
        weights=[1., 1., 1.],
        device=None,
    ):
        if device is None:
            device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.optimizer = optimizer
        self._weights = weights  # (loss_kd, loss_amal, loss_recons)
        amal_blocks = []
        # Hook every grouped layer so its output features can be collected
        # during the forward pass, and attach one AmalBlock per group.
        for group, C in zip(layer_groups, layer_channels):
            hooks = [FeatureHook(layer) for layer in group]
            amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train()
            amal_blocks.append((amal_block, hooks, C))
        self._amal_blocks = amal_blocks

    @property
    def device(self):
        return self._device
    def run(self, max_iter, start_iter=0, epoch_length=None):
        # The amalgamation blocks are trained jointly with the student, using
        # a separate optimizer over all block parameters.
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend(list(block.parameters()))
        if isinstance(self.optimizer, oneflow.optim.SGD):
            self._amal_optimizer = oneflow.optim.SGD(block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4)
        else:
            self._amal_optimizer = oneflow.optim.Adam(block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4)
        self._amal_scheduler = oneflow.optim.lr_scheduler.CosineAnnealingLR(self._amal_optimizer, T_max=max_iter)
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super(LayerWiseAmalgamator, self).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)
    def step_fn(self, engine, batch):
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self.student(data)
        with oneflow.no_grad():
            t_out = [teacher(data) for teacher in self.teachers]
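        # Feature-level amalgamation: for every grouped layer, fuse the
        # teachers' hooked features into a compact code, align the student's
        # feature with that code (loss_amal), and require the code to
        # reconstruct each teacher's features (loss_recons).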
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            features = [h.feat_out for h in hooks]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block(fs, fts)
            mse_loss = nn.MSELoss()
            loss_amal += mse_loss(_fs, rep.detach())
            loss_recons += sum([mse_loss(_ft, ft) for (_ft, ft) in zip(_fts, fts)])
        # Response-level distillation on the concatenated teacher logits.
        loss_kd = tasks.loss.kldiv(s_out, oneflow.cat(t_out, dim=1))
        # loss_kd = F.mse_loss(s_out, oneflow.cat(t_out, dim=1))
        loss_dict = {"loss_kd": self._weights[0] * loss_kd,
                     "loss_amal": self._weights[1] * loss_amal,
                     "loss_recons": self._weights[2] * loss_recons}
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        self._amal_optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._amal_optimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
@@ -0,0 +1,2 @@
from . import engine, tasks, metrics, callbacks, exceptions
from .attach import AttachTo
@@ -0,0 +1,28 @@
from typing import Sequence, Callable
from numbers import Number
from kamal.core import exceptions

class AttachTo(object):
    """ Attach tasks, metrics or visualizers to specified model outputs.
    """
    def __init__(self, attach_to=None):
        if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable)):
            raise exceptions.InvalidMapping()
        self._attach_to = attach_to
    def __call__(self, *tensors):
        if self._attach_to is not None:
            if isinstance(self._attach_to, Callable):
                return self._attach_to(*tensors)
            if isinstance(self._attach_to, Sequence):
                _attach_to = self._attach_to
            else:
                _attach_to = [self._attach_to for _ in range(len(tensors))]
            _attach_to = _attach_to[:len(tensors)]
            tensors = [tensor[index] for (tensor, index) in zip(tensors, _attach_to)]
        if len(tensors) == 1:
            tensors = tensors[0]
        return tensors
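    # Example: metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))
    # evaluates only the first 196 logits (the cars head of an amalgamated
    # student) against the targets, as done in the training scripts.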
    def __repr__(self):
        return "AttachTo: %s" % (self._attach_to,)
@@ -0,0 +1,4 @@
from .logging import MetricsLogging, ProgressCallback
from .base import Callback
from .eval_and_ckpt import EvalAndCkpt
from .scheduler import LRSchedulerCallback
@@ -0,0 +1,26 @@ | |||||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================= | |||||
import abc | |||||
class Callback(abc.ABC): | |||||
r""" Base Class for Callbacks | |||||
""" | |||||
def __init__(self): | |||||
pass | |||||
@abc.abstractmethod | |||||
def __call__(self, engine): | |||||
pass |
@@ -0,0 +1,133 @@
from .base import Callback
import weakref
from kamal import utils
from typing import Sequence, Optional
import numbers
import os, shutil
import oneflow

class EvalAndCkpt(Callback):
    def __init__(self,
                 model,
                 evaluator,
                 metric_name: str,
                 metric_mode: str = 'max',
                 save_type: Optional[Sequence] = ('best', 'latest'),
                 ckpt_dir: str = 'checkpoints',
                 ckpt_prefix: str = None,
                 log_tag: str = 'model',
                 weights_only: bool = True,
                 verbose: bool = False):
        super(EvalAndCkpt, self).__init__()
        self.metric_name = metric_name
        assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'"
        self._metric_mode = metric_mode
        self._model = weakref.ref(model)  # avoid a reference cycle with the trainer
        self._evaluator = evaluator
        self._ckpt_dir = ckpt_dir
        self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix + '_')
        if isinstance(save_type, str):
            save_type = (save_type,)
        self._save_type = save_type
        if self._save_type is not None:
            for save_type in self._save_type:
                assert save_type in ('best', 'latest', 'all'), \
                    'save_type should be None or a subset of ("best", "latest", "all")'
        self._log_tag = log_tag
        self._weights_only = weights_only
        self._verbose = verbose
        self._best_score = -float('inf') if self._metric_mode == 'max' else float('inf')
        self._best_ckpt = None
        self._latest_ckpt = None
    @property
    def best_ckpt(self):
        return self._best_ckpt

    @property
    def latest_ckpt(self):
        return self._latest_ckpt

    @property
    def best_score(self):
        return self._best_score
    def __call__(self, trainer):
        model = self._model()
        results = self._evaluator.eval(model, device=trainer.device)
        results = utils.flatten_dict(results)
        scalar_results = {k: v.item() for k, v in results.items()}
        current_score = scalar_results[self.metric_name]
        if trainer.logger is not None:
            trainer.logger.info("[Eval %s] Iter %d/%d: %s" % (self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results))
        trainer.state.metrics.update(scalar_results)
        # Log results to TensorBoard if a writer is attached.
        if trainer.tb_writer is not None:
            for k, v in scalar_results.items():
                log_tag = "%s:%s" % (self._log_tag, k)
                trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter)
        if self._save_type is not None:
            pth_path_list = []
            # save at every evaluation
            if 'all' in self._save_type:
                pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f"
                                        % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                pth_path_list.append(pth_path)
            # the latest model
            if 'latest' in self._save_type:
                pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f"
                                        % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                # remove the old checkpoint (oneflow.save writes a directory, hence rmtree)
                if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt):
                    shutil.rmtree(self._latest_ckpt)
                pth_path_list.append(pth_path)
                self._latest_ckpt = pth_path
            # the best model
            if 'best' in self._save_type:
                if (current_score >= self._best_score and self._metric_mode == 'max') or \
                   (current_score <= self._best_score and self._metric_mode == 'min'):
                    pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f" %
                                            (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                    # remove the old best checkpoint
                    if self._best_ckpt is not None and os.path.exists(self._best_ckpt):
                        shutil.rmtree(self._best_ckpt)
                    pth_path_list.append(pth_path)
                    self._best_score = current_score
                    self._best_ckpt = pth_path
            # save the model
            if self._verbose and trainer.logger:
                trainer.logger.info("Model saved as:")
            obj = model.state_dict() if self._weights_only else model
            os.makedirs(self._ckpt_dir, exist_ok=True)
            for pth_path in pth_path_list:
                oneflow.save(obj, pth_path)
                if self._verbose and trainer.logger:
                    trainer.logger.info("\t%s" % (pth_path))
    def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False):
        if ckpt_dir is None:
            ckpt_dir = self._ckpt_dir
        if ckpt_prefix is None:
            ckpt_prefix = self._ckpt_prefix
        if self._save_type is not None:
            if 'best' in self._save_type and self._best_ckpt is not None:
                os.makedirs(ckpt_dir, exist_ok=True)
                save_name = "%sbest%s.pth" % (ckpt_prefix, "" if not add_md5 else "-%s" % utils.md5(self._best_ckpt))
                shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name))
@@ -0,0 +1,59 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
from .base import Callback
import numbers
from tqdm import tqdm

class MetricsLogging(Callback):
    def __init__(self, keys):
        super(MetricsLogging, self).__init__()
        self._keys = keys

    def __call__(self, engine):
        if engine.logger is None:
            return
        state = engine.state
        content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)" % (
            state.iter, state.max_iter,
            state.current_epoch, state.max_epoch,
            state.current_batch_index, state.max_batch_index
        )
        for key in self._keys:
            value = state.metrics.get(key, None)
            if value is not None:
                if isinstance(value, numbers.Number):
                    content += " %s=%.4f" % (key, value)
                    if engine.tb_writer is not None:
                        engine.tb_writer.add_scalar(key, value, global_step=state.iter)
                elif isinstance(value, (list, tuple)):
                    content += " %s=%s" % (key, value)
        engine.logger.info(content)
class ProgressCallback(Callback):
    def __init__(self, max_iter=100, tag=None):
        self._max_iter = max_iter
        self._tag = tag
        self._pbar = None  # created lazily by reset()

    def __call__(self, engine):
        if self._pbar is None:
            self.reset()
        self._pbar.update(1)
        if self._pbar.n == self._max_iter:
            self._pbar.close()

    def reset(self):
        self._pbar = tqdm(total=self._max_iter, desc=self._tag)
@@ -0,0 +1,32 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
from .base import Callback
from typing import Sequence

class LRSchedulerCallback(Callback):
    r""" LR scheduler callback: steps the wrapped schedulers when invoked.
    """
    def __init__(self, schedulers=None):
        super(LRSchedulerCallback, self).__init__()
        if schedulers is not None and not isinstance(schedulers, Sequence):
            schedulers = (schedulers,)
        self._schedulers = schedulers

    def __call__(self, trainer):
        if self._schedulers is None:
            return
        for sched in self._schedulers:
            sched.step()
@@ -0,0 +1,7 @@
from . import evaluator
from . import trainer
from . import hooks
from . import events
from .engine import DefaultEvents, Event
@@ -0,0 +1,188 @@ | |||||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================= | |||||
import oneflow.nn as nn | |||||
import abc, weakref | |||||
from typing import Any, Callable, Sequence | |||||
from kamal.core.engine.events import DefaultEvents, Event | |||||
from kamal.utils import get_logger | |||||
from collections import defaultdict | |||||
import numbers | |||||
import contextlib | |||||
import pdb | |||||
class State(object): | |||||
def __init__(self): | |||||
self.iter = 0 | |||||
self.max_iter = None | |||||
self.epoch_length = None | |||||
self.dataloader = None | |||||
self.seed = None | |||||
self.metrics=dict() | |||||
self.batch=None | |||||
@property | |||||
def current_epoch(self): | |||||
if self.epoch_length is not None: | |||||
return self.iter // self.epoch_length | |||||
return None | |||||
@property | |||||
def max_epoch(self): | |||||
if self.epoch_length is not None: | |||||
return self.max_iter // self.epoch_length | |||||
return None | |||||
@property | |||||
def current_batch_index(self): | |||||
if self.epoch_length is not None: | |||||
return self.iter % self.epoch_length | |||||
return None | |||||
@property | |||||
def max_batch_index(self): | |||||
return self.epoch_length | |||||
def __repr__(self): | |||||
rep = "State:\n" | |||||
for attr, value in self.__dict__.items(): | |||||
if not isinstance(value, (numbers.Number, str, dict)): | |||||
value = type(value) | |||||
rep += "\t{}: {}\n".format(attr, value) | |||||
return rep | |||||
class Engine(abc.ABC):
    def __init__(self, logger=None, tb_writer=None):
        self._logger = logger if logger else get_logger(name='kamal', color=True)
        self._tb_writer = tb_writer
        self._callbacks = defaultdict(list)
        self._allowed_events = [*DefaultEvents]
        self._state = State()

    def reset(self):
        self._state = State()

    def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None):
        self.state.iter = self._state.start_iter = start_iter
        self.state.max_iter = max_iter
        self.state.epoch_length = epoch_length if epoch_length else len(dataloader)
        self.state.dataloader = dataloader
        self.state.dataloader_iter = iter(dataloader)
        self.state.step_fn = step_fn
        self.trigger_events(DefaultEvents.BEFORE_RUN)
        for self.state.iter in range(start_iter, max_iter):
            if self.state.epoch_length is not None and \
               self.state.iter % self.state.epoch_length == 0:  # epoch start
                self.trigger_events(DefaultEvents.BEFORE_EPOCH)
            self.trigger_events(DefaultEvents.BEFORE_STEP)
            self.state.batch = self._get_batch()
            step_output = step_fn(self, self.state.batch)
            if isinstance(step_output, dict):
                self.state.metrics.update(step_output)
            self.trigger_events(DefaultEvents.AFTER_STEP)
            if self.state.epoch_length is not None and \
               (self.state.iter + 1) % self.state.epoch_length == 0:  # epoch end
                self.trigger_events(DefaultEvents.AFTER_EPOCH)
        self.trigger_events(DefaultEvents.AFTER_RUN)

    def _get_batch(self):
        try:
            batch = next(self.state.dataloader_iter)
        except StopIteration:
            self.state.dataloader_iter = iter(self.state.dataloader)  # restart the iterator
            batch = next(self.state.dataloader_iter)
        if not isinstance(batch, (list, tuple)):
            batch = [batch, ]  # no targets
        return batch
    @property
    def state(self):
        return self._state

    @property
    def logger(self):
        return self._logger

    @property
    def tb_writer(self):
        return self._tb_writer

    def add_callback(self, event: Event, callbacks):
        if not isinstance(callbacks, Sequence):
            callbacks = [callbacks]
        if event in self._allowed_events:
            for callback in callbacks:
                if callback not in self._callbacks[event]:
                    if event.trigger != event.default_trigger:
                        callback = self._trigger_wrapper(self, event.trigger, callback)
                    self._callbacks[event].append(callback)
        callbacks = [RemovableCallback(self, event, c) for c in callbacks]
        return (callbacks[0] if len(callbacks) == 1 else callbacks)

    def remove_callback(self, event, callback):
        for c in self._callbacks[event]:
            if c == callback:
                self._callbacks[event].remove(c)
                return True
        return False
    @staticmethod
    def _trigger_wrapper(engine, trigger, callback):
        def wrapper(*args, **kwargs) -> Any:
            if trigger(engine):
                return callback(engine)
        return wrapper

    def trigger_events(self, *events):
        for e in events:
            if e in self._allowed_events:
                for callback in self._callbacks[e]:
                    callback(self)

    def register_events(self, *events):
        for e in events:
            if e not in self._allowed_events:
                self._allowed_events.append(e)

    @contextlib.contextmanager
    def save_current_callbacks(self):
        temp = self._callbacks
        self._callbacks = defaultdict(list)
        yield
        self._callbacks = temp
class RemovableCallback:
    def __init__(self, engine, event, callback):
        self._engine = weakref.ref(engine)
        self._callback = weakref.ref(callback)
        self._event = weakref.ref(event)

    @property
    def callback(self):
        return self._callback()

    def remove(self):
        engine = self._engine()
        callback = self._callback()
        event = self._event()
        return engine.remove_callback(event, callback)
@@ -0,0 +1,62 @@
import oneflow
from kamal.core import metrics
from kamal.utils import set_mode
from typing import Callable
from .engine import Engine
from .events import DefaultEvents
from kamal.core import callbacks
import weakref
from kamal.utils import move_to_device, split_batch

class BasicEvaluator(Engine):
    def __init__(self,
                 dataloader: oneflow.utils.data.DataLoader,
                 metric: metrics.MetricCompose,
                 eval_fn: Callable = None,
                 tag: str = 'Eval',
                 progress: bool = False):
        super(BasicEvaluator, self).__init__()
        self.dataloader = dataloader
        self.metric = metric
        self.progress = progress
        if progress:
            self.progress_callback = self.add_callback(
                DefaultEvents.AFTER_STEP, callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag))
        self._model = None
        self._tag = tag
        if eval_fn is None:
            eval_fn = BasicEvaluator.default_eval_fn
        self.eval_fn = eval_fn
    def eval(self, model, device=None):
        device = device if device is not None else \
            oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
        self._model = weakref.ref(model)  # use a weakref to avoid keeping the model alive
        self.device = device
        self.metric.reset()
        model.to(device)
        if self.progress:
            self.progress_callback.callback.reset()
        with oneflow.no_grad(), set_mode(model, training=False):
            super(BasicEvaluator, self).run(self.step_fn, self.dataloader, max_iter=len(self.dataloader))
        return self.metric.get_results()

    @property
    def model(self):
        if self._model is not None:
            return self._model()
        return None

    def step_fn(self, engine, batch):
        batch = move_to_device(batch, self.device)
        self.eval_fn(engine, batch)

    @staticmethod
    def default_eval_fn(evaluator, batch):
        model = evaluator.model
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        evaluator.metric.update(outputs, targets)
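# Minimal usage sketch (names follow the training scripts):
#   evaluator = BasicEvaluator(val_loader, metric_compose)
#   results = evaluator.eval(model)   # e.g. {'car_acc': tensor(0.83)}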
@@ -0,0 +1,90 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
from typing import Callable, Optional
from enum import Enum

class Event(object):
    def __init__(self, value: str, event_trigger: Optional[Callable] = None):
        if event_trigger is None:
            event_trigger = Event.default_trigger
        self._trigger = event_trigger
        self._name_ = self._value_ = value

    @property
    def trigger(self):
        return self._trigger

    @property
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @property
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @staticmethod
    def default_trigger(engine):
        return True

    @staticmethod
    def once_trigger():
        is_triggered = False
        def wrapper(engine):
            nonlocal is_triggered
            if is_triggered:
                return False
            is_triggered = True
            return True
        return wrapper
    @staticmethod
    def every_trigger(every: int):
        def wrapper(engine):
            return every > 0 and (engine.state.iter % every) == 0
        return wrapper

    def __call__(self, every: Optional[int] = None, once: Optional[bool] = None):
        if every is not None:
            assert once is None
            return Event(self.value, event_trigger=Event.every_trigger(every))
        if once is not None:
            return Event(self.value, event_trigger=Event.once_trigger())
        return Event(self.value)
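    # Example: DefaultEvents.AFTER_STEP(every=10) yields a derived event whose
    # trigger fires every 10 iterations, as used by MetricsLogging in the
    # training scripts.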
    def __hash__(self):
        return hash(self._name_)

    def __eq__(self, other):
        if hasattr(other, 'value'):
            return self.value == other.value
        return NotImplemented
class DefaultEvents(Event, Enum):
    BEFORE_RUN = "before_train"
    AFTER_RUN = "after_train"
    BEFORE_EPOCH = "before_epoch"
    AFTER_EPOCH = "after_epoch"
    BEFORE_STEP = "before_step"
    AFTER_STEP = "after_step"
    BEFORE_GET_BATCH = "before_get_batch"
    AFTER_GET_BATCH = "after_get_batch"
@@ -0,0 +1,32 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================
class FeatureHook():
    def __init__(self, module):
        self.module = module
        self.feat_in = None
        self.feat_out = None
        self.register()

    def register(self):
        self._hook = self.module.register_forward_hook(self.hook_fn_forward)

    def remove(self):
        self._hook.remove()

    def hook_fn_forward(self, module, fea_in, fea_out):
        self.feat_in = fea_in[0]
        self.feat_out = fea_out
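# Usage: hook = FeatureHook(conv_layer); after a forward pass, hook.feat_in /
# hook.feat_out hold that module's input and output feature maps. This is how
# LayerWiseAmalgamator collects per-layer features from the student and teachers.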
@@ -0,0 +1,115 @@
import oneflow
import oneflow.nn as nn
from kamal.core.engine.engine import Engine
from kamal.core import tasks
from kamal.utils import set_mode, move_to_device, split_batch
from typing import Sequence
import time

class BasicTrainer(Engine):
    def __init__(self,
                 logger=None,
                 tb_writer=None):
        super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer)

    def setup(self,
              model: oneflow.nn.Module,
              task: tasks.Task,
              dataloader: oneflow.utils.data.DataLoader,
              optimizer: oneflow.optim.Optimizer,
              device: oneflow.device = None):
        if device is None:
            device = oneflow.device('cuda' if oneflow.cuda.is_available() else 'cpu')
        self.device = device
        if isinstance(task, Sequence):
            task = tasks.TaskCompose(task)
        self.task = task
        self.model = model.to(self.device)
        # Multi-GPU (requires a distributed launch):
        # self.model = nn.parallel.DistributedDataParallel(self.model)
        self.dataloader = dataloader
        self.optimizer = optimizer
        return self
    def run(self, max_iter, start_iter=0, epoch_length=None):
        self.model.to(self.device)
        with set_mode(self.model, training=True):
            super(BasicTrainer, self).run(self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        model = self.model
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        loss_dict = self.task.get_loss(outputs, targets)  # get loss
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
class KDTrainer(BasicTrainer):
    def setup(self,
              student: oneflow.nn.Module,
              teacher: oneflow.nn.Module,
              task: tasks.Task,
              dataloader: oneflow.utils.data.DataLoader,
              optimizer: oneflow.optim.Optimizer,
              device: oneflow.device = None):
        super(KDTrainer, self).setup(
            model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device)
        if isinstance(teacher, (list, tuple)):
            if len(teacher) == 1:
                teacher = teacher[0]
            else:
                teacher = nn.ModuleList(teacher)
        self.student = self.model
        self.teacher = teacher
        return self

    def run(self, max_iter, start_iter=0, epoch_length=None):
        self.student.to(self.device)
        self.teacher.to(self.device)
        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False):
            super(BasicTrainer, self).run(
                self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        model = self.model
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = model(inputs)
        # Teacher predictions serve as soft targets; no gradients are needed.
        with oneflow.no_grad():
            if isinstance(self.teacher, nn.ModuleList):
                soft_targets = [t(inputs) for t in self.teacher]
            else:
                soft_targets = self.teacher(inputs)
        loss_dict = self.task.get_loss(outputs, soft_targets)  # get loss
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
@@ -0,0 +1,5 @@
class DataTypeError(Exception):
    pass

class InvalidMapping(Exception):
    pass
@@ -0,0 +1,3 @@
from .stream_metrics import Metric, MetricCompose
from .accuracy import Accuracy
@@ -0,0 +1,23 @@
import oneflow
from kamal.core.metrics.stream_metrics import Metric

__all__ = ['Accuracy']

class Accuracy(Metric):
    def __init__(self, attach_to=None):
        super(Accuracy, self).__init__(attach_to=attach_to)
        self.reset()

    @oneflow.no_grad()
    def update(self, outputs, targets):
        outputs, targets = self._attach(outputs, targets)
        outputs = outputs.max(1)[1]  # predicted class indices
        self._correct += (outputs.view(-1) == targets.view(-1)).sum()
        self._cnt += oneflow.numel(targets)

    def get_results(self):
        return (self._correct / self._cnt).detach().cpu()

    def reset(self):
        self._correct = self._cnt = 0.0
@@ -0,0 +1,63 @@ | |||||
from __future__ import division | |||||
import oneflow | |||||
from abc import ABC, abstractmethod | |||||
from typing import Mapping | |||||
from kamal.core.attach import AttachTo | |||||
class Metric(ABC): | |||||
def __init__(self, attach_to=None): | |||||
self._attach = AttachTo(attach_to) | |||||
@abstractmethod | |||||
def update(self, pred, target): | |||||
""" Overridden by subclasses """ | |||||
raise NotImplementedError() | |||||
@abstractmethod | |||||
def get_results(self): | |||||
""" Overridden by subclasses """ | |||||
raise NotImplementedError() | |||||
@abstractmethod | |||||
def reset(self): | |||||
""" Overridden by subclasses """ | |||||
raise NotImplementedError() | |||||
class MetricCompose(dict): | |||||
def __init__(self, metric_dict: Mapping): | |||||
self._metric_dict = metric_dict | |||||
def add_metrics( self, metric_dict: Mapping): | |||||
if isinstance(metric_dict, MetricCompose): | |||||
metric_dict = metric_dict.metrics | |||||
self._metric_dict.update(metric_dict) | |||||
return self | |||||
@property | |||||
def metrics(self): | |||||
return self._metric_dict | |||||
@oneflow.no_grad() | |||||
def update(self, outputs, targets): | |||||
for key, metric in self._metric_dict.items(): | |||||
if isinstance(metric, Metric): | |||||
metric.update(outputs, targets) | |||||
def get_results(self): | |||||
results = {} | |||||
for key, metric in self._metric_dict.items(): | |||||
if isinstance(metric, Metric): | |||||
results[key] = metric.get_results() | |||||
return results | |||||
def reset(self): | |||||
for key, metric in self._metric_dict.items(): | |||||
if isinstance(metric, Metric): | |||||
metric.reset() | |||||
def __getitem__(self, name): | |||||
return self._metric_dict[name] | |||||
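if __name__ == "__main__":
    # Minimal self-check (a sketch, assuming the Accuracy metric from
    # kamal/core/metrics/accuracy.py and that AttachTo(None) passes tensors
    # through unchanged): compose one metric, feed fake logits, read results.
    import oneflow
    from kamal.core.metrics import Accuracy

    compose = MetricCompose({'acc': Accuracy()})
    logits = oneflow.randn(8, 5)              # fake logits: 8 samples, 5 classes
    labels = oneflow.randint(0, 5, (8,))
    compose.update(logits, labels)
    print(compose.get_results())              # e.g. {'acc': tensor(0.2500)}
    compose.reset()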
@@ -0,0 +1,3 @@ | |||||
from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose | |||||
from . import loss | |||||
@@ -0,0 +1,2 @@ | |||||
from . import functional, loss | |||||
from .loss import * |
@@ -0,0 +1,17 @@ | |||||
import oneflow | |||||
import oneflow.nn.functional as F | |||||
def kldiv(logits, targets, T=1.0): | |||||
""" Cross Entropy for soft targets | |||||
Parameters: | |||||
- logits (Tensor): logits score (e.g. outputs of fc layer) | |||||
- targets (Tensor): logits of soft targets | |||||
- T (float): temperature of distill | |||||
- reduction: reduction to the output | |||||
""" | |||||
p_targets = F.softmax(targets/T, dim=1) | |||||
logp_logits = F.log_softmax(logits/T, dim=1) | |||||
kl_div = oneflow.nn.KLDivLoss(reduction="none") | |||||
kld = kl_div(logp_logits, p_targets) * (T**2) | |||||
return kld.sum(1).mean() |
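if __name__ == "__main__":
    # Sanity check (illustrative): identical student/teacher logits give a KL
    # divergence of ~0; a perturbed student gives a positive distillation loss.
    student_logits = oneflow.randn(4, 10)
    teacher_logits = student_logits.clone()
    print(kldiv(student_logits, teacher_logits, T=4.0))                         # ~0.0
    print(kldiv(student_logits + oneflow.randn(4, 10), teacher_logits, T=4.0))  # > 0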
@@ -0,0 +1,8 @@ | |||||
from .functional import * | |||||
class KLDiv(object): | |||||
def __init__(self, T=1.0): | |||||
self.T = T | |||||
def __call__(self, logits, targets): | |||||
return kldiv( logits, targets, T=self.T ) |
@@ -0,0 +1,180 @@ | |||||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================= | |||||
import abc | |||||
import oneflow.nn as nn | |||||
import oneflow.nn.functional as F | |||||
from typing import Callable, Dict, Any | |||||
from . import loss | |||||
from kamal.core import metrics | |||||
from kamal.core.attach import AttachTo | |||||
class Task(abc.ABC):
def __init__(self, name): | |||||
self.name = name | |||||
@abc.abstractmethod | |||||
def get_loss( self, outputs, targets ) -> Dict: | |||||
pass | |||||
@abc.abstractmethod | |||||
def predict(self, outputs) -> Any: | |||||
pass | |||||
class GeneralTask(Task): | |||||
def __init__(self, | |||||
name: str, | |||||
loss_fn: Callable, | |||||
scaling:float=1.0, | |||||
pred_fn: Callable=lambda x: x, | |||||
attach_to=None): | |||||
super(GeneralTask, self).__init__(name) | |||||
self._attach = AttachTo(attach_to) | |||||
self.loss_fn = loss_fn | |||||
self.pred_fn = pred_fn | |||||
self.scaling = scaling | |||||
def get_loss(self, outputs, targets): | |||||
outputs, targets = self._attach(outputs, targets) | |||||
return { self.name: self.loss_fn( outputs, targets ) * self.scaling } | |||||
def predict(self, outputs): | |||||
outputs = self._attach(outputs) | |||||
return self.pred_fn(outputs) | |||||
def __repr__(self): | |||||
rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach) | |||||
return rep | |||||
class TaskCompose(list): | |||||
def __init__(self, tasks: list): | |||||
for task in tasks: | |||||
if isinstance(task, Task): | |||||
self.append(task) | |||||
def get_loss(self, outputs, targets): | |||||
loss_dict = {} | |||||
for task in self: | |||||
loss_dict.update( task.get_loss( outputs, targets ) ) | |||||
return loss_dict | |||||
def predict(self, outputs): | |||||
results = [] | |||||
for task in self: | |||||
results.append( task.predict( outputs ) ) | |||||
return results | |||||
    def __repr__(self):
        rep = "TaskCompose: \n"
        for task in self:
            rep += "\t%s\n" % task
        return rep
class StandardTask: | |||||
@staticmethod | |||||
def classification(name='ce', scaling=1.0, attach_to=None): | |||||
return GeneralTask( name=name, | |||||
loss_fn=nn.CrossEntropyLoss(), | |||||
scaling=scaling, | |||||
pred_fn=lambda x: x.max(1)[1], | |||||
attach_to=attach_to ) | |||||
@staticmethod | |||||
def binary_classification(name='bce', scaling=1.0, attach_to=None): | |||||
return GeneralTask(name=name, | |||||
loss_fn=F.binary_cross_entropy_with_logits, | |||||
scaling=scaling, | |||||
pred_fn=lambda x: (x>0.5), | |||||
attach_to=attach_to ) | |||||
@staticmethod | |||||
def regression(name='mse', scaling=1.0, attach_to=None): | |||||
return GeneralTask(name=name, | |||||
loss_fn=nn.MSELoss(), | |||||
scaling=scaling, | |||||
pred_fn=lambda x: x, | |||||
attach_to=attach_to ) | |||||
@staticmethod | |||||
def segmentation(name='ce', scaling=1.0, attach_to=None): | |||||
return GeneralTask(name=name, | |||||
loss_fn=nn.CrossEntropyLoss(ignore_index=255), | |||||
scaling=scaling, | |||||
pred_fn=lambda x: x.max(1)[1], | |||||
attach_to=attach_to ) | |||||
@staticmethod | |||||
def monocular_depth(name='l1', scaling=1.0, attach_to=None): | |||||
return GeneralTask(name=name, | |||||
loss_fn=nn.L1Loss(), | |||||
scaling=scaling, | |||||
pred_fn=lambda x: x, | |||||
attach_to=attach_to) | |||||
@staticmethod | |||||
def detection(): | |||||
raise NotImplementedError | |||||
@staticmethod | |||||
def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None): | |||||
return GeneralTask(name=name, | |||||
loss_fn=loss.KLDiv(T=T), | |||||
scaling=scaling, | |||||
pred_fn=lambda x: x.max(1)[1], | |||||
attach_to=attach_to) | |||||
class StandardMetrics(object): | |||||
@staticmethod | |||||
def classification(attach_to=None): | |||||
return metrics.MetricCompose( | |||||
metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)} | |||||
) | |||||
@staticmethod | |||||
def regression(attach_to=None): | |||||
return metrics.MetricCompose( | |||||
metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)} | |||||
) | |||||
@staticmethod | |||||
def segmentation(num_classes, ignore_idx=255, attach_to=None): | |||||
confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to) | |||||
return metrics.MetricCompose( | |||||
metric_dict={'acc': metrics.Accuracy(attach_to=attach_to), | |||||
'confusion_matrix': confusion_matrix , | |||||
'miou': metrics.mIoU(confusion_matrix)} | |||||
) | |||||
@staticmethod | |||||
def monocular_depth(attach_to=None): | |||||
return metrics.MetricCompose( | |||||
metric_dict={ | |||||
'rmse': metrics.RootMeanSquaredError(attach_to=attach_to), | |||||
'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ), | |||||
'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to), | |||||
'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to), | |||||
'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to), | |||||
'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to ) | |||||
} | |||||
) | |||||
@staticmethod | |||||
def loss_metric(loss_fn): | |||||
return metrics.MetricCompose( | |||||
metric_dict={ | |||||
'loss': metrics.AverageMetric( loss_fn ) | |||||
} | |||||
) |
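if __name__ == "__main__":
    # Illustrative check of the task abstraction (assumes AttachTo(None) passes
    # tensors through unchanged): a distillation task scores student logits
    # against teacher logits, a classification task against integer labels.
    import oneflow
    student_logits = oneflow.randn(4, 196)
    teacher_logits = oneflow.randn(4, 196)
    labels = oneflow.randint(0, 196, (4,))

    kd_task = StandardTask.distillation(T=4.0)
    print(kd_task.get_loss(student_logits, teacher_logits))   # {'kld': tensor(...)}

    ce_task = StandardTask.classification()
    print(ce_task.get_loss(student_logits, labels))           # {'ce': tensor(...)}
    print(ce_task.predict(student_logits))                    # predicted class indices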
@@ -0,0 +1,2 @@ | |||||
from ._utils import * | |||||
from .logger import get_logger |
@@ -0,0 +1,45 @@ | |||||
import oneflow | |||||
import contextlib | |||||
import oneflow.nn as nn | |||||
def split_batch(batch): | |||||
if isinstance(batch, (list, tuple)): | |||||
inputs, *targets = batch | |||||
if len(targets)==1: | |||||
targets = targets[0] | |||||
return inputs, targets | |||||
else: | |||||
return [batch, None] | |||||
@contextlib.contextmanager | |||||
def set_mode(model, training=True): | |||||
ori_mode = model.training | |||||
model.train(training) | |||||
yield | |||||
model.train(ori_mode) | |||||
def move_to_device(obj, device):
    if isinstance(obj, oneflow.Tensor):
        return obj.to(device=device)
    elif isinstance(obj, (list, tuple)):
        return [o.to(device=device) for o in obj]
    elif isinstance(obj, nn.Module):
        return obj.to(device=device)
    return obj  # leave unhandled types (e.g. None placeholders) unchanged
def flatten_dict(dic):
    flattened = dict()
    def _flatten(prefix, d):
        for k, v in d.items():
            if isinstance(v, dict):
                _flatten(prefix + '%s/' % k, v)
            else:
                flattened[(prefix + '%s/' % k).strip('/')] = v
    _flatten('', dic)
    return flattened
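if __name__ == "__main__":
    # Small demos of the helpers above (illustrative):
    inputs, targets = split_batch((oneflow.randn(2, 3), oneflow.tensor([0, 1])))
    print(inputs.shape, targets)

    print(flatten_dict({'train': {'loss': 0.5, 'acc': 0.9}, 'lr': 1e-3}))
    # -> {'train/loss': 0.5, 'train/acc': 0.9, 'lr': 0.001}

    model = nn.Linear(3, 2)
    with set_mode(model, training=False):
        assert not model.training   # eval mode inside the context
    assert model.training           # original mode restored on exit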
@@ -0,0 +1,56 @@ | |||||
import logging | |||||
import os, sys | |||||
from termcolor import colored | |||||
class _ColorfulFormatter(logging.Formatter): | |||||
def __init__(self, *args, **kwargs): | |||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs) | |||||
def formatMessage(self, record): | |||||
log = super(_ColorfulFormatter, self).formatMessage(record) | |||||
if record.levelno == logging.WARNING: | |||||
prefix = colored("WARNING", "yellow", attrs=["blink"]) | |||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: | |||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"]) | |||||
else: | |||||
return log | |||||
return prefix + " " + log | |||||
def get_logger(name='Kamal', output=None, color=True): | |||||
logger = logging.getLogger(name) | |||||
logger.setLevel(logging.DEBUG) | |||||
logger.propagate = False | |||||
# STDOUT | |||||
stdout_handler = logging.StreamHandler( stream=sys.stdout ) | |||||
stdout_handler.setLevel( logging.DEBUG ) | |||||
plain_formatter = logging.Formatter( | |||||
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) | |||||
if color: | |||||
formatter = _ColorfulFormatter( | |||||
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", | |||||
datefmt="%m/%d %H:%M:%S") | |||||
else: | |||||
formatter = plain_formatter | |||||
stdout_handler.setFormatter(formatter) | |||||
logger.addHandler(stdout_handler) | |||||
# FILE | |||||
if output is not None: | |||||
        if output.endswith('.txt') or output.endswith('.log'):
            dirname = os.path.dirname(output)
            if dirname:  # a bare filename like 'log.txt' has no directory part
                os.makedirs(dirname, exist_ok=True)
            filename = output
else: | |||||
os.makedirs(output, exist_ok=True) | |||||
filename = os.path.join(output, "log.txt") | |||||
file_handler = logging.FileHandler(filename) | |||||
file_handler.setFormatter(plain_formatter) | |||||
file_handler.setLevel(logging.DEBUG) | |||||
logger.addHandler(file_handler) | |||||
return logger | |||||
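if __name__ == "__main__":
    # Illustrative usage: colored DEBUG-level console output plus a plain-text
    # file log (the ./run/ path is arbitrary).
    logger = get_logger('Kamal', output='./run/log.txt')
    logger.info('training started')
    logger.warning('this line gets a colored WARNING prefix on the console')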
@@ -0,0 +1,3 @@ | |||||
from . import models | |||||
from . import datasets | |||||
from . import sync_transforms |
@@ -0,0 +1,5 @@ | |||||
from .stanford_cars import StanfordCars | |||||
from .stanford_dogs import StanfordDogs | |||||
from .flowers102 import Flowers102 | |||||
from .sun397 import SUN397 | |||||
from .fgvc_aircraft import FGVCAircraft |
@@ -0,0 +1,152 @@ | |||||
import oneflow.utils.data as data | |||||
from flowvision.datasets.folder import default_loader | |||||
import os | |||||
import numpy as np | |||||
import random | |||||
def make_dataset(dir, image_ids, targets): | |||||
assert(len(image_ids) == len(targets)) | |||||
images = [] | |||||
dir = os.path.expanduser(dir) | |||||
for i in range(len(image_ids)): | |||||
item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images', | |||||
'%s.jpg' % image_ids[i]), targets[i]) | |||||
images.append(item) | |||||
return images | |||||
def find_classes(classes_file):
    # read the classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.strip().split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))
    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(data.Dataset): | |||||
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset. | |||||
Args: | |||||
root (string): Root directory path to dataset. | |||||
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification | |||||
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). | |||||
transforms (callable, optional): A function/transforms that takes in a PIL image | |||||
and returns a transformed version. E.g. ``transforms.RandomCrop`` | |||||
target_transform (callable, optional): A function/transforms that takes in the | |||||
target and transforms it. | |||||
loader (callable, optional): A function to load an image given its path. | |||||
download (bool, optional): If true, downloads the dataset from the internet and | |||||
puts it in the root directory. If dataset is already downloaded, it is not | |||||
downloaded again. | |||||
""" | |||||
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' | |||||
class_types = ('variant', 'family', 'manufacturer') | |||||
splits = ('train', 'val', 'trainval', 'test') | |||||
def __init__(self, root, class_type='variant', split='train', s=0.5, transform=None, | |||||
target_transform=None, loader=default_loader, download=False): | |||||
if split not in self.splits: | |||||
raise ValueError('Split "{}" not found. Valid splits are: {}'.format( | |||||
split, ', '.join(self.splits), | |||||
)) | |||||
if class_type not in self.class_types: | |||||
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( | |||||
class_type, ', '.join(self.class_types), | |||||
)) | |||||
self.root = root | |||||
self.class_type = class_type | |||||
self.split = split | |||||
self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', | |||||
'images_%s_%s.txt' % (self.class_type, self.split)) | |||||
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) | |||||
"""if split == 'trainval': | |||||
self.image_ids = image_ids | |||||
self.targets = targets | |||||
self.image_ids, self.targets = self.sample_by_class(s=s) | |||||
image_ids = self.image_ids | |||||
targets = self.targets""" | |||||
samples = make_dataset(self.root, image_ids, targets) | |||||
self.transform = transform | |||||
self.target_transform = target_transform | |||||
self.loader = loader | |||||
self.samples = samples | |||||
self.classes = classes | |||||
self.class_to_idx = class_to_idx | |||||
with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f: | |||||
self.object_categories = [ | |||||
line.strip('\n') for line in f.readlines()] | |||||
print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__())) | |||||
def __getitem__(self, index): | |||||
""" | |||||
Args: | |||||
index (int): Index | |||||
Returns: | |||||
tuple: (sample, target) where target is class_index of the target class. | |||||
""" | |||||
path, target = self.samples[index] | |||||
sample = self.loader(path) | |||||
if self.transform is not None: | |||||
sample = self.transform(sample) | |||||
if self.target_transform is not None: | |||||
target = self.target_transform(target) | |||||
return sample, target | |||||
def __len__(self): | |||||
return len(self.samples) | |||||
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b',
                                           'data', 'images')) and \
            os.path.exists(self.classes_file)
def sample_by_class(self, s): | |||||
class_dit = {} | |||||
image_dit = {} | |||||
for class_name, image in zip(self.targets,self.image_ids): | |||||
if class_name not in class_dit.keys(): | |||||
class_dit[class_name] = [] | |||||
image_dit[class_name] = [] | |||||
class_dit[class_name].append(class_name) | |||||
image_dit[class_name].append(image) | |||||
labels, images = [], [] | |||||
for key in class_dit.keys(): | |||||
n1 = len(class_dit[key]) | |||||
n2 = len(image_dit[key]) | |||||
assert n1 == n2, "{} not equal {}".format(n1, n2) | |||||
random.shuffle(image_dit[key]) | |||||
labels += class_dit[key][:int(n1*s)] | |||||
images += image_dit[key][:int(n1*s)] | |||||
return images, labels |
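if __name__ == "__main__":
    # Illustrative usage (assumes fgvc-aircraft-2013b has already been extracted
    # under ./DataSets/FGVCAircraft/, matching the training scripts in this repo):
    import flowvision.transforms as T
    dst = FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval',
                       transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]))
    img, target = dst[0]
    print(img.shape, target, len(dst.classes))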
@@ -0,0 +1,11 @@ | |||||
import flowvision | |||||
import os | |||||
class Flowers102(flowvision.datasets.ImageFolder): | |||||
def __init__(self, root, split='train'): | |||||
self.split = split | |||||
self.num_classes = 102 | |||||
path = os.path.join(root, 'train') if self.split=='train' else os.path.join(root, 'valid') | |||||
super().__init__(path) | |||||
print('Flowers-102, Split: %s, Size: %d' % (self.split, len(self.imgs))) |
@@ -0,0 +1,81 @@ | |||||
import os | |||||
import glob | |||||
from PIL import Image | |||||
import numpy as np | |||||
from scipy.io import loadmat | |||||
from oneflow.utils import data | |||||
import random | |||||
class StanfordCars(data.Dataset): | |||||
"""Dataset for Stanford Cars | |||||
""" | |||||
urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz', | |||||
'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz', | |||||
'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz', | |||||
'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'} | |||||
def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None): | |||||
self.root = os.path.abspath( os.path.expanduser(root) ) | |||||
self.split = split | |||||
self.transform = transform | |||||
self.target_transform = target_transform | |||||
if download: | |||||
self.download() | |||||
if self.split == 'train': | |||||
annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat') | |||||
else: | |||||
annos = os.path.join(self.root, 'devkit', | |||||
'cars_test_annos_withlabels.mat') | |||||
annos = loadmat(annos) | |||||
size = len(annos['annotations'][0]) | |||||
self.files = glob.glob(os.path.join( | |||||
self.root, 'cars_'+self.split, '*.jpg')) | |||||
self.files.sort() | |||||
"""if split == 'train': | |||||
self.files = self.sample_by_class(s=s)""" | |||||
self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]]) | |||||
lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat')) | |||||
self.object_categories = [str(c[0]) | |||||
for c in lbl_annos['class_names'][0]] | |||||
print('Stanford Cars, Split: %s, Size: %d' % | |||||
(self.split, self.__len__())) | |||||
def __len__(self): | |||||
return len(self.files) | |||||
    def __getitem__(self, idx):
        # self.files already holds the absolute paths collected by glob above
        img = Image.open(self.files[idx]).convert("RGB")
lbl = self.labels[idx] | |||||
if self.transform is not None: | |||||
img = self.transform(img) | |||||
if self.target_transform is not None: | |||||
lbl = self.target_transform(lbl) | |||||
return img, lbl | |||||
def sample_by_class(self, s): | |||||
class_dit = {} | |||||
for file in self.files: | |||||
class_name = file.split('/')[0] | |||||
if class_name not in class_dit.keys(): | |||||
class_dit[class_name] = [] | |||||
class_dit[class_name].append(file) | |||||
files = [] | |||||
for key in class_dit.keys(): | |||||
n = len(class_dit[key]) | |||||
random.shuffle(class_dit[key]) | |||||
files += class_dit[key][:int(n*s)] | |||||
return files |
@@ -0,0 +1,67 @@ | |||||
import os | |||||
import numpy as np | |||||
from PIL import Image | |||||
from scipy.io import loadmat | |||||
from oneflow.utils import data | |||||
import random | |||||
class StanfordDogs(data.Dataset): | |||||
"""Dataset for Stanford Dogs | |||||
""" | |||||
urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar", | |||||
"annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar", | |||||
"lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"} | |||||
def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None): | |||||
self.root = os.path.abspath( os.path.expanduser(root) ) | |||||
self.split = split | |||||
self.transform = transform | |||||
self.target_transform = target_transform | |||||
if download: | |||||
self.download() | |||||
list_file = os.path.join(self.root, 'lists', self.split+'_list.mat') | |||||
mat_file = loadmat(list_file) | |||||
size = len(mat_file['file_list']) | |||||
self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)] | |||||
"""if split == 'train': | |||||
self.files = self.sample_by_class(s=s)""" | |||||
self.labels = np.array( | |||||
[mat_file['labels'][i][0]-1 for i in range(size)]) | |||||
categories = os.listdir(os.path.join(self.root, 'Images')) | |||||
categories.sort() | |||||
self.object_categories = [c[10:] for c in categories] | |||||
print('Stanford Dogs, Split: %s, Size: %d' % | |||||
(self.split, self.__len__())) | |||||
def __len__(self): | |||||
return len(self.files) | |||||
def __getitem__(self, idx): | |||||
img = Image.open(os.path.join(self.root, 'Images', | |||||
self.files[idx])).convert("RGB") | |||||
lbl = self.labels[idx] | |||||
if self.transform is not None: | |||||
img = self.transform(img) | |||||
if self.target_transform is not None: | |||||
lbl = self.target_transform( lbl ) | |||||
return img, lbl | |||||
def sample_by_class(self, s): | |||||
class_dit = {} | |||||
for file in self.files: | |||||
class_name = file.split('/')[0] | |||||
if class_name not in class_dit.keys(): | |||||
class_dit[class_name] = [] | |||||
class_dit[class_name].append(file) | |||||
files = [] | |||||
for key in class_dit.keys(): | |||||
n = len(class_dit[key]) | |||||
random.shuffle(class_dit[key]) | |||||
files += class_dit[key][:int(n*s)] | |||||
return files |
@@ -0,0 +1,11 @@ | |||||
import flowvision | |||||
import os | |||||
class SUN397(flowvision.datasets.ImageFolder): | |||||
def __init__(self, root, split='train'): | |||||
self.split = split | |||||
self.num_classes = 397 | |||||
path = os.path.join(root, 'train') if self.split=='train' else os.path.join(root, 'test') | |||||
super().__init__(path) | |||||
print('SUN397, Split: %s, Size: %d' % (self.split, len(self.imgs))) |
@@ -0,0 +1 @@ | |||||
from . import classification |
@@ -0,0 +1 @@ | |||||
from .resnet import * |