@@ -0,0 +1,40 @@ | |||
# Model Knowledge Amalgamation (KA)
## Introduction | |||
Model knowledge amalgamation is a lightweight algorithm package built for knowledge fusion and for measuring model transferability. For details, see [KAE](https://github.com/zju-vipa/KamalEngine/tree/master/examples).
## Table of contents | |||
* [Introduction](#introduction)
* [Algorithms](#algorithms)
* [Authors](#authors)
## Algorithms | |||
#### 0. Datasets
* data/StanfordCars: https://ai.stanford.edu/~jkrause/cars/car_dataset.html | |||
* data/StanfordDogs: http://vision.stanford.edu/aditya86/ImageNetDogs/ | |||
* data/FGVCAircraft: http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/ | |||
* data/Flowers102: http://www.robots.ox.ac.uk/~vgg/data/flowers/102/
#### 1. Dependencies
* [requirements.txt](requirements.txt) | |||
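Install them with pip before running the scripts (standard usage; a OneFlow build matching your CUDA setup is assumed):
```bash
pip install -r requirements.txt
```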
#### 2. Algorithms
* TW_Layerwise | |||
```bash | |||
python3 TTL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model
``` | |||
* TH_Layerwise | |||
```bash | |||
python3 THL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model --aircraft_ckpt ./ckpt/aircraft_res50_model
``` | |||
* TF_Layerwise | |||
```bash | |||
python3 TFL.py --car_ckpt ./ckpt/car_res50_model --dog_ckpt ./ckpt/dog_res50_model --aircraft_ckpt ./ckpt/aircraft_res50_model --flower_ckpt ./ckpt/flower_res50_model
``` | |||
* flow — convert PyTorch teacher weights to OneFlow format with `flow.py` (the path below is an example; `--car_ckpt` expects a PyTorch checkpoint saved as a full model)
```bash
python3 flow.py --car_ckpt ./ckpt/aircraft_res50.pth
```
#### 3. Weights
The MODEL_ckpt folder stores the weights produced by the three model reassembly/amalgamation algorithms, in OneFlow format, for reference when reproducing results. The ckpt folder holds the PyTorch weights converted to OneFlow format, which are used to load the teacher models; because it is large, it is provided via the link below instead.
* Teacher model weights [link]: https://pan.baidu.com/s/19pRLZKQvjgEQ7CuD_RiPPw (extraction code: FLOW)
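A minimal sketch of loading a converted teacher checkpoint (mirroring the calls in TTL/THL/TFL; the path follows the layout above):
```python
import oneflow
from kamal import vision

# 196-way Stanford Cars teacher
car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False)
car_teacher.load_state_dict(oneflow.load("./ckpt/car_res50_model"))
```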
## Authors | |||
This project is a simplified version of KAE, which is developed by [VIPA Lab](http://vipazoo.cn) at Zhejiang University and [Zhejiang Lab](http://www.zhejianglab.com/).
@@ -0,0 +1,124 @@ | |||
from kamal import vision, engine, utils, amalgamation, metrics, callbacks | |||
from kamal.vision import sync_transforms as sT | |||
import oneflow | |||
import argparse | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model") | |||
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model") | |||
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model") | |||
parser.add_argument('--flower_ckpt', default="./ckpt/flower_res50_model") | |||
parser.add_argument('--lr', type=float, default=1e-3) | |||
parser.add_argument('--batch_size', type=int, default=16) | |||
args = parser.parse_args() | |||
def main(): | |||
    # Datasets
car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train') | |||
car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test') | |||
dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train') | |||
dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test') | |||
aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval') | |||
aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test') | |||
flower_train_dst = vision.datasets.Flowers102('./DataSets/Flower102/', split='train') | |||
flower_val_dst = vision.datasets.Flowers102('./DataSets/Flower102/', split='valid') | |||
    # Teacher / student models
car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False) | |||
dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False) | |||
aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False) | |||
flower_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False) | |||
student = vision.models.classification.resnet50(num_classes=196+120+102+102, pretrained=False) | |||
    # Load teacher weights
cars_parameters = oneflow.load(args.car_ckpt) | |||
dogs_parameters = oneflow.load(args.dog_ckpt) | |||
aircraft_parameters = oneflow.load(args.aircraft_ckpt) | |||
flowers_parameters = oneflow.load(args.flower_ckpt) | |||
car_teacher.load_state_dict(cars_parameters) | |||
dog_teacher.load_state_dict(dogs_parameters) | |||
aircraft_teacher.load_state_dict(aircraft_parameters) | |||
flower_teacher.load_state_dict(flowers_parameters) | |||
train_transform = sT.Compose( [ | |||
sT.RandomResizedCrop(224), | |||
sT.RandomHorizontalFlip(), | |||
sT.ToTensor(), | |||
sT.Normalize( mean=[0.485, 0.456, 0.406], | |||
std=[0.229, 0.224, 0.225] ) | |||
] ) | |||
val_transform = sT.Compose( [ | |||
sT.Resize(256), | |||
sT.CenterCrop( 224 ), | |||
sT.ToTensor(), | |||
sT.Normalize( mean=[0.485, 0.456, 0.406], | |||
std=[0.229, 0.224, 0.225] ) | |||
] ) | |||
car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = flower_train_dst.transform = train_transform | |||
car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = flower_val_dst.transform = val_transform | |||
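    # The student's 520-way output concatenates the four teachers' logits
    # (196 cars + 120 dogs + 102 aircraft + 102 flowers); each metric slices its own range.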
car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))}) | |||
dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))}) | |||
aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))}) | |||
flower_metric = metrics.MetricCompose(metric_dict={'flower_acc': metrics.Accuracy(attach_to=lambda o, t: ( o[:, 196+120+102:196+120+102+102], t ) ) } ) | |||
train_dst = oneflow.utils.data.ConcatDataset([car_train_dst, dog_train_dst, aircraft_train_dst, flower_train_dst]) | |||
train_loader = oneflow.utils.data.DataLoader(train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0) | |||
car_loader = oneflow.utils.data.DataLoader(car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0) | |||
dog_loader = oneflow.utils.data.DataLoader(dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0) | |||
aircraft_loader = oneflow.utils.data.DataLoader(aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True) | |||
flower_loader = oneflow.utils.data.DataLoader(flower_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0) | |||
car_evaluator = engine.evaluator.BasicEvaluator(car_loader, car_metric) | |||
dog_evaluator = engine.evaluator.BasicEvaluator(dog_loader, dog_metric) | |||
aircraft_evaluator = engine.evaluator.BasicEvaluator(aircraft_loader, aircraft_metric) | |||
flower_evaluator = engine.evaluator.BasicEvaluator(flower_loader, flower_metric) | |||
TOTAL_ITERS=len(train_loader) * 100 | |||
device = oneflow.device( 'cuda' if oneflow.cuda.is_available() else 'cpu' ) | |||
optim = oneflow.optim.Adam( student.parameters(), lr=args.lr, weight_decay=1e-4) | |||
sched = oneflow.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=TOTAL_ITERS ) | |||
trainer = amalgamation.LayerWiseAmalgamator( | |||
logger=utils.logger.get_logger('layerwise-ka'), | |||
# tb_writer=SummaryWriter( log_dir='run/layerwise_ka-%s'%( time.asctime().replace( ' ', '_' ) ) ) | |||
) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_STEP(every=10), | |||
callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr'))) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_EPOCH, | |||
callbacks=[ | |||
callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='tfl_car'), | |||
callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='tfl_dog'), | |||
callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='tfl_aircraft'), | |||
callbacks.EvalAndCkpt(model=student, evaluator=flower_evaluator, metric_name='flower_acc', ckpt_prefix='tfl_flower'), | |||
] ) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_STEP, | |||
callbacks=callbacks.LRSchedulerCallback(schedulers=[sched])) | |||
layer_groups = [] | |||
layer_channels = [] | |||
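    # Pair every Conv2d in the student with the corresponding teacher convs;
    # FeatureHooks registered on these layers feed the AmalBlocks during training.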
for stu_block, car_block, dog_block, aircraft_block, flower_block in zip( student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules(), flower_teacher.modules() ): | |||
if isinstance( stu_block, oneflow.nn.Conv2d ): | |||
layer_groups.append( [ stu_block, car_block, dog_block, aircraft_block, flower_block ] ) | |||
layer_channels.append( [ stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels, flower_block.out_channels ] ) | |||
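    # weights = [kd, amal, recons]: coefficients for the three loss terms in
    # LayerWiseAmalgamator.step_fn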
trainer.setup( student=student, | |||
teachers=[car_teacher, dog_teacher, aircraft_teacher, flower_teacher], | |||
layer_groups=layer_groups, | |||
layer_channels=layer_channels, | |||
dataloader=train_loader, | |||
optimizer=optim, | |||
device=device, | |||
weights=[1., 1., 1.] ) | |||
trainer.run(start_iter=0, max_iter=TOTAL_ITERS) | |||
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,116 @@ | |||
from kamal import vision, engine, utils, amalgamation, metrics, callbacks | |||
from kamal.vision import sync_transforms as sT | |||
import oneflow | |||
import argparse | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--car_ckpt', default="./ckpt/car_res50_model") | |||
parser.add_argument('--dog_ckpt', default="./ckpt/dog_res50_model") | |||
parser.add_argument('--aircraft_ckpt', default="./ckpt/aircraft_res50_model") | |||
parser.add_argument('--lr', type=float, default=1e-3) | |||
parser.add_argument('--batch_size', type=int, default=32) | |||
args = parser.parse_args() | |||
def main(): | |||
    # Datasets
car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train') | |||
car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test') | |||
dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train') | |||
dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test') | |||
aircraft_train_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval') | |||
aircraft_val_dst = vision.datasets.FGVCAircraft('./DataSets/FGVCAircraft/', split='test') | |||
    # Teacher / student models
car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False) | |||
dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False) | |||
aircraft_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False) | |||
student = vision.models.classification.resnet50(num_classes=196+120+102, pretrained=False) | |||
    # Load teacher weights
cars_parameters = oneflow.load(args.car_ckpt) | |||
dogs_parameters = oneflow.load(args.dog_ckpt) | |||
aircraft_parameters = oneflow.load(args.aircraft_ckpt) | |||
car_teacher.load_state_dict(cars_parameters) | |||
dog_teacher.load_state_dict(dogs_parameters) | |||
aircraft_teacher.load_state_dict(aircraft_parameters) | |||
train_transform = sT.Compose( [ | |||
sT.RandomResizedCrop(224), | |||
sT.RandomHorizontalFlip(), | |||
sT.ToTensor(), | |||
sT.Normalize( mean=[0.485, 0.456, 0.406], | |||
std=[0.229, 0.224, 0.225] ) | |||
] ) | |||
val_transform = sT.Compose( [ | |||
sT.Resize(256), | |||
sT.CenterCrop( 224 ), | |||
sT.ToTensor(), | |||
sT.Normalize( mean=[0.485, 0.456, 0.406], | |||
std=[0.229, 0.224, 0.225] ) | |||
] ) | |||
car_train_dst.transform = dog_train_dst.transform = aircraft_train_dst.transform = train_transform | |||
car_val_dst.transform = dog_val_dst.transform = aircraft_val_dst.transform = val_transform | |||
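    # The student's 418-way output concatenates car (196), dog (120) and aircraft (102)
    # logits; each metric slices its own range.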
car_metric = metrics.MetricCompose(metric_dict={'car_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, :196], t))}) | |||
dog_metric = metrics.MetricCompose(metric_dict={'dog_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196:196+120], t))}) | |||
aircraft_metric = metrics.MetricCompose(metric_dict={'aircraft_acc': metrics.Accuracy(attach_to=lambda o, t: (o[:, 196+120:196+120+102], t))}) | |||
train_dst = oneflow.utils.data.ConcatDataset( [car_train_dst, dog_train_dst, aircraft_train_dst] ) | |||
train_loader = oneflow.utils.data.DataLoader( train_dst, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True ) | |||
car_loader = oneflow.utils.data.DataLoader( car_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True ) | |||
dog_loader = oneflow.utils.data.DataLoader( dog_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True) | |||
aircraft_loader = oneflow.utils.data.DataLoader( aircraft_val_dst, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=True ) | |||
car_evaluator = engine.evaluator.BasicEvaluator( car_loader, car_metric ) | |||
dog_evaluator = engine.evaluator.BasicEvaluator( dog_loader, dog_metric ) | |||
aircraft_evaluator = engine.evaluator.BasicEvaluator( aircraft_loader, aircraft_metric ) | |||
TOTAL_ITERS=len(train_loader) * 100 | |||
device = oneflow.device( 'cuda' if oneflow.cuda.is_available() else 'cpu' ) | |||
optim = oneflow.optim.Adam( student.parameters(), lr=args.lr, weight_decay=1e-4) | |||
sched = oneflow.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=TOTAL_ITERS ) | |||
trainer = amalgamation.LayerWiseAmalgamator( | |||
logger=utils.logger.get_logger('layerwise-ka'), | |||
# tb_writer=SummaryWriter( log_dir='run/layerwise_ka-%s'%( time.asctime().replace( ' ', '_' ) ) ) | |||
) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_STEP(every=10), | |||
callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr'))) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_EPOCH, | |||
callbacks=[ | |||
callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='thl_car'), | |||
callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='thl_dog'), | |||
callbacks.EvalAndCkpt(model=student, evaluator=aircraft_evaluator, metric_name='aircraft_acc', ckpt_prefix='thl_aircraft'), | |||
] ) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_STEP, | |||
callbacks=callbacks.LRSchedulerCallback(schedulers=[sched])) | |||
layer_groups = [] | |||
layer_channels = [] | |||
for stu_block, car_block, dog_block, aircraft_block in zip( student.modules(), car_teacher.modules(), dog_teacher.modules(), aircraft_teacher.modules() ): | |||
if isinstance( stu_block, oneflow.nn.Conv2d ): | |||
layer_groups.append( [ stu_block, car_block, dog_block, aircraft_block ] ) | |||
layer_channels.append( [ stu_block.out_channels, car_block.out_channels, dog_block.out_channels, aircraft_block.out_channels ] ) | |||
trainer.setup( student=student, | |||
teachers=[car_teacher, dog_teacher, aircraft_teacher], | |||
layer_groups=layer_groups, | |||
layer_channels=layer_channels, | |||
dataloader=train_loader, | |||
optimizer=optim, | |||
device=device, | |||
weights=[1., 1., 1.] ) | |||
trainer.run(start_iter=0, max_iter=TOTAL_ITERS) | |||
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,105 @@ | |||
from kamal import vision, engine, utils, amalgamation, metrics, callbacks | |||
from kamal.vision import sync_transforms as sT | |||
import oneflow | |||
import argparse | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument( '--car_ckpt', default="./ckpt/car_res50_model") | |||
parser.add_argument( '--dog_ckpt', default="./ckpt/dog_res50_model") | |||
parser.add_argument( '--lr', type=float, default=1e-3) | |||
args = parser.parse_args() | |||
def main(): | |||
    # Datasets
car_train_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='train', s=0.1) | |||
car_val_dst = vision.datasets.StanfordCars('./DataSets/StanfordCars/', split='test') | |||
    dog_train_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='train', s=0.1)
    dog_val_dst = vision.datasets.StanfordDogs('./DataSets/StanfordDogs/', split='test')
    # Teacher / student models
car_teacher = vision.models.classification.resnet50(num_classes=196, pretrained=False) | |||
dog_teacher = vision.models.classification.resnet50(num_classes=120, pretrained=False) | |||
student = vision.models.classification.resnet50(num_classes=196+120, pretrained=False) | |||
    # Load teacher weights
cars_parameters = oneflow.load(args.car_ckpt) | |||
dogs_parameters = oneflow.load(args.dog_ckpt) | |||
car_teacher.load_state_dict(cars_parameters) | |||
dog_teacher.load_state_dict(dogs_parameters) | |||
train_transform = sT.Compose( [ | |||
sT.RandomResizedCrop(224), | |||
sT.RandomHorizontalFlip(), | |||
sT.ToTensor(), | |||
sT.Normalize( mean=[0.485, 0.456, 0.406], | |||
std=[0.229, 0.224, 0.225] ) | |||
] ) | |||
val_transform = sT.Compose( [ | |||
sT.Resize(256), | |||
sT.CenterCrop( 224 ), | |||
sT.ToTensor(), | |||
sT.Normalize( mean=[0.485, 0.456, 0.406], | |||
std=[0.229, 0.224, 0.225] ) | |||
] ) | |||
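    # The student's 316-way output concatenates the car (196) and dog (120) logits;
    # the metrics below slice out the corresponding ranges.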
    car_train_dst.transform = dog_train_dst.transform = train_transform
    car_val_dst.transform = dog_val_dst.transform = val_transform
car_metric = metrics.MetricCompose(metric_dict={ 'car_acc': metrics.Accuracy(attach_to=lambda o, t: ( o[:, :196], t ) ) } ) | |||
    dog_metric = metrics.MetricCompose(metric_dict={ 'dog_acc': metrics.Accuracy(attach_to=lambda o, t: ( o[:, 196:], t ) ) } )
    train_dst = oneflow.utils.data.ConcatDataset( [car_train_dst, dog_train_dst] )
train_loader = oneflow.utils.data.DataLoader( train_dst, batch_size=32, shuffle=True, num_workers=0 ) | |||
car_loader = oneflow.utils.data.DataLoader( car_val_dst, batch_size=32, shuffle=False, num_workers=0 ) | |||
    dog_loader = oneflow.utils.data.DataLoader( dog_val_dst, batch_size=32, shuffle=False, num_workers=0 )
car_evaluator = engine.evaluator.BasicEvaluator( car_loader, car_metric ) | |||
    dog_evaluator = engine.evaluator.BasicEvaluator( dog_loader, dog_metric )
TOTAL_ITERS=len(train_loader) * 100 | |||
device = oneflow.device( 'cuda' if oneflow.cuda.is_available() else 'cpu' ) | |||
optim = oneflow.optim.Adam( student.parameters(), lr=args.lr, weight_decay=1e-4) | |||
sched = oneflow.optim.lr_scheduler.CosineAnnealingLR( optim, T_max=TOTAL_ITERS ) | |||
trainer = amalgamation.LayerWiseAmalgamator( | |||
logger=utils.logger.get_logger('layerwise-ka'), | |||
# tb_writer=SummaryWriter( log_dir='run/layerwise_ka-%s'%( time.asctime().replace( ' ', '_' ) ) ) | |||
) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_STEP(every=10), | |||
callbacks=callbacks.MetricsLogging(keys=('total_loss', 'loss_kd', 'loss_amal', 'loss_recons', 'lr'))) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_EPOCH, | |||
callbacks=[ | |||
callbacks.EvalAndCkpt(model=student, evaluator=car_evaluator, metric_name='car_acc', ckpt_prefix='ttl_car'), | |||
            callbacks.EvalAndCkpt(model=student, evaluator=dog_evaluator, metric_name='dog_acc', ckpt_prefix='ttl_dog'),
] ) | |||
trainer.add_callback( | |||
engine.DefaultEvents.AFTER_STEP, | |||
callbacks=callbacks.LRSchedulerCallback(schedulers=[sched])) | |||
layer_groups = [] | |||
layer_channels = [] | |||
    for stu_block, car_block, dog_block in zip( student.modules(), car_teacher.modules(), dog_teacher.modules() ):
        if isinstance( stu_block, oneflow.nn.Conv2d ):
            layer_groups.append( [ stu_block, car_block, dog_block ] )
            layer_channels.append( [ stu_block.out_channels, car_block.out_channels, dog_block.out_channels ] )
trainer.setup( student=student, | |||
teachers=[car_teacher, dog_teacher], | |||
layer_groups=layer_groups, | |||
layer_channels=layer_channels, | |||
dataloader=train_loader, | |||
optimizer=optim, | |||
device=device, | |||
weights=[1., 1., 1.] ) | |||
trainer.run(start_iter=0, max_iter=TOTAL_ITERS) | |||
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,41 @@ | |||
""" | |||
pytorch -- oneflow 格式权重转换 | |||
""" | |||
from kamal import vision | |||
import torch | |||
import argparse | |||
import oneflow as flow | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument( '--car_ckpt', required=True ) | |||
parser.add_argument( '--dog_ckpt' ) | |||
args = parser.parse_args() | |||
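# Note: despite the '--car_ckpt' flag name, this script was last run to convert the
# aircraft teacher (hence num_classes=102 and the aircraft output path below);
# adjust num_classes and the paths when converting other teachers.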
cars_parameters = torch.load(args.car_ckpt).state_dict() | |||
# dogs_parameters = torch.load(args.dog_ckpt).state_dict() | |||
cars_para, dogs_para = {}, {} | |||
for key, value in cars_parameters.items(): | |||
val = value.detach().cpu().numpy() | |||
if not str(key).endswith('num_batches_tracked'): | |||
cars_para[key] = val | |||
# for key, value in dogs_parameters.items(): | |||
# val = value.detach().cpu().numpy() | |||
# if not str(key).endswith('num_batches_tracked'): | |||
# dogs_para[key] = val | |||
car_teacher = vision.models.classification.resnet50(num_classes=102, pretrained=False) | |||
# dog_teacher = vision.models.classification.resnet50(num_classes=397, pretrained=False) | |||
car_teacher.load_state_dict(cars_para) | |||
# dog_teacher.load_state_dict(dogs_para) | |||
# torch.save(car_teacher, 'ckpt/aircraft_res50.pth') | |||
# torch.save(dog_teacher, 'checkpoint/sun_res50.pth') | |||
flow.save(car_teacher.state_dict(), "./ckpt/aircraft_res50_model") | |||
# flow.save(dog_teacher.state_dict(), "./checkpoint/sun_res50_model") |
@@ -0,0 +1,3 @@ | |||
from .core import metrics, engine, callbacks | |||
from . import amalgamation, vision |
@@ -0,0 +1 @@ | |||
from .layerwise_amalgamation import LayerWiseAmalgamator |
@@ -0,0 +1,113 @@ | |||
import oneflow
import oneflow.nn as nn | |||
from kamal.core.engine.engine import Engine | |||
from kamal.core.engine.hooks import FeatureHook | |||
from kamal.core import tasks | |||
import typing | |||
import time | |||
from kamal.utils import move_to_device, set_mode | |||
class AmalBlock(nn.Module): | |||
def __init__(self, cs, cts): | |||
super( AmalBlock, self ).__init__() | |||
self.cs, self.cts = cs, cts | |||
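        # enc: compresses the concatenated teacher features into a common code of student width
        # fam: feature-alignment module mapping the student feature into the same code space
        # dec: reconstructs the individual teacher features from the common code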
self.enc = nn.Conv2d( in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True ) | |||
self.fam = nn.Conv2d( in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True ) | |||
self.dec = nn.Conv2d( in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True ) | |||
def forward(self, fs, fts): | |||
rep = self.enc( oneflow.cat( fts, dim=1 ) ) | |||
_fts = self.dec( rep ) | |||
_fts = oneflow.split( _fts, self.cts, dim=1 ) | |||
_fs = self.fam( fs ) | |||
return rep, _fs, _fts | |||
class LayerWiseAmalgamator(Engine): | |||
def setup( | |||
self, | |||
student, | |||
teachers, | |||
layer_groups: typing.Sequence[typing.Sequence], | |||
layer_channels: typing.Sequence[typing.Sequence], | |||
dataloader: oneflow.utils.data.DataLoader, | |||
optimizer: oneflow.optim.Optimizer, | |||
weights = [1., 1., 1.], | |||
device=None, | |||
): | |||
if device is None: | |||
device = oneflow.device('cuda:0') | |||
# device = oneflow.device( 'cuda' if oneflow.cuda.is_available() else 'cpu' ) | |||
self._device = device | |||
self._dataloader = dataloader | |||
self.model = self.student = student.to(self.device) | |||
self.teachers = nn.ModuleList(teachers).to(self.device) | |||
self.optimizer = optimizer | |||
self._weights = weights | |||
amal_blocks = [] | |||
for group, C in zip(layer_groups, layer_channels): | |||
hooks = [ FeatureHook(layer) for layer in group ] | |||
amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train() | |||
amal_blocks.append( (amal_block, hooks, C) ) | |||
self._amal_blocks = amal_blocks | |||
@property | |||
def device(self): | |||
return self._device | |||
def run(self, max_iter, start_iter=0, epoch_length=None ): | |||
block_params = [] | |||
for block, _, _ in self._amal_blocks: | |||
block_params.extend( list(block.parameters()) ) | |||
        if isinstance( self.optimizer, oneflow.optim.SGD ):
            self._amal_optimizer = oneflow.optim.SGD( block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 )
        else:
            self._amal_optimizer = oneflow.optim.Adam( block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4 )
        self._amal_scheduler = oneflow.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimizer, T_max=max_iter )
with set_mode(self.student, training=True), \ | |||
set_mode(self.teachers, training=False): | |||
super( LayerWiseAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
def step_fn(self, engine, batch): | |||
start_time = time.perf_counter() | |||
batch = move_to_device(batch, self._device) | |||
data = batch[0] | |||
s_out = self.student( data ) | |||
with oneflow.no_grad(): | |||
t_out = [ teacher( data ) for teacher in self.teachers ] | |||
        loss_amal = 0
        loss_recons = 0
        mse_loss = nn.MSELoss()
        for amal_block, hooks, C in self._amal_blocks:
            features = [ h.feat_out for h in hooks ]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block( fs, fts )
            loss_amal += mse_loss( _fs, rep.detach() )
            loss_recons += sum( [ mse_loss( _ft, ft ) for (_ft, ft) in zip( _fts, fts ) ] )
loss_kd = tasks.loss.kldiv( s_out, oneflow.cat( t_out, dim=1 ) ) | |||
#loss_kd = F.mse_loss( s_out, oneflow.cat( t_out, dim=1 ) ) | |||
loss_dict = { "loss_kd": self._weights[0] * loss_kd, | |||
"loss_amal": self._weights[1] * loss_amal, | |||
"loss_recons": self._weights[2] * loss_recons } | |||
loss = sum(loss_dict.values()) | |||
self.optimizer.zero_grad() | |||
        self._amal_optimizer.zero_grad()
loss.backward() | |||
self.optimizer.step() | |||
        self._amal_optimizer.step()
self._amal_scheduler.step() | |||
step_time = time.perf_counter() - start_time | |||
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
metrics.update({ | |||
'total_loss': loss.item(), | |||
'step_time': step_time, | |||
'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
}) | |||
return metrics |
@@ -0,0 +1,2 @@ | |||
from . import engine, tasks, metrics, callbacks, exceptions | |||
from .attach import AttachTo |
@@ -0,0 +1,28 @@ | |||
from typing import Sequence, Callable | |||
from numbers import Number | |||
from kamal.core import exceptions | |||
class AttachTo(object): | |||
""" Attach task, metrics or visualizer to specified model outputs | |||
""" | |||
def __init__(self, attach_to=None): | |||
if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable) ): | |||
raise exceptions.InvalidMapping | |||
self._attach_to = attach_to | |||
def __call__(self, *tensors): | |||
if self._attach_to is not None: | |||
if isinstance(self._attach_to, Callable): | |||
return self._attach_to( *tensors ) | |||
if isinstance(self._attach_to, Sequence): | |||
_attach_to = self._attach_to | |||
else: | |||
_attach_to = [ self._attach_to for _ in range(len(tensors)) ] | |||
_attach_to = _attach_to[:len(tensors)] | |||
tensors = [ tensor[index] for (tensor, index) in zip( tensors, _attach_to ) ] | |||
if len(tensors)==1: | |||
tensors = tensors[0] | |||
return tensors | |||
def __repr__(self): | |||
rep = "AttachTo: %s"%(self._attach_to) |
@@ -0,0 +1,4 @@ | |||
from .logging import MetricsLogging, ProgressCallback | |||
from .base import Callback | |||
from .eval_and_ckpt import EvalAndCkpt | |||
from .scheduler import LRSchedulerCallback |
@@ -0,0 +1,26 @@ | |||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
import abc | |||
class Callback(abc.ABC): | |||
r""" Base Class for Callbacks | |||
""" | |||
def __init__(self): | |||
pass | |||
@abc.abstractmethod | |||
def __call__(self, engine): | |||
pass |
@@ -0,0 +1,133 @@ | |||
from .base import Callback | |||
import weakref | |||
from kamal import utils | |||
from typing import Sequence, Optional | |||
import numbers | |||
import os, shutil | |||
import oneflow | |||
class EvalAndCkpt(Callback): | |||
def __init__(self, | |||
model, | |||
evaluator, | |||
metric_name:str, | |||
metric_mode:str ='max', | |||
save_type:Optional[Sequence]=('best', 'latest'), | |||
ckpt_dir:str ='checkpoints', | |||
ckpt_prefix:str =None, | |||
log_tag:str ='model', | |||
weights_only:bool =True, | |||
verbose:bool =False,): | |||
super(EvalAndCkpt, self).__init__() | |||
self.metric_name = metric_name | |||
assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'" | |||
self._metric_mode = metric_mode | |||
self._model = weakref.ref( model ) | |||
self._evaluator = evaluator | |||
self._ckpt_dir = ckpt_dir | |||
self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix+'_') | |||
if isinstance(save_type, str): | |||
save_type = ( save_type, ) | |||
self._save_type = save_type | |||
if self._save_type is not None: | |||
for save_type in self._save_type: | |||
assert save_type in ('best', 'latest', 'all'), \ | |||
'save_type should be None or a subset of (\"best\", \"latest\", \"all\")' | |||
self._log_tag = log_tag | |||
self._weights_only = weights_only | |||
self._verbose = verbose | |||
        self._best_score = -float('inf') if self._metric_mode=='max' else float('inf')
self._best_ckpt = None | |||
self._latest_ckpt = None | |||
@property | |||
def best_ckpt(self): | |||
return self._best_ckpt | |||
@property | |||
def latest_ckpt(self): | |||
return self._latest_ckpt | |||
@property | |||
def best_score(self): | |||
return self._best_score | |||
def __call__(self, trainer): | |||
model = self._model() | |||
results = self._evaluator.eval( model, device=trainer.device ) | |||
results = utils.flatten_dict(results) | |||
scalar_results = { k: v.item() for k, v in results.items()} | |||
current_score = scalar_results[self.metric_name] | |||
if trainer.logger is not None: | |||
trainer.logger.info( "[Eval %s] Iter %d/%d: %s"%(self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results) ) | |||
trainer.state.metrics.update( scalar_results ) | |||
# Visualize results if trainer.tb_writer is not None | |||
if trainer.tb_writer is not None: | |||
for k, v in scalar_results.items(): | |||
log_tag = "%s:%s"%(self._log_tag, k) | |||
trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter) | |||
if self._save_type is not None: | |||
pth_path_list = [] | |||
            # save a checkpoint at every evaluation
            if 'all' in self._save_type:
pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f" | |||
% (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) | |||
pth_path_list.append(pth_path) | |||
# the latest model | |||
if 'latest' in self._save_type: | |||
pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f" | |||
% (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) | |||
                # remove the old ckpt (an OneFlow checkpoint is a directory, hence rmtree)
                if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt):
shutil.rmtree(self._latest_ckpt) | |||
pth_path_list.append(pth_path) | |||
self._latest_ckpt = pth_path | |||
# the best model | |||
if 'best' in self._save_type: | |||
if (current_score >= self._best_score and self._metric_mode=='max') or \ | |||
(current_score <= self._best_score and self._metric_mode=='min'): | |||
pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f" % | |||
(self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) | |||
                    # remove the old ckpt (a directory, see above)
                    if self._best_ckpt is not None and os.path.exists(self._best_ckpt):
shutil.rmtree(self._best_ckpt) | |||
pth_path_list.append(pth_path) | |||
self._best_score = current_score | |||
self._best_ckpt = pth_path | |||
# save model | |||
if self._verbose and trainer.logger: | |||
trainer.logger.info("Model saved as:") | |||
obj = model.state_dict() if self._weights_only else model | |||
os.makedirs( self._ckpt_dir, exist_ok=True ) | |||
for pth_path in pth_path_list: | |||
oneflow.save(obj, pth_path) | |||
if self._verbose and trainer.logger: | |||
trainer.logger.info("\t%s" % (pth_path)) | |||
def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False): | |||
if ckpt_dir is None: | |||
ckpt_dir = self._ckpt_dir | |||
if ckpt_prefix is None: | |||
ckpt_prefix = self._ckpt_prefix | |||
if self._save_type is not None: | |||
#if 'latest' in self._save_type and self._latest_ckpt is not None: | |||
# os.makedirs(ckpt_dir, exist_ok=True) | |||
# save_name = "%slatest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._latest_ckpt)) | |||
# shutil.copy2(self._latest_ckpt, os.path.join(ckpt_dir, save_name)) | |||
if 'best' in self._save_type and self._best_ckpt is not None: | |||
os.makedirs(ckpt_dir, exist_ok=True) | |||
save_name = "%sbest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._best_ckpt)) | |||
shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name)) |
@@ -0,0 +1,59 @@ | |||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
from .base import Callback | |||
import numbers | |||
from tqdm import tqdm | |||
class MetricsLogging(Callback): | |||
def __init__(self, keys): | |||
super(MetricsLogging, self).__init__() | |||
self._keys = keys | |||
def __call__(self, engine): | |||
        if engine.logger is None:
return | |||
state = engine.state | |||
content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)"%( | |||
state.iter, state.max_iter, | |||
state.current_epoch, state.max_epoch, | |||
state.current_batch_index, state.max_batch_index | |||
) | |||
for key in self._keys: | |||
value = state.metrics.get(key, None) | |||
if value is not None: | |||
if isinstance(value, numbers.Number): | |||
content += " %s=%.4f"%(key, value) | |||
if engine.tb_writer is not None: | |||
engine.tb_writer.add_scalar(key, value, global_step=state.iter) | |||
elif isinstance(value, (list, tuple)): | |||
content += " %s=%s"%(key, value) | |||
engine.logger.info(content) | |||
class ProgressCallback(Callback): | |||
def __init__(self, max_iter=100, tag=None): | |||
self._max_iter = max_iter | |||
self._tag = tag | |||
        # the progress bar is created by reset() before each evaluation run
def __call__(self, engine): | |||
self._pbar.update(1) | |||
if self._pbar.n==self._max_iter: | |||
self._pbar.close() | |||
def reset(self): | |||
self._pbar = tqdm(total=self._max_iter, desc=self._tag) | |||
@@ -0,0 +1,32 @@ | |||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
from .base import Callback | |||
from typing import Sequence | |||
class LRSchedulerCallback(Callback): | |||
r""" LR scheduler callback | |||
""" | |||
def __init__(self, schedulers=None): | |||
super(LRSchedulerCallback, self).__init__() | |||
        if schedulers is not None and not isinstance(schedulers, Sequence):
            schedulers = ( schedulers, )
self._schedulers = schedulers | |||
def __call__(self, trainer): | |||
if self._schedulers is None: | |||
return | |||
for sched in self._schedulers: | |||
sched.step() |
@@ -0,0 +1,7 @@ | |||
from . import evaluator | |||
from . import trainer | |||
from . import hooks | |||
from . import events | |||
from .engine import DefaultEvents, Event |
@@ -0,0 +1,188 @@ | |||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
import oneflow.nn as nn | |||
import abc, weakref | |||
from typing import Any, Callable, Sequence | |||
from kamal.core.engine.events import DefaultEvents, Event | |||
from kamal.utils import get_logger | |||
from collections import defaultdict | |||
import numbers | |||
import contextlib | |||
class State(object): | |||
def __init__(self): | |||
self.iter = 0 | |||
self.max_iter = None | |||
self.epoch_length = None | |||
self.dataloader = None | |||
self.seed = None | |||
self.metrics=dict() | |||
self.batch=None | |||
@property | |||
def current_epoch(self): | |||
if self.epoch_length is not None: | |||
return self.iter // self.epoch_length | |||
return None | |||
@property | |||
def max_epoch(self): | |||
if self.epoch_length is not None: | |||
return self.max_iter // self.epoch_length | |||
return None | |||
@property | |||
def current_batch_index(self): | |||
if self.epoch_length is not None: | |||
return self.iter % self.epoch_length | |||
return None | |||
@property | |||
def max_batch_index(self): | |||
return self.epoch_length | |||
def __repr__(self): | |||
rep = "State:\n" | |||
for attr, value in self.__dict__.items(): | |||
if not isinstance(value, (numbers.Number, str, dict)): | |||
value = type(value) | |||
rep += "\t{}: {}\n".format(attr, value) | |||
return rep | |||
class Engine(abc.ABC): | |||
def __init__(self, logger=None, tb_writer=None): | |||
self._logger = logger if logger else get_logger(name='kamal', color=True) | |||
self._tb_writer = tb_writer | |||
self._callbacks = defaultdict(list) | |||
self._allowed_events = [ *DefaultEvents ] | |||
self._state = State() | |||
def reset(self): | |||
self._state = State() | |||
def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None): | |||
self.state.iter = self._state.start_iter = start_iter | |||
self.state.max_iter = max_iter | |||
self.state.epoch_length = epoch_length if epoch_length else len(dataloader) | |||
self.state.dataloader = dataloader | |||
self.state.dataloader_iter = iter(dataloader) | |||
self.state.step_fn = step_fn | |||
self.trigger_events(DefaultEvents.BEFORE_RUN) | |||
for self.state.iter in range( start_iter, max_iter ): | |||
if self.state.epoch_length!=None and \ | |||
self.state.iter%self.state.epoch_length==0: # Epoch Start | |||
self.trigger_events(DefaultEvents.BEFORE_EPOCH) | |||
self.trigger_events(DefaultEvents.BEFORE_STEP) | |||
self.state.batch = self._get_batch() | |||
step_output = step_fn(self, self.state.batch) | |||
if isinstance(step_output, dict): | |||
self.state.metrics.update(step_output) | |||
self.trigger_events(DefaultEvents.AFTER_STEP) | |||
if self.state.epoch_length!=None and \ | |||
(self.state.iter+1)%self.state.epoch_length==0: # Epoch End | |||
self.trigger_events(DefaultEvents.AFTER_EPOCH) | |||
self.trigger_events(DefaultEvents.AFTER_RUN) | |||
def _get_batch(self): | |||
try: | |||
batch = next( self.state.dataloader_iter ) | |||
except StopIteration: | |||
self.state.dataloader_iter = iter(self.state.dataloader) # reset iterator | |||
batch = next( self.state.dataloader_iter ) | |||
if not isinstance(batch, (list, tuple)): | |||
batch = [ batch, ] # no targets | |||
return batch | |||
@property | |||
def state(self): | |||
return self._state | |||
@property | |||
def logger(self): | |||
return self._logger | |||
@property | |||
def tb_writer(self): | |||
return self._tb_writer | |||
def add_callback(self, event: Event, callbacks ): | |||
if not isinstance(callbacks, Sequence): | |||
callbacks = [callbacks] | |||
if event in self._allowed_events: | |||
for callback in callbacks: | |||
if callback not in self._callbacks[event]: | |||
if event.trigger!=event.default_trigger: | |||
callback = self._trigger_wrapper(self, event.trigger, callback ) | |||
self._callbacks[event].append( callback ) | |||
callbacks = [ RemovableCallback(self, event, c) for c in callbacks ] | |||
return ( callbacks[0] if len(callbacks)==1 else callbacks ) | |||
def remove_callback(self, event, callback): | |||
for c in self._callbacks[event]: | |||
if c==callback: | |||
                self._callbacks[event].remove( c )
return True | |||
return False | |||
@staticmethod | |||
def _trigger_wrapper(engine, trigger, callback): | |||
def wrapper(*args, **kwargs) -> Any: | |||
if trigger(engine): | |||
return callback(engine) | |||
return wrapper | |||
def trigger_events(self, *events): | |||
for e in events: | |||
if e in self._allowed_events: | |||
for callback in self._callbacks[e]: | |||
callback(self) | |||
def register_events(self, *events): | |||
for e in events: | |||
if e not in self._allowed_events: | |||
                self._allowed_events.append( e )
@contextlib.contextmanager | |||
def save_current_callbacks(self): | |||
temp = self._callbacks | |||
self._callbacks = defaultdict(list) | |||
yield | |||
self._callbacks = temp | |||
class RemovableCallback: | |||
def __init__(self, engine, event, callback): | |||
self._engine = weakref.ref(engine) | |||
self._callback = weakref.ref(callback) | |||
self._event = weakref.ref(event) | |||
@property | |||
def callback(self): | |||
return self._callback() | |||
def remove(self): | |||
engine = self._engine() | |||
callback = self._callback() | |||
event = self._event() | |||
return engine.remove_callback(event, callback) | |||
@@ -0,0 +1,62 @@ | |||
import oneflow | |||
from kamal.core import metrics | |||
from kamal.utils import set_mode | |||
from typing import Callable | |||
from .engine import Engine | |||
from .events import DefaultEvents | |||
from kamal.core import callbacks | |||
import weakref | |||
from kamal.utils import move_to_device, split_batch | |||
class BasicEvaluator(Engine): | |||
def __init__(self, | |||
dataloader: oneflow.utils.data.DataLoader, | |||
metric: metrics.MetricCompose, | |||
eval_fn: Callable=None, | |||
tag: str='Eval', | |||
progress: bool=False ): | |||
super( BasicEvaluator, self ).__init__() | |||
self.dataloader = dataloader | |||
self.metric = metric | |||
self.progress = progress | |||
if progress: | |||
            self.progress_callback = self.add_callback(
DefaultEvents.AFTER_STEP, callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag)) | |||
self._model = None | |||
self._tag = tag | |||
if eval_fn is None: | |||
eval_fn = BasicEvaluator.default_eval_fn | |||
self.eval_fn = eval_fn | |||
def eval(self, model, device=None): | |||
device = device if device is not None else \ | |||
oneflow.device( 'cuda' if oneflow.cuda.is_available() else 'cpu' ) | |||
self._model = weakref.ref(model) # use weakref here | |||
self.device = device | |||
self.metric.reset() | |||
model.to(device) | |||
if self.progress: | |||
            self.progress_callback.callback.reset()
with oneflow.no_grad(), set_mode(model, training=False): | |||
super(BasicEvaluator, self).run( self.step_fn, self.dataloader, max_iter=len(self.dataloader) ) | |||
return self.metric.get_results() | |||
@property | |||
def model(self): | |||
if self._model is not None: | |||
return self._model() | |||
return None | |||
def step_fn(self, engine, batch): | |||
batch = move_to_device(batch, self.device) | |||
self.eval_fn( engine, batch ) | |||
@staticmethod | |||
def default_eval_fn(evaluator, batch): | |||
model = evaluator.model | |||
inputs, targets = split_batch(batch) | |||
outputs = model( inputs ) | |||
evaluator.metric.update( outputs, targets ) |
@@ -0,0 +1,90 @@ | |||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
from typing import Callable, Optional | |||
from enum import Enum | |||
class Event(object): | |||
def __init__(self, value: str, event_trigger: Optional[Callable]=None ): | |||
if event_trigger is None: | |||
event_trigger = Event.default_trigger | |||
self._trigger = event_trigger | |||
self._name_ = self._value_ = value | |||
@property | |||
def trigger(self): | |||
return self._trigger | |||
@property | |||
def name(self): | |||
"""The name of the Enum member.""" | |||
return self._name_ | |||
@property | |||
def value(self): | |||
"""The value of the Enum member.""" | |||
return self._value_ | |||
@staticmethod | |||
def default_trigger(engine): | |||
return True | |||
@staticmethod | |||
def once_trigger(): | |||
is_triggered = False | |||
        def wrapper(engine):
            nonlocal is_triggered
            if is_triggered:
                return False
            is_triggered=True
return True | |||
return wrapper | |||
@staticmethod | |||
def every_trigger(every: int): | |||
def wrapper(engine): | |||
return every>0 and (engine.state.iter % every)==0 | |||
return wrapper | |||
def __call__(self, every: Optional[int]=None, once: Optional[bool]=None ): | |||
if every is not None: | |||
assert once is None | |||
return Event(self.value, event_trigger=Event.every_trigger(every) ) | |||
if once is not None: | |||
return Event(self.value, event_trigger=Event.once_trigger() ) | |||
return Event(self.value) | |||
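    # e.g. DefaultEvents.AFTER_STEP(every=10) fires every 10 iterations, and
    # DefaultEvents.AFTER_STEP(once=True) fires a single time (see the training scripts).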
def __hash__(self): | |||
return hash(self._name_) | |||
def __eq__(self, other): | |||
if hasattr(other, 'value'): | |||
return self.value==other.value | |||
        else:
            return NotImplemented
class DefaultEvents(Event, Enum): | |||
BEFORE_RUN = "before_train" | |||
AFTER_RUN = "after_train" | |||
BEFORE_EPOCH = "before_epoch" | |||
AFTER_EPOCH = "after_epoch" | |||
BEFORE_STEP = "before_step" | |||
AFTER_STEP = "after_step" | |||
BEFORE_GET_BATCH = "before_get_batch" | |||
AFTER_GET_BATCH = "after_get_batch" | |||
@@ -0,0 +1,32 @@ | |||
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
class FeatureHook(): | |||
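    """Caches a module's forward input/output via a registered forward hook."""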
def __init__(self, module): | |||
self.module = module | |||
self.feat_in = None | |||
self.feat_out = None | |||
self.register() | |||
def register(self): | |||
self._hook = self.module.register_forward_hook(self.hook_fn_forward) | |||
def remove(self): | |||
self._hook.remove() | |||
def hook_fn_forward(self, module, fea_in, fea_out): | |||
self.feat_in = fea_in[0] | |||
self.feat_out = fea_out | |||
@@ -0,0 +1,115 @@ | |||
import oneflow as torch  # aliased so the PyTorch-style type annotations below stay unchanged
import oneflow.nn as nn | |||
from kamal.core.engine.engine import Engine | |||
from kamal.core import tasks | |||
from kamal.utils import set_mode, move_to_device, split_batch | |||
from typing import Sequence | |||
import time | |||
class BasicTrainer(Engine): | |||
def __init__( self, | |||
logger=None, | |||
tb_writer=None): | |||
super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer) | |||
def setup(self, | |||
model: torch.nn.Module, | |||
task: tasks.Task, | |||
dataloader: torch.utils.data.DataLoader, | |||
optimizer: torch.optim.Optimizer, | |||
device: torch.device=None): | |||
if device is None: | |||
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
self.device = device | |||
if isinstance(task, Sequence): | |||
task = tasks.TaskCompose(task) | |||
self.task = task | |||
        # Multi-GPU: DistributedDataParallel assumes the distributed environment
        # has already been initialized before setup() is called.
        self.model = model.to(self.device)
        self.model = nn.parallel.DistributedDataParallel(self.model)
self.dataloader = dataloader | |||
self.optimizer = optimizer | |||
return self | |||
def run( self, max_iter, start_iter=0, epoch_length=None): | |||
self.model.to(self.device) | |||
with set_mode(self.model, training=True): | |||
super( BasicTrainer, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
def step_fn(self, engine, batch): | |||
model = self.model | |||
start_time = time.perf_counter() | |||
batch = move_to_device(batch, self.device) | |||
inputs, targets = split_batch(batch) | |||
outputs = model(inputs) | |||
loss_dict = self.task.get_loss(outputs, targets) # get loss | |||
loss = sum( loss_dict.values() ) | |||
self.optimizer.zero_grad() | |||
loss.backward() | |||
self.optimizer.step() | |||
step_time = time.perf_counter() - start_time | |||
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
metrics.update({ | |||
'total_loss': loss.item(), | |||
'step_time': step_time, | |||
'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
}) | |||
return metrics | |||
class KDTrainer(BasicTrainer): | |||
def setup(self, | |||
student: torch.nn.Module, | |||
teacher: torch.nn.Module, | |||
task: tasks.Task, | |||
dataloader: torch.utils.data.DataLoader, | |||
optimizer: torch.optim.Optimizer, | |||
device: torch.device=None): | |||
super(KDTrainer, self).setup( | |||
model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device) | |||
if isinstance(teacher, (list, tuple)): | |||
if len(teacher)==1: | |||
teacher=teacher[0] | |||
else: | |||
teacher = nn.ModuleList(teacher) | |||
self.student = self.model | |||
self.teacher = teacher | |||
return self | |||
def run( self, max_iter, start_iter=0, epoch_length=None): | |||
self.student.to(self.device) | |||
self.teacher.to(self.device) | |||
with set_mode(self.student, training=True), \ | |||
set_mode(self.teacher, training=False): | |||
super( BasicTrainer, self ).run( | |||
self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
def step_fn(self, engine, batch): | |||
model = self.model | |||
start_time = time.perf_counter() | |||
batch = move_to_device(batch, self.device) | |||
inputs, targets = split_batch(batch) | |||
outputs = model(inputs) | |||
if isinstance(self.teacher, nn.ModuleList): | |||
soft_targets = [ t(inputs) for t in self.teacher ] | |||
else: | |||
soft_targets = self.teacher(inputs) | |||
loss_dict = self.task.get_loss(outputs, soft_targets) # get loss | |||
loss = sum( loss_dict.values() ) | |||
self.optimizer.zero_grad() | |||
loss.backward() | |||
self.optimizer.step() | |||
step_time = time.perf_counter() - start_time | |||
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
metrics.update({ | |||
'total_loss': loss.item(), | |||
'step_time': step_time, | |||
'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
}) | |||
return metrics |
@@ -0,0 +1,5 @@ | |||
class DataTypeError(Exception):
    pass
class InvalidMapping(Exception):
    pass
@@ -0,0 +1,3 @@ | |||
from .stream_metrics import Metric, MetricCompose | |||
from .accuracy import Accuracy | |||
@@ -0,0 +1,23 @@ | |||
import oneflow | |||
from kamal.core.metrics.stream_metrics import Metric | |||
__all__=['Accuracy'] | |||
class Accuracy(Metric): | |||
def __init__(self, attach_to=None): | |||
super(Accuracy, self).__init__(attach_to=attach_to) | |||
self.reset() | |||
@oneflow.no_grad() | |||
def update(self, outputs, targets): | |||
outputs, targets = self._attach(outputs, targets) | |||
outputs = outputs.max(1)[1] | |||
self._correct += ( outputs.view(-1)==targets.view(-1) ).sum() | |||
self._cnt += oneflow.numel( targets ) | |||
def get_results(self): | |||
return (self._correct / self._cnt).detach().cpu() | |||
def reset(self): | |||
self._correct = self._cnt = 0.0 |
@@ -0,0 +1,63 @@ | |||
from __future__ import division | |||
import oneflow | |||
from abc import ABC, abstractmethod | |||
from typing import Mapping | |||
from kamal.core.attach import AttachTo | |||
class Metric(ABC): | |||
def __init__(self, attach_to=None): | |||
self._attach = AttachTo(attach_to) | |||
@abstractmethod | |||
def update(self, pred, target): | |||
""" Overridden by subclasses """ | |||
raise NotImplementedError() | |||
@abstractmethod | |||
def get_results(self): | |||
""" Overridden by subclasses """ | |||
raise NotImplementedError() | |||
@abstractmethod | |||
def reset(self): | |||
""" Overridden by subclasses """ | |||
raise NotImplementedError() | |||
class MetricCompose(dict): | |||
def __init__(self, metric_dict: Mapping): | |||
self._metric_dict = metric_dict | |||
def add_metrics( self, metric_dict: Mapping): | |||
if isinstance(metric_dict, MetricCompose): | |||
metric_dict = metric_dict.metrics | |||
self._metric_dict.update(metric_dict) | |||
return self | |||
@property | |||
def metrics(self): | |||
return self._metric_dict | |||
@oneflow.no_grad() | |||
def update(self, outputs, targets): | |||
for key, metric in self._metric_dict.items(): | |||
if isinstance(metric, Metric): | |||
metric.update(outputs, targets) | |||
def get_results(self): | |||
results = {} | |||
for key, metric in self._metric_dict.items(): | |||
if isinstance(metric, Metric): | |||
results[key] = metric.get_results() | |||
return results | |||
def reset(self): | |||
for key, metric in self._metric_dict.items(): | |||
if isinstance(metric, Metric): | |||
metric.reset() | |||
def __getitem__(self, name): | |||
return self._metric_dict[name] | |||
@@ -0,0 +1,3 @@ | |||
from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose | |||
from . import loss | |||
@@ -0,0 +1,2 @@ | |||
from . import functional, loss | |||
from .loss import * |
@@ -0,0 +1,17 @@ | |||
import oneflow | |||
import oneflow.nn.functional as F | |||
def kldiv(logits, targets, T=1.0): | |||
""" Cross Entropy for soft targets | |||
Parameters: | |||
- logits (Tensor): logits score (e.g. outputs of fc layer) | |||
- targets (Tensor): logits of soft targets | |||
- T (float): temperature of distill | |||
- reduction: reduction to the output | |||
""" | |||
p_targets = F.softmax(targets/T, dim=1) | |||
logp_logits = F.log_softmax(logits/T, dim=1) | |||
kl_div = oneflow.nn.KLDivLoss(reduction="none") | |||
kld = kl_div(logp_logits, p_targets) * (T**2) | |||
return kld.sum(1).mean() |
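
# Usage sketch (as in LayerWiseAmalgamator.step_fn, where the student logits are
# compared against the concatenated teacher logits):
#   loss = kldiv(s_out, oneflow.cat(t_out, dim=1), T=1.0)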
@@ -0,0 +1,8 @@ | |||
from .functional import * | |||
class KLDiv(object): | |||
def __init__(self, T=1.0): | |||
self.T = T | |||
def __call__(self, logits, targets): | |||
        return kldiv( logits, targets, T=self.T )
@@ -0,0 +1,180 @@
# Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================= | |||
import abc | |||
import oneflow.nn as nn | |||
import oneflow.nn.functional as F | |||
from typing import Callable, Dict, Any | |||
from . import loss | |||
from kamal.core import metrics | |||
from kamal.core.attach import AttachTo | |||
class Task(abc.ABC):  # must be an ABC for the @abc.abstractmethod decorators below to be enforced
def __init__(self, name): | |||
self.name = name | |||
@abc.abstractmethod | |||
def get_loss( self, outputs, targets ) -> Dict: | |||
pass | |||
@abc.abstractmethod | |||
def predict(self, outputs) -> Any: | |||
pass | |||
class GeneralTask(Task): | |||
def __init__(self, | |||
name: str, | |||
loss_fn: Callable, | |||
scaling:float=1.0, | |||
pred_fn: Callable=lambda x: x, | |||
attach_to=None): | |||
super(GeneralTask, self).__init__(name) | |||
self._attach = AttachTo(attach_to) | |||
self.loss_fn = loss_fn | |||
self.pred_fn = pred_fn | |||
self.scaling = scaling | |||
def get_loss(self, outputs, targets): | |||
outputs, targets = self._attach(outputs, targets) | |||
return { self.name: self.loss_fn( outputs, targets ) * self.scaling } | |||
def predict(self, outputs): | |||
outputs = self._attach(outputs) | |||
return self.pred_fn(outputs) | |||
def __repr__(self): | |||
rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach) | |||
return rep | |||
class TaskCompose(list): | |||
def __init__(self, tasks: list): | |||
for task in tasks: | |||
if isinstance(task, Task): | |||
self.append(task) | |||
def get_loss(self, outputs, targets): | |||
loss_dict = {} | |||
for task in self: | |||
loss_dict.update( task.get_loss( outputs, targets ) ) | |||
return loss_dict | |||
def predict(self, outputs): | |||
results = [] | |||
for task in self: | |||
results.append( task.predict( outputs ) ) | |||
return results | |||
    def __repr__(self):
        rep = "TaskCompose: \n"
        for task in self:
            rep += "\t%s\n" % task
        return rep
class StandardTask: | |||
@staticmethod | |||
def classification(name='ce', scaling=1.0, attach_to=None): | |||
return GeneralTask( name=name, | |||
loss_fn=nn.CrossEntropyLoss(), | |||
scaling=scaling, | |||
pred_fn=lambda x: x.max(1)[1], | |||
attach_to=attach_to ) | |||
@staticmethod | |||
    def binary_classification(name='bce', scaling=1.0, attach_to=None):
        return GeneralTask(name=name,
                    loss_fn=F.binary_cross_entropy_with_logits,
                    scaling=scaling,
                    pred_fn=lambda x: (x > 0),  # loss_fn takes logits, and logit > 0 <=> sigmoid(x) > 0.5
                    attach_to=attach_to )
@staticmethod | |||
def regression(name='mse', scaling=1.0, attach_to=None): | |||
return GeneralTask(name=name, | |||
loss_fn=nn.MSELoss(), | |||
scaling=scaling, | |||
pred_fn=lambda x: x, | |||
attach_to=attach_to ) | |||
@staticmethod | |||
def segmentation(name='ce', scaling=1.0, attach_to=None): | |||
return GeneralTask(name=name, | |||
loss_fn=nn.CrossEntropyLoss(ignore_index=255), | |||
scaling=scaling, | |||
pred_fn=lambda x: x.max(1)[1], | |||
attach_to=attach_to ) | |||
@staticmethod | |||
def monocular_depth(name='l1', scaling=1.0, attach_to=None): | |||
return GeneralTask(name=name, | |||
loss_fn=nn.L1Loss(), | |||
scaling=scaling, | |||
pred_fn=lambda x: x, | |||
attach_to=attach_to) | |||
@staticmethod | |||
def detection(): | |||
raise NotImplementedError | |||
@staticmethod | |||
def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None): | |||
return GeneralTask(name=name, | |||
loss_fn=loss.KLDiv(T=T), | |||
scaling=scaling, | |||
pred_fn=lambda x: x.max(1)[1], | |||
attach_to=attach_to) | |||
class StandardMetrics(object): | |||
@staticmethod | |||
def classification(attach_to=None): | |||
return metrics.MetricCompose( | |||
metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)} | |||
) | |||
@staticmethod | |||
def regression(attach_to=None): | |||
return metrics.MetricCompose( | |||
metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)} | |||
) | |||
@staticmethod | |||
def segmentation(num_classes, ignore_idx=255, attach_to=None): | |||
confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to) | |||
return metrics.MetricCompose( | |||
metric_dict={'acc': metrics.Accuracy(attach_to=attach_to), | |||
'confusion_matrix': confusion_matrix , | |||
'miou': metrics.mIoU(confusion_matrix)} | |||
) | |||
@staticmethod | |||
def monocular_depth(attach_to=None): | |||
return metrics.MetricCompose( | |||
metric_dict={ | |||
'rmse': metrics.RootMeanSquaredError(attach_to=attach_to), | |||
'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ), | |||
'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to), | |||
'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to), | |||
'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to), | |||
'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to ) | |||
} | |||
) | |||
@staticmethod | |||
def loss_metric(loss_fn): | |||
return metrics.MetricCompose( | |||
metric_dict={ | |||
'loss': metrics.AverageMetric( loss_fn ) | |||
} | |||
        )
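How these pieces compose for amalgamation training: a distillation task wraps `KLDiv`, `get_loss` returns a named loss dict, and `predict` maps logits to class indices. A hedged sketch with made-up logits, again assuming `attach_to=None` is a pass-through:

```python
import oneflow

task = StandardTask.distillation(T=4.0, scaling=1.0)
student_logits = oneflow.randn(16, 196)  # e.g. a Stanford Cars classification head
teacher_logits = oneflow.randn(16, 196)

loss_dict = task.get_loss(student_logits, teacher_logits)  # {'kld': <scalar tensor>}
total_loss = sum(loss_dict.values())                       # ready for backward()
preds = task.predict(student_logits)                       # argmax class indices
```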
@@ -0,0 +1,2 @@
from ._utils import * | |||
from .logger import get_logger
@@ -0,0 +1,45 @@
import oneflow | |||
import contextlib | |||
import oneflow.nn as nn | |||
def split_batch(batch): | |||
if isinstance(batch, (list, tuple)): | |||
inputs, *targets = batch | |||
if len(targets)==1: | |||
targets = targets[0] | |||
return inputs, targets | |||
    else:
        return batch, None
@contextlib.contextmanager | |||
def set_mode(model, training=True): | |||
ori_mode = model.training | |||
model.train(training) | |||
yield | |||
model.train(ori_mode) | |||
def move_to_device(obj, device):
    if isinstance(obj, oneflow.Tensor):
        return obj.to(device=device)
    elif isinstance( obj, (list, tuple) ):
        return [ o.to(device=device) for o in obj ]
    elif isinstance(obj, nn.Module):
        return obj.to(device=device)
    return obj  # anything else is returned unchanged rather than silently becoming None
def flatten_dict(dic):
    """Flatten a nested dict into slash-delimited keys, e.g. {'a': {'b': 1}} -> {'a/b': 1}."""
    flattened = dict()
    def _flatten(prefix, d):
        for k, v in d.items():
            if isinstance(v, dict):
                _flatten( prefix+'%s/'%k, v )
            else:
                flattened[ (prefix+'%s/'%k).strip('/') ] = v
    _flatten('', dic)
    return flattened
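`split_batch` normalizes whatever a dataloader yields into an `(inputs, targets)` pair, and `flatten_dict` turns nested loss/metric dicts into slash-delimited keys for logging. A self-contained sketch with placeholder values:

```python
# single-target and multi-target batches ('img'/'lbl' stand in for real tensors)
inputs, targets = split_batch(('img', 'lbl'))             # -> 'img', 'lbl'
inputs, targets = split_batch(('img', 'lbl_a', 'lbl_b'))  # -> 'img', ['lbl_a', 'lbl_b']

nested = {'loss': {'kld': 0.7, 'ce': 0.3}, 'lr': 1e-3}
assert flatten_dict(nested) == {'loss/kld': 0.7, 'loss/ce': 0.3, 'lr': 1e-3}
```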
@@ -0,0 +1,56 @@
import logging | |||
import os, sys | |||
from termcolor import colored | |||
class _ColorfulFormatter(logging.Formatter): | |||
def __init__(self, *args, **kwargs): | |||
super(_ColorfulFormatter, self).__init__(*args, **kwargs) | |||
def formatMessage(self, record): | |||
log = super(_ColorfulFormatter, self).formatMessage(record) | |||
if record.levelno == logging.WARNING: | |||
prefix = colored("WARNING", "yellow", attrs=["blink"]) | |||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: | |||
prefix = colored("ERROR", "red", attrs=["blink", "underline"]) | |||
else: | |||
return log | |||
return prefix + " " + log | |||
def get_logger(name='Kamal', output=None, color=True): | |||
logger = logging.getLogger(name) | |||
logger.setLevel(logging.DEBUG) | |||
logger.propagate = False | |||
# STDOUT | |||
stdout_handler = logging.StreamHandler( stream=sys.stdout ) | |||
stdout_handler.setLevel( logging.DEBUG ) | |||
plain_formatter = logging.Formatter( | |||
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) | |||
if color: | |||
formatter = _ColorfulFormatter( | |||
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", | |||
datefmt="%m/%d %H:%M:%S") | |||
else: | |||
formatter = plain_formatter | |||
stdout_handler.setFormatter(formatter) | |||
logger.addHandler(stdout_handler) | |||
# FILE | |||
if output is not None: | |||
if output.endswith('.txt') or output.endswith('.log'): | |||
os.makedirs(os.path.dirname(output), exist_ok=True) | |||
filename = output | |||
else: | |||
os.makedirs(output, exist_ok=True) | |||
filename = os.path.join(output, "log.txt") | |||
file_handler = logging.FileHandler(filename) | |||
file_handler.setFormatter(plain_formatter) | |||
file_handler.setLevel(logging.DEBUG) | |||
logger.addHandler(file_handler) | |||
    return logger
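Typical use of `get_logger`; note it attaches new handlers on every call, so fetch the logger once and reuse it (the output path is illustrative):

```python
logger = get_logger('Kamal', output='./run/log.txt')  # creates ./run/ if missing
logger.info('amalgamation started')
logger.warning('checkpoint not found, training from scratch')
```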
@@ -0,0 +1,3 @@
from . import models | |||
from . import datasets | |||
from . import sync_transforms
@@ -0,0 +1,5 @@
from .stanford_cars import StanfordCars | |||
from .stanford_dogs import StanfordDogs | |||
from .flowers102 import Flowers102 | |||
from .sun397 import SUN397 | |||
from .fgvc_aircraft import FGVCAircraft
@@ -0,0 +1,152 @@
import oneflow.utils.data as data | |||
from flowvision.datasets.folder import default_loader | |||
import os | |||
import numpy as np | |||
import random | |||
def make_dataset(dir, image_ids, targets): | |||
assert(len(image_ids) == len(targets)) | |||
images = [] | |||
dir = os.path.expanduser(dir) | |||
for i in range(len(image_ids)): | |||
item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images', | |||
'%s.jpg' % image_ids[i]), targets[i]) | |||
images.append(item) | |||
return images | |||
def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.strip().split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))
    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(data.Dataset): | |||
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset. | |||
Args: | |||
root (string): Root directory path to dataset. | |||
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification | |||
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). | |||
transforms (callable, optional): A function/transforms that takes in a PIL image | |||
and returns a transformed version. E.g. ``transforms.RandomCrop`` | |||
target_transform (callable, optional): A function/transforms that takes in the | |||
target and transforms it. | |||
loader (callable, optional): A function to load an image given its path. | |||
download (bool, optional): If true, downloads the dataset from the internet and | |||
puts it in the root directory. If dataset is already downloaded, it is not | |||
downloaded again. | |||
""" | |||
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' | |||
class_types = ('variant', 'family', 'manufacturer') | |||
splits = ('train', 'val', 'trainval', 'test') | |||
def __init__(self, root, class_type='variant', split='train', s=0.5, transform=None, | |||
target_transform=None, loader=default_loader, download=False): | |||
if split not in self.splits: | |||
raise ValueError('Split "{}" not found. Valid splits are: {}'.format( | |||
split, ', '.join(self.splits), | |||
)) | |||
if class_type not in self.class_types: | |||
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( | |||
class_type, ', '.join(self.class_types), | |||
)) | |||
self.root = root | |||
self.class_type = class_type | |||
self.split = split | |||
self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', | |||
'images_%s_%s.txt' % (self.class_type, self.split)) | |||
(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) | |||
"""if split == 'trainval': | |||
self.image_ids = image_ids | |||
self.targets = targets | |||
self.image_ids, self.targets = self.sample_by_class(s=s) | |||
image_ids = self.image_ids | |||
targets = self.targets""" | |||
samples = make_dataset(self.root, image_ids, targets) | |||
self.transform = transform | |||
self.target_transform = target_transform | |||
self.loader = loader | |||
self.samples = samples | |||
self.classes = classes | |||
self.class_to_idx = class_to_idx | |||
with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f: | |||
self.object_categories = [ | |||
line.strip('\n') for line in f.readlines()] | |||
print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__())) | |||
def __getitem__(self, index): | |||
""" | |||
Args: | |||
index (int): Index | |||
Returns: | |||
tuple: (sample, target) where target is class_index of the target class. | |||
""" | |||
path, target = self.samples[index] | |||
sample = self.loader(path) | |||
if self.transform is not None: | |||
sample = self.transform(sample) | |||
if self.target_transform is not None: | |||
target = self.target_transform(target) | |||
return sample, target | |||
def __len__(self): | |||
return len(self.samples) | |||
    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
    def _check_exists(self):
        # the archive extracts to <root>/fgvc-aircraft-2013b/data/images
        return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')) and \
            os.path.exists(self.classes_file)
def sample_by_class(self, s): | |||
class_dit = {} | |||
image_dit = {} | |||
for class_name, image in zip(self.targets,self.image_ids): | |||
if class_name not in class_dit.keys(): | |||
class_dit[class_name] = [] | |||
image_dit[class_name] = [] | |||
class_dit[class_name].append(class_name) | |||
image_dit[class_name].append(image) | |||
labels, images = [], [] | |||
for key in class_dit.keys(): | |||
n1 = len(class_dit[key]) | |||
n2 = len(image_dit[key]) | |||
assert n1 == n2, "{} not equal {}".format(n1, n2) | |||
random.shuffle(image_dit[key]) | |||
labels += class_dit[key][:int(n1*s)] | |||
images += image_dit[key][:int(n1*s)] | |||
        return images, labels
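Instantiating the dataset for the `trainval` split used by the training scripts; the transform below is an illustrative assumption (flowvision mirrors torchvision's transforms API), not the repository's canonical preprocessing. The remaining dataset classes follow the same pattern:

```python
import flowvision.transforms as T

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])
dst = FGVCAircraft('./DataSets/FGVCAircraft/', split='trainval', transform=transform)
img, target = dst[0]  # transformed image and integer class index
```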
@@ -0,0 +1,11 @@
import flowvision | |||
import os | |||
class Flowers102(flowvision.datasets.ImageFolder): | |||
def __init__(self, root, split='train'): | |||
self.split = split | |||
self.num_classes = 102 | |||
path = os.path.join(root, 'train') if self.split=='train' else os.path.join(root, 'valid') | |||
super().__init__(path) | |||
        print('Flowers-102, Split: %s, Size: %d' % (self.split, len(self.imgs)))
@@ -0,0 +1,81 @@
import os | |||
import glob | |||
from PIL import Image | |||
import numpy as np | |||
from scipy.io import loadmat | |||
from oneflow.utils import data | |||
import random | |||
class StanfordCars(data.Dataset): | |||
"""Dataset for Stanford Cars | |||
""" | |||
urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz', | |||
'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz', | |||
'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz', | |||
'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'} | |||
def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None): | |||
self.root = os.path.abspath( os.path.expanduser(root) ) | |||
self.split = split | |||
self.transform = transform | |||
self.target_transform = target_transform | |||
if download: | |||
self.download() | |||
if self.split == 'train': | |||
annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat') | |||
else: | |||
annos = os.path.join(self.root, 'devkit', | |||
'cars_test_annos_withlabels.mat') | |||
annos = loadmat(annos) | |||
size = len(annos['annotations'][0]) | |||
self.files = glob.glob(os.path.join( | |||
self.root, 'cars_'+self.split, '*.jpg')) | |||
self.files.sort() | |||
"""if split == 'train': | |||
self.files = self.sample_by_class(s=s)""" | |||
self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]]) | |||
lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat')) | |||
self.object_categories = [str(c[0]) | |||
for c in lbl_annos['class_names'][0]] | |||
print('Stanford Cars, Split: %s, Size: %d' % | |||
(self.split, self.__len__())) | |||
def __len__(self): | |||
return len(self.files) | |||
    def __getitem__(self, idx):
        # self.files already holds absolute paths from glob, so open them directly
        img = Image.open(self.files[idx]).convert("RGB")
lbl = self.labels[idx] | |||
if self.transform is not None: | |||
img = self.transform(img) | |||
if self.target_transform is not None: | |||
lbl = self.target_transform(lbl) | |||
return img, lbl | |||
def sample_by_class(self, s): | |||
class_dit = {} | |||
for file in self.files: | |||
class_name = file.split('/')[0] | |||
if class_name not in class_dit.keys(): | |||
class_dit[class_name] = [] | |||
class_dit[class_name].append(file) | |||
files = [] | |||
for key in class_dit.keys(): | |||
n = len(class_dit[key]) | |||
random.shuffle(class_dit[key]) | |||
files += class_dit[key][:int(n*s)] | |||
        return files
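The Stanford Cars annotations come from the devkit `.mat` files; labels are shifted from MATLAB's 1-based indexing to 0-based, so they index directly into `object_categories`. A brief sketch (path illustrative, data assumed already extracted):

```python
cars = StanfordCars('./DataSets/StanfordCars/', split='train')
img, lbl = cars[0]                  # PIL RGB image, label in [0, 195]
print(cars.object_categories[lbl])  # human-readable class name
```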
@@ -0,0 +1,67 @@
import os | |||
import numpy as np | |||
from PIL import Image | |||
from scipy.io import loadmat | |||
from oneflow.utils import data | |||
import random | |||
class StanfordDogs(data.Dataset): | |||
"""Dataset for Stanford Dogs | |||
""" | |||
urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar", | |||
"annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar", | |||
"lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"} | |||
def __init__(self, root, split='train', s=0.5, download=False, transform=None, target_transform=None): | |||
self.root = os.path.abspath( os.path.expanduser(root) ) | |||
self.split = split | |||
self.transform = transform | |||
self.target_transform = target_transform | |||
if download: | |||
self.download() | |||
list_file = os.path.join(self.root, 'lists', self.split+'_list.mat') | |||
mat_file = loadmat(list_file) | |||
size = len(mat_file['file_list']) | |||
self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)] | |||
"""if split == 'train': | |||
self.files = self.sample_by_class(s=s)""" | |||
self.labels = np.array( | |||
[mat_file['labels'][i][0]-1 for i in range(size)]) | |||
categories = os.listdir(os.path.join(self.root, 'Images')) | |||
categories.sort() | |||
self.object_categories = [c[10:] for c in categories] | |||
print('Stanford Dogs, Split: %s, Size: %d' % | |||
(self.split, self.__len__())) | |||
def __len__(self): | |||
return len(self.files) | |||
def __getitem__(self, idx): | |||
img = Image.open(os.path.join(self.root, 'Images', | |||
self.files[idx])).convert("RGB") | |||
lbl = self.labels[idx] | |||
if self.transform is not None: | |||
img = self.transform(img) | |||
if self.target_transform is not None: | |||
lbl = self.target_transform( lbl ) | |||
return img, lbl | |||
def sample_by_class(self, s): | |||
class_dit = {} | |||
for file in self.files: | |||
class_name = file.split('/')[0] | |||
if class_name not in class_dit.keys(): | |||
class_dit[class_name] = [] | |||
class_dit[class_name].append(file) | |||
files = [] | |||
for key in class_dit.keys(): | |||
n = len(class_dit[key]) | |||
random.shuffle(class_dit[key]) | |||
files += class_dit[key][:int(n*s)] | |||
        return files
@@ -0,0 +1,11 @@
import flowvision | |||
import os | |||
class SUN397(flowvision.datasets.ImageFolder): | |||
def __init__(self, root, split='train'): | |||
self.split = split | |||
self.num_classes = 397 | |||
path = os.path.join(root, 'train') if self.split=='train' else os.path.join(root, 'test') | |||
super().__init__(path) | |||
        print('SUN397, Split: %s, Size: %d' % (self.split, len(self.imgs)))
@@ -0,0 +1 @@
from . import classification
@@ -0,0 +1 @@
from .resnet import *