
Update model converter and model measuring

tags/v0.4.0
之江天枢 3 years ago
parent commit 5a0d9ad9f9
100 changed files with 7952 additions and 0 deletions
  1. model-converter/.gitignore (+2, -0)
  2. model-converter/Dockerfile (+7, -0)
  3. model-converter/main.py (+86, -0)
  4. model_measuring/README.md (+0, -0)
  5. model_measuring/app.py (+130, -0)
  6. model_measuring/kamal/__init__.py (+5, -0)
  7. model_measuring/kamal/amalgamation/__init__.py (+4, -0)
  8. model_measuring/kamal/amalgamation/common_feature.py (+291, -0)
  9. model_measuring/kamal/amalgamation/layerwise_amalgamation.py (+131, -0)
  10. model_measuring/kamal/amalgamation/recombination.py (+209, -0)
  11. model_measuring/kamal/amalgamation/task_branching.py (+295, -0)
  12. model_measuring/kamal/core/__init__.py (+4, -0)
  13. model_measuring/kamal/core/attach.py (+45, -0)
  14. model_measuring/kamal/core/callbacks/__init__.py (+5, -0)
  15. model_measuring/kamal/core/callbacks/base.py (+28, -0)
  16. model_measuring/kamal/core/callbacks/eval_and_ckpt.py (+145, -0)
  17. model_measuring/kamal/core/callbacks/logging.py (+61, -0)
  18. model_measuring/kamal/core/callbacks/scheduler.py (+34, -0)
  19. model_measuring/kamal/core/callbacks/visualize.py (+161, -0)
  20. model_measuring/kamal/core/engine/__init__.py (+7, -0)
  21. model_measuring/kamal/core/engine/engine.py (+190, -0)
  22. model_measuring/kamal/core/engine/evaluator.py (+130, -0)
  23. model_measuring/kamal/core/engine/events.py (+92, -0)
  24. model_measuring/kamal/core/engine/hooks.py (+34, -0)
  25. model_measuring/kamal/core/engine/lr_finder.py (+214, -0)
  26. model_measuring/kamal/core/engine/trainer.py (+129, -0)
  27. model_measuring/kamal/core/exceptions.py (+22, -0)
  28. model_measuring/kamal/core/hub/__init__.py (+2, -0)
  29. model_measuring/kamal/core/hub/_hub.py (+288, -0)
  30. model_measuring/kamal/core/hub/_module_mapping.py (+6, -0)
  31. model_measuring/kamal/core/hub/meta/TASK.py (+22, -0)
  32. model_measuring/kamal/core/hub/meta/__init__.py (+3, -0)
  33. model_measuring/kamal/core/hub/meta/input.py (+31, -0)
  34. model_measuring/kamal/core/hub/meta/meta.py (+44, -0)
  35. model_measuring/kamal/core/metrics/__init__.py (+8, -0)
  36. model_measuring/kamal/core/metrics/accuracy.py (+118, -0)
  37. model_measuring/kamal/core/metrics/average.py (+49, -0)
  38. model_measuring/kamal/core/metrics/confusion_matrix.py (+68, -0)
  39. model_measuring/kamal/core/metrics/normal.py (+86, -0)
  40. model_measuring/kamal/core/metrics/regression.py (+199, -0)
  41. model_measuring/kamal/core/metrics/stream_metrics.py (+82, -0)
  42. model_measuring/kamal/core/tasks/__init__.py (+3, -0)
  43. model_measuring/kamal/core/tasks/loss/__init__.py (+2, -0)
  44. model_measuring/kamal/core/tasks/loss/functional.py (+107, -0)
  45. model_measuring/kamal/core/tasks/loss/loss.py (+386, -0)
  46. model_measuring/kamal/core/tasks/task.py (+186, -0)
  47. model_measuring/kamal/slim/__init__.py (+2, -0)
  48. model_measuring/kamal/slim/distillation/__init__.py (+12, -0)
  49. model_measuring/kamal/slim/distillation/attention.py (+47, -0)
  50. model_measuring/kamal/slim/distillation/cc.py (+55, -0)
  51. model_measuring/kamal/slim/distillation/data_free/__init__.py (+1, -0)
  52. model_measuring/kamal/slim/distillation/data_free/zskt.py (+99, -0)
  53. model_measuring/kamal/slim/distillation/hint.py (+86, -0)
  54. model_measuring/kamal/slim/distillation/kd.py (+90, -0)
  55. model_measuring/kamal/slim/distillation/nst.py (+44, -0)
  56. model_measuring/kamal/slim/distillation/pkt.py (+44, -0)
  57. model_measuring/kamal/slim/distillation/rkd.py (+45, -0)
  58. model_measuring/kamal/slim/distillation/sp.py (+44, -0)
  59. model_measuring/kamal/slim/distillation/svd.py (+45, -0)
  60. model_measuring/kamal/slim/distillation/vid.py (+90, -0)
  61. model_measuring/kamal/slim/prunning/__init__.py (+2, -0)
  62. model_measuring/kamal/slim/prunning/pruner.py (+37, -0)
  63. model_measuring/kamal/slim/prunning/strategy.py (+85, -0)
  64. model_measuring/kamal/transferability/README.md (+18, -0)
  65. model_measuring/kamal/transferability/__init__.py (+20, -0)
  66. model_measuring/kamal/transferability/depara/__init__.py (+3, -0)
  67. model_measuring/kamal/transferability/depara/attribution_graph.py (+184, -0)
  68. model_measuring/kamal/transferability/depara/attribution_map.py (+87, -0)
  69. model_measuring/kamal/transferability/trans_graph.py (+135, -0)
  70. model_measuring/kamal/transferability/trans_metric.py (+109, -0)
  71. model_measuring/kamal/utils/__init__.py (+2, -0)
  72. model_measuring/kamal/utils/_utils.py (+153, -0)
  73. model_measuring/kamal/utils/logger.py (+56, -0)
  74. model_measuring/kamal/vision/__init__.py (+3, -0)
  75. model_measuring/kamal/vision/datasets/__init__.py (+16, -0)
  76. model_measuring/kamal/vision/datasets/ade20k.py (+70, -0)
  77. model_measuring/kamal/vision/datasets/caltech.py (+226, -0)
  78. model_measuring/kamal/vision/datasets/camvid.py (+78, -0)
  79. model_measuring/kamal/vision/datasets/cityscapes.py (+146, -0)
  80. model_measuring/kamal/vision/datasets/cub200.py (+70, -0)
  81. model_measuring/kamal/vision/datasets/dataset.py (+57, -0)
  82. model_measuring/kamal/vision/datasets/fgvc_aircraft.py (+143, -0)
  83. model_measuring/kamal/vision/datasets/nyu.py (+84, -0)
  84. model_measuring/kamal/vision/datasets/preprocess/prepare_caltech101.py (+61, -0)
  85. model_measuring/kamal/vision/datasets/preprocess/prepare_stl10.py (+196, -0)
  86. model_measuring/kamal/vision/datasets/preprocess/resize_camvid.py (+53, -0)
  87. model_measuring/kamal/vision/datasets/preprocess/resize_cityscapes.py (+65, -0)
  88. model_measuring/kamal/vision/datasets/preprocess/resize_voc.py (+59, -0)
  89. model_measuring/kamal/vision/datasets/preprocess/resize_voc_240.py (+57, -0)
  90. model_measuring/kamal/vision/datasets/stanford_cars.py (+80, -0)
  91. model_measuring/kamal/vision/datasets/stanford_dogs.py (+58, -0)
  92. model_measuring/kamal/vision/datasets/sunrgbd.py (+65, -0)
  93. model_measuring/kamal/vision/datasets/unlabeled.py (+67, -0)
  94. model_measuring/kamal/vision/datasets/utils.py (+161, -0)
  95. model_measuring/kamal/vision/datasets/voc.py (+209, -0)
  96. model_measuring/kamal/vision/models/__init__.py (+3, -0)
  97. model_measuring/kamal/vision/models/classification/__init__.py (+7, -0)
  98. model_measuring/kamal/vision/models/classification/alexnet.py (+63, -0)
  99. model_measuring/kamal/vision/models/classification/cifar/__init__.py (+1, -0)
  100. model_measuring/kamal/vision/models/classification/cifar/wrn.py (+108, -0)

model-converter/.gitignore (+2, -0)

@@ -0,0 +1,2 @@
/.idea/
*.iml

model-converter/Dockerfile (+7, -0)

@@ -0,0 +1,7 @@
FROM tensorflow/tensorflow:2.4.1

WORKDIR /app
RUN pip install web.py tf2onnx
COPY . /app

ENTRYPOINT ["python3", "main.py"]

model-converter/main.py (+86, -0)

@@ -0,0 +1,86 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import json
import os
import subprocess
import logging
import web
from subprocess import PIPE

urls = (
'/hello', 'Hello',
'/model_convert', 'ModelConvert'
)
logging.basicConfig(filename='onnx.log', level=logging.DEBUG)

class Hello(object):
def GET(self):
return 'service alive'

class ModelConvert(object):
def POST(self):
data = web.data()
web.header('Content-Type', 'application/json')
try:
json_data = json.loads(data)
model_path = json_data['model_path']
output_path = json_data['output_path']
if not os.path.isdir(model_path):
msg = 'model_path is not a dir: %s' % model_path
logging.error(msg)
return json.dumps({'code': 501, 'msg': msg, 'data': ''})
if not output_path.endswith('/'):
msg = 'output_path is not a dir: %s' % output_path
logging.error(msg)
return json.dumps({'code': 502, 'msg': msg, 'data': ''})
exist_flag = exist(model_path)
if not exist_flag:
msg = 'SavedModel file does not exist at: %s' % model_path
logging.error(msg)
return json.dumps({'code': 503, 'msg': msg, 'data': ''})
convert_flag, msg = convert(model_path, output_path)
if not convert_flag:
return json.dumps({'code': 504, 'msg': msg, 'data': ''})
except Exception as e:
logging.error(str(e))
return json.dumps({'code': 505, 'msg': str(e), 'data': ''})
return json.dumps({'code': 200, 'msg': 'ok', 'data': msg})

def exist(model_path):
for file in os.listdir(model_path):
if file=='saved_model.pbtxt' or file=='saved_model.pb':
return True
return False


def convert(model_path, output_path):
output_path = output_path+'model.onnx'
try:
logging.info('model_path=%s, output_path=%s' % (model_path, output_path))
result = subprocess.run(["python", "-m", "tf2onnx.convert", "--saved-model", model_path, "--output", output_path], stdout=PIPE, stderr=PIPE)
logging.info(repr(result))
if result.returncode != 0:
return False, str(result.stderr)
except Exception as e:
logging.error(str(e))
return False, str(e)
return True, output_path

if __name__ == '__main__':
app = web.application(urls, globals())
app.run()
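
For reference, the endpoint can be exercised with a short client sketch like the one below (not part of the commit; the port is web.py's default 8080 and the paths are placeholders):

import requests  # any HTTP client works; requests is used here for brevity

payload = {
    'model_path': '/models/my_savedmodel/',  # directory that contains saved_model.pb
    'output_path': '/models/onnx_out/',      # must end with '/'
}
resp = requests.post('http://localhost:8080/model_convert', json=payload)
print(resp.json())  # on success: {'code': 200, 'msg': 'ok', 'data': '/models/onnx_out/model.onnx'}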

model_measuring/README.md (+0, -0)


model_measuring/app.py (+130, -0)

@@ -0,0 +1,130 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

dependencies = ['torch', 'kamal']  # dependencies
import json
import traceback
import os
import torch
import web
import kamal
from PIL import Image
from kamal.transferability import TransferabilityGraph
from kamal.transferability.trans_metric import AttrMapMetric
from torchvision.models import *

urls = (
    '/model_measure/measure', 'Measure',
    '/model_measure/package', 'Package',
)
app = web.application(urls, globals())

class Package:
    def POST(self):
        req = web.data()
        save_path_list = []
        for json_data in json.loads(req):
            try:
                metadata = json_data['metadata']
            except Exception:
                traceback.print_exc()
                return json.dumps(Response(506, 'Failed to package model, Error: %s' % (traceback.format_exc(limit=1)), save_path_list).__dict__)
            entry_name = json_data['entry_name']
            for name, fn in kamal.hub.list_entry(__file__):
                if entry_name in name:
                    try:
                        dataset = metadata['dataset']
                        model = fn(pretrained=False, num_classes=metadata['entry_args']['num_classes'])
                        num_params = sum(torch.numel(p) for p in model.parameters())
                        save_path_for_measure = '%sfinegraind_%s/' % (json_data['ckpt'], entry_name)
                        save_path = save_path_for_measure + '%s' % (metadata['name'])
                        ckpt = self.file_name(json_data['ckpt'])
                        if ckpt == '':
                            return json.dumps(
                                Response(506, 'Failed to package model [%s]: no .pth file was found in directory ckpt' % (entry_name), save_path_list).__dict__)
                        model.load_state_dict(torch.load(ckpt), False)
                        kamal.hub.save(  # packages the user's PyTorch model into the hub format and stores it at the given location
                            model,  # the nn.Module to save
                            save_path=save_path,  # name of the export directory
                            entry_name=entry_name,  # entry-function name; must match the entry function above
                            spec_name=None,  # version name for this set of weights; if None, an md5 hash is used instead
                            code_path=__file__,  # code the model depends on: a directory (must contain hubconf.py) or the
                                                 # current hubconf.py; here the model implementation comes from the
                                                 # dependencies, so this file is sufficient
                            metadata=metadata,
                            tags=dict(
                                num_params=num_params,
                                metadata=metadata,
                                name=metadata['name'],
                                url=metadata['url'],
                                dataset=dataset,
                                img_size=metadata['input']['size'],
                                readme=json_data['readme'])
                        )
                        save_path_list.append(save_path_for_measure)
                        return json.dumps(Response(200, 'Success', save_path_list).__dict__)
                    except Exception:
                        traceback.print_exc()
                        return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)
        return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)

    def file_name(self, file_dir):
        for root, dirs, files in os.walk(file_dir):
            for file in files:
                if file.endswith('pth'):
                    return os.path.join(root, file)  # join with the proper path separator
        return ''


class Measure:
    def POST(self):
        req = web.data()
        json_data = json.loads(req)
        print(json_data)
        output_filename_list = []
        try:
            measure_name = 'measure'
            zoo_set = json_data['zoo_set']
            probe_set_root = json_data['probe_set_root']
            export_path = json_data['export_path']
            TG = TransferabilityGraph(zoo_set)
            print("Add %s" % (probe_set_root))
            imgs_set = list(os.listdir(probe_set_root))
            images = [Image.open(os.path.join(probe_set_root, img)) for img in imgs_set]
            metric = AttrMapMetric(images, device=torch.device('cuda'))
            TG.add_metric(probe_set_root, metric)
            if not os.path.exists(export_path):
                # create the directory if it does not exist
                os.makedirs(export_path)
            output_filename = os.path.join(export_path, '%s.json' % (measure_name))
            TG.export_to_json(probe_set_root, output_filename, topk=3, normalize=True)
            output_filename_list.append(output_filename)
        except Exception:
            traceback.print_exc()
            return json.dumps(Response(506, 'Failed to generate measurement file of [%s], Error: %s' % (probe_set_root, traceback.format_exc(limit=1)), output_filename_list).__dict__)
        return json.dumps(Response(200, 'Success', output_filename_list).__dict__)


class Response:
    def __init__(self, code, msg, data):
        self.code = code
        self.msg = msg
        self.data = data

if __name__ == "__main__":
    app.run()
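
Similarly, a sketch of a call to the measuring endpoint (port and paths are placeholders; AttrMapMetric expects a CUDA device to be available):

import requests

payload = {
    'zoo_set': '/zoo/packaged_models/',       # root of models packaged via /model_measure/package
    'probe_set_root': '/data/probe_images/',  # directory of probe images
    'export_path': '/data/measure_out/',
}
resp = requests.post('http://localhost:8080/model_measure/measure', json=payload)
print(resp.json())  # on success: {'code': 200, 'msg': 'Success', 'data': ['/data/measure_out/measure.json']}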

model_measuring/kamal/__init__.py (+5, -0)

@@ -0,0 +1,5 @@
from .core import tasks, metrics, engine, callbacks, hub

from . import amalgamation, slim, vision, transferability

from .core import load, save

model_measuring/kamal/amalgamation/__init__.py (+4, -0)

@@ -0,0 +1,4 @@
from .layerwise_amalgamation import LayerWiseAmalgamator
from .common_feature import CommonFeatureAmalgamator
from .task_branching import TaskBranchingAmalgamator, JointSegNet
from .recombination import RecombinationAmalgamator, CombinedModel

model_measuring/kamal/amalgamation/common_feature.py (+291, -0)

@@ -0,0 +1,291 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.core import tasks
from kamal.utils import set_mode, move_to_device

import torch
import torch.nn as nn
import torch.nn.functional as F

import typing, time
import numpy as np


def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)

class ResBlock(nn.Module):
""" Residual Blocks
"""
def __init__(self, inplanes, planes, stride=1, momentum=0.1):
super(ResBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes, momentum=momentum)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes, momentum=momentum)
if stride > 1 or inplanes != planes:
self.downsample = nn.Sequential(
nn.Conv2d(inplanes, planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(planes, momentum=momentum)
)
else:
self.downsample = None

self.stride = stride

def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out


class CFL_FCBlock(nn.Module):
"""Common Feature Blocks for Fully-Connected layer
This module is used to capture the common features of multiple teachers and calculate mmd with features of student.

**Parameters:**
- cs (int): channel number of student features
- channel_ts (list or tuple): channel number list of teacher features
- ch (int): channel number of hidden features
"""
def __init__(self, cs, cts, ch, k_size=5):
super(CFL_FCBlock, self).__init__()
self.align_t = nn.ModuleList()
for ct in cts:
self.align_t.append(
nn.Sequential(
nn.Linear(ct, ch),
nn.ReLU(inplace=True)
)
)

self.align_s = nn.Sequential(
nn.Linear(cs, ch),
nn.ReLU(inplace=True),
)

self.extractor = nn.Sequential(
nn.Linear(ch, ch),
nn.ReLU(),
nn.Linear(ch, ch),
)

self.dec_t = nn.ModuleList()
for ct in cts:
self.dec_t.append(
nn.Sequential(
nn.Linear(ch, ct),
nn.ReLU(inplace=True),
nn.Linear(ct, ct)
)
)

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, fs, fts):
aligned_t = [self.align_t[i](fts[i]) for i in range(len(fts))]
aligned_s = self.align_s(fs)

hts = [self.extractor(f) for f in aligned_t]
hs = self.extractor(aligned_s)

_fts = [self.dec_t[i](hts[i]) for i in range(len(hts))]
return (hs, hts), (_fts, fts)


class CFL_ConvBlock(nn.Module):
"""Common Feature Blocks for Convolutional layer
This module is used to capture the common features of multiple teachers and calculate mmd with features of student.

**Parameters:**
- cs (int): channel number of student features
- channel_ts (list or tuple): channel number list of teacher features
- ch (int): channel number of hidden features
"""
def __init__(self, cs, cts, ch, k_size=5):
super(CFL_ConvBlock, self).__init__()
self.align_t = nn.ModuleList()
for ct in cts:
self.align_t.append(
nn.Sequential(
nn.Conv2d(in_channels=ct, out_channels=ch,
kernel_size=1),
nn.BatchNorm2d(ch),
nn.ReLU(inplace=True)
)
)

self.align_s = nn.Sequential(
nn.Conv2d(in_channels=cs, out_channels=ch,
kernel_size=1),
nn.BatchNorm2d(ch),
nn.ReLU(inplace=True),
)

self.extractor = nn.Sequential(
ResBlock(inplanes=ch, planes=ch, stride=1),
ResBlock(inplanes=ch, planes=ch, stride=1),
)

self.dec_t = nn.ModuleList()
for ct in cts:
self.dec_t.append(
nn.Sequential(
nn.Conv2d(ch, ch, kernel_size=1, stride=1),
nn.BatchNorm2d(ch),
nn.ReLU(inplace=True),
nn.Conv2d(ch, ct, kernel_size=1, stride=1)
)
)

def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, fs, fts):
aligned_t = [self.align_t[i](fts[i]) for i in range(len(fts))]
aligned_s = self.align_s(fs)

hts = [self.extractor(f) for f in aligned_t]
hs = self.extractor(aligned_s)

_fts = [self.dec_t[i](hts[i]) for i in range(len(hts))]
return (hs, hts), (_fts, fts)

class CommonFeatureAmalgamator(Engine):
def setup(
self,
student,
teachers,
layer_groups: typing.Sequence[typing.Sequence],
layer_channels: typing.Sequence[typing.Sequence],
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
weights = [1.0, 1.0, 1.0],
on_layer_input=False,
device = None,
):
self._dataloader = dataloader
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self._device = device

self._model = self._student = student.to(self._device)
self._teachers = nn.ModuleList(teachers).to(self._device)
self._optimizer = optimizer
self._weights = weights
self._on_layer_input = on_layer_input

amal_blocks = []
for group, C in zip( layer_groups, layer_channels ):
hooks = [ FeatureHook(layer) for layer in group ]
if isinstance(group[0], nn.Linear):
amal_block = CFL_FCBlock( cs=C[0], cts=C[1:], ch=C[0]//4 ).to(self._device).train()
print("Building FC Blocks")
else:
amal_block = CFL_ConvBlock(cs=C[0], cts=C[1:], ch=C[0]//4).to(self._device).train()
print("Building Conv Blocks")
amal_blocks.append( (amal_block, hooks, C) )
self._amal_blocks = amal_blocks
self._cfl_criterion = tasks.loss.CFLLoss( sigmas=[0.001, 0.01, 0.05, 0.1, 0.2, 1, 2] )

@property
def device(self):
return self._device
def run(self, max_iter, start_iter=0, epoch_length=None):
block_params = []
for block, _, _ in self._amal_blocks:
block_params.extend( list(block.parameters()) )
if isinstance( self._optimizer, torch.optim.SGD ):
self._amal_optimizer = torch.optim.SGD( block_params, lr=self._optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 )
else:
self._amal_optimizer = torch.optim.Adam( block_params, lr=self._optimizer.param_groups[0]['lr'], weight_decay=1e-4 )
self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimizer, T_max=max_iter )
with set_mode(self._student, training=True), \
set_mode(self._teachers, training=False):
super( CommonFeatureAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

def step_fn(self, engine, batch):
start_time = time.perf_counter()
batch = move_to_device(batch, self._device)
data = batch[0]
s_out = self._student( data )
with torch.no_grad():
t_out = [ teacher( data ) for teacher in self._teachers ]
loss_amal = 0
loss_recons = 0
for amal_block, hooks, C in self._amal_blocks:
features = [ h.feat_in if self._on_layer_input else h.feat_out for h in hooks ]
fs, fts = features[0], features[1:]
(hs, hts), (_fts, fts) = amal_block( fs, fts )
_loss_amal, _loss_recons = self._cfl_criterion( hs, hts, _fts, fts )
loss_amal += _loss_amal
loss_recons += _loss_recons
loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) )
loss_dict = {
'loss_kd': self._weights[0]*loss_kd,
'loss_amal': self._weights[1]*loss_amal,
'loss_recons': self._weights[2]*loss_recons
}
loss = sum(loss_dict.values())
self._optimizer.zero_grad()
self._amal_optimizer.zero_grad()
loss.backward()
self._optimizer.step()
self._amal_optimizer.step()
self._amal_scheduler.step()
step_time = time.perf_counter() - start_time

metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
metrics.update({
'total_loss': loss.item(),
'step_time': step_time,
'lr': float( self._optimizer.param_groups[0]['lr'] )
})
return metrics
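
As a quick shape check of the CFL block above, a minimal sketch (batch and channel sizes are arbitrary; the import path assumes the package layout of this commit):

import torch
from kamal.amalgamation.common_feature import CFL_ConvBlock

block = CFL_ConvBlock(cs=64, cts=[128, 256], ch=16)   # student 64ch, two teachers 128/256ch
fs = torch.randn(2, 64, 32, 32)                       # student feature map
fts = [torch.randn(2, 128, 32, 32), torch.randn(2, 256, 32, 32)]
(hs, hts), (_fts, fts) = block(fs, fts)
print(hs.shape)                 # torch.Size([2, 16, 32, 32]): student in the common-feature space
print([f.shape for f in _fts])  # teacher features reconstructed from the common space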

model_measuring/kamal/amalgamation/layerwise_amalgamation.py (+131, -0)

@@ -0,0 +1,131 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.core import tasks

import typing
import time
from kamal.utils import move_to_device, set_mode

class AmalBlock(nn.Module):
def __init__(self, cs, cts):
super( AmalBlock, self ).__init__()
self.cs, self.cts = cs, cts
self.enc = nn.Conv2d( in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True )
self.fam = nn.Conv2d( in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True )
self.dec = nn.Conv2d( in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True )
def forward(self, fs, fts):
rep = self.enc( torch.cat( fts, dim=1 ) )
_fts = self.dec( rep )
_fts = torch.split( _fts, self.cts, dim=1 )
_fs = self.fam( fs )
return rep, _fs, _fts

class LayerWiseAmalgamator(Engine):
def setup(
self,
student,
teachers,
layer_groups: typing.Sequence[typing.Sequence],
layer_channels: typing.Sequence[typing.Sequence],
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
weights = [1., 1., 1.],
device=None,
):
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self._device = device
self._dataloader = dataloader
self.model = self.student = student.to(self.device)
self.teachers = nn.ModuleList(teachers).to(self.device)
self.optimizer = optimizer
self._weights = weights
amal_blocks = []
for group, C in zip(layer_groups, layer_channels):
hooks = [ FeatureHook(layer) for layer in group ]
amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train()
amal_blocks.append( (amal_block, hooks, C) )
self._amal_blocks = amal_blocks
@property
def device(self):
return self._device
def run(self, max_iter, start_iter=0, epoch_length=None ):
block_params = []
for block, _, _ in self._amal_blocks:
block_params.extend( list(block.parameters()) )
if isinstance( self.optimizer, torch.optim.SGD ):
self._amal_optimizer = torch.optim.SGD( block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 )
else:
self._amal_optimizer = torch.optim.Adam( block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4 )
self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimizer, T_max=max_iter )

with set_mode(self.student, training=True), \
set_mode(self.teachers, training=False):
super( LayerWiseAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)
def step_fn(self, engine, batch):
start_time = time.perf_counter()
batch = move_to_device(batch, self._device)
data = batch[0]
s_out = self.student( data )
with torch.no_grad():
t_out = [ teacher( data ) for teacher in self.teachers ]
loss_amal = 0
loss_recons = 0
for amal_block, hooks, C in self._amal_blocks:
features = [ h.feat_out for h in hooks ]
fs, fts = features[0], features[1:]
rep, _fs, _fts = amal_block( fs, fts )
loss_amal += F.mse_loss( _fs, rep.detach() )
loss_recons += sum( [ F.mse_loss( _ft, ft ) for (_ft, ft) in zip( _fts, fts ) ] )
loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) )
#loss_kd = F.mse_loss( s_out, torch.cat( t_out, dim=1 ) )
loss_dict = { "loss_kd": self._weights[0] * loss_kd,
"loss_amal": self._weights[1] * loss_amal,
"loss_recons": self._weights[2] * loss_recons }
loss = sum(loss_dict.values())
self.optimizer.zero_grad()
self._amal_optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self._amal_optimizer.step()
self._amal_scheduler.step()
step_time = time.perf_counter() - start_time
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
metrics.update({
'total_loss': loss.item(),
'step_time': step_time,
'lr': float( self.optimizer.param_groups[0]['lr'] )
})
return metrics
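
For orientation, a minimal wiring sketch (not from the commit): layer_groups pairs one student layer with the corresponding teacher layers, and the student head must be wide enough for the concatenated teacher logits. The DataLoader is a placeholder, and the Engine constructor arguments, if any, are assumed to have defaults.

import torch
from torchvision.models import resnet18
from kamal.amalgamation import LayerWiseAmalgamator

student = resnet18(num_classes=20)                 # 20 = 10 + 10 teacher logits concatenated
t1, t2 = resnet18(num_classes=10), resnet18(num_classes=10)
amal = LayerWiseAmalgamator()
amal.setup(
    student=student, teachers=[t1, t2],
    layer_groups=[[student.layer4, t1.layer4, t2.layer4]],  # student layer listed first
    layer_channels=[[512, 512, 512]],
    dataloader=my_dataloader,                      # user-supplied image DataLoader
    optimizer=torch.optim.Adam(student.parameters(), lr=1e-4),
)
amal.run(max_iter=1000)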



model_measuring/kamal/amalgamation/recombination.py (+209, -0)

@@ -0,0 +1,209 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from typing import Callable

from kamal.core.engine.engine import Engine
from kamal.core.engine.trainer import KDTrainer
from kamal.core.engine.hooks import FeatureHook
from kamal.core import tasks
import math

from kamal.slim.prunning import Pruner, strategy
from kamal.utils import set_mode

def _assert_same_type(layers, layer_type=None):
if layer_type is None:
layer_type = type(layers[0])

assert all(isinstance(l, layer_type) for l in layers), 'Model architectures must be the same'

def _get_layers(model_list):
submodel = [ model.modules() for model in model_list ]
for layers in zip(*submodel):
_assert_same_type(layers)
yield layers

def bn_combine_fn(layers):
"""Combine 2D Batch Normalization Layers
**Parameters:**
- **layers** (BatchNorm2D): Batch Normalization Layers.
"""
_assert_same_type(layers, nn.BatchNorm2d)
num_features = sum(l.num_features for l in layers)
combined_bn = nn.BatchNorm2d(num_features=num_features,
eps=layers[0].eps,
momentum=layers[0].momentum,
affine=layers[0].affine,
track_running_stats=layers[0].track_running_stats)
combined_bn.running_mean = torch.cat(
[l.running_mean for l in layers], dim=0).clone()
combined_bn.running_var = torch.cat(
[l.running_var for l in layers], dim=0).clone()

if combined_bn.affine:
combined_bn.weight = torch.nn.Parameter(
torch.cat([l.weight.data.clone() for l in layers], dim=0).clone())
combined_bn.bias = torch.nn.Parameter(
torch.cat([l.bias.data.clone() for l in layers], dim=0).clone())
return combined_bn


def conv2d_combine_fn(layers):
"""Combine 2D Conv Layers
**Parameters:**
- **layers** (Conv2d): Conv Layers.
"""
_assert_same_type(layers, nn.Conv2d)

CO, CI = 0, 0
for l in layers:
O, I, H, W = l.weight.shape
CO += O
CI += I

dtype = layers[0].weight.dtype
device = layers[0].weight.device

combined_weight = torch.nn.Parameter(
torch.zeros(CO, CI, H, W, dtype=dtype, device=device))
if layers[0].bias is not None:
combined_bias = torch.nn.Parameter(
torch.zeros(CO, dtype=dtype, device=device))
else:
combined_bias = None
co_offset = 0
ci_offset = 0
for idx, l in enumerate(layers):
co_len, ci_len = l.weight.shape[0], l.weight.shape[1]
combined_weight[co_offset: co_offset+co_len,
ci_offset: ci_offset+ci_len, :, :] = l.weight.clone()
if combined_bias is not None:
combined_bias[co_offset: co_offset+co_len] = l.bias.clone()
co_offset += co_len
ci_offset += ci_len
combined_conv2d = nn.Conv2d(in_channels=CI,
out_channels=CO,
kernel_size=layers[0].weight.shape[-2:],
stride=layers[0].stride,
padding=layers[0].padding,
bias=layers[0].bias is not None)
combined_conv2d.weight.data = combined_weight
if combined_bias is not None:
combined_conv2d.bias.data = combined_bias
for p in combined_conv2d.parameters():
p.requires_grad = True
return combined_conv2d


def combine_models(models):
"""Combine the matching Conv2d/BatchNorm2d layers of several models
**Parameters:**
- **models** (list of nn.Module): models to be combined; all must share the same architecture.
"""
def _recursively_combine(module):
module_output = module

if isinstance( module, nn.Conv2d ):
combined_module = conv2d_combine_fn( layer_mapping[module] )
elif isinstance( module, nn.BatchNorm2d ):
combined_module = bn_combine_fn( layer_mapping[module] )
else:
combined_module = module

if combined_module is not None:
module_output = combined_module

for name, child in module.named_children():
module_output.add_module(name, _recursively_combine(child))
return module_output

models = deepcopy(models)
combined_model = deepcopy(models[0]) # copy the model architecture and modify it with _recursively_combine

layer_mapping = {}
for combined_layer, layers in zip(combined_model.modules(), _get_layers(models)):
layer_mapping[combined_layer] = layers # link to teachers
combined_model = _recursively_combine(combined_model)
return combined_model
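
To illustrate what combine_models (defined above) returns, a small sketch with two identical toy branches; the shapes follow from the block-diagonal weight layout built above:

import torch
import torch.nn as nn

net_a = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
net_b = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
merged = combine_models([net_a, net_b])

x = torch.randn(1, 6, 32, 32)    # input channels are summed: 3 + 3
print(merged(x).shape)           # torch.Size([1, 16, 32, 32]): 8 + 8 output channels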


class CombinedModel(nn.Module):
def __init__(self, models):
super( CombinedModel, self ).__init__()
self.combined_model = combine_models( models )
self.expand = len(models)

def forward(self, x):
x = x.repeat( 1, self.expand, 1, 1 )  # replicate the input channel-wise, once per combined model
return self.combined_model(x)


class PruningKDTrainer(KDTrainer):
def setup(
self,
student,
teachers,
task,
dataloader: torch.utils.data.DataLoader,
get_optimizer_and_scheduler:Callable=None,
pruning_rounds=5,
device=None,
):
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self._device = device
self._dataloader = dataloader
self.model = self.student = student.to(self.device)
self.teachers = nn.ModuleList(teachers).to(self.device)
self.get_optimizer_and_scheduler = get_optimizer_and_scheduler
@property
def device(self):
return self._device
def run(self, max_iter, start_iter=0, epoch_length=None, pruning_rounds=3, target_model_size=0.6 ):
pruning_size_per_round = 1 - math.pow( target_model_size, 1/pruning_rounds )
pruner = Pruner( strategy.LNStrategy(n=1) )
for pruning_round in range(pruning_rounds):
pruner.prune( self.student, rate=pruning_size_per_round, example_inputs=torch.randn(1,3,240,240) )
self.student.to(self.device)
if self.get_optimizer_and_scheduler:
self.optimizer, self.scheduler = self.get_optimizer_and_scheduler( self.student )
else:
self.optimizer = torch.optim.Adam( self.student.parameters(), lr=1e-4, weight_decay=1e-5 )
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max= (max_iter-start_iter)//pruning_rounds )
step_iter = (max_iter - start_iter)//pruning_rounds

with set_mode(self.student, training=True), \
set_mode(self.teachers, training=False):
super( PruningKDTrainer, self ).run(self.step_fn, self._dataloader,
start_iter=start_iter+step_iter*pruning_round , max_iter=start_iter+step_iter*(pruning_round+1), epoch_length=epoch_length)
def step_fn(self, engine, batch):
metrics = super(PruningKDTrainer, self).step_fn( engine, batch )
self.scheduler.step()
return metrics

class RecombinationAmalgamator(PruningKDTrainer):
pass

model_measuring/kamal/amalgamation/task_branching.py (+295, -0)

@@ -0,0 +1,295 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.core import tasks
from kamal.utils import move_to_device, set_mode
from kamal.core.hub import meta
from kamal import vision
import kamal

import typing
import time

from copy import deepcopy
import random
import numpy as np
from collections import defaultdict
import numbers

class BranchySegNet(nn.Module):
def __init__(self, out_channels, segnet_fn=vision.models.segmentation.segnet_vgg16_bn):
super(BranchySegNet, self).__init__()
channels=[512, 512, 256, 128, 64]
self.register_buffer( 'branch_indices', torch.zeros((len(out_channels),)) )

self.student_b_decoders_list = nn.ModuleList()
self.student_adaptors_list = nn.ModuleList()

ses = []
for i in range(5):
se = int(channels[i]/4)
ses.append(16 if se < 16 else se)

for oc in out_channels:
segnet = self.get_segnet( oc, segnet_fn )
decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))
adaptors = nn.ModuleList()
for i in range(5):
adaptor = nn.Sequential(
nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
nn.ReLU(),
nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
nn.Sigmoid()
)
adaptors.append(adaptor)
self.student_b_decoders_list.append(decoders)
self.student_adaptors_list.append(adaptors)

self.student_encoders = nn.ModuleList(deepcopy(list(segnet.children())[0:5]))
self.student_decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))

def set_branch(self, branch_indices):
assert len(branch_indices)==len(self.student_b_decoders_list)
self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device)

def get_segnet(self, oc, segnet_fn):
return segnet_fn( num_classes=oc, pretrained_backbone=True )

def forward(self, inputs):
output_list = []
down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)

up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)

decoder_features = [down5, up5, up4, up3, up2]
decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]

# Mimic teachers.
for i in range(len(self.branch_indices)):
out_idx = int(self.branch_indices[i])
output = decoder_features[out_idx]
output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:]))
for j in range(out_idx, 5):
output = self.student_b_decoders_list[i][j](
output,
decoder_indices[j],
decoder_shapes[j]
)
output_list.append( output )
return output_list

class JointSegNet(nn.Module):
"""The online student model to learn from any number of single teacher with 'SegNet' structure.

**Parameters:**
- **teachers** (list of 'Module' object): Teachers with 'SegNet' structure to learn from.
- **indices** (list of int): Where to branch out for each task.
- **phase** (string): Should be 'block' or 'finetune'. Useful only in training mode.
- **channels** (list of int, optional): Parameter to build adaptor modules, corresponding to that of 'SegNet'.
"""
def __init__(self, teachers, student=None, channels=[512, 512, 256, 128, 64]):
super(JointSegNet, self).__init__()
self.register_buffer( 'branch_indices', torch.zeros((2,)) )

if student is None:
student = teachers[0]
self.student_encoders = nn.ModuleList(deepcopy(list(student.children())[0:5]))
self.student_decoders = nn.ModuleList(deepcopy(list(student.children())[5:]))
self.student_b_decoders_list = nn.ModuleList()
self.student_adaptors_list = nn.ModuleList()

ses = []
for i in range(5):
se = int(channels[i]/4)
ses.append(16 if se < 16 else se)

for teacher in teachers:
decoders = nn.ModuleList(deepcopy(list(teacher.children())[5:]))
adaptors = nn.ModuleList()
for i in range(5):
adaptor = nn.Sequential(
nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
nn.ReLU(),
nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
nn.Sigmoid()
)
adaptors.append(adaptor)
self.student_b_decoders_list.append(decoders)
self.student_adaptors_list.append(adaptors)

def set_branch(self, branch_indices):
assert len(branch_indices)==len(self.student_b_decoders_list)
self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device)

def forward(self, inputs):
output_list = []

down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)

up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)

decoder_features = [down5, up5, up4, up3, up2]
decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]

# Mimic teachers.
for i in range(len(self.branch_indices)):
out_idx = int(self.branch_indices[i])
output = decoder_features[out_idx]
output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:]))
for j in range(out_idx, 5):
output = self.student_b_decoders_list[i][j](
output,
decoder_indices[j],
decoder_shapes[j]
)
output_list.append( output )
return output_list


class TaskBranchingAmalgamator(Engine):
def setup(
self,
joint_student: JointSegNet,
teachers,
tasks,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
device=None,
):
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self._device = device
self._dataloader = dataloader
self.student = self.model = joint_student.to(self._device)
self.teachers = nn.ModuleList(teachers).to(self._device)
self.tasks = tasks
self.optimizer = optimizer

self.is_finetuning=False
@property
def device(self):
return self._device

def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None ):
# Branching
with set_mode(self.student, training=True), \
set_mode(self.teachers, training=False):
super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter//2, epoch_length=epoch_length)
branch = self.find_the_best_branch( self._dataloader )
self.logger.info("[Task Branching] the best branch indices: %s"%(branch))

if stage_callback is not None:
stage_callback()

# Finetuning
self.is_finetuning = True
with set_mode(self.student, training=True), \
set_mode(self.teachers, training=False):
super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=max_iter-max_iter//2, max_iter=max_iter, epoch_length=epoch_length)

def find_the_best_branch(self, dataloader):
with set_mode(self.student, training=False), \
set_mode(self.teachers, training=False), \
torch.no_grad():
n_blocks = len(self.student.student_decoders)
branch_loss = { task: [0 for _ in range(n_blocks)] for task in self.tasks }
for batch in dataloader:
batch = move_to_device(batch, self.device)
data = batch if isinstance(batch, torch.Tensor) else batch[0]
for candidate_branch in range( n_blocks ):
self.student.set_branch( [candidate_branch for _ in range(len(self.teachers))] )
s_out_list = self.student( data )
t_out_list = [ teacher( data ) for teacher in self.teachers ]
for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
task_loss = task.get_loss( s_out, t_out )
branch_loss[task][candidate_branch] += sum(task_loss.values())
best_branch = []
for task in self.tasks:
best_branch.append( int(np.argmin( branch_loss[task] )) )

self.student.set_branch(best_branch)
return best_branch

def step_fn(self, engine, batch):
start_time = time.perf_counter()
batch = move_to_device(batch, self._device)
data = batch[0]
#data = batch if isinstance(batch, torch.Tensor) else batch[0]
n_blocks = len(self.student.student_decoders)
if not self.is_finetuning:
rand_branch_indices = [ random.randint(0, n_blocks-1) for _ in range(len(self.teachers)) ]
self.student.set_branch(rand_branch_indices)

s_out_list = self.student( data )
with torch.no_grad():
t_out_list = [ teacher( data ) for teacher in self.teachers ]

loss_dict = {}
for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
task_loss = task.get_loss( s_out, t_out )
loss_dict.update( task_loss )
loss = sum(loss_dict.values())

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

step_time = time.perf_counter() - start_time
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
metrics.update({
'total_loss': loss.item(),
'step_time': step_time,
'lr': float( self.optimizer.param_groups[0]['lr'] ),
'branch': self.student.branch_indices.cpu().numpy().tolist()
})
return metrics
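
A minimal sketch of how the joint student is assembled and steered (teacher construction mirrors BranchySegNet.get_segnet above; class counts and input size are placeholders, and pretrained weights are omitted):

import torch
from kamal import vision
from kamal.amalgamation.task_branching import JointSegNet

t_seg = vision.models.segmentation.segnet_vgg16_bn(num_classes=11, pretrained_backbone=False)
t_depth = vision.models.segmentation.segnet_vgg16_bn(num_classes=1, pretrained_backbone=False)

joint = JointSegNet(teachers=[t_seg, t_depth])
joint.set_branch([2, 3])                   # branch out at decoder block 2 / 3 for each task
outs = joint(torch.randn(1, 3, 240, 240))
print([o.shape for o in outs])             # one output per task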



model_measuring/kamal/core/__init__.py (+4, -0)

@@ -0,0 +1,4 @@
from . import engine, tasks, metrics, callbacks, exceptions, hub
from .attach import AttachTo

from .hub import load, save

model_measuring/kamal/core/attach.py (+45, -0)

@@ -0,0 +1,45 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from typing import Sequence, Callable
from numbers import Number
from kamal.core import exceptions

class AttachTo(object):
""" Attach task, metrics or visualizer to specified model outputs
"""
def __init__(self, attach_to=None):
if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable) ):
raise exceptions.InvalidMapping
self._attach_to = attach_to

def __call__(self, *tensors):
if self._attach_to is not None:
if isinstance(self._attach_to, Callable):
return self._attach_to( *tensors )
if isinstance(self._attach_to, Sequence):
_attach_to = self._attach_to
else:
_attach_to = [ self._attach_to for _ in range(len(tensors)) ]
_attach_to = _attach_to[:len(tensors)]
tensors = [ tensor[index] for (tensor, index) in zip( tensors, _attach_to ) ]
if len(tensors)==1:
tensors = tensors[0]
return tensors

def __repr__(self):
return "AttachTo: %s"%(self._attach_to)
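
For illustration, a tiny sketch of the selection behavior (the tuples stand in for model outputs and targets):

from kamal.core.attach import AttachTo

attach = AttachTo(attach_to=[0, 1])        # take output 0 and target 1
outputs = ('seg_logits', 'depth_pred')
targets = ('seg_label', 'depth_label')
print(attach(outputs, targets))            # ['seg_logits', 'depth_label']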

model_measuring/kamal/core/callbacks/__init__.py (+5, -0)

@@ -0,0 +1,5 @@
from .logging import MetricsLogging, ProgressCallback
from .base import Callback
from .eval_and_ckpt import EvalAndCkpt
from .scheduler import LRSchedulerCallback
from .visualize import VisualizeOutputs, VisualizeSegmentation, VisualizeDepth

model_measuring/kamal/core/callbacks/base.py (+28, -0)

@@ -0,0 +1,28 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import abc

class Callback(abc.ABC):
r""" Base Class for Callbacks
"""
def __init__(self):
pass

@abc.abstractmethod
def __call__(self, engine):
pass

model_measuring/kamal/core/callbacks/eval_and_ckpt.py (+145, -0)

@@ -0,0 +1,145 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .base import Callback
import weakref
from kamal import utils
from typing import Sequence, Optional
import numbers
import os, shutil
import torch

class EvalAndCkpt(Callback):
def __init__(self,
model,
evaluator,
metric_name:str,
metric_mode:str ='max',
save_type:Optional[Sequence]=('best', 'latest'),
ckpt_dir:str ='checkpoints',
ckpt_prefix:str =None,
log_tag:str ='model',
weights_only:bool =True,
verbose:bool =False,):
super(EvalAndCkpt, self).__init__()
self.metric_name = metric_name
assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'"
self._metric_mode = metric_mode

self._model = weakref.ref( model )
self._evaluator = evaluator
self._ckpt_dir = ckpt_dir
self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix+'_')
if isinstance(save_type, str):
save_type = ( save_type, )
self._save_type = save_type
if self._save_type is not None:
for save_type in self._save_type:
assert save_type in ('best', 'latest', 'all'), \
'save_type should be None or a subset of (\"best\", \"latest\", \"all\")'
self._log_tag = log_tag
self._weights_only = weights_only
self._verbose = verbose

self._best_score = -float('inf') if self._metric_mode=='max' else float('inf')
self._best_ckpt = None
self._latest_ckpt = None

@property
def best_ckpt(self):
return self._best_ckpt
@property
def latest_ckpt(self):
return self._latest_ckpt

@property
def best_score(self):
return self._best_score

def __call__(self, trainer):
model = self._model()
results = self._evaluator.eval( model, device=trainer.device )
results = utils.flatten_dict(results)
current_score = results[self.metric_name]

scalar_results = { k: float(v) for (k, v) in results.items() if isinstance(v, numbers.Number) or len(v.shape)==0 }
if trainer.logger is not None:
trainer.logger.info( "[Eval %s] Iter %d/%d: %s"%(self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results) )
trainer.state.metrics.update( scalar_results )
# Visualize results if trainer.tb_writer is not None

if trainer.tb_writer is not None:
for k, v in scalar_results.items():
log_tag = "%s:%s"%(self._log_tag, k)
trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter)

if self._save_type is not None:
pth_path_list = []
# periodic snapshot: 'all' keeps a checkpoint for every evaluation
if 'all' in self._save_type:
pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f.pth"
% (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
pth_path_list.append(pth_path)

# the latest model
if 'latest' in self._save_type:
pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f.pth"
% (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
# remove the old ckpt
if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt):
os.remove(self._latest_ckpt)
pth_path_list.append(pth_path)
self._latest_ckpt = pth_path

# the best model
if 'best' in self._save_type:
if (current_score >= self._best_score and self._metric_mode=='max') or \
(current_score <= self._best_score and self._metric_mode=='min'):
pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f.pth" %
(self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
# remove the old ckpt
if self._best_ckpt is not None and os.path.exists(self._best_ckpt):
os.remove(self._best_ckpt)
pth_path_list.append(pth_path)
self._best_score = current_score
self._best_ckpt = pth_path
# save model
if self._verbose and trainer.logger:
trainer.logger.info("Model saved as:")
obj = model.state_dict() if self._weights_only else model
os.makedirs( self._ckpt_dir, exist_ok=True )
for pth_path in pth_path_list:
torch.save(obj, pth_path)
if self._verbose and trainer.logger:
trainer.logger.info("\t%s" % (pth_path))
def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False):
if ckpt_dir is None:
ckpt_dir = self._ckpt_dir
if ckpt_prefix is None:
ckpt_prefix = self._ckpt_prefix
if self._save_type is not None:
#if 'latest' in self._save_type and self._latest_ckpt is not None:
# os.makedirs(ckpt_dir, exist_ok=True)
# save_name = "%slatest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._latest_ckpt))
# shutil.copy2(self._latest_ckpt, os.path.join(ckpt_dir, save_name))
if 'best' in self._save_type and self._best_ckpt is not None:
os.makedirs(ckpt_dir, exist_ok=True)
save_name = "%sbest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._best_ckpt))
shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name))
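
A hedged usage sketch (the evaluator object and the way callbacks are registered on a trainer live in engine code not shown in this diff; model, evaluator, and trainer below are placeholders):

from kamal.core.callbacks import EvalAndCkpt

ckpt_cb = EvalAndCkpt(
    model=student,              # placeholder nn.Module being trained
    evaluator=my_evaluator,     # placeholder: exposes .eval(model, device=...)
    metric_name='acc',
    metric_mode='max',
    save_type=('best', 'latest'),
    ckpt_dir='checkpoints',
)
ckpt_cb(trainer)                # each call evaluates the model and snapshots its weights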

model_measuring/kamal/core/callbacks/logging.py (+61, -0)

@@ -0,0 +1,61 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .base import Callback
import numbers
from tqdm import tqdm

class MetricsLogging(Callback):
def __init__(self, keys):
super(MetricsLogging, self).__init__()
self._keys = keys

def __call__(self, engine):
if engine.logger is None:
return
state = engine.state
content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)"%(
state.iter, state.max_iter,
state.current_epoch, state.max_epoch,
state.current_batch_index, state.max_batch_index
)
for key in self._keys:
value = state.metrics.get(key, None)
if value is not None:
if isinstance(value, numbers.Number):
content += " %s=%.4f"%(key, value)
if engine.tb_writer is not None:
engine.tb_writer.add_scalar(key, value, global_step=state.iter)
elif isinstance(value, (list, tuple)):
content += " %s=%s"%(key, value)
engine.logger.info(content)
class ProgressCallback(Callback):
def __init__(self, max_iter=100, tag=None):
super(ProgressCallback, self).__init__()
self._max_iter = max_iter
self._tag = tag
self.reset()
def __call__(self, engine):
self._pbar.update(1)
if self._pbar.n==self._max_iter:
self._pbar.close()
def reset(self):
self._pbar = tqdm(total=self._max_iter, desc=self._tag)

model_measuring/kamal/core/callbacks/scheduler.py (+34, -0)

@@ -0,0 +1,34 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .base import Callback
from typing import Sequence

class LRSchedulerCallback(Callback):
r""" LR scheduler callback
"""
def __init__(self, schedulers=None):
super(LRSchedulerCallback, self).__init__()
if schedulers is not None and not isinstance(schedulers, Sequence):
schedulers = ( schedulers, )
self._schedulers = schedulers

def __call__(self, trainer):
if self._schedulers is None:
return
for sched in self._schedulers:
sched.step()
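
For illustration, a minimal sketch of wrapping a torch scheduler (optimizer and trainer are placeholders):

import torch
from kamal.core.callbacks import LRSchedulerCallback

sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
cb = LRSchedulerCallback(schedulers=sched)  # a single scheduler is wrapped into a tuple
cb(trainer)                                 # steps every wrapped scheduler once per call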

model_measuring/kamal/core/callbacks/visualize.py (+161, -0)

@@ -0,0 +1,161 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .base import Callback
from typing import Callable, Union, Sequence
import weakref
import random
from kamal.utils import move_to_device, set_mode, split_batch, colormap
from kamal.core.attach import AttachTo
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import math
import numbers

class VisualizeOutputs(Callback):
def __init__(self,
model,
dataset: torch.utils.data.Dataset,
idx_list_or_num_vis: Union[int, Sequence]=5,
normalizer: Callable=None,
prepare_fn: Callable=None,
decode_fn: Callable=None, # decode targets and preds
tag: str='viz'):
super(VisualizeOutputs, self).__init__()
self._dataset = dataset
self._model = weakref.ref(model)
if isinstance(idx_list_or_num_vis, int):
self.idx_list = self._get_vis_idx_list(self._dataset, idx_list_or_num_vis)
elif isinstance(idx_list_or_num_vis, Sequence):
self.idx_list = idx_list_or_num_vis
self._normalizer = normalizer
self._decode_fn = decode_fn
if prepare_fn is None:
prepare_fn = VisualizeOutputs.get_prepare_fn()
self._prepare_fn = prepare_fn
self._tag = tag

def _get_vis_idx_list(self, dataset, num_vis):
return random.sample(list(range(len(dataset))), num_vis)
@torch.no_grad()
def __call__(self, trainer):
if trainer.tb_writer is None:
trainer.logger.warning("summary writer was not found in trainer")
return
device = trainer.device
model = self._model()
with torch.no_grad(), set_mode(model, training=False):
for img_id, idx in enumerate(self.idx_list):
batch = move_to_device(self._dataset[idx], device)
batch = [ d.unsqueeze(0) for d in batch ]
inputs, targets, preds = self._prepare_fn(model, batch)
if self._normalizer is not None:
inputs = self._normalizer(inputs)
inputs = inputs.detach().cpu().numpy()
preds = preds.detach().cpu().numpy()
targets = targets.detach().cpu().numpy()
if self._decode_fn: # to RGB 0~1 NCHW
preds = self._decode_fn(preds)
targets = self._decode_fn(targets)
inputs = inputs[0]
preds = preds[0]
targets = targets[0]
trainer.tb_writer.add_images("%s-%d"%(self._tag, img_id), np.stack( [inputs, targets, preds], axis=0), global_step=trainer.state.iter)

@staticmethod
def get_prepare_fn(attach_to=None, pred_fn=lambda x: x):
attach_to = AttachTo(attach_to)
def wrapper(model, batch):
inputs, targets = split_batch(batch)
outputs = model(inputs)
outputs, targets = attach_to(outputs, targets)
return inputs, targets, pred_fn(outputs)
return wrapper
@staticmethod
def get_seg_decode_fn(cmap=colormap(), index_transform=lambda x: x+1): # 255->0, 0->1,
def wrapper(preds):
if len(preds.shape)>3:
preds = preds.squeeze(1)
out = cmap[ index_transform(preds.astype('uint8')) ]
out = out.transpose(0, 3, 1, 2) / 255
return out
return wrapper

@staticmethod
def get_depth_decode_fn(max_depth, log_scale=True, cmap=plt.get_cmap('jet')):
def wrapper(preds):
if log_scale:
_max_depth = np.log( max_depth )
preds = np.log( preds )
else:
_max_depth = max_depth
if len(preds.shape)>3:
preds = preds.squeeze(1)
out = (cmap(preds.clip(0, _max_depth)/_max_depth)).transpose(0, 3, 1, 2)[:, :3]
return out
return wrapper

class VisualizeSegmentation(VisualizeOutputs):
def __init__(
self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5,
cmap = colormap(),
attach_to=None,
normalizer: Callable=None,
prepare_fn: Callable=None,
decode_fn: Callable=None,
tag: str='seg'
):
if prepare_fn is None:
prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x.max(1)[1])
if decode_fn is None:
decode_fn = VisualizeOutputs.get_seg_decode_fn(cmap=cmap, index_transform=lambda x: x+1)

super(VisualizeSegmentation, self).__init__(
model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis,
normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn,
tag=tag
)

class VisualizeDepth(VisualizeOutputs):
def __init__(
self, model, dataset: torch.utils.data.Dataset,
idx_list_or_num_vis: Union[int, Sequence]=5,
max_depth = 10,
log_scale = True,
attach_to = None,
normalizer: Callable=None,
prepare_fn: Callable=None,
decode_fn: Callable=None,
tag: str='depth'
):
if prepare_fn is None:
prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x)
if decode_fn is None:
decode_fn = VisualizeOutputs.get_depth_decode_fn(max_depth=max_depth, log_scale=log_scale)
super(VisualizeDepth, self).__init__(
model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis,
normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn,
tag=tag
)
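
A usage sketch, assuming `model`, a validation dataset `val_dst`, and a `trainer` created with a TensorBoard writer:

from kamal.core.engine.events import DefaultEvents

# render 4 random samples (input / target / prediction) at the end of each epoch
viz = VisualizeSegmentation(model, val_dst, idx_list_or_num_vis=4)
trainer.add_callback(DefaultEvents.AFTER_EPOCH, callbacks=viz)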

+ 7
- 0
model_measuring/kamal/core/engine/__init__.py View File

@@ -0,0 +1,7 @@
from . import evaluator
from . import trainer
from . import lr_finder
from . import hooks
from . import events

from .engine import DefaultEvents, Event

+ 190
- 0
model_measuring/kamal/core/engine/engine.py View File

@@ -0,0 +1,190 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn as nn
import abc, math, weakref, typing, time
from typing import Any, Callable, Optional, Sequence
import numpy as np

from kamal.core.engine.events import DefaultEvents, Event
from kamal.core import tasks
from kamal.utils import set_mode, move_to_device, get_logger
from collections import defaultdict

import numbers
import contextlib

class State(object):
def __init__(self):
self.iter = 0
self.max_iter = None
self.epoch_length = None
self.dataloader = None
self.seed = None

self.metrics=dict()
self.batch=None

@property
def current_epoch(self):
if self.epoch_length is not None:
return self.iter // self.epoch_length
return None

@property
def max_epoch(self):
if self.epoch_length is not None:
return self.max_iter // self.epoch_length
return None

@property
def current_batch_index(self):
if self.epoch_length is not None:
return self.iter % self.epoch_length
return None

@property
def max_batch_index(self):
return self.epoch_length

def __repr__(self):
rep = "State:\n"
for attr, value in self.__dict__.items():
if not isinstance(value, (numbers.Number, str, dict)):
value = type(value)
rep += "\t{}: {}\n".format(attr, value)
return rep

class Engine(abc.ABC):
def __init__(self, logger=None, tb_writer=None):
self._logger = logger if logger else get_logger(name='kamal', color=True)
self._tb_writer = tb_writer
self._callbacks = defaultdict(list)
self._allowed_events = [ *DefaultEvents ]
self._state = State()
def reset(self):
self._state = State()

def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None):
self.state.iter = self._state.start_iter = start_iter
self.state.max_iter = max_iter
self.state.epoch_length = epoch_length if epoch_length else len(dataloader)
self.state.dataloader = dataloader
self.state.dataloader_iter = iter(dataloader)
self.state.step_fn = step_fn

self.trigger_events(DefaultEvents.BEFORE_RUN)
for self.state.iter in range( start_iter, max_iter ):
if self.state.epoch_length is not None and \
self.state.iter%self.state.epoch_length==0: # Epoch Start
self.trigger_events(DefaultEvents.BEFORE_EPOCH)
self.trigger_events(DefaultEvents.BEFORE_STEP)
self.state.batch = self._get_batch()
step_output = step_fn(self, self.state.batch)
if isinstance(step_output, dict):
self.state.metrics.update(step_output)
self.trigger_events(DefaultEvents.AFTER_STEP)
if self.state.epoch_length is not None and \
(self.state.iter+1)%self.state.epoch_length==0: # Epoch End
self.trigger_events(DefaultEvents.AFTER_EPOCH)
self.trigger_events(DefaultEvents.AFTER_RUN)

def _get_batch(self):
try:
batch = next( self.state.dataloader_iter )
except StopIteration:
self.state.dataloader_iter = iter(self.state.dataloader) # reset iterator
batch = next( self.state.dataloader_iter )
if not isinstance(batch, (list, tuple)):
batch = [ batch, ] # no targets
return batch

@property
def state(self):
return self._state

@property
def logger(self):
return self._logger

@property
def tb_writer(self):
return self._tb_writer

def add_callback(self, event: Event, callbacks ):
if not isinstance(callbacks, Sequence):
callbacks = [callbacks]
if event in self._allowed_events:
for callback in callbacks:
if callback not in self._callbacks[event]:
if event.trigger!=event.default_trigger:
callback = self._trigger_wrapper(self, event.trigger, callback )
self._callbacks[event].append( callback )
callbacks = [ RemovableCallback(self, event, c) for c in callbacks ]
return ( callbacks[0] if len(callbacks)==1 else callbacks )

def remove_callback(self, event, callback):
for c in self._callbacks[event]:
if c==callback:
self._callbacks[event].remove( callback )
return True
return False

@staticmethod
def _trigger_wrapper(engine, trigger, callback):
def wrapper(*args, **kwargs) -> Any:
if trigger(engine):
return callback(engine)
return wrapper

def trigger_events(self, *events):
for e in events:
if e in self._allowed_events:
for callback in self._callbacks[e]:
callback(self)

def register_events(self, *events):
for e in events:
if e not in self._allowed_events:
self._allowed_events.append( e )

@contextlib.contextmanager
def save_current_callbacks(self):
temp = self._callbacks
self._callbacks = defaultdict(list)
yield
self._callbacks = temp

class RemovableCallback:
def __init__(self, engine, event, callback):
self._engine = weakref.ref(engine)
self._callback = weakref.ref(callback)
self._event = weakref.ref(event)
@property
def callback(self):
return self._callback()

def remove(self):
engine = self._engine()
callback = self._callback()
event = self._event()
return engine.remove_callback(event, callback)
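
A sketch of the event machinery, assuming `engine` is any concrete engine instance (e.g. a trainer): register a custom event, attach a callback, and fire it manually.

from kamal.core.engine.events import Event

CHECKPOINT_EVENT = Event("checkpoint")   # custom event with the default (always-on) trigger

engine.register_events(CHECKPOINT_EVENT)
engine.add_callback(CHECKPOINT_EVENT,
                    callbacks=lambda eng: print("iter", eng.state.iter))
engine.trigger_events(CHECKPOINT_EVENT)  # invokes the callback immediately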


+ 130
- 0
model_measuring/kamal/core/engine/evaluator.py View File

@@ -0,0 +1,130 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import abc, sys
import torch
from tqdm import tqdm
from kamal.core import metrics
from kamal.utils import set_mode
from typing import Any, Callable
from .engine import Engine
from .events import DefaultEvents
from kamal.core import callbacks

import weakref
from kamal.utils import move_to_device, split_batch

class BasicEvaluator(Engine):
def __init__(self,
dataloader: torch.utils.data.DataLoader,
metric: metrics.MetricCompose,
eval_fn: Callable=None,
tag: str='Eval',
progress: bool=False ):
super( BasicEvaluator, self ).__init__()
self.dataloader = dataloader
self.metric = metric
self.progress = progress
if progress:
self.progress_callback = self.add_callback(
DefaultEvents.AFTER_STEP, callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag))
self._model = None
self._tag = tag
if eval_fn is None:
eval_fn = BasicEvaluator.default_eval_fn
self.eval_fn = eval_fn

def eval(self, model, device=None):
device = device if device is not None else \
torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self._model = weakref.ref(model) # use weakref here
self.device = device
self.metric.reset()
model.to(device)
if self.progress:
self.progress_callback.callback.reset()
with torch.no_grad(), set_mode(model, training=False):
super(BasicEvaluator, self).run( self.step_fn, self.dataloader, max_iter=len(self.dataloader) )
return self.metric.get_results()
@property
def model(self):
if self._model is not None:
return self._model()
return None

def step_fn(self, engine, batch):
batch = move_to_device(batch, self.device)
self.eval_fn( engine, batch )
@staticmethod
def default_eval_fn(evaluator, batch):
model = evaluator.model
inputs, targets = split_batch(batch)
outputs = model( inputs )
evaluator.metric.update( outputs, targets )

class TeacherEvaluator(BasicEvaluator):
def __init__(self,
dataloader: torch.utils.data.DataLoader,
teacher: torch.nn.Module,
task,
metric: metrics.MetricCompose,
eval_fn: Callable=None,
tag: str='Eval',
progress: bool=False ):
if eval_fn is None:
eval_fn = TeacherEvaluator.default_eval_fn
super( TeacherEvaluator, self ).__init__(dataloader=dataloader, metric=metric, eval_fn=eval_fn, tag=tag, progress=progress)
self._teacher = teacher
self.task = task
def eval(self, model, device=None):
device = device if device is not None else \
torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self.teacher.to(device)
with set_mode(self.teacher, training=False):
return super(TeacherEvaluator, self).eval( model, device=device )

@property
def model(self):
if self._model is not None:
return self._model()
return None

@property
def teacher(self):
return self._teacher

def step_fn(self, engine, batch):
batch = move_to_device(batch, self.device)
self.eval_fn( engine, batch )
@staticmethod
def default_eval_fn(evaluator, batch):
model = evaluator.model
teacher = evaluator.teacher

inputs, targets = split_batch(batch)
outputs = model( inputs )

# get teacher outputs
if isinstance(teacher, torch.nn.ModuleList):
targets = [ task.predict(tea(inputs)) for (tea, task) in zip(teacher, evaluator.task) ]
else:
t_outputs = teacher(inputs)
targets = evaluator.task.predict( t_outputs )
evaluator.metric.update( outputs, targets )
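
A usage sketch, assuming `model` and a validation loader `val_loader`:

from kamal.core import metrics

metric = metrics.MetricCompose({'acc': metrics.Accuracy()})
evaluator = BasicEvaluator(val_loader, metric=metric, progress=True)
results = evaluator.eval(model)   # e.g. {'acc': tensor(0.9312)}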

+ 92
- 0
model_measuring/kamal/core/engine/events.py View File

@@ -0,0 +1,92 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from typing import Callable, Optional
from enum import Enum

class Event(object):
def __init__(self, value: str, event_trigger: Optional[Callable]=None ):
if event_trigger is None:
event_trigger = Event.default_trigger
self._trigger = event_trigger
self._name_ = self._value_ = value

@property
def trigger(self):
return self._trigger

@property
def name(self):
"""The name of the Enum member."""
return self._name_

@property
def value(self):
"""The value of the Enum member."""
return self._value_

@staticmethod
def default_trigger(engine):
return True

@staticmethod
def once_trigger():
is_triggered = False
def wrapper(engine):
nonlocal is_triggered
if is_triggered:
return False
is_triggered=True
return True
return wrapper

@staticmethod
def every_trigger(every: int):
def wrapper(engine):
return every>0 and (engine.state.iter % every)==0
return wrapper

def __call__(self, every: Optional[int]=None, once: Optional[bool]=None ):
if every is not None:
assert once is None
return Event(self.value, event_trigger=Event.every_trigger(every) )
if once is not None:
return Event(self.value, event_trigger=Event.once_trigger() )
return Event(self.value)

def __hash__(self):
return hash(self._name_)
def __eq__(self, other):
if hasattr(other, 'value'):
return self.value==other.value
else:
return NotImplemented

class DefaultEvents(Event, Enum):
BEFORE_RUN = "before_train"
AFTER_RUN = "after_train"

BEFORE_EPOCH = "before_epoch"
AFTER_EPOCH = "after_epoch"

BEFORE_STEP = "before_step"
AFTER_STEP = "after_step"

BEFORE_GET_BATCH = "before_get_batch"
AFTER_GET_BATCH = "after_get_batch"


+ 34
- 0
model_measuring/kamal/core/engine/hooks.py View File

@@ -0,0 +1,34 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

class FeatureHook():
def __init__(self, module):
self.module = module
self.feat_in = None
self.feat_out = None
self.register()

def register(self):
self._hook = self.module.register_forward_hook(self.hook_fn_forward)

def remove(self):
self._hook.remove()

def hook_fn_forward(self, module, fea_in, fea_out):
self.feat_in = fea_in[0]
self.feat_out = fea_out
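
A sketch of capturing an intermediate feature map; torchvision's resnet18 is only an example backbone:

import torch
from torchvision.models import resnet18

model = resnet18()
hook = FeatureHook(model.layer4)           # register on any nn.Module
_ = model(torch.randn(1, 3, 224, 224))     # a forward pass fills the hook
feat = hook.feat_out                       # output tensor of layer4
hook.remove()                              # detach when done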


+ 214
- 0
model_measuring/kamal/core/engine/lr_finder.py View File

@@ -0,0 +1,214 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import os
import tempfile, uuid
import contextlib
import numpy as np
from tqdm import tqdm
from kamal.core import callbacks
from kamal.core.engine import evaluator
from kamal.core.engine.events import DefaultEvents

class _ProgressCallback(callbacks.Callback):
def __init__(self, max_iter, tag):
self._tag = tag
self._max_iter = max_iter
self._pbar = tqdm(total=self._max_iter, desc=self._tag)

def __call__(self, engine):
self._pbar.update(1)

def reset(self):
self._pbar = tqdm(total=self._max_iter, desc=self._tag)

class _LinearLRScheduler(torch.optim.lr_scheduler._LRScheduler):
""" Linear Scheduler
"""
def __init__(self,
optimizer,
lr_range,
max_iter,
last_epoch=-1):
self.lr_range = lr_range
self.max_iter = max_iter
super(_LinearLRScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
r = self.last_epoch / self.max_iter
return [self.lr_range[0] + (self.lr_range[1]-self.lr_range[0]) * r for base_lr in self.base_lrs]

class _ExponentialLRScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, lr_range, max_iter, last_epoch=-1):
self.lr_range = lr_range
self.max_iter = max_iter
super(_ExponentialLRScheduler, self).__init__(optimizer, last_epoch)

def get_lr(self):
r = self.last_epoch / (self.max_iter - 1)
return [self.lr_range[0] * (self.lr_range[1] / self.lr_range[0]) ** r for base_lr in self.base_lrs]

class _LRFinderCallback(object):
def __init__(self,
model,
optimizer,
metric_name,
evaluator,
smooth_momentum):
self._model = model
self._optimizer = optimizer
self._metric_name = metric_name
self._evaluator = evaluator
self._smooth_momentum = smooth_momentum
self._records = []
@property
def records(self):
return self._records
def reset(self):
self._records = []

def __call__(self, trainer):
model = self._model
optimizer = self._optimizer
if self._evaluator is not None:
results = self._evaluator.eval( model )
score = float(results[ self._metric_name ])
else:
score = float(trainer.state.metrics[self._metric_name])
if self._smooth_momentum>0 and len(self._records)>0:
score = self._records[-1][1] * self._smooth_momentum + score * (1-self._smooth_momentum)
lr = optimizer.param_groups[0]['lr']
self._records.append( ( lr, score ) )


class LRFinder(object):

def _reset(self):
init_state = torch.load( self._temp_file )
self.trainer.model.load_state_dict( init_state['model'] )
self.trainer.optimizer.load_state_dict( init_state['optimizer'] )
try:
self.trainer.reset()
except Exception: pass

def adjust_learning_rate(self, optimizer, lr):
for group in optimizer.param_groups:
group['lr'] = lr

def _get_default_lr_range(self, optimizer):
if isinstance( optimizer, torch.optim.Adam ):
return ( 1e-5, 1e-2 )
elif isinstance( optimizer, torch.optim.SGD ):
return ( 1e-3, 0.2 )
else:
return ( 1e-5, 0.5)
def plot(self, polyfit: int=3, log_x=True):
import matplotlib.pyplot as plt
lrs = [ rec[0] for rec in self.records ]
scores = [ rec[1] for rec in self.records ]
if polyfit is not None:
z = np.polyfit( lrs, scores, deg=polyfit )
fitted_score = np.polyval( z, lrs )
fig, ax = plt.subplots()
ax.plot(lrs, scores, label='score')
if polyfit is not None:
ax.plot(lrs, fitted_score, label='polyfit')

if log_x:
plt.xscale('log')

ax.set_xlabel("Learning rate")
ax.set_ylabel("Score")
return fig

def suggest( self, mode='min', skip_begin=10, skip_end=5, polyfit=None ):

scores = np.array( [ self.records[i][1] for i in range( len(self.records) ) ] )
lrs = np.array( [ self.records[i][0] for i in range( len(self.records) ) ] )
if polyfit is not None:
z = np.polyfit( lrs, scores, deg=polyfit )
scores = np.polyval( z, lrs )

grad = np.gradient( scores )[skip_begin:-skip_end]
index = grad.argmin() if mode=='min' else grad.argmax()
index = skip_begin + index
return index, self.records[index][0]

def find(self,
optimizer,
model,
trainer,
metric_name='total_loss',
metric_mode='min',
evaluator=None,
lr_range=[1e-4, 0.1],
max_iter=100,
num_eval=None,
smooth_momentum=0.9,
scheduler='exp', # exp
polyfit=None, # None
skip=[10, 5],
progress=True):

self.optimizer = optimizer
self.model = model
self.trainer = trainer
# save init state
_filename = str(uuid.uuid4())+'.pth'
_tempdir = tempfile.gettempdir()
self._temp_file = os.path.join(_tempdir, _filename)
init_state = {
'optimizer': optimizer.state_dict(),
'model': model.state_dict()
}
torch.save(init_state, self._temp_file)

if num_eval is None or num_eval > max_iter:
num_eval = max_iter
if lr_range is None:
lr_range = self._get_default_lr_range(optimizer)
interval = max_iter // num_eval
if scheduler=='exp':
lr_sched = _ExponentialLRScheduler(optimizer, lr_range, max_iter=max_iter)
else:
lr_sched = _LinearLRScheduler(optimizer, lr_range, max_iter=max_iter)

self._lr_callback = callbacks.LRSchedulerCallback(schedulers=[lr_sched])
self._finder_callback = _LRFinderCallback(model, optimizer, metric_name, evaluator, smooth_momentum)
self.adjust_learning_rate( self.optimizer, lr_range[0] )
with self.trainer.save_current_callbacks():
trainer.add_callback(
DefaultEvents.AFTER_STEP, callbacks=[
self._lr_callback,
_ProgressCallback(max_iter, '[LR Finder]') ])
trainer.add_callback(
DefaultEvents.AFTER_STEP(interval), callbacks=self._finder_callback)
self.trainer.run( start_iter=0, max_iter=max_iter )

self.records = self._finder_callback.records # get records
index, best_lr = self.suggest(mode=metric_mode, skip_begin=skip[0], skip_end=skip[1], polyfit=polyfit)
self._reset()
del self.model, self.optimizer, self.trainer
return best_lr
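
A usage sketch, assuming a configured BasicTrainer; find() restores the initial model and optimizer state before returning:

finder = LRFinder()
best_lr = finder.find(optimizer, model, trainer,
                      metric_name='total_loss', metric_mode='min',
                      lr_range=[1e-5, 0.1], max_iter=200)
fig = finder.plot()   # records stay on the finder after find()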

+ 129
- 0
model_measuring/kamal/core/engine/trainer.py View File

@@ -0,0 +1,129 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn as nn
from kamal.core.engine.engine import Engine, Event, DefaultEvents, State
from kamal.core import tasks
from kamal.utils import set_mode, move_to_device, get_logger, split_batch
from typing import Callable, Mapping, Any, Sequence
import time
import weakref

class BasicTrainer(Engine):
def __init__( self,
logger=None,
tb_writer=None):
super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer)

def setup(self,
model: torch.nn.Module,
task: tasks.Task,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
device: torch.device=None):
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self.device = device
if isinstance(task, Sequence):
task = tasks.TaskCompose(task)
self.task = task
self.model = model
self.dataloader = dataloader
self.optimizer = optimizer
return self

def run( self, max_iter, start_iter=0, epoch_length=None):
self.model.to(self.device)
with set_mode(self.model, training=True):
super( BasicTrainer, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

def step_fn(self, engine, batch):
model = self.model
start_time = time.perf_counter()
batch = move_to_device(batch, self.device)
inputs, targets = split_batch(batch)
outputs = model(inputs)
loss_dict = self.task.get_loss(outputs, targets) # get loss
loss = sum( loss_dict.values() )
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
step_time = time.perf_counter() - start_time
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
metrics.update({
'total_loss': loss.item(),
'step_time': step_time,
'lr': float( self.optimizer.param_groups[0]['lr'] )
})
return metrics


class KDTrainer(BasicTrainer):

def setup(self,
student: torch.nn.Module,
teacher: torch.nn.Module,
task: tasks.Task,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
device: torch.device=None):

super(KDTrainer, self).setup(
model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device)
if isinstance(teacher, (list, tuple)):
if len(teacher)==1:
teacher=teacher[0]
else:
teacher = nn.ModuleList(teacher)
self.student = self.model
self.teacher = teacher
return self

def run( self, max_iter, start_iter=0, epoch_length=None):
self.student.to(self.device)
self.teacher.to(self.device)

with set_mode(self.student, training=True), \
set_mode(self.teacher, training=False):
super( BasicTrainer, self ).run(
self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

def step_fn(self, engine, batch):
model = self.model
start_time = time.perf_counter()
batch = move_to_device(batch, self.device)
inputs, targets = split_batch(batch)
outputs = model(inputs)
if isinstance(self.teacher, nn.ModuleList):
soft_targets = [ t(inputs) for t in self.teacher ]
else:
soft_targets = self.teacher(inputs)
loss_dict = self.task.get_loss(outputs, soft_targets) # get loss
loss = sum( loss_dict.values() )
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
step_time = time.perf_counter() - start_time
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
metrics.update({
'total_loss': loss.item(),
'step_time': step_time,
'lr': float( self.optimizer.param_groups[0]['lr'] )
})
return metrics
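
A minimal training sketch; `model`, `train_loader`, and `task` (any kamal.core.tasks.Task) are assumptions:

import torch
from kamal.core import callbacks
from kamal.core.engine.events import DefaultEvents

trainer = BasicTrainer()
trainer.setup(model=model, task=task, dataloader=train_loader,
              optimizer=torch.optim.SGD(model.parameters(), lr=0.01))
trainer.add_callback(DefaultEvents.AFTER_STEP(every=10),
                     callbacks=callbacks.MetricsLogging(keys=('total_loss', 'lr')))
trainer.run(max_iter=10000)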

+ 22
- 0
model_measuring/kamal/core/exceptions.py View File

@@ -0,0 +1,22 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

class DataTypeError(Exception):
pass

class InvalidMapping(Exception):
pass

+ 2
- 0
model_measuring/kamal/core/hub/__init__.py View File

@@ -0,0 +1,2 @@
from ._hub import *
from . import meta

+ 288
- 0
model_measuring/kamal/core/hub/_hub.py View File

@@ -0,0 +1,288 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from pip._internal import main as pipmain
from typing import Sequence, Optional

import torch
import os, sys, re
import shutil, inspect
import importlib.util
from glob import glob

from ruamel.yaml import YAML
from ruamel.yaml import comments
from ._module_mapping import PACKAGE_NAME_TO_IMPORT_NAME
import hashlib

import inspect

_DEFAULT_PROTOCOL = 2
_DEPENDENCY_FILE = "requirements.txt"
_CODE_DIR = "code"
_WEIGHTS_DIR = "weight"
_TAGS_DIR = "tag"

yaml = YAML()

def _replace_invalid_char(name):
return name.replace('-', '_').replace(' ', '_')

def save(
model: torch.nn.Module,
save_path: str,
entry_name: str,
spec_name: str,
code_path: str, # path to SOURCE_CODE_DIR or hubconf.py

metadata: dict,
tags: dict = None,
ignore_files: Sequence = None,
save_arch: bool = False,
overwrite: bool = False,
):
entry_name = _replace_invalid_char(entry_name)
if spec_name is not None:
spec_name = _replace_invalid_char(spec_name)

if not os.path.isdir(save_path):
overwrite = True

save_path = os.path.abspath(save_path)
export_code_path = os.path.join(save_path, _CODE_DIR)
export_weights_path = os.path.join(save_path, _WEIGHTS_DIR)
os.makedirs(save_path, exist_ok=True)
os.makedirs(export_code_path, exist_ok=True)
os.makedirs(export_weights_path, exist_ok=True)

code_path = os.path.abspath(code_path)
if os.path.isdir( code_path ):
if overwrite:
shutil.rmtree(export_code_path)
_copy_file_or_tree(src=code_path, dst=export_code_path) # overwrite old files
elif code_path.endswith('.py'):
if overwrite:
shutil.copy2(src=code_path, dst=os.path.join( export_code_path, 'hubconf.py' )) # overwrite old files

if hasattr(model, 'SETUP_INFO'):
del model.SETUP_INFO
if hasattr(model, 'METADATA'):
del model.METADATA

model_and_metadata = {
'model': model if save_arch else model.state_dict(),
'metadata': metadata,
}
temp_pth = os.path.join(export_weights_path, 'temp.pth')
torch.save(model_and_metadata, temp_pth)
if spec_name is None:
_md5 = md5( temp_pth )
spec_name = _md5
shutil.move( temp_pth, os.path.join(export_weights_path, '%s-%s.pth'%(entry_name, spec_name)) )

if tags is not None:
save_tags( tags, save_path, entry_name, spec_name )

def list_entry(path):
path = os.path.abspath( os.path.expanduser( path ) )
if path.endswith('.py'):
code_dir = os.path.dirname( path )
entry_file = path
elif os.path.isdir(path):
code_dir = os.path.join( path, _CODE_DIR )
entry_file = os.path.join(path, _CODE_DIR, 'app.py')
else:
raise NotImplementedError
sys.path.insert(0, code_dir)
spec = importlib.util.spec_from_file_location('app', entry_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
entry_list = [ ( f, getattr(module, f) ) for f in dir(module) if callable(getattr(module, f)) and not f.startswith('_') ]
sys.path.remove(code_dir)
return entry_list

def list_spec(path, entry_name=None):
path = os.path.abspath( os.path.expanduser( path ) )
weight_dir = os.path.join( path, _WEIGHTS_DIR )
spec_list = [ f.split('-') for f in os.listdir( weight_dir ) if f.endswith('.pth')]
spec_list = [ (f[0], f[1][:-4]) for f in spec_list ]
if entry_name is not None:
spec_list = [ s for s in spec_list if s[0]==entry_name ]
return spec_list

def load(path: str, entry_name:str=None, spec_name: str=None, pretrained=True, **kwargs):
"""
check dependencies and load pytorch models.
Args:
path: path to the model package
pretrained: load the pretrained model
"""
path = os.path.abspath(path)
if not os.path.exists(path):
raise FileNotFoundError

code_dir = os.path.join(path, _CODE_DIR)
if os.path.isdir(path):
cwd = os.getcwd()
os.chdir(code_dir)
sys.path.insert(0, code_dir)
spec = importlib.util.spec_from_file_location(
'hubconf', os.path.join(code_dir, 'hubconf.py'))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

if hasattr( module, 'dependencies' ):
deps = getattr( module, 'dependencies')
for dep in deps:
_import_with_auto_install(dep)

if entry_name is None: # auto select
pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) )]
assert len(pth_file)<=1, "Loading models with more than one weight file (.pth) is ambiguous"
pth_file = pth_file[0]
entry_name = pth_file.split('-')[0]
entry_fn = getattr( module, entry_name )
else:
entry_fn = getattr( module, entry_name )
if spec_name is None:
pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) ) if pth.startswith(entry_name) ]
assert len(pth_file)<=1, "Loading models with more than one weight file (.pth) is ambiguous"
pth_file = pth_file[0]
else:
pth_file = '%s-%s.pth'%(entry_name, spec_name)
try:
model_and_metadata = torch.load(os.path.join(path, _WEIGHTS_DIR, pth_file), map_location='cpu' )
except: raise FileNotFoundError

if isinstance( model_and_metadata['model'], torch.nn.Module ):
model = model_and_metadata['model']
else:
entry_args = model_and_metadata['metadata']['entry_args']
if entry_args is None:
entry_args = dict()
model = entry_fn( **entry_args )
if pretrained:
model.load_state_dict( model_and_metadata['model'], False)
# setup metadata and atlas info
model.METADATA = model_and_metadata['metadata']
model.SETUP_INFO = {"path": path, "entry_name": entry_name}
sys.path.pop(0)
os.chdir(cwd)
return model
raise NotImplementedError

def load_metadata(path, entry_name=None, spec_name=None):
path = os.path.abspath(path)
if os.path.isdir(path):
if entry_name is None: # auto select
pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) )]
assert len(pth_file)<=1, "Loading models with more than one weight file (.pth) is ambiguous"
pth_file = pth_file[0]
else:
if spec_name is None:
pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) ) if pth.startswith(entry_name) ]
assert len(pth_file)<=1, "Loading models with more than one weight file (.pth) is ambiguous"
pth_file = pth_file[0]
else:
pth_file = '%s-%s.pth'%(entry_name, spec_name)
try:
metadata = torch.load( os.path.join( path, _WEIGHTS_DIR, pth_file ), map_location='cpu' )['metadata']
except:
raise FileNotFoundError
return metadata

def load_tags(path, entry_name, spec_name):
path = os.path.abspath(path)
tags_path = os.path.join(path, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name ))
if os.path.isfile(tags_path):
return _to_python_type(_yaml_load(tags_path))
return dict()

def save_tags(tags, path, entry_name, spec_name):
path = os.path.abspath(path)
if tags is None:
tags = {}
tags_path = os.path.join(path, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name ))
os.makedirs( os.path.join( path, _TAGS_DIR ), exist_ok=True )
_yaml_dump(tags_path, tags)

def _yaml_dump(f, obj):
with open(f, 'w') as f:
yaml.dump(obj, f)
def _yaml_load(f):
with open(f, 'r') as f:
return yaml.load(f)

def _to_python_type(data):
if isinstance(data, dict):
for k, v in data.items():
data[k] = _to_python_type(v)
return dict(data)
elif isinstance(data, comments.CommentedSeq ):
for idx, v in enumerate(data):
data[idx] = _to_python_type(v)
return list(data)
return data

def md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()

def _get_package_name_and_version(package):
_version_sym_list = ('==', '>=', '<=')
for sym in _version_sym_list:
if sym in package:
return package.split(sym)
return package, None

def _import_with_auto_install(package):
package_name, version = _get_package_name_and_version(package)
package_name = package_name.strip()
import_name = PACKAGE_NAME_TO_IMPORT_NAME.get(
package_name, package_name).replace('-', '_')
try:
return __import__(import_name)
except ImportError:
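# compatibility shim: newer pip versions expose pip._internal.main as a
# module, older versions as a callable; the nested try below handles both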
try:
pipmain.main(['install', package])
except:
pipmain(['install', package])
return __import__(import_name)

from distutils.dir_util import copy_tree
def _copy_file_or_tree(src, dst):
if os.path.isfile(src):
shutil.copy2(src=src, dst=dst)
else:
copy_tree(src=src, dst=dst)

def _glob_list(path_list, recursive=False):
results = []
for path in path_list:
if '*' in path:
path = list(glob(path, recursive=recursive))
results.extend(path)
else:
results.append(path)
return results
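
A save/load round-trip sketch; the paths, entry name, and `metadata` construction are assumptions (see kamal/core/hub/meta for Metadata):

save(model,
     save_path='exported/my_model',
     entry_name='resnet18',
     spec_name=None,               # None: use the md5 of the weight file
     code_path='hubconf.py',       # file defining the resnet18 entry function
     metadata=metadata)
model = load('exported/my_model', pretrained=True)
print(model.METADATA)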

+ 6
- 0
model_measuring/kamal/core/hub/_module_mapping.py View File

@@ -0,0 +1,6 @@
PACKAGE_NAME_TO_IMPORT_NAME = {
'opencv-python': 'cv2',
'pillow': 'PIL',
'scikit-learn': 'sklearn',
'scikit-image': 'skimage',
}

+ 22
- 0
model_measuring/kamal/core/hub/meta/TASK.py View File

@@ -0,0 +1,22 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

CLASSIFICATION = 'classification'
SEGMENTATION = 'segmentation'
DETECTION = 'detection'
DEPTH = 'depth'
AUTOENCODER = 'autoencoder'

+ 3
- 0
model_measuring/kamal/core/hub/meta/__init__.py View File

@@ -0,0 +1,3 @@
from .meta import Metadata
from .input import ImageInput
from . import TASK

+ 31
- 0
model_measuring/kamal/core/hub/meta/input.py View File

@@ -0,0 +1,31 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from typing import Union, Sequence, Optional

# INPUT Metadata
def ImageInput( size: Union[Sequence[int], int],
range: Union[Sequence[int], int],
space: str,
normalize: Optional[Sequence]=None):
assert space in ['rgb', 'gray', 'bgr', 'rgbd']
return dict(
size=size,
range=range,
space=space,
normalize=normalize
)

+ 44
- 0
model_measuring/kamal/core/hub/meta/meta.py View File

@@ -0,0 +1,44 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from copy import deepcopy
import abc, os

from . import TASK
import torch

__all__ = ['Metadata']

# Model Metadata
def Metadata( name: str,
dataset: str,
task: str,
url: str,
input: dict,
entry_args: dict,
other_metadata: dict):
if task in [TASK.SEGMENTATION, TASK.CLASSIFICATION]:
assert 'num_classes' in other_metadata
return dict(
name=name,
dataset=dataset,
task=task,
url=url,
input=dict(input),
entry_args=entry_args,
other_metadata=other_metadata
)
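
A sketch combining Metadata with ImageInput; the concrete values are assumptions:

from kamal.core.hub import meta

metadata = meta.Metadata(
    name='resnet18_cifar10',
    dataset='CIFAR-10',
    task=meta.TASK.CLASSIFICATION,
    url='',
    input=meta.ImageInput(size=32, range=[0, 1], space='rgb', normalize=None),
    entry_args=dict(num_classes=10),
    other_metadata=dict(num_classes=10))   # num_classes is required for classification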

+ 8
- 0
model_measuring/kamal/core/metrics/__init__.py View File

@@ -0,0 +1,8 @@
from .stream_metrics import Metric, MetricCompose
from .accuracy import Accuracy, TopkAccuracy
from .confusion_matrix import ConfusionMatrix, IoU, mIoU
from .regression import *
from .average import AverageMetric



+ 118
- 0
model_measuring/kamal/core/metrics/accuracy.py View File

@@ -0,0 +1,118 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import numpy as np
import torch
from kamal.core.metrics.stream_metrics import Metric
from typing import Callable

__all__=['Accuracy', 'TopkAccuracy']

class Accuracy(Metric):
def __init__(self, attach_to=None):
super(Accuracy, self).__init__(attach_to=attach_to)
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
outputs = outputs.max(1)[1]
self._correct += ( outputs.view(-1)==targets.view(-1) ).sum()
self._cnt += torch.numel( targets )

def get_results(self):
return (self._correct / self._cnt).detach().cpu()
def reset(self):
self._correct = self._cnt = 0.0


class TopkAccuracy(Metric):
def __init__(self, topk=5, attach_to=None):
super(TopkAccuracy, self).__init__(attach_to=attach_to)
self._topk = topk
self.reset()
@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
_, outputs = outputs.topk(self._topk, dim=1, largest=True, sorted=True)
correct = outputs.eq( targets.view(-1, 1).expand_as(outputs) )
self._correct += correct[:, :self._topk].view(-1).float().sum(0).item()
self._cnt += len(targets)
def get_results(self):
return self._correct / self._cnt

def reset(self):
self._correct = 0.0
self._cnt = 0.0


class StreamCEMAPMetrics():
@property
def PRIMARY_METRIC(self):
return "eap"

def __init__(self):
self.reset()

def update(self, logits, targets):
# keep the full (N, num_labels) score matrix; the AP computation in
# get_results ranks raw scores rather than argmax predictions
preds = logits
# targets: -1 negative, 0 difficult, 1 positive
if isinstance(preds, torch.Tensor):
preds = preds.cpu().numpy()
targets = targets.cpu().numpy()

self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0)
self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0)

def get_results(self):
nTest = self.targets.shape[0]
nLabel = self.targets.shape[1]
eap = np.zeros(nTest)
for i in range(0,nTest):
R = np.sum(self.targets[i,:]==1)
for j in range(0,nLabel):
if self.targets[i,j]==1:
r = np.sum(self.preds[i,np.nonzero(self.targets[i,:]!=0)]>=self.preds[i,j])
rb = np.sum(self.preds[i,np.nonzero(self.targets[i,:]==1)] >= self.preds[i,j])

eap[i] = eap[i] + rb/(r*1.0)
eap[i] = eap[i]/R
# emap = np.nanmean(ap)

cap = np.zeros(nLabel)
for i in range(0,nLabel):
R = np.sum(self.targets[:,i]==1)
for j in range(0,nTest):
if self.targets[j,i]==1:
r = np.sum(self.preds[np.nonzero(self.targets[:,i]!=0),i] >= self.preds[j,i])
rb = np.sum(self.preds[np.nonzero(self.targets[:,i]==1),i] >= self.preds[j,i])
cap[i] = cap[i] + rb/(r*1.0)
cap[i] = cap[i]/R
# cmap = np.nanmean(ap)
return {
'eap': eap,
'emap': np.nanmean(eap),
'cap': cap,
'cmap': np.nanmean(cap),
}

def reset(self):
self.preds = None
self.targets = None

+ 49
- 0
model_measuring/kamal/core/metrics/average.py View File

@@ -0,0 +1,49 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import numpy as np
import torch
from kamal.core.metrics.stream_metrics import Metric
from typing import Callable

__all__=['AverageMetric']

class AverageMetric(Metric):
def __init__(self, fn:Callable, attach_to=None):
super(AverageMetric, self).__init__(attach_to=attach_to)
self._fn = fn
self.reset()

@torch.no_grad()
def update(self, outputs, targets):

outputs, targets = self._attach(outputs, targets)
m = self._fn( outputs, targets )

if m.ndim > 1:
self._cnt += m.shape[0]
self._accum += m.sum(0)
else:
self._cnt += 1
self._accum += m
def get_results(self):
return (self._accum / self._cnt).detach().cpu()
def reset(self):
self._accum = 0.
self._cnt = 0.

+ 68
- 0
model_measuring/kamal/core/metrics/confusion_matrix.py View File

@@ -0,0 +1,68 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .stream_metrics import Metric
import torch
from typing import Callable

class ConfusionMatrix(Metric):
def __init__(self, num_classes, ignore_idx=None, attach_to=None):
super(ConfusionMatrix, self).__init__(attach_to=attach_to)
self._num_classes = num_classes
self._ignore_idx = ignore_idx
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
if self.confusion_matrix.device != outputs.device:
self.confusion_matrix = self.confusion_matrix.to(device=outputs.device)
preds = outputs.max(1)[1].flatten()
targets = targets.flatten()
mask = (preds<self._num_classes) & (preds>=0)
if self._ignore_idx is not None:
mask = mask & (targets!=self._ignore_idx)
preds, targets = preds[mask], targets[mask]
hist = torch.bincount( self._num_classes * targets + preds,
minlength=self._num_classes ** 2 ).view(self._num_classes, self._num_classes)
self.confusion_matrix += hist

def get_results(self):
return self.confusion_matrix.detach().cpu()

def reset(self):
self._cnt = 0
self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False)

class IoU(Metric):
def __init__(self, confusion_matrix: ConfusionMatrix, attach_to=None):
super(IoU, self).__init__(attach_to=attach_to)
self._confusion_matrix = confusion_matrix

def update(self, outputs, targets):
pass # update will be done in confusion matrix

def reset(self):
pass
def get_results(self):
cm = self._confusion_matrix.get_results()
iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9)
return iou

class mIoU(IoU):
def get_results(self):
return super(mIoU, self).get_results().mean()

+ 86
- 0
model_measuring/kamal/core/metrics/normal.py View File

@@ -0,0 +1,86 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import numpy as np
import torch
from kamal.core.metrics.stream_metrics import Metric

class NormalPredictionMetrics(Metric):

@property
def PRIMARY_METRIC(self):
return 'mean angle'

def __init__(self, thresholds, attach_to=None):
super(NormalPredictionMetrics, self).__init__(attach_to=attach_to)
self.thresholds = thresholds
self.preds = None
self.targets = None
self.masks = None

@torch.no_grad()
def update(self, preds, targets, masks):
"""
**Type**: numpy.ndarray or torch.Tensor
**Shape:**
- **preds**: $(N, 3, H, W)$.
- **targets**: $(N, 3, H, W)$.
- **masks**: $(N, 1, H, W)$.
"""
if isinstance(preds, torch.Tensor):
preds = preds.cpu().numpy()
targets = targets.cpu().numpy()
masks = masks.cpu().numpy()

self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0)
self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0)
self.masks = masks if self.masks is None else np.append(self.masks, masks, axis=0)

def get_results(self):
"""
**Returns:**
- **mean angle**
- **median angle**
- **percents of angles within thresholds**
"""
products = np.sum(self.preds * self.targets, axis=1)

angles = np.arccos(np.clip(products, -1.0, 1.0)) / np.pi * 180
self.masks = self.masks.squeeze(1)
angles = angles[self.masks == 1]

mean_angle = np.mean(angles)
median_angle = np.median(angles)
count = self.masks.sum()

threshold_percents = {}
for threshold in self.thresholds:
# threshold_percents[threshold] = np.sum((angles < threshold)) / count
threshold_percents[threshold] = np.mean(angles < threshold)

return {
'mean angle': mean_angle,
'median angle': median_angle,
'percents within thresholds': threshold_percents
}

def reset(self):
self.preds = None
self.targets = None
self.masks = None

+ 199
- 0
model_measuring/kamal/core/metrics/regression.py View File

@@ -0,0 +1,199 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import numpy as np
import torch

from kamal.core.metrics.stream_metrics import Metric
from typing import Callable

__all__=['MeanSquaredError', 'RootMeanSquaredError', 'MeanAbsoluteError',
'ScaleInveriantMeanSquaredError', 'RelativeDifference',
'AbsoluteRelativeDifference', 'SquaredRelativeDifference', 'Threshold' ]

class MeanSquaredError(Metric):
def __init__(self, log_scale=False, attach_to=None):
super(MeanSquaredError, self).__init__( attach_to=attach_to )
self.reset()
self.log_scale=log_scale

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
if self.log_scale:
diff = torch.sum((torch.log(outputs+1e-8) - torch.log(targets+1e-8))**2)
else:
diff = torch.sum((outputs - targets)**2)
self._accum_sq_diff += diff
self._cnt += torch.numel(outputs)

def get_results(self):
return (self._accum_sq_diff / self._cnt).detach().cpu()
def reset(self):
self._accum_sq_diff = 0.
self._cnt = 0.


class RootMeanSquaredError(MeanSquaredError):
def get_results(self):
return torch.sqrt( (self._accum_sq_diff / self._cnt) ).detach().cpu()


class MeanAbsoluteError(Metric):
def __init__(self, attach_to=None ):
super(MeanAbsoluteError, self).__init__( attach_to=attach_to )
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
diff = torch.sum((outputs - targets).abs())
self._accum_abs_diff += diff
self._cnt += torch.numel(outputs)

def get_results(self):
return (self._accum_abs_diff / self._cnt).detach().cpu()
def reset(self):
self._accum_abs_diff = 0.
self._cnt = 0.


class ScaleInveriantMeanSquaredError(Metric):
def __init__(self, attach_to=None ):
super(ScaleInveriantMeanSquaredError, self).__init__( attach_to=attach_to )
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
diff_log = torch.log( outputs+1e-8 ) - torch.log( targets+1e-8 )
self._accum_log_diff += diff_log.sum()
self._accum_sq_log_diff += (diff_log**2).sum()
self._cnt += torch.numel(outputs)

def get_results(self):
return ( self._accum_sq_log_diff / self._cnt - 0.5 * (self._accum_log_diff**2 / self._cnt**2) ).detach().cpu()
def reset(self):
self._accum_log_diff = 0.
self._accum_sq_log_diff = 0.
self._cnt = 0.


class RelativeDifference(Metric):
def __init__(self, attach_to=None ):
super(RelativeDifference, self).__init__( attach_to=attach_to )
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
diff = (outputs - targets).abs()
self._accum_abs_rel += (diff/targets).sum()
self._cnt += torch.numel(outputs)

def get_results(self):
return (self._accum_abs_rel / self._cnt).detach().cpu()
def reset(self):
self._accum_abs_rel = 0.
self._cnt = 0.


class AbsoluteRelativeDifference(Metric):
def __init__(self, attach_to=None ):
super(AbsoluteRelativeDifference, self).__init__( attach_to=attach_to )
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
diff = (outputs - targets).abs()
self._accum_abs_rel += (diff/targets).sum()
self._cnt += torch.numel(outputs)

def get_results(self):
return (self._accum_abs_rel / self._cnt).detach().cpu()
def reset(self):
self._accum_abs_rel = 0.
self._cnt = 0.


class SquaredRelativeDifference(Metric):
def __init__(self, attach_to=None ):
super(SquaredRelativeDifference, self).__init__( attach_to=attach_to )
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
diff = (outputs - targets)**2
self._accum_sq_rel += (diff/targets).sum()
self._cnt += torch.numel(outputs)

def get_results(self):
return (self._accum_sq_rel / self._cnt).detach().cpu()
def reset(self):
self._accum_sq_rel = 0.
self._cnt = 0.


class Threshold(Metric):
def __init__(self, thresholds=[1.25, 1.25**2, 1.25**3], attach_to=None ):
super(Threshold, self).__init__( attach_to=attach_to )
self.thresholds = thresholds
self.reset()

@torch.no_grad()
def update(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
sigma = torch.max(outputs / targets, targets / outputs)
for thres in self.thresholds:
self._accum_thres[thres]+=torch.sum( sigma<thres )
self._cnt += torch.numel(outputs)

def get_results(self):
return { thres: (self._accum_thres[thres] / self._cnt).detach().cpu() for thres in self.thresholds }
def reset(self):
self._cnt = 0.
self._accum_thres = {thres: 0. for thres in self.thresholds}
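
if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original file; assumes the
    # default AttachTo(None) is a pass-through). The delta-thresholds report
    # the fraction of predictions with max(pred/gt, gt/pred) below 1.25**k.
    m = Threshold(thresholds=[1.25, 1.25**2])
    pred = torch.tensor([1.0, 2.0, 4.0])
    gt = torch.tensor([1.1, 2.6, 4.0])
    m.update(pred, gt)
    print(m.get_results())  # {1.25: ~0.667, 1.5625: 1.0}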

+ 82
- 0
model_measuring/kamal/core/metrics/stream_metrics.py View File

@@ -0,0 +1,82 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from __future__ import division
import torch

import numpy as np
from abc import ABC, abstractmethod
from typing import Callable, Union, Any, Mapping, Sequence
import numbers
from kamal.core.attach import AttachTo

class Metric(ABC):
def __init__(self, attach_to=None):
self._attach = AttachTo(attach_to)

@abstractmethod
def update(self, pred, target):
""" Overridden by subclasses """
raise NotImplementedError()
@abstractmethod
def get_results(self):
""" Overridden by subclasses """
raise NotImplementedError()

@abstractmethod
def reset(self):
""" Overridden by subclasses """
raise NotImplementedError()


class MetricCompose(dict):
def __init__(self, metric_dict: Mapping):
self._metric_dict = metric_dict

def add_metrics( self, metric_dict: Mapping):
if isinstance(metric_dict, MetricCompose):
metric_dict = metric_dict.metrics
self._metric_dict.update(metric_dict)
return self

@property
def metrics(self):
return self._metric_dict
@torch.no_grad()
def update(self, outputs, targets):
for key, metric in self._metric_dict.items():
if isinstance(metric, Metric):
metric.update(outputs, targets)
def get_results(self):
results = {}
for key, metric in self._metric_dict.items():
if isinstance(metric, Metric):
results[key] = metric.get_results()
return results

def reset(self):
for key, metric in self._metric_dict.items():
if isinstance(metric, Metric):
metric.reset()

def __getitem__(self, name):
return self._metric_dict[name]
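
if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original file): compose a
    # trivial counting metric and query it through the dict-style interface.
    class _Count(Metric):
        def __init__(self):
            super(_Count, self).__init__()
            self.reset()
        def update(self, outputs, targets):
            self._n += outputs.shape[0]
        def get_results(self):
            return self._n
        def reset(self):
            self._n = 0

    mc = MetricCompose({'count': _Count()})
    mc.update(torch.zeros(8, 3), torch.zeros(8))
    print(mc.get_results())  # {'count': 8}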



+ 3
- 0
model_measuring/kamal/core/tasks/__init__.py View File

@@ -0,0 +1,3 @@
from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose
from . import loss


+ 2
- 0
model_measuring/kamal/core/tasks/loss/__init__.py View File

@@ -0,0 +1,2 @@
from . import functional, loss
from .loss import *

+ 107
- 0
model_measuring/kamal/core/tasks/loss/functional.py View File

@@ -0,0 +1,107 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn.functional as F
import torch.nn as nn

def kldiv(logits, targets, T=1.0):
""" Cross Entropy for soft targets
Parameters:
- logits (Tensor): logits score (e.g. outputs of fc layer)
- targets (Tensor): logits of soft targets
- T (float): temperature of distill
- reduction: reduction to the output
"""
p_targets = F.softmax(targets/T, dim=1)
logp_logits = F.log_softmax(logits/T, dim=1)
kld = F.kl_div(logp_logits, p_targets, reduction='none') * (T**2)
return kld.sum(1).mean()

def jsdiv(logits, targets, T=1.0, reduction='mean'):
p = F.softmax(logits, dim=1)
q = F.softmax(targets, dim=1)
log_m = torch.log( (p+q) / 2 )
return 0.5* ( F.kl_div( log_m, p, reduction=reduction) + F.kl_div( log_m, q, reduction=reduction) )

def mmd_loss(f1, f2, sigmas, normalized=False):
if len(f1.shape) != 2:
N, C, H, W = f1.shape
f1 = f1.view(N, -1)
N, C, H, W = f2.shape
f2 = f2.view(N, -1)

if normalized:
f1 = F.normalize(f1, p=2, dim=1)
f2 = F.normalize(f2, p=2, dim=1)

return _mmd_rbf2(f1, f2, sigmas=sigmas)

def psnr(img1, img2, size_average=True, data_range=255):
N = img1.shape[0]
mse = torch.mean(((img1-img2)**2).view(N, -1), dim=1)
psnr = torch.clamp(torch.log10(data_range**2 / mse) * 10, 0.0, 99.99)
if size_average:
psnr = psnr.mean()
return psnr

def soft_cross_entropy(logits, targets, T=1.0, size_average=True):
""" Cross Entropy for soft targets
**Parameters:**
- **logits** (Tensor): logits score (e.g. outputs of fc layer)
- **targets** (Tensor): logits of soft targets
- **T** (float): temperature of distill
- **size_average**: average the outputs
"""
p_targets = F.softmax(targets/T, dim=1)
logp_pred = F.log_softmax(logits/T, dim=1)
ce = torch.sum(-p_targets * logp_pred, dim=1)
if size_average:
return ce.mean() * T * T
else:
return ce * T * T

def _mmd_rbf2(x, y, sigmas=None):
N, _ = x.shape
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

rx = (xx.diag().unsqueeze(0).expand_as(xx))
ry = (yy.diag().unsqueeze(0).expand_as(yy))

K = L = P = 0.0
XX2 = rx.t() + rx - 2*xx
YY2 = ry.t() + ry - 2*yy
XY2 = rx.t() + ry - 2*zz

if sigmas is None:
sigma2 = torch.mean((XX2.detach()+YY2.detach()+2*XY2.detach()) / 4)
sigmas2 = [sigma2/4, sigma2/2, sigma2, sigma2*2, sigma2*4]
alphas = [1.0 / (2 * sigma2) for sigma2 in sigmas2]
else:
alphas = [1.0 / (2 * sigma**2) for sigma in sigmas]

for alpha in alphas:
K += torch.exp(- alpha * (XX2.clamp(min=1e-12)))
L += torch.exp(- alpha * (YY2.clamp(min=1e-12)))
P += torch.exp(- alpha * (XY2.clamp(min=1e-12)))

beta = (1./(N*(N)))
gamma = (2./(N*N))

return F.relu(beta * (torch.sum(K)+torch.sum(L)) - gamma * torch.sum(P))
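
if __name__ == "__main__":
    # Sanity-check sketch (illustrative, not part of the original file).
    # kldiv and soft_cross_entropy differ only by the entropy of the soft
    # targets, which is constant w.r.t. the student logits, so ce >= kl and
    # both yield the same gradients for the student.
    torch.manual_seed(0)
    logits = torch.randn(4, 10, requires_grad=True)
    targets = torch.randn(4, 10)
    kl = kldiv(logits, targets, T=2.0)
    ce = soft_cross_entropy(logits, targets, T=2.0)
    print(float(kl), float(ce))  # gap = T**2 * mean entropy of softened targets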

+ 386
- 0
model_measuring/kamal/core/tasks/loss/loss.py View File

@@ -0,0 +1,386 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
#from pytorch_msssim import ssim, ms_ssim, MS_SSIM, SSIM

from .functional import *

class KLDiv(object):
def __init__(self, T=1.0):
self.T = T
def __call__(self, logits, targets):
return kldiv( logits, targets, T=self.T )

class KDLoss(nn.Module):
""" KD Loss Function
"""
def __init__(self, T=1.0, alpha=1.0, use_kldiv=False):
super(KDLoss, self).__init__()
self.T = T
self.alpha = alpha
self.kdloss = kldiv if use_kldiv else soft_cross_entropy

def forward(self, logits, targets, hard_targets=None):
loss = self.kdloss(logits, targets, T=self.T)
if hard_targets is not None and self.alpha != 0.0:
loss += self.alpha*F.cross_entropy(logits, hard_targets)
return loss

class CFLLoss(nn.Module):
""" Common Feature Learning Loss
CFL Loss = MMD + MSE
"""
def __init__(self, sigmas, normalized=True):
super(CFLLoss, self).__init__()
self.sigmas = sigmas
self.normalized = normalized

def forward(self, hs, hts, fts_, fts):
mmd = mse = 0.0
for ht_i in hts:
mmd += mmd_loss(hs, ht_i, sigmas=self.sigmas, normalized=self.normalized)
for i in range(len(fts_)):
mse += F.mse_loss(fts_[i], fts[i])
return mmd, mse

class PSNR_Loss(nn.Module):
def __init__(self, data_range=1.0, size_average=True):
super(PSNR_Loss, self).__init__()
self.data_range = data_range
self.size_average = size_average

def forward(self, img1, img2):
return 100 - psnr(img1, img2, size_average=self.size_average, data_range=self.data_range)


#class MS_SSIM_Loss(MS_SSIM):
# def forward(self, img1, img2):
# return 100*(1 - super(MS_SSIM_Loss, self).forward(img1, img2))

class ScaleInvariantLoss(nn.Module):
"""This criterion is used in depth prediction task.

**Parameters:**
- **la** (int, optional): Default value is 0.5. No need to change.
- **ignore_index** (int, optional): Value to ignore.

**Shape:**
- **inputs**: $(N, H, W)$.
- **targets**: $(N, H, W)$.
- **output**: scalar.
"""
def __init__(self, la=0.5, ignore_index=0):
super(ScaleInvariantLoss, self).__init__()
self.la = la
self.ignore_index = ignore_index

def forward(self, inputs, targets):
size = inputs.size()
if len(size) == 3:
inputs = inputs.view(size[0], -1)
targets = targets.view(size[0], -1)

inv_mask = targets.eq(self.ignore_index)
nums = (1-inv_mask.float()).sum(1)
log_d = torch.log(inputs) - torch.log(targets)
log_d[inv_mask] = 0

loss = torch.div(torch.pow(log_d, 2).sum(1), nums) - \
self.la * torch.pow(torch.div(log_d.sum(1), nums), 2)

return loss.mean()
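
# Worked example (illustrative): for a single sample with
# inputs = exp([1, 2]) and targets = [1, 1] (no pixel equals ignore_index=0),
# log_d = [1, 2], so loss = (1 + 4)/2 - 0.5 * ((1 + 2)/2)**2 = 1.375.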

class AngleLoss(nn.Module):
"""This criterion is used in surface normal prediction task.

**Shape:**
- **inputs**: $(N, 3, H, W)$. Predicted space vector for each pixel. Must be formalized before.
- **targets**: $(N, 3, H, W)$. Ground truth. Must be formalized before.
- **masks**: $(N, 1, H, W)$. One for valid pixels, else zero.
- **output**: scalar.
"""
def forward(self, inputs, targets, masks):
nums = masks.sum(dim=[1,2,3])

product = (inputs * targets).sum(1, keepdim=True)
loss = -torch.div((product * masks.float()).sum([1,2,3]), nums)
return loss.mean()

class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.ignore_index = ignore_index
self.size_average = size_average

def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index)
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
if self.size_average:
return focal_loss.mean()
else:
return focal_loss.sum()

class AttentionLoss(nn.Module):
""" Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer"""
def __init__(self, p=2):
super(AttentionLoss, self).__init__()
self.p = p

def forward(self, g_s, g_t):
return sum([self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])

def at_loss(self, f_s, f_t):
s_H, t_H = f_s.shape[2], f_t.shape[2]
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass
return (self.at(f_s) - self.at(f_t)).pow(2).mean()

def at(self, f):
return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))

class NSTLoss(nn.Module):
"""like what you like: knowledge distill via neuron selectivity transfer"""
def __init__(self):
super(NSTLoss, self).__init__()
pass

def forward(self, g_s, g_t):
return sum([self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])

def nst_loss(self, f_s, f_t):
s_H, t_H = f_s.shape[2], f_t.shape[2]
if s_H > t_H:
f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
elif s_H < t_H:
f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
else:
pass

f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
f_s = F.normalize(f_s, dim=2)
f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
f_t = F.normalize(f_t, dim=2)

# set full_loss as False to avoid unnecessary computation
full_loss = True
if full_loss:
return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean()
- 2 * self.poly_kernel(f_s, f_t).mean())
else:
return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()

def poly_kernel(self, a, b):
a = a.unsqueeze(1)
b = b.unsqueeze(2)
res = (a * b).sum(-1).pow(2)
return res

class SPLoss(nn.Module):
"""Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
def __init__(self):
super(SPLoss, self).__init__()

def forward(self, g_s, g_t):
return sum([self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])

def similarity_loss(self, f_s, f_t):
bsz = f_s.shape[0]
f_s = f_s.view(bsz, -1)
f_t = f_t.view(bsz, -1)

G_s = torch.mm(f_s, torch.t(f_s))
# G_s = G_s / G_s.norm(2)
G_s = torch.nn.functional.normalize(G_s)
G_t = torch.mm(f_t, torch.t(f_t))
# G_t = G_t / G_t.norm(2)
G_t = torch.nn.functional.normalize(G_t)

G_diff = G_t - G_s
loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
return loss

class RKDLoss(nn.Module):
"""Relational Knowledge Disitllation, CVPR2019"""
def __init__(self, w_d=25, w_a=50):
super(RKDLoss, self).__init__()
self.w_d = w_d
self.w_a = w_a

def forward(self, f_s, f_t):
student = f_s.view(f_s.shape[0], -1)
teacher = f_t.view(f_t.shape[0], -1)

# RKD distance loss
with torch.no_grad():
t_d = self.pdist(teacher, squared=False)
mean_td = t_d[t_d > 0].mean()
t_d = t_d / mean_td

d = self.pdist(student, squared=False)
mean_d = d[d > 0].mean()
d = d / mean_d

loss_d = F.smooth_l1_loss(d, t_d)

# RKD Angle loss
with torch.no_grad():
td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
norm_td = F.normalize(td, p=2, dim=2)
t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)

sd = (student.unsqueeze(0) - student.unsqueeze(1))
norm_sd = F.normalize(sd, p=2, dim=2)
s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)

loss_a = F.smooth_l1_loss(s_angle, t_angle)

loss = self.w_d * loss_d + self.w_a * loss_a

return loss

@staticmethod
def pdist(e, squared=False, eps=1e-12):
e_square = e.pow(2).sum(dim=1)
prod = e @ e.t()
res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

if not squared:
res = res.sqrt()

res = res.clone()
res[range(len(e)), range(len(e))] = 0
return res

class PKTLoss(nn.Module):
"""Probabilistic Knowledge Transfer for deep representation learning"""
def __init__(self):
super(PKTLoss, self).__init__()

def forward(self, f_s, f_t):
return self.cosine_similarity_loss(f_s, f_t)

@staticmethod
def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
# Normalize each vector by its norm
output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
output_net = output_net / (output_net_norm + eps)
output_net[output_net != output_net] = 0

target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
target_net = target_net / (target_net_norm + eps)
target_net[target_net != target_net] = 0

# Calculate the cosine similarity
model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
target_similarity = torch.mm(target_net, target_net.transpose(0, 1))

# Scale cosine similarity to 0..1
model_similarity = (model_similarity + 1.0) / 2.0
target_similarity = (target_similarity + 1.0) / 2.0

# Transform them into probabilities
model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

# Calculate the KL-divergence
loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))

return loss

class SVDLoss(nn.Module):
"""
Self-supervised Knowledge Distillation using Singular Value Decomposition
"""
def __init__(self, k=1):
super(SVDLoss, self).__init__()
self.k = k

def forward(self, g_s, g_t):
v_sb = None
v_tb = None
losses = []
for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t):

u_t, s_t, v_t = self.svd(f_t, self.k)
u_s, s_s, v_s = self.svd(f_s, self.k + 3)
v_s, v_t = self.align_rsv(v_s, v_t)
s_t = s_t.unsqueeze(1)
v_t = v_t * s_t
v_s = v_s * s_t

if i > 0:
s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8)
t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8)

l2loss = (s_rbf - t_rbf.detach()).pow(2)
l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss))
losses.append(l2loss.sum())

v_tb = v_t
v_sb = v_s

bsz = g_s[0].shape[0]
losses = [l / bsz for l in losses]
return sum(losses)

def svd(self, feat, n=1):
size = feat.shape
assert len(size) == 4

x = feat.view(-1, size[1], size[2] * size[3]).transpose(-2, -1)  # flatten HxW; the original used size[2]*size[2], which assumes square maps
u, s, v = torch.svd(x)

u = self.removenan(u)
s = self.removenan(s)
v = self.removenan(v)

if n > 0:
u = F.normalize(u[:, :, :n], dim=1)
s = F.normalize(s[:, :n], dim=1)
v = F.normalize(v[:, :, :n], dim=1)

return u, s, v

@staticmethod
def removenan(x):
x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
return x

@staticmethod
def align_rsv(a, b):
cosine = torch.matmul(a.transpose(-2, -1), b)
max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True)
mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)),
torch.sign(cosine), torch.zeros_like(cosine))
a = torch.matmul(a, mask)
return a, b



+ 186
- 0
model_measuring/kamal/core/tasks/task.py View File

@@ -0,0 +1,186 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import abc
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import typing
from typing import Callable, Dict, List, Any
from collections.abc import Mapping, Sequence
from . import loss
from kamal.core import metrics, exceptions
from kamal.core.attach import AttachTo

class Task(object):
def __init__(self, name):
self.name = name
@abc.abstractmethod
def get_loss( self, outputs, targets ) -> Dict:
pass

@abc.abstractmethod
def predict(self, outputs) -> Any:
pass

class GeneralTask(Task):
def __init__(self,
name: str,
loss_fn: Callable,
scaling:float=1.0,
pred_fn: Callable=lambda x: x,
attach_to=None):
super(GeneralTask, self).__init__(name)
self._attach = AttachTo(attach_to)
self.loss_fn = loss_fn
self.pred_fn = pred_fn
self.scaling = scaling

def get_loss(self, outputs, targets):
outputs, targets = self._attach(outputs, targets)
return { self.name: self.loss_fn( outputs, targets ) * self.scaling }

def predict(self, outputs):
outputs = self._attach(outputs)
return self.pred_fn(outputs)
def __repr__(self):
rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach)
return rep

class TaskCompose(list):
def __init__(self, tasks: list):
for task in tasks:
if isinstance(task, Task):
self.append(task)
def get_loss(self, outputs, targets):
loss_dict = {}
for task in self:
loss_dict.update( task.get_loss( outputs, targets ) )
return loss_dict
def predict(self, outputs):
results = []
for task in self:
results.append( task.predict( outputs ) )
return results
def __repr__(self):
rep="TaskCompose: \n"
for task in self:
rep+="\t%s\n"%task

class StandardTask:
@staticmethod
def classification(name='ce', scaling=1.0, attach_to=None):
return GeneralTask( name=name,
loss_fn=nn.CrossEntropyLoss(),
scaling=scaling,
pred_fn=lambda x: x.max(1)[1],
attach_to=attach_to )

@staticmethod
def binary_classification(name='bce', scaling=1.0, attach_to=None):
return GeneralTask(name=name,
loss_fn=F.binary_cross_entropy_with_logits,
scaling=scaling,
pred_fn=lambda x: (torch.sigmoid(x) > 0.5),  # loss_fn takes logits, so apply sigmoid before thresholding
attach_to=attach_to )

@staticmethod
def regression(name='mse', scaling=1.0, attach_to=None):
return GeneralTask(name=name,
loss_fn=nn.MSELoss(),
scaling=scaling,
pred_fn=lambda x: x,
attach_to=attach_to )

@staticmethod
def segmentation(name='ce', scaling=1.0, attach_to=None):
return GeneralTask(name=name,
loss_fn=nn.CrossEntropyLoss(ignore_index=255),
scaling=scaling,
pred_fn=lambda x: x.max(1)[1],
attach_to=attach_to )
@staticmethod
def monocular_depth(name='l1', scaling=1.0, attach_to=None):
return GeneralTask(name=name,
loss_fn=nn.L1Loss(),
scaling=scaling,
pred_fn=lambda x: x,
attach_to=attach_to)

@staticmethod
def detection():
raise NotImplementedError

@staticmethod
def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None):
return GeneralTask(name=name,
loss_fn=loss.KLDiv(T=T),
scaling=scaling,
pred_fn=lambda x: x.max(1)[1],
attach_to=attach_to)
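
# Usage sketch (illustrative, assuming the default AttachTo(None) is a
# pass-through):
#   task = StandardTask.classification()
#   logits = torch.randn(8, 10)
#   labels = torch.randint(0, 10, (8,))
#   loss_dict = task.get_loss(logits, labels)  # {'ce': tensor(...)}
#   preds = task.predict(logits)               # argmax over the class dim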


class StandardMetrics(object):

@staticmethod
def classification(attach_to=None):
return metrics.MetricCompose(
metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)}
)

@staticmethod
def regression(attach_to=None):
return metrics.MetricCompose(
metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)}
)

@staticmethod
def segmentation(num_classes, ignore_idx=255, attach_to=None):
confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to)
return metrics.MetricCompose(
metric_dict={'acc': metrics.Accuracy(attach_to=attach_to),
'confusion_matrix': confusion_matrix ,
'miou': metrics.mIoU(confusion_matrix)}
)

@staticmethod
def monocular_depth(attach_to=None):
return metrics.MetricCompose(
metric_dict={
'rmse': metrics.RootMeanSquaredError(attach_to=attach_to),
'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ),
'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to),
'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to),
'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to),
'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to )
}
)
@staticmethod
def loss_metric(loss_fn):
return metrics.MetricCompose(
metric_dict={
'loss': metrics.AverageMetric( loss_fn )
}
)

+ 2
- 0
model_measuring/kamal/slim/__init__.py View File

@@ -0,0 +1,2 @@
from .prunning import Pruner, strategy
from .distillation import *

+ 12
- 0
model_measuring/kamal/slim/distillation/__init__.py View File

@@ -0,0 +1,12 @@
from .kd import KDDistiller
from .hint import *
from .attention import AttentionDistiller
from .nst import NSTDistiller
from .sp import SPDistiller
from .rkd import RKDDistiller
from .pkt import PKTDistiller
from .svd import SVDDistiller
from .cc import *
from .vid import *

from . import data_free

+ 47
- 0
model_measuring/kamal/slim/distillation/attention.py View File

@@ -0,0 +1,47 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import AttentionLoss
from kamal.core.tasks.loss import KDLoss
from kamal.utils import set_mode, move_to_device

import torch
import torch.nn as nn

import time

class AttentionDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(AttentionDistiller, self).__init__( logger, tb_writer )
def setup(self,
student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super(AttentionDistiller, self).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self._at_loss = AttentionLoss()
def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
return self._at_loss(g_s, g_t)

+ 55
- 0
model_measuring/kamal/slim/distillation/cc.py View File

@@ -0,0 +1,55 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import KDLoss

import torch.nn as nn
import torch

import time

class CCDistiller(KDDistiller):

def __init__(self, logger=None, tb_writer=None ):
super(CCDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, embed_s, embed_t, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None ):
super(CCDistiller, self).setup(
student, teacher, dataloader, optimizer, T=T, gamma=gamma, alpha=alpha, device=device)
self.embed_s = embed_s.to(self.device).train()
self.embed_t = embed_t.to(self.device).train()
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags

def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
f_s = self.embed_s(feat_s[-1])
f_t = self.embed_t(feat_t[-1])
return torch.mean((torch.abs(f_s-f_t)[:-1] * torch.abs(f_s-f_t)[1:]).sum(1))

class LinearEmbed(nn.Module):
def __init__(self, dim_in=1024, dim_out=128):
super(LinearEmbed, self).__init__()
self.linear = nn.Linear(dim_in, dim_out)

def forward(self, x):
x = x.view(x.shape[0], -1)
x = self.linear(x)
return x

+ 1
- 0
model_measuring/kamal/slim/distillation/data_free/__init__.py View File

@@ -0,0 +1 @@
from .zskt import ZSKTDistiller

+ 99
- 0
model_measuring/kamal/slim/distillation/data_free/zskt.py View File

@@ -0,0 +1,99 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
import time
import torch.nn.functional as F
from kamal.slim.distillation.kd import KDDistiller
from kamal.utils import set_mode
from kamal.core.tasks.loss import kldiv

class ZSKTDistiller(KDDistiller):
def __init__( self,
student,
teacher,
generator,
z_dim,
logger=None,
viz=None):
super(ZSKTDistiller, self).__init__(logger, viz)
self.teacher = teacher
self.model = self.student = student
self.generator = generator
self.z_dim = z_dim

def train(self, start_iter, max_iter, optim_s, optim_g, device=None):
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self.device = device
self.optim_s, self.optim_g = optim_s, optim_g

self.model.to(self.device)
self.teacher.to(self.device)
self.generator.to(self.device)
self.train_loader = [0, ]
with set_mode(self.student, training=True), \
set_mode(self.teacher, training=False), \
set_mode(self.generator, training=True):
super( ZSKTDistiller, self ).train( start_iter, max_iter )
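# NOTE: the helper below expects an `hpo` (hyper-parameter optimization)
# module to be importable in this scope; it is not imported in this file.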
def search_optimizer(self, evaluator, train_loader, hpo_space=None, mode='min', max_evals=20, max_iters=400):
optimizer = hpo.search_optimizer(self, train_loader, evaluator=evaluator, hpo_space=hpo_space, mode=mode, max_evals=max_evals, max_iters=max_iters)
return optimizer
def step(self):
start_time = time.perf_counter()
# Adv
z = torch.randn( self.z_dim ).to(self.device)
fake = self.generator( z )
self.optim_g.zero_grad()
t_out = self.teacher( fake )
s_out = self.student( fake )
loss_g = -kldiv( s_out, t_out )
loss_g.backward()
self.optim_g.step()

with torch.no_grad():
fake = self.generator( z )
t_out = self.teacher( fake.detach() )
for _ in range(10):
self.optim_s.zero_grad()
s_out = self.student( fake.detach() )
loss_s = kldiv( s_out, t_out )
loss_s.backward()
self.optim_s.step()
loss_dict = {
'loss_g': loss_g,
'loss_s': loss_s,
}

step_time = time.perf_counter() - start_time

# record training info
info = loss_dict
info['step_time'] = step_time
info['lr_s'] = float( self.optim_s.param_groups[0]['lr'] )
info['lr_g'] = float( self.optim_g.param_groups[0]['lr'] )
self.history.put_scalars( **info )
def reset(self):
self.history = None
self._train_loader_iter = iter(self.train_loader)
self.iter = self.start_iter

+ 86
- 0
model_measuring/kamal/slim/distillation/hint.py View File

@@ -0,0 +1,86 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import KDLoss

import torch.nn as nn
import torch

import time


class HintDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(HintDistiller, self).__init__( logger, tb_writer )

def setup(self,
student, teacher, regressor, dataloader, optimizer,
hint_layer=2, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( HintDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
self.regressor = regressor
self._hint_layer = hint_layer
self._beta = beta
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self.regressor.to(device)
def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
f_s = self.regressor(feat_s[self._hint_layer])
f_t = feat_t[self._hint_layer]
return nn.functional.mse_loss(f_s, f_t)

class Regressor(nn.Module):
"""
Convolutional regression for FitNet
@inproceedings{tian2019crd,
title={Contrastive Representation Distillation},
author={Yonglong Tian and Dilip Krishnan and Phillip Isola},
booktitle={International Conference on Learning Representations},
year={2020}
}
"""

def __init__(self, s_shape, t_shape, is_relu=True):
super(Regressor, self).__init__()
self.is_relu = is_relu
_, s_C, s_H, s_W = s_shape
_, t_C, t_H, t_W = t_shape
if s_H == 2 * t_H:
self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
elif s_H * 2 == t_H:
self.conv = nn.ConvTranspose2d(
s_C, t_C, kernel_size=4, stride=2, padding=1)
elif s_H >= t_H:
self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
else:
raise NotImplementedError(
'student size {}, teacher size {}'.format(s_H, t_H))
self.bn = nn.BatchNorm2d(t_C)
self.relu = nn.ReLU(inplace=True)

def forward(self, x):
x = self.conv(x)
if self.is_relu:
return self.relu(self.bn(x))
else:
return self.bn(x)
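
if __name__ == "__main__":
    # Shape-matching sketch (illustrative, not part of the original file):
    # a 32x32 student map against a 16x16 teacher map selects the stride-2
    # conv branch, halving the spatial size and matching teacher channels.
    reg = Regressor(s_shape=(1, 64, 32, 32), t_shape=(1, 128, 16, 16))
    out = reg(torch.randn(2, 64, 32, 32))
    print(out.shape)  # torch.Size([2, 128, 16, 16])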

+ 90
- 0
model_measuring/kamal/slim/distillation/kd.py View File

@@ -0,0 +1,90 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from kamal.core.engine.trainer import Engine
from kamal.core.tasks.loss import kldiv
import torch.nn.functional as F
from kamal.utils.logger import get_logger
from kamal.utils import set_mode, move_to_device
import weakref

import torch
import torch.nn as nn

import time
import numpy as np

class KDDistiller(Engine):
def __init__( self,
logger=None,
tb_writer=None):
super(KDDistiller, self).__init__(logger=logger, tb_writer=tb_writer)

def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, device=None):
self.model = self.student = student
self.teacher = teacher
self.dataloader = dataloader
self.optimizer = optimizer
if device is None:
device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
self.device = device

self.T = T
self.gamma = gamma
self.alpha = alpha
self.beta = beta

self.student.to(self.device)
self.teacher.to(self.device)

def run( self, max_iter, start_iter=0, epoch_length=None):
with set_mode(self.student, training=True), \
set_mode(self.teacher, training=False):
super( KDDistiller, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

def additional_kd_loss(self, engine, batch):
return batch[0].new_zeros(1)

def step_fn(self, engine, batch):
student = self.student
teacher = self.teacher
start_time = time.perf_counter()
batch = move_to_device(batch, self.device)
inputs, targets = batch
outputs = student(inputs)
with torch.no_grad():
soft_targets = teacher(inputs)
loss_dict = { "loss_kld": self.alpha * kldiv(outputs, soft_targets, T=self.T),
"loss_ce": self.beta * F.cross_entropy( outputs, targets ),
"loss_additional": self.gamma * self.additional_kd_loss(engine, batch) }
loss = sum( loss_dict.values() )
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
step_time = time.perf_counter() - start_time
metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
metrics.update({
'total_loss': loss.item(),
'step_time': step_time,
'lr': float( self.optimizer.param_groups[0]['lr'] )
})
return metrics
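
# Extension sketch (illustrative, not part of the original file): concrete
# distillers in this package only override additional_kd_loss; step_fn then
# optimizes alpha * KL + beta * CE + gamma * additional.
class _L2LogitDistiller(KDDistiller):
    def additional_kd_loss(self, engine, batch):
        # hypothetical extra term: L2 distance between student and teacher logits
        inputs, _ = batch
        with torch.no_grad():
            t_out = self.teacher(inputs)
        return F.mse_loss(self.student(inputs), t_out)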



+ 44
- 0
model_measuring/kamal/slim/distillation/nst.py View File

@@ -0,0 +1,44 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import KDLoss
from kamal.core.tasks.loss import NSTLoss

import torch
import torch.nn as nn

import time

class NSTDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(NSTDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( NSTDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self._nst_loss = NSTLoss()

def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
return self._nst_loss(g_s, g_t)

+ 44
- 0
model_measuring/kamal/slim/distillation/pkt.py View File

@@ -0,0 +1,44 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import PKTLoss
from kamal.core.tasks.loss import KDLoss

import torch
import torch.nn as nn

import time

class PKTDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(PKTDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( PKTDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self._pkt_loss = PKTLoss()
def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
f_s = feat_s[-1]
f_t = feat_t[-1]
return self._pkt_loss(f_s, f_t)

+ 45
- 0
model_measuring/kamal/slim/distillation/rkd.py View File

@@ -0,0 +1,45 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import RKDLoss
from kamal.core.tasks.loss import KDLoss

import torch
import torch.nn as nn

import time

class RKDDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(RKDDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( RKDDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, gamma=gamma, alpha=alpha, device=device )
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self._rkd_loss = RKDLoss()
def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
f_s = feat_s[-1]
f_t = feat_t[-1]
return self._rkd_loss(f_s, f_t)

+ 44
- 0
model_measuring/kamal/slim/distillation/sp.py View File

@@ -0,0 +1,44 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import SPLoss
from kamal.core.tasks.loss import KDLoss

import torch
import torch.nn as nn

import time

class SPDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(SPDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( SPDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self._sp_loss = SPLoss()

def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
g_s = [feat_s[-2]]
g_t = [feat_t[-2]]
return self._sp_loss(g_s, g_t)

+ 45
- 0
model_measuring/kamal/slim/distillation/svd.py View File

@@ -0,0 +1,45 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from .kd import KDDistiller
from kamal.core.tasks.loss import SVDLoss
from kamal.core.tasks.loss import KDLoss

import torch
import torch.nn as nn

import time


class SVDDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(SVDDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( SVDDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self._svd_loss = SVDLoss()

def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
return self._svd_loss( g_s, g_t )

+ 90
- 0
model_measuring/kamal/slim/distillation/vid.py View File

@@ -0,0 +1,90 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import numpy as np
import time
import torch.nn as nn
import torch
import torch.nn.functional as F
from .kd import KDDistiller
from kamal.utils import set_mode
from kamal.core.tasks.loss import KDLoss

class VIDDistiller(KDDistiller):
def __init__(self, logger=None, tb_writer=None ):
super(VIDDistiller, self).__init__( logger, tb_writer )

def setup(self, student, teacher, dataloader, optimizer, regressor_l, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
super( VIDDistiller, self ).setup(
student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
self.regressor_l = regressor_l
self.stu_hooks = stu_hooks
self.tea_hooks = tea_hooks
self.out_flags = out_flags
self.regressor_l = [regressor.to(self.device).train() for regressor in self.regressor_l]

def additional_kd_loss(self, engine, batch):
feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
g_s = feat_s[1:-1]
g_t = feat_t[1:-1]
return sum([c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, self.regressor_l)])

class VIDRegressor(nn.Module):
def __init__(self,
num_input_channels,
num_mid_channel,
num_target_channels,
init_pred_var=5.0,
eps=1e-5):
super(VIDRegressor, self).__init__()

def conv1x1(in_channels, out_channels, stride=1):
return nn.Conv2d(
in_channels, out_channels,
kernel_size=1, padding=0,
bias=False, stride=stride)

self.regressor = nn.Sequential(
conv1x1(num_input_channels, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_mid_channel),
nn.ReLU(),
conv1x1(num_mid_channel, num_target_channels),
)
self.log_scale = torch.nn.Parameter(
np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
)
self.eps = eps

def forward(self, input, target):
# pool for dimension match
s_H, t_H = input.shape[2], target.shape[2]
if s_H > t_H:
input = F.adaptive_avg_pool2d(input, (t_H, t_H))
elif s_H < t_H:
target = F.adaptive_avg_pool2d(target, (s_H, s_H))
else:
pass
pred_mean = self.regressor(input)
pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
pred_var = pred_var.view(1, -1, 1, 1)
neg_log_prob = 0.5*(
(pred_mean-target)**2/pred_var+torch.log(pred_var)
)
loss = torch.mean(neg_log_prob)
return loss

+ 2
- 0
model_measuring/kamal/slim/prunning/__init__.py View File

@@ -0,0 +1,2 @@
from .pruner import Pruner
from .strategy import LNStrategy, RandomStrategy

+ 37
- 0
model_measuring/kamal/slim/prunning/pruner.py View File

@@ -0,0 +1,37 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from copy import deepcopy
import os
import torch

class Pruner(object):
def __init__(self, strategy):
self.strategy = strategy
def prune(self, model, rate=0.1, example_inputs=None):
ori_num_params = sum( [ torch.numel(p) for p in model.parameters() ] )
model = deepcopy(model).cpu()
model = self._prune( model, rate=rate, example_inputs=example_inputs )
new_num_params = sum( [ torch.numel(p) for p in model.parameters() ] )
print( "%d=>%d, %.2f%% params were pruned"%( ori_num_params, new_num_params, 100*(ori_num_params-new_num_params)/ori_num_params ) )
return model
def _prune(self, model, **kargs):
return self.strategy( model, **kargs)


+ 85
- 0
model_measuring/kamal/slim/prunning/strategy.py View File

@@ -0,0 +1,85 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch_pruning as tp
import abc
import torch
import torch.nn as nn
import random
import numpy as np

_PRUNABLE_MODULES= tp.DependencyGraph.PRUNABLE_MODULES

class BaseStrategy(abc.ABC):

@abc.abstractmethod
def select(self, layer_to_prune):
pass

def __call__(self, model, rate=0.1, example_inputs=None):
if example_inputs is None:
example_inputs = torch.randn( 1,3,256,256 )

DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=example_inputs)

prunable_layers = []
total_params = 0
num_accumulative_conv_params = [ 0, ]

for m in model.modules():
if isinstance(m, _PRUNABLE_MODULES ) :
nparam = tp.utils.count_prunable_params( m )
total_params += nparam
if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)):
prunable_layers.append( m )
num_accumulative_conv_params.append( num_accumulative_conv_params[-1]+nparam )
prunable_layers.pop(-1) # remove the last layer
num_accumulative_conv_params.pop(-1) # remove the last layer
num_conv_params = num_accumulative_conv_params[-1]
num_accumulative_conv_params = [ ( num_accumulative_conv_params[i], num_accumulative_conv_params[i+1] ) for i in range(len(num_accumulative_conv_params)-1) ]
def map_param_idx_to_conv_layer(i):
for l, accu in zip( prunable_layers, num_accumulative_conv_params ):
if accu[0]<=i and i<accu[1]:
return l

num_pruned = 0
while num_pruned<total_params*rate:
layer_to_prune = map_param_idx_to_conv_layer( random.randint( 0, num_conv_params-1 ) )
if layer_to_prune.weight.shape[0]<1:
continue
idx = self.select( layer_to_prune )
fn = tp.prune_conv if isinstance(layer_to_prune, nn.modules.conv._ConvNd) else tp.prune_linear
plan = DG.get_pruning_plan( layer_to_prune, fn, idxs=idx )
num_pruned += plan.exec()
return model

class RandomStrategy(BaseStrategy):
def select(self, layer_to_prune):
return [ random.randint( 0, layer_to_prune.weight.shape[0]-1 ) ]

class LNStrategy(BaseStrategy):
def __init__(self, n=2):
self.n = n

def select(self, layer_to_prune):
w = torch.flatten( layer_to_prune.weight, 1 )
norm = torch.norm(w, p=self.n, dim=1)
idx = [ int(norm.min(dim=0)[1].item()) ]
return idx
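
if __name__ == "__main__":
    # Usage sketch (illustrative; assumes torchvision is installed, not part
    # of the original file): prune roughly 30% of a ResNet-18's parameters
    # with the L2-norm strategy by calling the strategy directly.
    from torchvision.models import resnet18
    model = resnet18(num_classes=10)
    model = LNStrategy(n=2)(model, rate=0.3, example_inputs=torch.randn(1, 3, 224, 224))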

+ 18
- 0
model_measuring/kamal/transferability/README.md View File

@@ -0,0 +1,18 @@
# Deep Model Transferability from Attribution Maps

- [*"Paper: Deep Model Transferbility from Attribution Maps"*](https:), NeurIPS 2019.(released soon)

J. Song, Y. Chen, X. Wang, C. Shen, M. Song

[Homepage of VIPA Group](https://www.vipazoo.cn/index_en.html), Zhejiang University, China

This repo was rewritten by zhfeing in `PyTorch`.



<div align="left">
<img src="vipa-logo.png" width = "40%" height = "40%" alt="icon"/>
</div>



+ 20
- 0
model_measuring/kamal/transferability/__init__.py View File

@@ -0,0 +1,20 @@
from kamal.transferability.trans_graph import TransferabilityGraph
from kamal.transferability.trans_metric import AttrMapMetric

import kamal
from kamal.vision import sync_transforms as sT
import os
import torch
from PIL import Image

if __name__=='__main__':
zoo = '/tmp/pycharm_project_225/kamal/transferability/model2'
TG = TransferabilityGraph(zoo)
probe_set_root = '/tmp/pycharm_project_225/kamal/transferability/probe_data'
for probe_set in os.listdir( probe_set_root ):
print("Add %s"%(probe_set))
imgs_set = list( os.listdir( os.path.join( probe_set_root, probe_set ) ) )
images = [ Image.open( os.path.join(probe_set_root, probe_set, img) ) for img in imgs_set ]
metric = AttrMapMetric(images, device=torch.device('cuda'))
TG.add_metric( probe_set, metric)
TG.export_to_json(probe_set, 'exported_metrics/%s.json'%(probe_set), topk=3, normalize=True)

+ 3
- 0
model_measuring/kamal/transferability/depara/__init__.py View File

@@ -0,0 +1,3 @@
from .attribution_graph import get_attribution_graph, graph_similarity

from .attribution_map import attribution_map, attr_map_distance

+ 184
- 0
model_measuring/kamal/transferability/depara/attribution_graph.py View File

@@ -0,0 +1,184 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

from typing import Type, Dict, Any
import itertools

import numpy as np
from numpy import Inf
import networkx

import torch
from torch.nn import Module
from torch.utils.data import Dataset

from .attribution_map import attribution_map, attr_map_similarity


def graph_to_array(graph: networkx.Graph):
weight_matrix = np.zeros((len(graph.nodes), len(graph.nodes)))
for i, n1 in enumerate(graph.nodes):
for j, n2 in enumerate(graph.nodes):
try:
dist = graph[n1][n2]["weight"]
except KeyError:
dist = 1
weight_matrix[i, j] = dist
return weight_matrix


class FeatureMapExtractor():
def __init__(self, module: Module):
self.module = module
self.feature_pool: Dict[str, Dict[str, Any]] = dict()
self.register_hooks()

def register_hooks(self):
for name, m in self.module.named_modules():
if "pool" in name:
m.name = name
self.feature_pool[name] = dict()

def hook(m: Module, input, output):
self.feature_pool[m.name]["feature"] = input
self.feature_pool[name]["handle"] = m.register_forward_hook(hook)

def _forward(self, x):
self.module(x)

def remove_hooks(self):
for name, cfg in self.feature_pool.items():
cfg["handle"].remove()
cfg.clear()
self.feature_pool.clear()

def extract_final_map(self, x):
self._forward(x)
feature_map = None
max_channel = 0
min_size = Inf
for name, cfg in self.feature_pool.items():
f = cfg["feature"]
if len(f) == 1 and isinstance(f[0], torch.Tensor):
f = f[0]
if f.dim() == 4: # BxCxHxW
b, c, h, w = f.shape
if c >= max_channel and 1 < h * w <= min_size:
feature_map = f
max_channel = c
min_size = h * w
return feature_map


def get_attribution_graph(
model: Module,
attribution_type: Type,
with_noise: bool,
probe_data: Dataset,
device: torch.device,
norm_square: bool = False,
):
attribution_graph = networkx.Graph()
model = model.to(device)
extractor = FeatureMapExtractor(model)
for i, x in enumerate(probe_data):
x = x.to(device)
x.requires_grad_()

attribution = attribution_map(
func=lambda x: extractor.extract_final_map(x),
attribution_type=attribution_type,
with_noise=with_noise,
probe_data=x.unsqueeze(0),
norm_square=norm_square
)

attribution_graph.add_node(i, attribution_map=attribution)

nodes = attribution_graph.nodes
for i, j in itertools.product(nodes, nodes):
if i < j:
weight = attr_map_similarity(
attribution_graph.nodes(data=True)[i]["attribution_map"],
attribution_graph.nodes(data=True)[j]["attribution_map"]
)
attribution_graph.add_edge(i, j, weight=weight)

return attribution_graph


def edge_to_embedding(graph: networkx.Graph):
adj = graph_to_array(graph)
up_tri_mask = np.tri(*adj.shape[-2:], k=0, dtype=bool)
return adj[up_tri_mask]


def embedding_to_rank(embedding: np.ndarray):
order = embedding.argsort()
ranks = order.argsort()
return ranks


def graph_similarity(g1: networkx.Graph, g2: networkx.Graph, Lambda: float = 1.0):
nodes_1 = g1.nodes(data=True)
nodes_2 = g2.nodes(data=True)
assert len(nodes_1) == len(nodes_2)

# calculate vertex similarity
v_s = 0
n = len(g1.nodes)
for i in range(n):
v_s += attr_map_similarity(
map_1=g1.nodes(data=True)[i]["attribution_map"],
map_2=g2.nodes(data=True)[i]["attribution_map"]
)
vertex_similarity = v_s / n

# calculate edges similarity
emb_1 = edge_to_embedding(g1)
emb_2 = edge_to_embedding(g2)
rank_1 = embedding_to_rank(emb_1)
rank_2 = embedding_to_rank(emb_2)
k = emb_1.shape[0]
edge_similarity = 1 - 6 * np.sum(np.square(rank_1 - rank_2)) / (k ** 3 - k)

return vertex_similarity + Lambda * edge_similarity


if __name__ == "__main__":
from captum.attr import InputXGradient
from torchvision.models import resnet34

model_1 = resnet34(num_classes=10)
graph_1 = get_attribution_graph(
model_1,
attribution_type=InputXGradient,
with_noise=False,
probe_data=torch.rand(10, 3, 244, 244),
device=torch.device("cpu")
)

model_2 = resnet34(num_classes=10)
graph_2 = get_attribution_graph(
model_2,
attribution_type=InputXGradient,
with_noise=False,
probe_data=torch.rand(10, 3, 244, 244),
device=torch.device("cpu")
)

print(graph_similarity(graph_1, graph_2))

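For intuition, the edge term in graph_similarity above is a Spearman rank correlation over the flattened adjacency entries. A minimal, self-contained sketch with toy numbers (the two arrays stand in for edge_to_embedding outputs of two graphs over the same node set):

import numpy as np

emb_1 = np.array([0.9, 0.1, 0.4])  # pairwise edge weights of graph 1
emb_2 = np.array([0.8, 0.2, 0.3])  # pairwise edge weights of graph 2

rank_1 = emb_1.argsort().argsort()  # same transform as embedding_to_rank
rank_2 = emb_2.argsort().argsort()

k = emb_1.shape[0]
edge_similarity = 1 - 6 * np.sum(np.square(rank_1 - rank_2)) / (k ** 3 - k)
print(edge_similarity)  # 1.0 here: both graphs order their edges identically
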
+ 87
- 0
model_measuring/kamal/transferability/depara/attribution_map.py View File

@@ -0,0 +1,87 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
from typing import Type, Callable
from captum.attr import Attribution
from captum.attr import NoiseTunnel


def with_norm(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, square: bool = False):
x = func(x)
x = torch.norm(x.flatten(1), dim=1, p=2)
if square:
x = torch.pow(x, 2)
return x


def attribution_map(
func: Callable[[torch.Tensor], torch.Tensor],
attribution_type: Type,
with_noise: bool,
probe_data: torch.Tensor,
norm_square: bool = False,
**attribution_kwargs
) -> torch.Tensor:
"""
Calculate attribution map with given attribution type(algorithm).
Args:
model: pytorch module
attribution_type: attribution algorithm, e.g. IntegratedGradients, InputXGradient, ...
with_noise: whether to add noise tunnel
probe_data: input data to model
device: torch.device("cuda: 0")
attribution_kwargs: other kwargs for attribution method
Return: attribution map
"""
attribution: Attribution = attribution_type(lambda x: with_norm(func, x, norm_square))
if with_noise:
attribution = NoiseTunnel(attribution)
attr_map = attribution.attribute(
inputs=probe_data,
target=None,
**attribution_kwargs
)
return attr_map.detach()

def attr_map_distance(map_1: torch.Tensor, map_2: torch.Tensor):
if map_1.shape != map_2.shape:
map_1 = torch.nn.functional.interpolate( map_1, size=map_2.shape[-2:] )
#dist = torch.dist(map_1.flatten(1), map_2.flatten(1), p=2).mean()
dist = 1 - torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean()
return dist.item()

def attr_map_similarity(map_1: torch.Tensor, map_2: torch.Tensor):
    assert map_1.shape == map_2.shape, "attribution maps must share a shape"
dist = torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean()
return dist.item()


if __name__ == "__main__":
import captum

def ff(x):
return x ** 2

m = attribution_map(
ff,
captum.attr.InputXGradient,
with_noise=False,
probe_data=torch.tensor([[1, 2, 3, 4]], dtype=torch.float, requires_grad=True),
norm_square=True
)
print(m)

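A quick sanity check of attr_map_distance on toy maps (a sketch only; the interpolation mirrors the shape handling above, and random non-negative maps land at a moderate cosine distance rather than 0):

import torch
import torch.nn.functional as F

map_1 = torch.rand(1, 3, 8, 8)    # BxCxHxW attribution map
map_2 = torch.rand(1, 3, 16, 16)  # deliberately different spatial size

map_1 = F.interpolate(map_1, size=map_2.shape[-2:])  # match HxW first
dist = 1 - torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean()
print(dist.item())  # in [0, 2]; smaller means more similar maps
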
+ 135
- 0
model_measuring/kamal/transferability/trans_graph.py View File

@@ -0,0 +1,135 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""
import torch
import networkx as nx
from . import depara
import os, abc
from typing import Callable
from kamal import hub
import json, numbers

from tqdm import tqdm

class Node(object):
def __init__(self, hub_root, entry_name, spec_name):
self.hub_root = hub_root
self.entry_name = entry_name
self.spec_name = spec_name

@property
def model(self):
return hub.load( self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name ).eval()

@property
def tag(self):
return hub.load_tags(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name)

@property
def metadata(self):
return hub.load_metadata(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name)

class TransferabilityGraph(object):
def __init__(self, model_zoo_set):
self.model_zoo_set = model_zoo_set
# self.model_zoo = os.path.abspath( os.path.expanduser( model_zoo ) )
self._graphs = dict()
self._models = dict()
self._register_models()

def _register_models(self):
cnt = 0
for model_zoo in self.model_zoo_set:
model_zoo = os.path.abspath(os.path.expanduser(model_zoo))
for hub_root in self._list_modelzoo(model_zoo):
for entry_name, spec_name in hub.list_spec(hub_root):
node = Node( hub_root, entry_name, spec_name )
name = node.metadata['name']
self._models[name] = node
cnt += 1
print("%d models has been registered!"%cnt)
def _list_modelzoo(self, zoo_dir):
zoo_list = []
def _traverse(path):
for item in os.listdir(path):
item_path = os.path.join(path, item)
if os.path.isdir(item_path):
if os.path.exists(os.path.join( item_path, 'code/hubconf.py' )):
zoo_list.append(item_path)
else:
_traverse( item_path )
_traverse(zoo_dir)
return zoo_list

def add_metric(self, metric_name, metric):
self._graphs[metric_name] = g = nx.DiGraph()
g.add_nodes_from( self._models.values() )
for n1 in self._models.values():
for n2 in tqdm(self._models.values()):
                if n1 != n2 and not g.has_edge(n1, n2):
                    try:
                        g.add_edge(n1, n2, dist=metric(n1, n2))
                    except RuntimeError:  # e.g. CUDA OOM: retry once on CPU
                        ori_device = metric.device
                        metric.device = torch.device('cpu')
                        g.add_edge(n1, n2, dist=metric(n1, n2))
                        metric.device = ori_device

    def export_to_json(self, metric_name, output_filename, topk=None, normalize=False):
graph = self._graphs.get( metric_name, None )
assert graph is not None
graph_data={
'nodes': [],
'edges': [],
}
node_to_idx = {}
for i, node in enumerate(self._models.values()):
tags = node.tag
metadata = node.metadata
node_data = { k:v for (k, v) in tags.items() if isinstance(v, (numbers.Number, str) ) }
node_data['name'] = metadata['name']
node_data['task'] = metadata['task']
node_data['dataset'] = metadata['dataset']
node_data['url'] = metadata['url']
node_data['id'] = i
graph_data['nodes'].append({'tags': node_data})
node_to_idx[node] = i

# record Edges
edge_list = graph_data['edges']
topk_dist = { idx: [] for idx in range(len( self._models )) }
for i, edge in enumerate(graph.edges.data('dist')):
s, t, d = int( node_to_idx[edge[0]] ), int( node_to_idx[edge[1]] ), float(edge[2])
topk_dist[s].append(d)
edge_list.append([
s, t, d # source, target, distance
])

        if isinstance(topk, int):
            for i, dist in topk_dist.items():
                dist.sort()
                # keep edges below the k-th smallest distance (clamped to the list length)
                topk_dist[i] = dist[min(topk, len(dist) - 1)]
            graph_data['edges'] = [edge for edge in edge_list if edge[2] < topk_dist[edge[0]]]

if normalize:
edge_dist = [e[2] for e in graph_data['edges']]
min_dist, max_dist = min(edge_dist), max(edge_dist)
for e in graph_data['edges']:
e[2] = (e[2] - min_dist+1e-8) / (max_dist - min_dist+1e-8)
with open(output_filename, 'w') as fp:
json.dump(graph_data, fp)

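For orientation, export_to_json writes a payload of the following shape; the values here are hand-written placeholders, not actual output:

graph_data = {
    "nodes": [
        {"tags": {"name": "resnet18_cifar10",  # hypothetical model name
                  "task": "classification",
                  "dataset": "CIFAR-10",
                  "url": "...",
                  "id": 0}},
    ],
    "edges": [
        [0, 1, 0.42],  # [source id, target id, distance]
    ],
}
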
+ 109
- 0
model_measuring/kamal/transferability/trans_metric.py View File

@@ -0,0 +1,109 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import abc

from . import depara
from captum.attr import InputXGradient
import torch
from kamal.core import hub
from kamal.vision import sync_transforms as sT

class TransMetric(abc.ABC):
def __init__(self):
pass

    @abc.abstractmethod
    def __call__(self, a, b) -> float:
        ...

class DeparaMetric(TransMetric):
def __init__(self, data, device):
self.data = data
self.device = device
self._cache = {}
def _get_transform(self, metadata):
input_metadata = metadata['input']
size = input_metadata['size']
space = input_metadata['space']
drange = input_metadata['range']
normalize = input_metadata['normalize']
        if size is None:
            size = 224
if isinstance(size, (list, tuple)):
size = size[-1]
transform = [
sT.Resize(size),
sT.CenterCrop(size),
]
if space=='bgr':
transform.append(sT.FlipChannels())
if list(drange)==[0, 1]:
transform.append( sT.ToTensor() )
elif list(drange)==[0, 255]:
transform.append( sT.ToTensor(normalize=False, dtype=torch.float) )
else:
raise NotImplementedError
if normalize is not None:
transform.append(sT.Normalize( mean=normalize['mean'], std=normalize['std'] ))
return sT.Compose(transform)

def _get_attr_graph(self, n):
transform = self._get_transform(n.metadata)
data = torch.stack( [ transform( d ) for d in self.data ], dim=0 )
return depara.get_attribution_graph(
n.model,
attribution_type=InputXGradient,
with_noise=False,
probe_data=data,
device=self.device
)

    def __call__(self, n1, n2):
        attrgraph_1 = self._cache.get(n1, None)
        attrgraph_2 = self._cache.get(n2, None)
        if attrgraph_1 is None:
            # networkx graphs have no .cpu(); cache the graph as-is
            self._cache[n1] = attrgraph_1 = self._get_attr_graph(n1)
        if attrgraph_2 is None:
            self._cache[n2] = attrgraph_2 = self._get_attr_graph(n2)
        return depara.graph_similarity(attrgraph_1, attrgraph_2)

class AttrMapMetric(DeparaMetric):
def _get_attr_map(self, n):
transform = self._get_transform(n.metadata)
data = torch.stack( [ transform( d ).to(self.device) for d in self.data ], dim=0 )
return depara.attribution_map(
n.model.to(self.device),
attribution_type=InputXGradient,
with_noise=False,
probe_data=data,
)

    def __call__(self, n1, n2):
        attr_map_1 = self._cache.get(n1, None)
        attr_map_2 = self._cache.get(n2, None)
        if attr_map_1 is None:
            self._cache[n1] = attr_map_1 = self._get_attr_map(n1).cpu()
        if attr_map_2 is None:
            self._cache[n2] = attr_map_2 = self._get_attr_map(n2).cpu()
        return depara.attr_map_distance(attr_map_1, attr_map_2)


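Putting the two files together, a hypothetical end-to-end run (the zoo path and probe images are placeholders; probe images should be PIL images so the metadata-driven transforms apply):

import numpy as np
import torch
from PIL import Image
from kamal.transferability.trans_graph import TransferabilityGraph
from kamal.transferability.trans_metric import AttrMapMetric

probe = [Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
         for _ in range(8)]                                 # stand-in probe images
tg = TransferabilityGraph(["~/model_zoo"])                  # hypothetical zoo root
metric = AttrMapMetric(probe, device=torch.device("cpu"))
tg.add_metric("depara_attr_map", metric)
tg.export_to_json("depara_attr_map", "transferability.json", topk=5, normalize=True)
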
+ 2
- 0
model_measuring/kamal/utils/__init__.py View File

@@ -0,0 +1,2 @@
from ._utils import *
from .logger import get_logger

+ 153
- 0
model_measuring/kamal/utils/_utils.py View File

@@ -0,0 +1,153 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import numpy as np
import math
import torch
import random
from copy import deepcopy
import contextlib, hashlib

def split_batch(batch):
if isinstance(batch, (list, tuple)):
inputs, *targets = batch
if len(targets)==1:
targets = targets[0]
return inputs, targets
else:
return [batch, None]

@contextlib.contextmanager
def set_mode(model, training=True):
ori_mode = model.training
model.train(training)
yield
model.train(ori_mode)

def move_to_device(obj, device):
    if isinstance(obj, torch.Tensor):
        return obj.to(device=device)
    elif isinstance(obj, torch.nn.Module):  # was a bare `nn`, which is never imported here
        return obj.to(device=device)
    elif isinstance(obj, (list, tuple)):
        return [o.to(device=device) for o in obj]
    return obj


def pack_images(images, col=None, channel_last=False):
# N, C, H, W
if isinstance(images, (list, tuple) ):
images = np.stack(images, 0)
if channel_last:
images = images.transpose(0,3,1,2) # make it channel first
assert len(images.shape)==4
assert isinstance(images, np.ndarray)
N,C,H,W = images.shape
if col is None:
col = int(math.ceil(math.sqrt(N)))
row = int(math.ceil(N / col))
pack = np.zeros( (C, H*row, W*col), dtype=images.dtype )
for idx, img in enumerate(images):
h = (idx//col) * H
w = (idx% col) * W
pack[:, h:h+H, w:w+W] = img
return pack

def normalize(tensor, mean, std, reverse=False):
if reverse:
_mean = [ -m / s for m, s in zip(mean, std) ]
_std = [ 1/s for s in std ]
else:
_mean = mean
_std = std
_mean = torch.as_tensor(_mean, dtype=tensor.dtype, device=tensor.device)
_std = torch.as_tensor(_std, dtype=tensor.dtype, device=tensor.device)
tensor = (tensor - _mean[None, :, None, None]) / (_std[None, :, None, None])
return tensor

class Normalizer(object):
def __init__(self, mean, std, reverse=False):
self.mean = mean
self.std = std
self.reverse = reverse

def __call__(self, x):
if self.reverse:
return self.denormalize(x)
else:
return self.normalize(x)
def normalize(self, x):
return normalize( x, self.mean, self.std )
def denormalize(self, x):
return normalize( x, self.mean, self.std, reverse=True )


def colormap(N=256, normalized=False):
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)

dtype = 'float32' if normalized else 'uint8'
cmap = np.zeros((N, 3), dtype=dtype)
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7-j)
g = g | (bitget(c, 1) << 7-j)
b = b | (bitget(c, 2) << 7-j)
c = c >> 3

cmap[i] = np.array([r, g, b])

cmap = cmap/255 if normalized else cmap
return cmap

DEFAULT_COLORMAP = colormap()

def flatten_dict(dic):
    flattened = dict()

    def _flatten(prefix, d):
        for k, v in d.items():
            if isinstance(v, dict):
                _flatten(prefix + '%s/' % k, v)
            else:
                flattened[(prefix + '%s/' % k).strip('/')] = v

    _flatten('', dic)
    return flattened

def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

def count_parameters(model):
return sum( [ p.numel() for p in model.parameters() ] )

def md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()

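A minimal round-trip sketch for Normalizer (ImageNet statistics assumed), showing that denormalize inverts normalize:

import torch
from kamal.utils import Normalizer

normalizer = Normalizer(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
x = torch.rand(2, 3, 32, 32)                 # NCHW batch in [0, 1]
x_norm = normalizer(x)                       # standardize per channel
x_back = normalizer.denormalize(x_norm)      # invert the transform
print(torch.allclose(x, x_back, atol=1e-6))  # True up to float error
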
+ 56
- 0
model_measuring/kamal/utils/logger.py View File

@@ -0,0 +1,56 @@
import logging
import os, sys
from termcolor import colored

class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
super(_ColorfulFormatter, self).__init__(*args, **kwargs)

def formatMessage(self, record):
log = super(_ColorfulFormatter, self).formatMessage(record)

if record.levelno == logging.WARNING:
prefix = colored("WARNING", "yellow", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log

def get_logger(name='Kamal', output=None, color=True):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
# STDOUT
stdout_handler = logging.StreamHandler( stream=sys.stdout )
stdout_handler.setLevel( logging.DEBUG )

plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" )
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S")
else:
formatter = plain_formatter
stdout_handler.setFormatter(formatter)

logger.addHandler(stdout_handler)

# FILE
if output is not None:
        if output.endswith('.txt') or output.endswith('.log'):
            dirname = os.path.dirname(output)
            if dirname:  # a bare file name has no directory component
                os.makedirs(dirname, exist_ok=True)
            filename = output
else:
os.makedirs(output, exist_ok=True)
filename = os.path.join(output, "log.txt")
file_handler = logging.FileHandler(filename)
file_handler.setFormatter(plain_formatter)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
return logger


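Typical usage (the output directory is a placeholder); console lines are colored, the file copy stays plain:

from kamal.utils import get_logger

logger = get_logger("Kamal", output="run/logs")  # also writes run/logs/log.txt
logger.info("training started")
logger.warning("falling back to CPU")            # gets a colored WARNING prefix
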
+ 3
- 0
model_measuring/kamal/vision/__init__.py View File

@@ -0,0 +1,3 @@
from . import models
from . import datasets
from . import sync_transforms

+ 16
- 0
model_measuring/kamal/vision/datasets/__init__.py View File

@@ -0,0 +1,16 @@
from .ade20k import ADE20K
from .caltech import Caltech101, Caltech256
from .camvid import CamVid
from .cityscapes import Cityscapes
from .cub200 import CUB200
from .fgvc_aircraft import FGVCAircraft
from .nyu import NYUv2
from .stanford_cars import StanfordCars
from .stanford_dogs import StanfordDogs
from .sunrgbd import SunRGBD
from .voc import VOCClassification, VOCSegmentation
from .dataset import LabelConcatDataset

from torchvision import datasets as torchvision_datasets

from .unlabeled import UnlabeledDataset

+ 70
- 0
model_measuring/kamal/vision/datasets/ade20k.py View File

@@ -0,0 +1,70 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import collections
import torch
import torchvision
import numpy as np
from PIL import Image
import os
from torchvision.datasets import VisionDataset
from .utils import colormap

class ADE20K(VisionDataset):
cmap = colormap()

def __init__(
self,
root,
split="training",
transform=None,
target_transform=None,
transforms=None,
):
super( ADE20K, self ).__init__( root=root, transforms=transforms, transform=transform, target_transform=target_transform )
assert split in ['training', 'validation'], "split should be \'training\' or \'validation\'"
self.root = os.path.expanduser(root)
self.split = split
self.num_classes = 150

img_list = []
lbl_list = []
img_dir = os.path.join( self.root, 'images', self.split )
lbl_dir = os.path.join( self.root, 'annotations', self.split )

for img_name in os.listdir( img_dir ):
img_list.append( os.path.join( img_dir, img_name ) )
lbl_list.append( os.path.join( lbl_dir, img_name[:-3]+'png') )

self.img_list = img_list
self.lbl_list = lbl_list

def __len__(self):
return len(self.img_list)

def __getitem__(self, index):
img = Image.open( self.img_list[index] )
lbl = Image.open( self.lbl_list[index] )
if self.transforms:
img, lbl = self.transforms(img, lbl)
lbl = np.array(lbl, dtype='uint8')-1 # 1-150 => 0-149 + 255
return img, lbl

@classmethod
def decode_seg_to_color(cls, mask):
"""decode semantic mask to RGB image"""
return cls.cmap[mask+1]

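A minimal usage sketch (the dataset root is a placeholder). Note the label shift in __getitem__: classes 1-150 come back as 0-149, and the original 0 wraps to the ignore value 255 via uint8 underflow:

from kamal.vision.datasets import ADE20K

train_set = ADE20K("~/data/ADEChallengeData2016", split="training")
img, lbl = train_set[0]  # PIL image, uint8 label array
print(len(train_set), lbl.min(), lbl.max())
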
+ 226
- 0
model_measuring/kamal/vision/datasets/caltech.py View File

@@ -0,0 +1,226 @@
# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py
from __future__ import print_function
from PIL import Image
import os
import os.path

from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_url


class Caltech101(VisionDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified target types.
``category`` represents the target class, and ``annotation`` is a list of points
from a hand-generated outline. Defaults to ``category``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

def __init__(self, root, target_type="category", train=True,
transform=None, target_transform=None,
download=False):
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
self.train = train
self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test'
        os.makedirs(self.root, exist_ok=True)
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class

# For some reason, the category names in "101_ObjectCategories" and
# "Annotations" do not always match. This is a manual map between the
# two. Defaults to using same name, since most names are fine.
name_map = {"Faces": "Faces_2",
"Faces_easy": "Faces_3",
"Motorbikes": "Motorbikes_16",
"airplanes": "Airplanes_Side_2"}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))

self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
file_names = os.listdir(os.path.join(self.root, self.dir_name, c))
n = len(file_names)
self.index.extend( file_names )
self.y.extend(n * [i])
        print('Caltech101, Train: %s, Size: %d' % (self.train, len(self.index)))

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where the type of target specified by target_type.
"""
import scipy.io

img = Image.open(os.path.join(self.root,
self.dir_name,
self.categories[self.y[index]],
self.index[index])).convert("RGB")
target = []
for t in self.target_type:
if t == "category":
target.append(self.y[index])
elif t == "annotation":
data = scipy.io.loadmat(os.path.join(self.root,
"Annotations",
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index])))
target.append(data["obj_contour"])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def _check_integrity(self):
# can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

def __len__(self):
return len(self.index)

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91")

# extract file
with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
tar.extractall(path=self.root)

with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)

def extra_repr(self):
return "Target type: {target_type}".format(**self.__dict__)


class Caltech256(VisionDataset):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

Args:
root (string): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

def __init__(self, root,
transform=None, target_transform=None,
download=False):
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
os.makedirs(self.root, exist_ok=True)
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])

def __getitem__(self, index):
"""
Args:
index (int): Index

Returns:
tuple: (image, target) where target is index of the target class.
"""
img = Image.open(os.path.join(self.root,
"256_ObjectCategories",
self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))

target = self.y[index]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def _check_integrity(self):
# can be more robust and check hash of files
return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))

def __len__(self):
return len(self.index)

def download(self):
import tarfile

if self._check_integrity():
print('Files already downloaded and verified')
return

download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d")

# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)

+ 78
- 0
model_measuring/kamal/vision/datasets/camvid.py View File

@@ -0,0 +1,78 @@
# Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py
import os
import torch.utils.data as data
from glob import glob
from PIL import Image
import numpy as np
from torchvision.datasets import VisionDataset

class CamVid(VisionDataset):
"""CamVid dataset loader where the dataset is arranged as in https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.
Args:
root (string):
split (string): The type of dataset: 'train', 'val', 'trainval', or 'test'
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default: None.
target_transform (callable, optional): A function/transform that takes in the target and transform it. Default: None.
transforms (callable, optional): A function/transform that takes in both the image and target and transform them. Default: None.
"""

cmap = np.array([
(128, 128, 128),
(128, 0, 0),
(192, 192, 128),
(128, 64, 128),
(60, 40, 222),
(128, 128, 0),
(192, 128, 128),
(64, 64, 128),
(64, 0, 128),
(64, 64, 0),
(0, 128, 192),
(0, 0, 0),
])

def __init__(self,
root,
split='train',
transform=None,
target_transform=None,
transforms=None):
assert split in ('train', 'val', 'test', 'trainval')
super( CamVid, self ).__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform)
self.root = os.path.expanduser(root)
self.split = split

if split == 'trainval':
self.images = glob(os.path.join(self.root, 'train', '*.png')) + glob(os.path.join(self.root, 'val', '*.png'))
self.labels = glob(os.path.join(self.root, 'trainannot', '*.png')) + glob(os.path.join(self.root, 'valannot', '*.png'))
else:
self.images = glob(os.path.join(self.root, self.split, '*.png'))
self.labels = glob(os.path.join(self.root, self.split+'annot', '*.png'))
self.images.sort()
self.labels.sort()

def __getitem__(self, idx):
"""
Args:
- index (``int``): index of the item in the dataset
Returns:
A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
of the image.
"""

        img, label = Image.open(self.images[idx]), Image.open(self.labels[idx])
        if self.transforms is not None:
            img, label = self.transforms(img, label)
        label[label == 11] = 255  # remap void to the ignore index (assumes transforms return tensors)
        return img, label.squeeze(0)

    def __len__(self):
return len(self.images)

@classmethod
def decode_fn(cls, mask):
"""decode semantic mask to RGB image"""
mask[mask == 255] = 11
return cls.cmap[mask]

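A sketch of the color decoding (no samples are read here). decode_fn maps an 11-class prediction, plus 255 for void, back to RGB colors:

import numpy as np
from kamal.vision.datasets import CamVid

pred = np.random.randint(0, 11, size=(320, 480))  # fake class-index mask
rgb = CamVid.decode_fn(pred)                      # (320, 480, 3) color array
print(rgb.shape)
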
+ 146
- 0
model_measuring/kamal/vision/datasets/cityscapes.py View File

@@ -0,0 +1,146 @@
# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py
import json
import os
from collections import namedtuple

import torch
import torch.utils.data as data
from PIL import Image
import numpy as np
from torchvision.datasets import VisionDataset

class Cityscapes(VisionDataset):
"""Cityscapes <http://www.cityscapes-dataset.com/> Dataset.
Args:
root (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located.
split (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val'
mode (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""

# Based on https://github.com/mcordts/cityscapesScripts
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
'has_instances', 'ignore_in_eval', 'color'])
classes = [
CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
]

_TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
_TRAIN_ID_TO_COLOR.append([0, 0, 0])
_TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR)
_ID_TO_TRAIN_ID = np.array([c.train_id for c in classes])
def __init__(self, root, split='train', mode='gtFine', target_type='semantic', transform=None, target_transform=None, transforms=None):
super(Cityscapes, self).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
self.root = os.path.expanduser(root)
self.mode = mode
self.target_type = target_type

self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
self.targets_dir = os.path.join(self.root, self.mode, split)
self.split = split

self.images = []
self.targets = []

if split not in ['train', 'test', 'val']:
raise ValueError('Invalid split for mode! Please use split="train", split="test"'
' or split="val"')

if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
' specified "split" and "mode" are inside the "root" directory')
for city in os.listdir(self.images_dir):
img_dir = os.path.join(self.images_dir, city)
target_dir = os.path.join(self.targets_dir, city)

for file_name in os.listdir(img_dir):
self.images.append(os.path.join(img_dir, file_name))
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
self._get_target_suffix(self.mode, self.target_type))
                self.targets.append(os.path.join(target_dir, target_name))

    @classmethod
def encode_target(cls, target):
if isinstance( target, torch.Tensor ):
return torch.from_numpy( cls._ID_TO_TRAIN_ID[np.array(target)] )
else:
return cls._ID_TO_TRAIN_ID[target]

@classmethod
def decode_fn(cls, target):
target[target == 255] = 19
#target = target.astype('uint8') + 1
return cls._TRAIN_ID_TO_COLOR[target]

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
"""
image = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.targets[index])
if self.transforms:
image, target = self.transforms(image, target)
target = self.encode_target(target)
return image, target

def __len__(self):
return len(self.images)

def _load_json(self, path):
with open(path, 'r') as file:
data = json.load(file)
return data

def _get_target_suffix(self, mode, target_type):
if target_type == 'instance':
return '{}_instanceIds.png'.format(mode)
elif target_type == 'semantic':
return '{}_labelIds.png'.format(mode)
elif target_type == 'color':
return '{}_color.png'.format(mode)
elif target_type == 'polygon':
return '{}_polygons.json'.format(mode)
elif target_type == 'depth':
return '{}_disparity.png'.format(mode)

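A small sketch of the ID round trip: encode_target maps raw label IDs to the 19 train IDs (255 for ignored classes), and decode_fn maps train IDs to colors. Note that decode_fn rewrites 255 to 19 in place:

import numpy as np
from kamal.vision.datasets import Cityscapes

raw = np.array([[7, 26], [33, 0]])         # road, car, bicycle, unlabeled
train_ids = Cityscapes.encode_target(raw)
print(train_ids)                           # [[0 13] [18 255]]
rgb = Cityscapes.decode_fn(train_ids)      # (2, 2, 3) color array; 255 -> 19 in place
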
+ 70
- 0
model_measuring/kamal/vision/datasets/cub200.py View File

@@ -0,0 +1,70 @@
# Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py
import os
import pandas as pd
from torchvision.datasets.folder import default_loader
from .utils import download_url
from torch.utils.data import Dataset
import shutil


class CUB200(Dataset):
base_folder = 'images'
url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
filename = 'CUB_200_2011.tgz'
tgz_md5 = '97eceeb196236b17998738112f37df78'

def __init__(self, root, split='train', transform=None, target_transform=None, loader=default_loader, download=False):
self.root = os.path.abspath( os.path.expanduser( root ) )
self.transform = transform
self.target_transform = target_transform
        self.loader = loader
self.split = split

if download:
self.download()
self._load_metadata()
categories = os.listdir(os.path.join(
self.root, 'CUB_200_2011', 'images'))
categories.sort()
self.object_categories = [c[4:] for c in categories]
print('CUB200, Split: %s, Size: %d' % (self.split, self.__len__()))

def _load_metadata(self):
images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
names=['img_id', 'filepath'])
image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
sep=' ', names=['img_id', 'target'], encoding='latin-1', engine='python')
train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
sep=' ', names=['img_id', 'is_training_img'], encoding='latin-1', engine='python')
data = images.merge(image_class_labels, on='img_id')
self.data = data.merge(train_test_split, on='img_id')

if self.split == 'train':
self.data = self.data[self.data.is_training_img == 1]
else:
self.data = self.data[self.data.is_training_img == 0]

def download(self):
import tarfile
os.makedirs(self.root, exist_ok=True)
if not os.path.isfile(os.path.join(self.root, self.filename)):
download_url(self.url, self.root, self.filename)
print("Extracting %s..." % self.filename)
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)

def __len__(self):
return len(self.data)

def __getitem__(self, idx):
sample = self.data.iloc[idx]
path = os.path.join(self.root, 'CUB_200_2011',
self.base_folder, sample.filepath)
lbl = sample.target - 1
img = self.loader(path)

if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
lbl = self.target_transform(lbl)
return img, lbl

+ 57
- 0
model_measuring/kamal/vision/datasets/dataset.py View File

@@ -0,0 +1,57 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""


from torchvision.datasets import VisionDataset
from PIL import Image
import torch

class LabelConcatDataset(VisionDataset):
"""Dataset as a concatenation of dataset's lables.

This class is useful to assemble the same dataset's labels.

Arguments:
datasets (sequence): List of datasets to be concatenated
tasks (list) : List of teacher tasks
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version.
"""

def __init__(self, datasets, transforms=None, transform=None, target_transform=None):
super(LabelConcatDataset, self).__init__(
root=None, transforms=transforms, transform=transform, target_transform=target_transform)
assert len(datasets) > 0, 'datasets should not be an empty iterable'
self.datasets = list(datasets)
for d in self.datasets:
for t in ('transform', 'transforms', 'target_transform'):
if hasattr( d, t ):
setattr( d, t, None )

def __getitem__(self, idx):
targets_list = []
for dst in self.datasets:
image, target = dst[idx]
targets_list.append(target)
if self.transforms is not None:
image, *targets_list = self.transforms( image, *targets_list )
data = [ image, *targets_list ]
return data

    def __len__(self):
        return len(self.datasets[0])  # avoid relying on a dataset-specific .images attribute

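A hypothetical sketch (the dataset root is a placeholder): pairing the semantic and depth targets of NYUv2 so each sample yields [image, seg_label, depth_label]:

from kamal.vision.datasets import NYUv2, LabelConcatDataset

seg = NYUv2("data/NYUv2", split="train", target_type="semantic")
depth = NYUv2("data/NYUv2", split="train", target_type="depth")
multi = LabelConcatDataset([seg, depth])
image, seg_lbl, depth_lbl = multi[0]  # one image, two aligned targets
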
+ 143
- 0
model_measuring/kamal/vision/datasets/fgvc_aircraft.py View File

@@ -0,0 +1,143 @@
#Modified from https://github.com/pytorch/vision/pull/467/files
from __future__ import print_function
import torch.utils.data as data
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader
from PIL import Image
import os
import numpy as np

from .utils import download_url, mkdir


def make_dataset(dir, image_ids, targets):
assert(len(image_ids) == len(targets))
images = []
dir = os.path.expanduser(dir)
for i in range(len(image_ids)):
item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images',
'%s.jpg' % image_ids[i]), targets[i])
images.append(item)
return images


def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]).rstrip('\n'))

# index class names
classes = np.unique(targets)
class_to_idx = {classes[i]: i for i in range(len(classes))}
targets = [class_to_idx[c] for c in targets]

return (image_ids, targets, classes, class_to_idx)


class FGVCAircraft(data.Dataset):
"""`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.
Args:
root (string): Root directory path to dataset.
class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
transforms (callable, optional): A function/transforms that takes in a PIL image
and returns a transformed version. E.g. ``transforms.RandomCrop``
target_transform (callable, optional): A function/transforms that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
class_types = ('variant', 'family', 'manufacturer')
splits = ('train', 'val', 'trainval', 'test')

def __init__(self, root, class_type='variant', split='train', transform=None,
target_transform=None, loader=default_loader, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
if class_type not in self.class_types:
raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
class_type, ', '.join(self.class_types),
))

self.root = root
self.class_type = class_type
self.split = split
self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
'images_%s_%s.txt' % (self.class_type, self.split))
if download:
self.download()

(image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
samples = make_dataset(self.root, image_ids, targets)

self.transform = transform
self.target_transform = target_transform
self.loader = loader

self.samples = samples
self.classes = classes
self.class_to_idx = class_to_idx

with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f:
self.object_categories = [
line.strip('\n') for line in f.readlines()]
print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__()))

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""

path, target = self.samples[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample, target

def __len__(self):
return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')) and \
            os.path.exists(self.classes_file)

def download(self):
"""Download the FGVC-Aircraft data if it doesn't exist already."""
from six.moves import urllib
import tarfile

mkdir(self.root)

fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz')
if not os.path.isfile(fpath):
download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz')
print("Extracting fgvc-aircraft-2013b.tar.gz...")
with tarfile.open(fpath, "r:gz") as tar:
tar.extractall(path=self.root)

+ 84
- 0
model_measuring/kamal/vision/datasets/nyu.py View File

@@ -0,0 +1,84 @@
# Modified from https://github.com/VainF/nyuv2-python-toolkit
import os
import torch
import torch.utils.data as data
from PIL import Image
from scipy.io import loadmat
import numpy as np
import glob
from torchvision import transforms
from torchvision.datasets import VisionDataset
import random

from .utils import colormap

class NYUv2(VisionDataset):
"""NYUv2 dataset
See https://github.com/VainF/nyuv2-python-toolkit for more details.
Args:
root (string): Root directory path.
split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'.
target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``.
num_classes (int, optional): The number of classes, must be 40 or 13. Default:13.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version.
"""
cmap = colormap()
def __init__(self,
root,
split='train',
target_type='semantic',
num_classes=13,
transforms=None,
transform=None,
target_transform=None):
super( NYUv2, self ).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
assert(split in ('train', 'test'))

self.root = root
self.split = split
self.target_type = target_type
self.num_classes = num_classes
split_mat = loadmat(os.path.join(self.root, 'splits.mat'))
idxs = split_mat[self.split+'Ndxs'].reshape(-1) - 1

img_names = os.listdir( os.path.join(self.root, 'image', self.split) )
img_names.sort()
images_dir = os.path.join(self.root, 'image', self.split)
self.images = [os.path.join(images_dir, name) for name in img_names]

self._is_depth = False
if self.target_type=='semantic':
semantic_dir = os.path.join(self.root, 'seg%d'%self.num_classes, self.split)
self.labels = [os.path.join(semantic_dir, name) for name in img_names]
self.targets = self.labels
if self.target_type=='depth':
depth_dir = os.path.join(self.root, 'depth', self.split)
self.depths = [os.path.join(depth_dir, name) for name in img_names]
self.targets = self.depths
self._is_depth = True
if self.target_type=='normal':
normal_dir = os.path.join(self.root, 'normal', self.split)
self.normals = [os.path.join(normal_dir, name) for name in img_names]
            self.targets = self.normals

    def __getitem__(self, idx):
image = Image.open(self.images[idx])
target = Image.open(self.targets[idx])
if self.transforms is not None:
image, target = self.transforms( image, target )
return image, target

def __len__(self):
return len(self.images)

@classmethod
def decode_fn(cls, mask: np.ndarray):
"""decode semantic mask to RGB image"""
mask = mask.astype('uint8') + 1 # 255 => 0
return cls.cmap[mask]

+ 61
- 0
model_measuring/kamal/vision/datasets/preprocess/prepare_caltech101.py View File

@@ -0,0 +1,61 @@

import os, sys
from glob import glob
import random
import argparse
from PIL import Image

if __name__=='__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='../101_ObjectCategories')
parser.add_argument('--test_split', type=float, default=0.3)
args = parser.parse_args()

SAVE_DIR = os.path.join( os.path.dirname(args.data_root), 'caltech101_data' )
if not os.path.exists(SAVE_DIR):
os.mkdir(SAVE_DIR)

# Train
TRAIN_DIR = os.path.join( SAVE_DIR, 'train' )
if not os.path.exists(TRAIN_DIR):
os.mkdir(TRAIN_DIR)
# Test
TEST_DIR = os.path.join( SAVE_DIR, 'test' )
if not os.path.exists(TEST_DIR):
os.mkdir(TEST_DIR)
img_folders = os.listdir(args.data_root)
img_folders.sort()

for folder in img_folders:
if folder=='Faces':
continue
print('Processing %s'%(folder))
img_paths = glob(os.path.join( args.data_root, folder, '*.jpg') )
img_name = [os.path.split(p)[-1] for p in img_paths]

random.shuffle(img_name)

img_n = len(img_name)
test_n = int(args.test_split * img_n)
test_set = img_name[:test_n]
train_set = img_name[test_n:]

# test
dst_path = os.path.join(TEST_DIR, folder)
if not os.path.exists(dst_path):
os.mkdir(dst_path)
for test_name in test_set:
img = Image.open(os.path.join( args.data_root, folder, test_name ))
img.save( os.path.join(dst_path, test_name ) )

# train
dst_path = os.path.join(TRAIN_DIR, folder)
if not os.path.exists(dst_path):
os.mkdir(dst_path)
for train_name in train_set:
img = Image.open(os.path.join( args.data_root, folder, train_name ))
img.save( os.path.join(dst_path, train_name ) )

+ 196
- 0
model_measuring/kamal/vision/datasets/preprocess/prepare_stl10.py View File

@@ -0,0 +1,196 @@
from __future__ import print_function

import sys
import os, sys, tarfile, errno
import numpy as np
import matplotlib.pyplot as plt
import argparse
if sys.version_info >= (3, 0, 0):
import urllib.request as urllib # ugly but works
else:
import urllib

try:
    from imageio import imsave
except ImportError:
    from scipy.misc import imsave

print(sys.version_info)

# image shape
HEIGHT = 96
WIDTH = 96
DEPTH = 3

# size of a single image in bytes
SIZE = HEIGHT * WIDTH * DEPTH

# path to the directory with the data
#DATA_DIR = './data'

# url of the binary data (referenced by download_and_extract below)
DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'

# path to the binary train file with image data
#DATA_PATH = './data/stl10_binary/train_X.bin'

# path to the binary train file with labels
#LABEL_PATH = './data/stl10_binary/train_y.bin'

def read_labels(path_to_labels):
"""
:param path_to_labels: path to the binary file containing labels from the STL-10 dataset
:return: an array containing the labels
"""
with open(path_to_labels, 'rb') as f:
labels = np.fromfile(f, dtype=np.uint8)
return labels


def read_all_images(path_to_data):
"""
:param path_to_data: the file containing the binary images from the STL-10 dataset
:return: an array containing all the images
"""

with open(path_to_data, 'rb') as f:
# read whole file in uint8 chunks
everything = np.fromfile(f, dtype=np.uint8)

# We force the data into 3x96x96 chunks, since the
# images are stored in "column-major order", meaning
# that "the first 96*96 values are the red channel,
# the next 96*96 are green, and the last are blue."
# The -1 is since the size of the pictures depends
# on the input file, and this way numpy determines
# the size on its own.

images = np.reshape(everything, (-1, 3, 96, 96))

# Now transpose the images into a standard image format
# readable by, for example, matplotlib.imshow
# You might want to comment this line or reverse the shuffle
# if you will use a learning algorithm like CNN, since they like
# their channels separated.
images = np.transpose(images, (0, 3, 2, 1))
return images


def read_single_image(image_file):
"""
CAREFUL! - this method uses a file as input instead of the path - so the
position of the reader will be remembered outside of context of this method.
:param image_file: the open file containing the images
:return: a single image
"""
# read a single image, count determines the number of uint8's to read
image = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
# force into image matrix
image = np.reshape(image, (3, 96, 96))
# transpose to standard format
# You might want to comment this line or reverse the shuffle
# if you will use a learning algorithm like CNN, since they like
# their channels separated.
image = np.transpose(image, (2, 1, 0))
return image


def plot_image(image):
"""
:param image: the image to be plotted in a 3-D matrix format
:return: None
"""
plt.imshow(image)
plt.show()

def save_image(image, name):
imsave("%s.png" % name, image, format="png")

def download_and_extract(DATA_DIR):
"""
Download and extract the STL-10 dataset
:return: None
"""
dest_directory = DATA_DIR
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\rDownloading %s %.2f%%' % (filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
print('Downloaded', filename)
tarfile.open(filepath, 'r:gz').extractall(dest_directory)

def save_images(images, labels, save_dir):
print("Saving images to disk")
i = 0
if not os.path.exists(save_dir):
os.mkdir(save_dir)

for image in images:
if labels is not None:
label = labels[i]
directory = os.path.join( save_dir, str(label) )
if not os.path.exists(directory):
os.makedirs(directory)
else:
directory = save_dir

filename = os.path.join( directory, str(i) )
print(filename)
save_image(image, filename)
        i = i + 1


if __name__ == "__main__":
# download data if needed
# download_and_extract()
parser = argparse.ArgumentParser()
parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary')
args = parser.parse_args()

BASE_PATH = os.path.join( args.DATA_DIR, 'stl10_data' )
if not os.path.exists(BASE_PATH):
os.mkdir(BASE_PATH)

# Train
DATA_PATH = os.path.join( args.DATA_DIR, 'train_X.bin')
LABEL_PATH = os.path.join( args.DATA_DIR, 'train_y.bin')
print('Preparing Train')
images = read_all_images(DATA_PATH)
print(images.shape)
labels = read_labels(LABEL_PATH)
print(labels.shape)
save_images(images, labels, os.path.join(BASE_PATH, 'train'))

# Test
DATA_PATH = os.path.join( args.DATA_DIR, 'test_X.bin')
LABEL_PATH = os.path.join( args.DATA_DIR, 'test_y.bin')

#with open(DATA_PATH) as f:
# image = read_single_image(f)
# plot_image(image)
print('Preparing Test')
images = read_all_images(DATA_PATH)
print(images.shape)
labels = read_labels(LABEL_PATH)
print(labels.shape)
save_images(images, labels, os.path.join(BASE_PATH, 'test'))

# Unlabeled
print('Preparing Unlabeled')
DATA_PATH = os.path.join( args.DATA_DIR, 'unlabeled_X.bin')
images = read_all_images(DATA_PATH)
    save_images(images, None, os.path.join(BASE_PATH, 'unlabeled'))

+ 53
- 0
model_measuring/kamal/vision/datasets/preprocess/resize_camvid.py View File

@@ -0,0 +1,53 @@
import argparse
import os
from PIL import Image
import numpy as np

SIZE=(480, 320) # W, H

def is_image(path: str):
return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' )

def copy_and_resize( src_dir, dst_dir, resize_fn ):

for file_or_dir in os.listdir( src_dir ):
src = os.path.join( src_dir, file_or_dir )
dst = os.path.join( dst_dir, file_or_dir )
if os.path.isdir( src ):
os.mkdir( dst )
copy_and_resize( src, dst, resize_fn )
elif is_image( src ):
print(src, ' -> ', dst)
image = Image.open( src )
resized_image = resize_fn(image)
resized_image.save( dst )

def resize_input( image: Image.Image ):
return image.resize( SIZE, resample=Image.BICUBIC )

def resize_seg( image: Image.Image ):
return image.resize( SIZE, resample=Image.NEAREST )

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True)
ROOT = parser.parse_args().root
    NEW_ROOT = os.path.join( ROOT, '%d_%d' % SIZE )  # SIZE is already a tuple; '%d_%d'%(*SIZE) was a syntax error
    os.mkdir(NEW_ROOT)
for split in ['train', 'val', 'test']:
IMG_DIR = os.path.join( ROOT, split )
GT_DIR = os.path.join( ROOT, split+'annot' )
NEW_IMG_DIR = os.path.join( NEW_ROOT, split )
NEW_GT_DIR = os.path.join( NEW_ROOT, split+'annot' )
os.mkdir( NEW_IMG_DIR )
os.mkdir( NEW_GT_DIR )
copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input )
copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg )




if __name__=='__main__':
main()
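
# Hedged usage note (the path below is a placeholder): assuming the CamVid root
# contains train/, trainannot/, val/, valannot/, test/ and testannot/, this
# writes 480x320 copies of every image under <root>/480_320:
#
#   python resize_camvid.py --root /path/to/CamVid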

+ 65
- 0
model_measuring/kamal/vision/datasets/preprocess/resize_cityscapes.py View File

@@ -0,0 +1,65 @@
import argparse
import os
from PIL import Image
import numpy as np

SIZE=(640, 320)

def is_image(path: str):
return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' )

def copy_and_resize( src_dir, dst_dir, resize_fn ):

for file_or_dir in os.listdir( src_dir ):
src = os.path.join( src_dir, file_or_dir )
dst = os.path.join( dst_dir, file_or_dir )
if os.path.isdir( src ):
os.mkdir( dst )
copy_and_resize( src, dst, resize_fn )
elif is_image( src ):
print(src, ' -> ', dst)
image = Image.open( src )
resized_image = resize_fn(image)
resized_image.save( dst )

def resize_input( image: Image.Image ):
return image.resize( SIZE, resample=Image.BICUBIC )

def resize_seg( image: Image.Image ):
return image.resize( SIZE, resample=Image.NEAREST )

def resize_depth( image: Image.Image ):
return image.resize( SIZE, resample=Image.NEAREST )

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True)

ROOT = parser.parse_args().root
IMG_DIR = os.path.join( ROOT, 'leftImg8bit' )
GT_DIR = os.path.join( ROOT, 'gtFine' )
DEPTH_DIR = os.path.join( ROOT, 'disparity' )

    NEW_ROOT = os.path.join( ROOT, '%d_%d' % SIZE )  # SIZE is already a tuple; star-unpacking here was a syntax error
NEW_IMG_DIR = os.path.join( NEW_ROOT, 'leftImg8bit' )
NEW_GT_DIR = os.path.join( NEW_ROOT, 'gtFine' )
NEW_DEPTH_DIR = os.path.join( NEW_ROOT, 'disparity' )

if os.path.exists(NEW_ROOT):
print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
return
os.mkdir(NEW_ROOT)
os.mkdir( NEW_IMG_DIR )
os.mkdir( NEW_GT_DIR )
os.mkdir( NEW_DEPTH_DIR )
copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input )
copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg )
copy_and_resize( DEPTH_DIR, NEW_DEPTH_DIR, resize_depth )




if __name__=='__main__':
main()

+ 59
- 0
model_measuring/kamal/vision/datasets/preprocess/resize_voc.py View File

@@ -0,0 +1,59 @@
import argparse
import os
from PIL import Image
import numpy as np
from torchvision import transforms
import shutil

SIZE=240

def is_image(path: str):
return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' )

def copy_and_resize( src_dir, dst_dir, resize_fn ):

for file_or_dir in os.listdir( src_dir ):
src = os.path.join( src_dir, file_or_dir )
dst = os.path.join( dst_dir, file_or_dir )
if os.path.isdir( src ):
os.mkdir( dst )
copy_and_resize( src, dst, resize_fn )
elif is_image( src ):
print(src, ' -> ', dst)
image = Image.open( src )
resized_image = resize_fn(image)
resized_image.save( dst )

def resize_input( image: Image.Image ):
return transforms.functional.resize( image, SIZE, interpolation=Image.BILINEAR )

def resize_seg( image: Image.Image ):
return transforms.functional.resize( image, SIZE, interpolation=Image.NEAREST)

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True)

ROOT = parser.parse_args().root
IMG_DIR = os.path.join( ROOT, 'JPEGImages' )
GT_DIR = os.path.join( ROOT, 'SegmentationClass' )

NEW_ROOT = os.path.join( ROOT, str(SIZE) )
NEW_IMG_DIR = os.path.join( NEW_ROOT, 'JPEGImages' )
NEW_GT_DIR = os.path.join( NEW_ROOT, 'SegmentationClass' )

if os.path.exists(NEW_ROOT):
print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
return
os.mkdir(NEW_ROOT)
os.mkdir( NEW_IMG_DIR )
os.mkdir( NEW_GT_DIR )
copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input )
copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg )
shutil.copytree( os.path.join( ROOT, 'ImageSets'), os.path.join( NEW_ROOT, 'ImageSets' ) )


if __name__=='__main__':
main()

+ 57
- 0
model_measuring/kamal/vision/datasets/preprocess/resize_voc_240.py View File

@@ -0,0 +1,57 @@
import argparse
import os
from PIL import Image
import numpy as np
from torchvision import transforms
import shutil

def is_image(path: str):
return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' )

def copy_and_resize( src_dir, dst_dir, resize_fn ):

for file_or_dir in os.listdir( src_dir ):
src = os.path.join( src_dir, file_or_dir )
dst = os.path.join( dst_dir, file_or_dir )
if os.path.isdir( src ):
os.mkdir( dst )
copy_and_resize( src, dst, resize_fn )
elif is_image( src ):
print(src, ' -> ', dst)
image = Image.open( src )
resized_image = resize_fn(image)
resized_image.save( dst )

def resize_input( image: Image.Image ):
return transforms.functional.resize( image, 240, interpolation=Image.BILINEAR )

def resize_seg( image: Image.Image ):
return transforms.functional.resize( image, 240, interpolation=Image.NEAREST)

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, required=True)

ROOT = parser.parse_args().root
IMG_DIR = os.path.join( ROOT, 'JPEGImages' )
GT_DIR = os.path.join( ROOT, 'SegmentationClass' )

NEW_ROOT = os.path.join( ROOT, '240' )
NEW_IMG_DIR = os.path.join( NEW_ROOT, 'JPEGImages' )
NEW_GT_DIR = os.path.join( NEW_ROOT, 'SegmentationClass' )

if os.path.exists(NEW_ROOT):
print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
return
os.mkdir(NEW_ROOT)
os.mkdir( NEW_IMG_DIR )
os.mkdir( NEW_GT_DIR )
copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input )
copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg )
shutil.copytree( os.path.join( ROOT, 'ImageSets'), os.path.join( NEW_ROOT, 'ImageSets' ) )


if __name__=='__main__':
main()

+ 80
- 0
model_measuring/kamal/vision/datasets/stanford_cars.py View File

@@ -0,0 +1,80 @@
import os
import glob
from PIL import Image
import numpy as np
from scipy.io import loadmat

from torch.utils import data
from .utils import download_url, mkdir

from shutil import copyfile


class StanfordCars(data.Dataset):
"""Dataset for Stanford Cars
"""

urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz',
'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz',
'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz',
'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'}

def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
self.root = os.path.abspath( os.path.expanduser(root) )
self.split = split
self.transform = transform
self.target_transform = target_transform

if download:
self.download()

if self.split == 'train':
annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat')
else:
annos = os.path.join(self.root, 'devkit',
'cars_test_annos_withlabels.mat')

annos = loadmat(annos)
size = len(annos['annotations'][0])

self.files = glob.glob(os.path.join(
self.root, 'cars_'+self.split, '*.jpg'))
self.files.sort()

self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]])

lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat'))

self.object_categories = [str(c[0])
for c in lbl_annos['class_names'][0]]

print('Stanford Cars, Split: %s, Size: %d' %
(self.split, self.__len__()))

def __len__(self):
return len(self.files)

def __getitem__(self, idx):
        # self.files already holds the absolute paths collected by glob in __init__,
        # so re-joining them with self.root/'Images' only worked by accident
        img = Image.open(self.files[idx]).convert("RGB")
lbl = self.labels[idx]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
lbl = self.target_transform(lbl)
return img, lbl

def download(self):
import tarfile

mkdir(self.root)
for fname, url in self.urls.items():
if not os.path.isfile(os.path.join(self.root, fname)):
download_url(url, self.root, fname)
if fname.endswith('tgz'):
print("Extracting %s..." % fname)
with tarfile.open(os.path.join(self.root, fname), "r:gz") as tar:
tar.extractall(path=self.root)

copyfile(os.path.join(self.root, 'cars_test_annos_withlabels.mat'),
os.path.join(self.root, 'devkit', 'cars_test_annos_withlabels.mat'))
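
# Minimal usage sketch (paths are hypothetical and the transform is an
# assumption, not something this file prescribes):
#
#   from torchvision import transforms
#   train_set = StanfordCars('~/data/stanford_cars', split='train',
#                            download=True, transform=transforms.ToTensor())
#   img, lbl = train_set[0]   # lbl is a 0-based class index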

+ 58
- 0
model_measuring/kamal/vision/datasets/stanford_dogs.py View File

@@ -0,0 +1,58 @@
import os
import numpy as np
from PIL import Image
from scipy.io import loadmat

from torch.utils import data
from .utils import download_url
from shutil import move

class StanfordDogs(data.Dataset):
"""Dataset for Stanford Dogs
"""
urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
"annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
"lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"}

def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
self.root = os.path.abspath( os.path.expanduser(root) )
self.split = split
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
list_file = os.path.join(self.root, self.split+'_list.mat')
mat_file = loadmat(list_file)
size = len(mat_file['file_list'])
self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)]
self.labels = np.array(
[mat_file['labels'][i][0]-1 for i in range(size)])
categories = os.listdir(os.path.join(self.root, 'Images'))
categories.sort()
self.object_categories = [c[10:] for c in categories]
print('Stanford Dogs, Split: %s, Size: %d' %
(self.split, self.__len__()))

def __len__(self):
return len(self.files)

def __getitem__(self, idx):
img = Image.open(os.path.join(self.root, 'Images',
self.files[idx])).convert("RGB")
lbl = self.labels[idx]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
lbl = self.target_transform( lbl )
return img, lbl

def download(self):
import tarfile
os.makedirs(self.root, exist_ok=True)
for fname, url in self.urls.items():
if not os.path.isfile(os.path.join(self.root, fname)):
download_url(url, self.root, fname)
# extract file
print("Extracting %s..." % fname)
with tarfile.open(os.path.join(self.root, fname), "r") as tar:
tar.extractall(path=self.root)

+ 65
- 0
model_measuring/kamal/vision/datasets/sunrgbd.py View File

@@ -0,0 +1,65 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""


import os
from glob import glob
from PIL import Image
from .utils import colormap
from torchvision.datasets import VisionDataset

class SunRGBD(VisionDataset):
cmap = colormap()
def __init__(self,
root,
split='train',
transform=None,
target_transform=None,
transforms=None):
super( SunRGBD, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
self.root = root
self.split = split

self.images = glob(os.path.join(self.root, 'SUNRGBD-%s_images'%self.split, '*.jpg'))
self.labels = glob(os.path.join(self.root, '%s13labels'%self.split, '*.png'))

self.images.sort()
self.labels.sort()

def __getitem__(self, idx):
"""
Args:
- index (``int``): index of the item in the dataset
Returns:
A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
of the image.
"""

img, label = Image.open(self.images[idx]), Image.open(self.labels[idx])

if self.transform is not None:
img, label = self.transform(img, label)
        label = label - 1  # shift ids down by 1 so the void class (0) wraps to 255; assumes the transform converted label to an array/tensor
return img, label

def __len__(self):
return len(self.images)

@classmethod
def decode_fn(cls, mask):
"""decode semantic mask to RGB image"""
return cls.cmap[mask.astype('uint8')+1]
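
# Hedged usage sketch: the label arithmetic in __getitem__ requires a paired
# transform that converts the PIL images to tensors (`pair_transform` below is
# a placeholder, and the root path is hypothetical):
#
#   ds = SunRGBD('/data/SUNRGBD', split='train', transform=pair_transform)
#   img, label = ds[0]
#   rgb_mask = SunRGBD.decode_fn(label.numpy())  # assumes label is a torch tensor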

+ 67
- 0
model_measuring/kamal/vision/datasets/unlabeled.py View File

@@ -0,0 +1,67 @@
"""
Copyright 2020 Tianshu AI Platform. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=============================================================
"""

import torch
from torch.utils.data import Dataset
import os

from PIL import Image
import random
from copy import deepcopy

def _collect_all_images(root, postfix=['png', 'jpg', 'jpeg', 'JPEG']):
images = []
if isinstance( postfix, str):
postfix = [ postfix ]
for dirpath, dirnames, files in os.walk(root):
for pos in postfix:
for f in files:
if f.endswith( pos ):
images.append( os.path.join( dirpath, f ) )
return images

def get_train_val_set(root, val_size=0.3):
if not isinstance(root, (list, tuple)):
root = [root]
train_set = []
val_set = []
for _root in root:
_part_train_set = _collect_all_images( _root)
if os.path.isdir( os.path.join(_root, 'test') ):
_part_val_set = _collect_all_images( os.path.join(_root, 'test') )
else:
            _val_size = int( len(_part_train_set) * val_size )
            _part_val_set = random.sample( _part_train_set, k=_val_size )
            _val_lookup = set(_part_val_set)  # set membership is O(1); scanning the list was quadratic
            _part_train_set = [ d for d in _part_train_set if d not in _val_lookup ]
train_set.extend(_part_train_set)
val_set.extend(_part_val_set)
return train_set, val_set

class UnlabeledDataset(Dataset):
def __init__(self, data, transform=None, postfix=['png', 'jpg', 'jpeg', 'JPEG']):
self.transform = transform
self.data = data

def __getitem__(self, idx):
data = Image.open( self.data[idx] )
if self.transform is not None:
data = self.transform(data)
return data

def __len__(self):
return len(self.data)
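
# Hedged usage sketch (the roots and transform below are placeholders):
#
#   train_files, val_files = get_train_val_set(['/data/partA', '/data/partB'], val_size=0.3)
#   train_set = UnlabeledDataset(train_files, transform=some_transform)
#   val_set = UnlabeledDataset(val_files, transform=some_transform)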


+ 161
- 0
model_measuring/kamal/vision/datasets/utils.py View File

@@ -0,0 +1,161 @@
# Modified from https://github.com/pytorch/vision
import os
import os.path
import hashlib
import errno
from tqdm import tqdm
import numpy as np
import torch
import random

def mkdir(dir_path):  # parameter renamed to avoid shadowing the built-in dir()
    if not os.path.isdir(dir_path):
        os.mkdir(dir_path)

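# colormap() reproduces the standard PASCAL VOC palette: the 3-bit groups of
# each class index are spread across the R/G/B channels, one bit per group,
# filled from the most significant bit, so consecutive ids get distinct colors.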
def colormap(N=256, normalized=False):
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)

dtype = 'float32' if normalized else 'uint8'
cmap = np.zeros((N, 3), dtype=dtype)
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7-j)
g = g | (bitget(c, 1) << 7-j)
b = b | (bitget(c, 2) << 7-j)
c = c >> 3

cmap[i] = np.array([r, g, b])

cmap = cmap/255 if normalized else cmap
return cmap

DEFAULT_COLORMAP = colormap()

def gen_bar_updater(pbar):
def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)

return bar_update


def check_integrity(fpath, md5=None):
if md5 is None:
return True
if not os.path.isfile(fpath):
return False
md5o = hashlib.md5()
with open(fpath, 'rb') as f:
# read in 1MB chunks
for chunk in iter(lambda: f.read(1024 * 1024), b''):
md5o.update(chunk)
md5c = md5o.hexdigest()
if md5c != md5:
return False
return True


def makedir_exist_ok(dirpath):
"""
Python2 support for os.makedirs(.., exist_ok=True)
"""
try:
os.makedirs(dirpath)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise

def download_url(url, root, filename=None, md5=None):
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str): Name to save the file under. If None, use the basename of the URL
md5 (str): MD5 checksum of the download. If None, do not check
"""
from six.moves import urllib

root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)

makedir_exist_ok(root)

# downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
)
except OSError:
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
)


def list_dir(root, prefix=False):
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)

if prefix is True:
directories = [os.path.join(root, d) for d in directories]

return directories


def list_files(root, suffix, prefix=False):
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(
root, p)) and p.endswith(suffix),
os.listdir(root)
)
)

if prefix is True:
files = [os.path.join(root, d) for d in files]

return files

def set_seed(random_seed):
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

+ 209
- 0
model_measuring/kamal/vision/datasets/voc.py View File

@@ -0,0 +1,209 @@
# Modified from https://github.com/pytorch/vision
import os
import sys
import tarfile
import collections
import torch.utils.data as data
import shutil
import numpy as np
from .utils import colormap
from torchvision.datasets import VisionDataset
import torch
from PIL import Image
from torchvision.datasets.utils import download_url, check_integrity

DATASET_YEAR_DICT = {
'2012aug': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '6cd6e144f989b92b3379bac3b3de84fd',
'base_dir': 'VOCdevkit/VOC2012'
},
'2012': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '6cd6e144f989b92b3379bac3b3de84fd',
'base_dir': 'VOCdevkit/VOC2012'
},
'2011': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
'filename': 'VOCtrainval_25-May-2011.tar',
'md5': '6c3384ef61512963050cb5d687e5bf1e',
'base_dir': 'TrainVal/VOCdevkit/VOC2011'
},
'2010': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
'filename': 'VOCtrainval_03-May-2010.tar',
'md5': 'da459979d0c395079b5c75ee67908abb',
'base_dir': 'VOCdevkit/VOC2010'
},
'2009': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
'filename': 'VOCtrainval_11-May-2009.tar',
'md5': '59065e4b188729180974ef6572f6a212',
'base_dir': 'VOCdevkit/VOC2009'
},
'2008': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        'filename': 'VOCtrainval_14-Jul-2008.tar',  # matches the URL basename (was mistakenly the 2012 filename)
'md5': '2629fa636546599198acfcfbfcf1904a',
'base_dir': 'VOCdevkit/VOC2008'
},
'2007': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
'filename': 'VOCtrainval_06-Nov-2007.tar',
'md5': 'c52e279531787c972589f7e41ab4ae64',
'base_dir': 'VOCdevkit/VOC2007'
}
}

class VOCSegmentation(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args:
root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year; '2007' to '2012' and '2012aug' are supported.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
"""
cmap = colormap()
def __init__(self,
root,
year='2012',
image_set='train',
download=False,
transform=None,
target_transform=None,
transforms=None,
):
super( VOCSegmentation, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )

is_aug=False
if year=='2012aug':
is_aug = True
year = '2012'
self.root = os.path.expanduser(root)
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')

if download:
download_extract(self.url, self.root, self.filename, self.md5)

if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if is_aug and image_set=='train':
mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
split_f = os.path.join( self.root, 'train_aug.txt')
else:
mask_dir = os.path.join(voc_root, 'SegmentationClass')
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="trainval" or image_set="val"')

with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.masks[index])
if self.transforms is not None:
img, target = self.transforms(img, target)
return img, target.squeeze(0)

def __len__(self):
return len(self.images)

@classmethod
def decode_fn(cls, mask):
"""decode semantic mask to RGB image"""
return cls.cmap[mask]

def download_extract(url, root, filename, md5):
download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root)

CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']

class VOCClassification(data.Dataset):
def __init__(self,
root,
year='2010',
split='train',
download=False,
transforms=None,
target_transforms=None):

voc_root = os.path.join(root, 'VOC{}'.format(year))
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')

self.transforms = transforms
self.target_transforms = target_transforms
image_dir = os.path.join(voc_root, 'JPEGImages')
label_dir = os.path.join(voc_root, 'ImageSets/Main')
self.labels_list = []

fname = os.path.join(label_dir, '{}.txt'.format(split))
with open(fname) as f:
self.images = [os.path.join(image_dir, line.split()[0]+'.jpg') for line in f]

for clas in CLASSES:
labels = []
with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f:
labels = [int(line.split()[1]) for line in f]
self.labels_list.append(labels)

assert (len(self.images) == len(self.labels_list[0]))

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
            tuple: (image, labels) where labels is the list of per-class presence labels
"""
img = Image.open(self.images[index]).convert('RGB')
labels = [labels[index] for labels in self.labels_list]

if self.transforms is not None:
img = self.transforms(img)

if self.target_transforms is not None:
labels = self.target_transforms(labels)

return img, labels

def __len__(self):
return len(self.images)
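
# Hedged usage sketch for VOCSegmentation (paths hypothetical; a paired
# transform producing tensors is assumed, since __getitem__ calls
# target.squeeze(0)):
#
#   seg = VOCSegmentation('/data/voc', year='2012', image_set='train',
#                         download=True, transforms=pair_transform)
#   img, target = seg[0]   # target: HxW class-id tensor
#   rgb = VOCSegmentation.decode_fn(target.numpy().astype('uint8'))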

+ 3
- 0
model_measuring/kamal/vision/models/__init__.py View File

@@ -0,0 +1,3 @@
from . import classification, segmentation

from torchvision import models as torchvision_models

+ 7
- 0
model_measuring/kamal/vision/models/classification/__init__.py View File

@@ -0,0 +1,7 @@
from .darknet import *
from .mobilenetv2 import *
from .resnet import *
from .vgg import *
from . import cifar

from .alexnet import alexnet

+ 63
- 0
model_measuring/kamal/vision/models/classification/alexnet.py View File

@@ -0,0 +1,63 @@
# Modified from https://github.com/pytorch/vision
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['AlexNet', 'alexnet']


model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):

def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)

def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), 256 * 6 * 6)
x = self.classifier(x)
return x


def alexnet(pretrained=False, **kwargs):
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
num_classes = kwargs.pop('num_classes', None)
model = AlexNet(**kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
if num_classes is not None and num_classes!=1000:
model.classifier[-1] = nn.Linear(4096, num_classes)
return model
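
# Hedged usage sketch: with the head-replacement logic above, this loads the
# ImageNet weights and then swaps in a fresh 10-way classifier:
#
#   model = alexnet(pretrained=True, num_classes=10)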

+ 1
- 0
model_measuring/kamal/vision/models/classification/cifar/__init__.py View File

@@ -0,0 +1 @@
from . import wrn

+ 108
- 0
model_measuring/kamal/vision/models/classification/cifar/wrn.py View File

@@ -0,0 +1,108 @@
#Adapted from https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['wrn']

class BasicBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0):
super(BasicBlock, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_planes)
self.relu2 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.dropout = nn.Dropout( dropout_rate )
self.equalInOut = (in_planes == out_planes)
        self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                      padding=0, bias=False) if not self.equalInOut else None

def forward(self, x):
if not self.equalInOut:
x = self.relu1(self.bn1(x))
else:
out = self.relu1(self.bn1(x))
out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
out = self.dropout(out)
out = self.conv2(out)
return torch.add(x if self.equalInOut else self.convShortcut(x), out)


class NetworkBlock(nn.Module):
def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0):
super(NetworkBlock, self).__init__()
self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate)

def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate):
layers = []
for i in range(nb_layers):
            layers.append(block(in_planes if i == 0 else out_planes, out_planes,
                                stride if i == 0 else 1, dropout_rate))
return nn.Sequential(*layers)

def forward(self, x):
return self.layer(x)


class WideResNet(nn.Module):
def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0):
super(WideResNet, self).__init__()
nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
n = (depth - 4) // 6
block = BasicBlock
# 1st conv before any network block
self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
padding=1, bias=False)
# 1st block
self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate)
# 2nd block
self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate)
# 3rd block
self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate)
# global average pooling and classifier
self.bn1 = nn.BatchNorm2d(nChannels[3])
self.relu = nn.ReLU(inplace=True)
self.fc = nn.Linear(nChannels[3], num_classes)
self.nChannels = nChannels[3]

for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.bias.data.zero_()

def forward(self, x, return_features=False):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.block3(out)
out = self.relu(self.bn1(out))
out = F.avg_pool2d(out, 8)
features = out.view(-1, self.nChannels)
out = self.fc(features)
if return_features:
return out, features
else:
return out

def wrn_16_1(num_classes, dropout_rate=0):
return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate)

def wrn_16_2(num_classes, dropout_rate=0):
return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate)

def wrn_40_1(num_classes, dropout_rate=0):
return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate)

def wrn_40_2(num_classes, dropout_rate=0):
return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate)
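
# Hedged usage sketch: the fixed 8x8 average pooling in forward() assumes
# CIFAR-sized 32x32 inputs:
#
#   net = wrn_16_2(num_classes=10, dropout_rate=0.3)
#   logits, feats = net(torch.randn(1, 3, 32, 32), return_features=True)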
