@@ -0,0 +1,2 @@ | |||
# JetBrains IDE project settings directory at the repository root.
/.idea/
# IntelliJ module descriptor files, anywhere in the tree.
*.iml
@@ -0,0 +1,7 @@ | |||
# Base image: TensorFlow 2.4.1.
FROM tensorflow/tensorflow:2.4.1
# Working directory for the following instructions and for the running container.
WORKDIR /app
# web.py and tf2onnx are the only extra Python packages installed on top of the base image.
RUN pip install web.py tf2onnx
# Copy the whole build context into the image.
COPY . /app
# Run main.py with python3 when the container starts.
ENTRYPOINT ["python3", "main.py"]
@@ -0,0 +1,86 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import json | |||
import os | |||
import subprocess | |||
import logging | |||
import web | |||
from subprocess import PIPE | |||
# web.py routing table: alternating (URL pattern, handler class name) entries.
# The class names are resolved through the namespace passed to web.application().
urls = (
    '/hello', 'Hello',
    '/model_convert', 'ModelConvert'
)
# Root logger writes DEBUG and above to the file 'onnx.log'.
logging.basicConfig(filename='onnx.log', level=logging.DEBUG)
class Hello(object):
    """Liveness endpoint: answers GET requests with a fixed text body."""

    # Body returned to every caller; there is no request-dependent state.
    _ALIVE_MESSAGE = 'service alive'

    def GET(self):
        """Return the constant liveness message."""
        return self._ALIVE_MESSAGE
class ModelConvert(object):
    """HTTP handler that converts a TensorFlow SavedModel directory to ONNX.

    The request body is a JSON object with the keys ``model_path`` and
    ``output_path``.  Every response is a JSON object of the form
    ``{"code": ..., "msg": ..., "data": ...}``:

    - 200: conversion succeeded, ``data`` holds the value returned by
      ``convert`` on success.
    - 501: ``model_path`` is not an existing directory.
    - 502: ``output_path`` does not end with ``/``.
    - 503: no SavedModel file was found in ``model_path``.
    - 504: the conversion itself failed, ``msg`` carries the reason.
    - 505: any other exception (malformed JSON, missing key, ...).
    """

    def POST(self):
        """Validate the request, run the conversion and return a JSON reply."""
        payload = web.data()
        web.header('Content-Type', 'application/json')

        def reply(code, msg, data=''):
            # Single place where the response envelope is serialised.
            return json.dumps({'code': code, 'msg': msg, 'data': data})

        def reject(code, msg):
            # Validation failures are logged before being reported.
            logging.error(msg)
            return reply(code, msg)

        try:
            request = json.loads(payload)
            model_path = request['model_path']
            output_path = request['output_path']

            if not os.path.isdir(model_path):
                return reject(501, 'model_path is not a dir: %s' % model_path)
            # The output location is only checked syntactically: it must
            # carry a trailing slash.
            if not output_path.endswith('/'):
                return reject(502, 'output_path is not a dir: %s' % output_path)
            if not exist(model_path):
                return reject(503, 'SavedModel file does not exist at: %s' % model_path)

            converted, msg = convert(model_path, output_path)
            if not converted:
                return reply(504, msg)
        except Exception as e:
            logging.error(str(e))
            return reply(505, str(e))
        return reply(200, 'ok', msg)
def exist(model_path):
    """Tell whether ``model_path`` directly contains a SavedModel file.

    Only the top level of the directory is inspected; an entry counts when
    it is named ``saved_model.pbtxt`` or ``saved_model.pb``.

    Returns:
        bool: True when such an entry exists, False otherwise.
    """
    saved_model_names = ('saved_model.pbtxt', 'saved_model.pb')
    return any(entry in saved_model_names for entry in os.listdir(model_path))
def convert(model_path, output_path):
    """Convert a SavedModel directory to ONNX by running ``tf2onnx.convert``.

    The converter is started as a child process of the *current* interpreter
    (``sys.executable``) so that it sees the same installed packages as this
    service, instead of whatever ``python`` happens to be first on ``PATH``.

    Args:
        model_path: directory holding the SavedModel.
        output_path: output location prefix; ``model.onnx`` is appended to it
            verbatim, so it is expected to end with a path separator.

    Returns:
        tuple: ``(True, <path of the written .onnx file>)`` on success, or
        ``(False, <error text>)`` when the converter exits with a non-zero
        status or cannot be started.
    """
    import sys  # local import: only this function needs the interpreter path

    output_path = output_path + 'model.onnx'
    command = [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", model_path,
        "--output", output_path,
    ]
    try:
        logging.info('model_path=%s, output_path=%s', model_path, output_path)
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        logging.info(repr(result))
        if result.returncode != 0:
            # stderr is captured as bytes; decode it so the caller receives
            # readable text rather than a "b'...'" repr.
            return False, result.stderr.decode('utf-8', errors='replace')
    except Exception as e:
        logging.error(str(e))
        return False, str(e)
    return True, output_path
# Script entry point: build the web.py application from the routing table and
# the handler classes in this module's namespace, then start serving.
if __name__ == '__main__':
    app = web.application(urls, globals())
    app.run()
@@ -0,0 +1,130 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
# Packages this module depends on.
# NOTE(review): presumably read by the hub loader (hubconf convention) — confirm.
dependencies = ['torch', 'kamal']  # dependencies
import json | |||
import traceback | |||
import os | |||
import torch | |||
import web | |||
import kamal | |||
from PIL import Image | |||
from kamal.transferability import TransferabilityGraph | |||
from kamal.transferability.trans_metric import AttrMapMetric | |||
from torchvision.models import * | |||
# web.py routing table: alternating (URL pattern, handler class name) entries.
urls = (
    '/model_measure/measure', 'Measure',
    '/model_measure/package', 'Package',
)
# Application object; handler class names are looked up in this module's globals.
app = web.application(urls, globals())
class Package:
    """HTTP handler that packages a PyTorch checkpoint with ``kamal.hub.save``.

    The request body is a JSON list.  Each item is expected to provide
    ``metadata`` (with ``dataset``, ``name``, ``url``, ``entry_args.num_classes``
    and ``input.size``), ``entry_name``, ``ckpt`` (a directory prefix) and
    ``readme``.  Responses are the JSON form of a ``Response`` object.
    """

    def POST(self):
        """Package the first request item whose entry name matches a hub entry.

        Returns:
            str: JSON with code 200 and the list of output directories on
            success, or code 506 and an error description otherwise.  The
            handler returns as soon as one item has been packaged or failed.
        """
        req = web.data()
        save_path_list = []
        # Bound up front so the fall-through response at the end cannot raise
        # NameError when the request list is empty.
        entry_name = None
        for json_data in json.loads(req):
            try:
                metadata = json_data['metadata']
            except Exception:
                traceback.print_exc()
                return json.dumps(Response(506, 'Failed to package model, Error: %s'% (traceback.format_exc(limit=1)), save_path_list).__dict__)
            entry_name = json_data['entry_name']
            for name, fn in kamal.hub.list_entry(__file__):
                # Substring match: the requested name only has to occur in the entry name.
                if entry_name in name:
                    try:
                        dataset = metadata['dataset']
                        model = fn(pretrained=False, num_classes=metadata['entry_args']['num_classes'])
                        num_params = sum(torch.numel(p) for p in model.parameters())
                        save_path_for_measure = '%sfinegraind_%s/' % (json_data['ckpt'], entry_name)
                        save_path = save_path_for_measure+'%s' % (metadata['name'])
                        ckpt = self.file_name(json_data['ckpt'])
                        if ckpt == '':
                            return json.dumps(
                                Response(506, 'Failed to package model [%s]: No .pth file was found in directory ckpt' % (entry_name), save_path_list).__dict__)
                        # NOTE(review): torch.load unpickles the file; the checkpoint
                        # path comes from the request, so only trusted checkpoints
                        # should be reachable here.
                        # Second positional argument is strict=False.
                        model.load_state_dict(torch.load(ckpt), False)
                        kamal.hub.save(  # packages the user's PyTorch model and stores it at the given location
                            model,  # the nn.Module to save
                            save_path=save_path,
                            # name of the export folder
                            entry_name=entry_name,  # entry function name; must match the hub entry used above
                            spec_name=None,  # parameter version name; an md5 is used when left empty
                            code_path=__file__,  # code the model depends on: a folder (must contain hubconf.py),
                            # or this hubconf.py itself; the model implementation comes from
                            # the dependencies, so pointing at this file is sufficient
                            metadata=metadata,
                            tags=dict(
                                num_params=num_params,
                                metadata=metadata,
                                name=metadata['name'],
                                url=metadata['url'],
                                dataset=dataset,
                                img_size=metadata['input']['size'],
                                readme=json_data['readme'])
                        )
                        save_path_list.append(save_path_for_measure)
                        return json.dumps(Response(200, 'Success', save_path_list).__dict__)
                    except Exception:
                        traceback.print_exc()
                        return json.dumps(Response(506,'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)
        # Reached when no request item matched any hub entry.
        return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)

    def file_name(self, file_dir):
        """Return the path of the first file under ``file_dir`` ending in ``pth``.

        The directory tree is walked with ``os.walk``.  The result is built
        with ``os.path.join`` so that files located in sub-directories get a
        valid path (plain concatenation dropped the separator there).

        Returns:
            str: path of the first match, or ``''`` when nothing matches.
        """
        for root, _dirs, files in os.walk(file_dir):
            for filename in files:
                if filename.endswith('pth'):
                    return os.path.join(root, filename)
        return ''
class Measure:
    """HTTP handler that writes a transferability measurement file.

    The request body is a JSON object with the keys ``zoo_set``,
    ``probe_set_root`` (directory of probe images) and ``export_path``
    (output location prefix, concatenated verbatim with ``measure.json``).
    Responses are the JSON form of a ``Response`` object.
    """

    def POST(self):
        """Build the transferability graph and export it as JSON.

        Returns:
            str: JSON with code 200 and the list of written files on success,
            or code 506 and an error description on any failure.
        """
        req = web.data()
        json_data = json.loads(req)
        print(json_data)
        # Bound before the try block: the error handler below reads both names,
        # and they would otherwise be undefined if an early key lookup fails.
        probe_set_root = None
        output_filename_list = []
        try:
            measure_name = 'measure'
            zoo_set = json_data['zoo_set']
            probe_set_root = json_data['probe_set_root']
            export_path = json_data['export_path']
            TG = TransferabilityGraph(zoo_set)
            print("Add %s" % (probe_set_root))
            imgs_set = os.listdir(probe_set_root)
            # NOTE(review): the opened images are handed to the metric and are
            # not closed here — confirm AttrMapMetric takes ownership.
            images = [Image.open(os.path.join(probe_set_root, img)) for img in imgs_set]
            metric = AttrMapMetric(images, device=torch.device('cuda'))
            TG.add_metric(probe_set_root, metric)
            # Create the export directory when it does not exist yet.
            os.makedirs(export_path, exist_ok=True)
            output_filename = export_path+'%s.json' % (measure_name)
            TG.export_to_json(probe_set_root, output_filename, topk=3, normalize=True)
            output_filename_list.append(output_filename)
        except Exception:
            traceback.print_exc()
            return json.dumps(Response(506, 'Failed to generate measurement file of [%s], Error: %s' % (probe_set_root, traceback.format_exc(limit=1)), output_filename_list).__dict__)
        return json.dumps(Response(200, 'Success', output_filename_list).__dict__)
class Response:
    """Plain response envelope; handlers serialise its ``__dict__`` to JSON.

    Attributes are set in the order ``code``, ``msg``, ``data``.
    """

    def __init__(self, code, msg, data):
        self.code, self.msg, self.data = code, msg, data
# Script entry point: start serving the web.py application defined above.
if __name__ == "__main__":
    app.run()
@@ -0,0 +1,5 @@ | |||
from .core import tasks, metrics, engine, callbacks, hub | |||
from . import amalgamation, slim, vision, transferability | |||
from .core import load, save |
@@ -0,0 +1,4 @@ | |||
from .layerwise_amalgamation import LayerWiseAmalgamator | |||
from .common_feature import CommonFeatureAmalgamator | |||
from .task_branching import TaskBranchingAmalgamator, JointSegNet | |||
from .recombination import RecombinationAmalgamator, CombinedModel |
@@ -0,0 +1,291 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from kamal.core.engine.engine import Engine | |||
from kamal.core.engine.hooks import FeatureHook | |||
from kamal.core import tasks | |||
from kamal.utils import set_mode, move_to_device | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import typing, time | |||
import numpy as np | |||
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    When ``stride > 1`` or the channel count changes, the shortcut is a
    1x1 convolution followed by batch norm; otherwise the input is added
    unchanged.

    **Parameters:**
    - inplanes (int): number of input channels
    - planes (int): number of output channels
    - stride (int): stride of the first convolution and of the shortcut projection
    - momentum (float): momentum passed to every batch-norm layer
    """

    def __init__(self, inplanes, planes, stride=1, momentum=0.1):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=momentum)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=momentum)

        needs_projection = stride > 1 or inplanes != planes
        self.downsample = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1,
                      stride=stride, bias=False),
            nn.BatchNorm2d(planes, momentum=momentum)
        ) if needs_projection else None
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: identity unless a projection was built.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
class CFL_FCBlock(nn.Module):
    """Common Feature Block for fully-connected features.

    Projects the student feature and every teacher feature into a shared
    hidden space, passes all of them through one shared extractor, and
    decodes each teacher's hidden feature back to its original width.

    **Parameters:**
    - cs (int): channel number of student features
    - cts (list or tuple): channel number list of teacher features
    - ch (int): channel number of hidden features
    - k_size (int): accepted but not used by this block
    """

    def __init__(self, cs, cts, ch, k_size=5):
        super().__init__()

        # One alignment layer per teacher.
        self.align_t = nn.ModuleList([
            nn.Sequential(nn.Linear(ct, ch), nn.ReLU(inplace=True))
            for ct in cts
        ])
        self.align_s = nn.Sequential(
            nn.Linear(cs, ch),
            nn.ReLU(inplace=True),
        )
        # Shared by the student and all teachers.
        self.extractor = nn.Sequential(
            nn.Linear(ch, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        # One decoder per teacher, mapping hidden features back to ``ct`` channels.
        self.dec_t = nn.ModuleList([
            nn.Sequential(
                nn.Linear(ch, ct),
                nn.ReLU(inplace=True),
                nn.Linear(ct, ct)
            )
            for ct in cts
        ])

    def init_weights(self):
        """Kaiming-initialise Linear weights; reset any BatchNorm2d to (1, 0)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, fs, fts):
        """Return ``(hs, hts), (_fts, fts)``.

        ``hs`` / ``hts`` are the hidden features of the student / teachers,
        ``_fts`` are the teacher features decoded from ``hts`` and ``fts``
        is the input teacher feature sequence, passed through unchanged.
        """
        aligned_s = self.align_s(fs)
        hts = [self.extractor(self.align_t[i](ft)) for i, ft in enumerate(fts)]
        hs = self.extractor(aligned_s)
        _fts = [self.dec_t[i](ht) for i, ht in enumerate(hts)]
        return (hs, hts), (_fts, fts)
class CFL_ConvBlock(nn.Module):
    """Common Feature Block for convolutional feature maps.

    Projects the student feature map and every teacher feature map to a
    shared hidden channel count with 1x1 convolutions, passes all of them
    through one shared residual extractor, and decodes each teacher's
    hidden feature back to its original channel count.

    **Parameters:**
    - cs (int): channel number of student features
    - cts (list or tuple): channel number list of teacher features
    - ch (int): channel number of hidden features
    - k_size (int): accepted but not used by this block
    """

    def __init__(self, cs, cts, ch, k_size=5):
        super().__init__()

        # One 1x1 alignment stage per teacher.
        self.align_t = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels=ct, out_channels=ch,
                          kernel_size=1),
                nn.BatchNorm2d(ch),
                nn.ReLU(inplace=True)
            )
            for ct in cts
        ])
        self.align_s = nn.Sequential(
            nn.Conv2d(in_channels=cs, out_channels=ch,
                      kernel_size=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
        )
        # Shared by the student and all teachers.
        self.extractor = nn.Sequential(
            ResBlock(inplanes=ch, planes=ch, stride=1),
            ResBlock(inplanes=ch, planes=ch, stride=1),
        )
        # One decoder per teacher, mapping hidden features back to ``ct`` channels.
        self.dec_t = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(ch, ch, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(ch, ct, kernel_size=1, stride=1)
            )
            for ct in cts
        ])

    def init_weights(self):
        """Kaiming-initialise Conv2d weights; reset BatchNorm2d to (1, 0)."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, fs, fts):
        """Return ``(hs, hts), (_fts, fts)``.

        ``hs`` / ``hts`` are the hidden features of the student / teachers,
        ``_fts`` are the teacher features decoded from ``hts`` and ``fts``
        is the input teacher feature sequence, passed through unchanged.
        """
        aligned_t = [self.align_t[i](ft) for i, ft in enumerate(fts)]
        aligned_s = self.align_s(fs)
        hts = [self.extractor(feature) for feature in aligned_t]
        hs = self.extractor(aligned_s)
        _fts = [self.dec_t[i](ht) for i, ht in enumerate(hts)]
        return (hs, hts), (_fts, fts)
class CommonFeatureAmalgamator(Engine):
    """Engine that trains a student from several teachers with common-feature blocks.

    For every layer group a ``CFL_FCBlock`` or ``CFL_ConvBlock`` is built and
    trained jointly with the student.  The total loss is the weighted sum of
    a KD loss on the outputs, the common-feature loss and the reconstruction
    loss returned by ``tasks.loss.CFLLoss``.
    """

    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        weights = None,
        on_layer_input=False,
        device = None,
    ):
        """Store the training components and build one block per layer group.

        **Parameters:**
        - student: student network (moved to ``device``)
        - teachers: iterable of teacher networks (wrapped in a ModuleList)
        - layer_groups: per group, the hooked layers; index 0 is read as the
          student layer, the rest as teacher layers
        - layer_channels: per group, the channel counts in the same order
        - dataloader: training data loader
        - optimizer: optimizer of the student
        - weights: three factors for (kd, amal, recons) losses;
          defaults to ``[1.0, 1.0, 1.0]``
        - on_layer_input: use the hooked layers' inputs instead of outputs
        - device: target device; CUDA when available, else CPU, if omitted
        """
        self._dataloader = dataloader
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._model = self._student = student.to(self._device)
        self._teachers = nn.ModuleList(teachers).to(self._device)
        self._optimizer = optimizer
        # A fresh list per instance: a shared list default would be aliased by
        # every engine that relies on it.
        self._weights = [1.0, 1.0, 1.0] if weights is None else weights
        self._on_layer_input = on_layer_input

        amal_blocks = []
        for group, C in zip( layer_groups, layer_channels ):
            hooks = [ FeatureHook(layer) for layer in group ]
            # The block type follows the type of the first (student) layer;
            # the hidden width is a quarter of the student width.
            if isinstance(group[0], nn.Linear):
                amal_block = CFL_FCBlock( cs=C[0], cts=C[1:], ch=C[0]//4 ).to(self._device).train()
                print("Building FC Blocks")
            else:
                amal_block = CFL_ConvBlock(cs=C[0], cts=C[1:], ch=C[0]//4).to(self._device).train()
                print("Building Conv Blocks")
            amal_blocks.append( (amal_block, hooks, C) )
        self._amal_blocks = amal_blocks
        self._cfl_criterion = tasks.loss.CFLLoss( sigmas=[0.001, 0.01, 0.05, 0.1, 0.2, 1, 2] )

    @property
    def device(self):
        """Device selected in ``setup``."""
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None):
        """Create the optimizer/scheduler for the blocks and start training.

        The block optimizer mirrors the student's optimizer family (SGD or
        Adam) and reuses the learning rate of its first parameter group.
        """
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend( list(block.parameters()) )
        base_lr = self._optimizer.param_groups[0]['lr']
        if isinstance( self._optimizer, torch.optim.SGD ):
            self._amal_optimimizer = torch.optim.SGD( block_params, lr=base_lr, momentum=0.9, weight_decay=1e-4 )
        else:
            self._amal_optimimizer = torch.optim.Adam( block_params, lr=base_lr, weight_decay=1e-4 )
        self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimimizer, T_max=max_iter )

        with set_mode(self._student, training=True), \
             set_mode(self._teachers, training=False):
            super( CommonFeatureAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Run one optimisation step and return a dict of scalar metrics."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        # The forward passes fill the feature hooks read below.
        s_out = self._student( data )
        with torch.no_grad():
            t_out = [ teacher( data ) for teacher in self._teachers ]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            features = [ h.feat_in if self._on_layer_input else h.feat_out for h in hooks ]
            fs, fts = features[0], features[1:]
            (hs, hts), (_fts, fts) = amal_block( fs, fts )
            _loss_amal, _loss_recons = self._cfl_criterion( hs, hts, _fts, fts )
            loss_amal += _loss_amal
            loss_recons += _loss_recons
        # Teacher outputs are concatenated along dim 1 before the KD loss.
        loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) )
        loss_dict = {
            'loss_kd': self._weights[0]*loss_kd,
            'loss_amal': self._weights[1]*loss_amal,
            'loss_recons': self._weights[2]*loss_recons
        }
        loss = sum(loss_dict.values())
        self._optimizer.zero_grad()
        self._amal_optimimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        self._amal_optimimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        # float() accepts scalar tensors as well as plain numbers; the amal and
        # recons terms stay plain numbers when there are no blocks.
        metrics = { loss_name: float(loss_value) for (loss_name, loss_value) in loss_dict.items() }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float( self._optimizer.param_groups[0]['lr'] )
        })
        return metrics
@@ -0,0 +1,131 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from kamal.core.engine.engine import Engine | |||
from kamal.core.engine.hooks import FeatureHook | |||
from kamal.core import tasks | |||
from kamal.utils import set_mode | |||
import typing | |||
import time | |||
from kamal.utils import move_to_device, set_mode | |||
class AmalBlock(nn.Module):
    """1x1-convolution block relating teacher feature maps to the student's.

    **Parameters:**
    - cs (int): channel number of student features
    - cts (list or tuple): channel number list of teacher features
    """

    def __init__(self, cs, cts):
        super().__init__()
        self.cs, self.cts = cs, cts
        teacher_channels = sum(self.cts)
        pointwise = dict(kernel_size=1, stride=1, padding=0, bias=True)
        # enc: concatenated teacher channels -> student channels
        self.enc = nn.Conv2d(in_channels=teacher_channels, out_channels=self.cs, **pointwise)
        # fam: student channels -> student channels
        self.fam = nn.Conv2d(in_channels=self.cs, out_channels=self.cs, **pointwise)
        # dec: student channels -> concatenated teacher channels
        self.dec = nn.Conv2d(in_channels=self.cs, out_channels=teacher_channels, **pointwise)

    def forward(self, fs, fts):
        """Return ``(rep, _fs, _fts)``.

        ``rep`` is the encoding of the channel-concatenated teacher features,
        ``_fs`` the adapted student feature, and ``_fts`` the decoding of
        ``rep`` split back into one tensor per teacher.
        """
        stacked = torch.cat(fts, dim=1)
        rep = self.enc(stacked)
        _fts = torch.split(self.dec(rep), self.cts, dim=1)
        _fs = self.fam(fs)
        return rep, _fs, _fts
class LayerWiseAmalgamator(Engine):
    """Engine that trains a student from several teachers layer by layer.

    For every layer group an ``AmalBlock`` is built and trained jointly with
    the student.  The total loss is the weighted sum of a KD loss on the
    outputs, an MSE loss between the adapted student feature and the encoded
    teacher features, and an MSE reconstruction loss of the teacher features.
    """

    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        weights = None,
        device=None,
    ):
        """Store the training components and build one block per layer group.

        **Parameters:**
        - student: student network (moved to ``device``)
        - teachers: iterable of teacher networks (wrapped in a ModuleList)
        - layer_groups: per group, the hooked layers; index 0 is read as the
          student layer, the rest as teacher layers
        - layer_channels: per group, the channel counts in the same order
        - dataloader: training data loader
        - optimizer: optimizer of the student
        - weights: three factors for (kd, amal, recons) losses;
          defaults to ``[1., 1., 1.]``
        - device: target device; CUDA when available, else CPU, if omitted
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.optimizer = optimizer
        # A fresh list per instance: a shared list default would be aliased by
        # every engine that relies on it.
        self._weights = [1., 1., 1.] if weights is None else weights
        amal_blocks = []
        for group, C in zip(layer_groups, layer_channels):
            hooks = [ FeatureHook(layer) for layer in group ]
            amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train()
            amal_blocks.append( (amal_block, hooks, C) )
        self._amal_blocks = amal_blocks

    @property
    def device(self):
        """Device selected in ``setup``."""
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None ):
        """Create the optimizer/scheduler for the blocks and start training.

        The block optimizer mirrors the student's optimizer family (SGD or
        Adam) and reuses the learning rate of its first parameter group.
        """
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend( list(block.parameters()) )
        base_lr = self.optimizer.param_groups[0]['lr']
        if isinstance( self.optimizer, torch.optim.SGD ):
            self._amal_optimimizer = torch.optim.SGD( block_params, lr=base_lr, momentum=0.9, weight_decay=1e-4 )
        else:
            self._amal_optimimizer = torch.optim.Adam( block_params, lr=base_lr, weight_decay=1e-4 )
        self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimimizer, T_max=max_iter )

        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super( LayerWiseAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Run one optimisation step and return a dict of scalar metrics."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        # The forward passes fill the feature hooks read below.
        s_out = self.student( data )
        with torch.no_grad():
            t_out = [ teacher( data ) for teacher in self.teachers ]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            features = [ h.feat_out for h in hooks ]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block( fs, fts )
            # The encoded teacher feature is a fixed target for the student branch.
            loss_amal += F.mse_loss( _fs, rep.detach() )
            loss_recons += sum( [ F.mse_loss( _ft, ft ) for (_ft, ft) in zip( _fts, fts ) ] )
        # Teacher outputs are concatenated along dim 1 before the KD loss.
        loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) )
        loss_dict = { "loss_kd":      self._weights[0] * loss_kd,
                      "loss_amal":    self._weights[1] * loss_amal,
                      "loss_recons":  self._weights[2] * loss_recons }
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        self._amal_optimimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._amal_optimimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        # float() accepts scalar tensors as well as plain numbers; the amal and
        # recons terms stay plain numbers when there are no blocks.
        metrics = { loss_name: float(loss_value) for (loss_name, loss_value) in loss_dict.items() }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float( self.optimizer.param_groups[0]['lr'] )
        })
        return metrics
@@ -0,0 +1,209 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from copy import deepcopy | |||
from typing import Callable | |||
from kamal.core.engine.engine import Engine | |||
from kamal.core.engine.trainer import KDTrainer | |||
from kamal.core.engine.hooks import FeatureHook | |||
from kamal.core import tasks | |||
import math | |||
from kamal.slim.prunning import Pruner, strategy | |||
def _assert_same_type(layers, layer_type=None):
    """Assert that every element of ``layers`` is an instance of one type.

    When ``layer_type`` is omitted, the type of the first element is used
    as the reference.
    """
    expected = type(layers[0]) if layer_type is None else layer_type
    for layer in layers:
        assert isinstance(layer, expected), 'Model archictures must be the same'
def _get_layers(model_list):
    """Yield, position by position, the tuple of corresponding sub-modules.

    The models are traversed in lockstep via ``modules()``; every yielded
    tuple is first checked to contain modules of a single type.
    """
    module_iterators = [model.modules() for model in model_list]
    for group in zip(*module_iterators):
        _assert_same_type(group)
        yield group
def bn_combine_fn(layers):
    """Combine 2D Batch Normalization Layers

    The combined layer has the summed feature count.  Running statistics and,
    for affine layers, weight and bias are the concatenation of the inputs'
    values in order.  ``eps``, ``momentum``, ``affine`` and
    ``track_running_stats`` are taken from the first layer.

    **Parameters:**
        - **layers** (BatchNorm2D): Batch Normalization Layers.

    **Returns:** a new ``nn.BatchNorm2d``.
    """
    assert all(isinstance(l, nn.BatchNorm2d) for l in layers), 'Model archictures must be the same'
    reference = layers[0]
    num_features = sum(l.num_features for l in layers)
    combined_bn = nn.BatchNorm2d(num_features=num_features,
                                 eps=reference.eps,
                                 momentum=reference.momentum,
                                 affine=reference.affine,
                                 track_running_stats=reference.track_running_stats)
    # Layers built with track_running_stats=False carry no running buffers
    # (they are None), so there is nothing to concatenate in that case.
    if combined_bn.track_running_stats:
        combined_bn.running_mean = torch.cat(
            [l.running_mean for l in layers], dim=0).clone()
        combined_bn.running_var = torch.cat(
            [l.running_var for l in layers], dim=0).clone()

    if combined_bn.affine:
        combined_bn.weight = torch.nn.Parameter(
            torch.cat([l.weight.data.clone() for l in layers], dim=0).clone())
        combined_bn.bias = torch.nn.Parameter(
            torch.cat([l.bias.data.clone() for l in layers], dim=0).clone())
    return combined_bn
def conv2d_combine_fn(layers):
    """Combine 2D Conv Layers

    Builds one convolution whose weight is block-diagonal: layer ``k`` owns
    its own slice of output channels and its own slice of input channels,
    and all cross terms are zero.  Fed with the channel-wise concatenation
    of the individual inputs, the combined layer therefore produces the
    channel-wise concatenation of the individual outputs.

    Kernel size, stride, padding, dilation and the presence of a bias are
    taken from the first layer.
    NOTE(review): grouped convolutions are not handled specially — confirm
    that all inputs use ``groups=1``.

    **Parameters:**
        - **layers** (Conv2d): Conv Layers.

    **Returns:** a new ``nn.Conv2d`` with trainable parameters.
    """
    assert all(isinstance(l, nn.Conv2d) for l in layers), 'Model archictures must be the same'
    reference = layers[0]
    total_out = sum(l.weight.shape[0] for l in layers)
    total_in = sum(l.weight.shape[1] for l in layers)
    kernel_size = tuple(reference.weight.shape[-2:])
    has_bias = reference.bias is not None
    dtype = reference.weight.dtype
    device = reference.weight.device

    # Plain tensors are filled first: writing slices in place into a leaf
    # Parameter that requires grad is rejected by autograd.
    weight = torch.zeros(total_out, total_in, *kernel_size, dtype=dtype, device=device)
    bias = torch.zeros(total_out, dtype=dtype, device=device) if has_bias else None

    co_offset = 0
    ci_offset = 0
    with torch.no_grad():
        for l in layers:
            co_len, ci_len = l.weight.shape[0], l.weight.shape[1]
            weight[co_offset: co_offset+co_len,
                   ci_offset: ci_offset+ci_len, :, :] = l.weight.detach()
            if bias is not None:
                bias[co_offset: co_offset+co_len] = l.bias.detach()
            co_offset += co_len
            # Advance by this layer's input width so that the next layer lands
            # on its own input-channel slice.
            ci_offset += ci_len

    combined_conv2d = nn.Conv2d(in_channels=total_in,
                                out_channels=total_out,
                                kernel_size=kernel_size,
                                stride=reference.stride,
                                padding=reference.padding,
                                dilation=reference.dilation,
                                bias=has_bias)
    combined_conv2d.weight = torch.nn.Parameter(weight)
    if bias is not None:
        combined_conv2d.bias = torch.nn.Parameter(bias)
    for p in combined_conv2d.parameters():
        p.requires_grad = True
    return combined_conv2d
def combine_models(models):
    """Merge several models of identical architecture into one model.

    A deep copy of the first model serves as the skeleton.  Every
    ``nn.Conv2d`` in it is replaced by ``conv2d_combine_fn`` applied to the
    corresponding layers of all models, and every ``nn.BatchNorm2d`` by
    ``bn_combine_fn``; all other modules are kept as they are.

    **Parameters:**
        - **models** (nn.Module): modules to be combined.

    **Returns:** the combined model.
    """
    def _recursively_combine(module):
        if isinstance(module, nn.Conv2d):
            replacement = conv2d_combine_fn(layer_mapping[module])
        elif isinstance(module, nn.BatchNorm2d):
            replacement = bn_combine_fn(layer_mapping[module])
        else:
            replacement = module
        merged = module if replacement is None else replacement
        # Children are looked up on the original module and re-attached
        # (after their own recursion) to whatever replaces it.
        for child_name, child in module.named_children():
            merged.add_module(child_name, _recursively_combine(child))
        return merged

    models = deepcopy(models)
    # Copy the model architecture; it is modified in place by _recursively_combine.
    skeleton = deepcopy(models[0])

    # Link every module of the skeleton to the matching modules of all models.
    layer_mapping = dict(zip(skeleton.modules(), _get_layers(models)))
    return _recursively_combine(skeleton)
class CombinedModel(nn.Module):
    """Wrap several same-architecture models merged by ``combine_models``.

    The merged network expects the input channels of all source models side
    by side, so ``forward`` tiles the input along the channel dimension once
    per source model before running the merged network.
    """

    def __init__(self, models):
        super().__init__()
        self.combined_model = combine_models( models )
        # Number of source models == number of channel-wise input copies.
        self.expand = len(models)

    def forward(self, x):
        """Run the merged model on ``x`` tiled ``expand`` times along dim 1.

        Expects a 4-D input (the repeat pattern has four entries).
        """
        # Tensor.repeat returns a new tensor (the result must be kept) and
        # takes per-dimension repeat counts; 1 leaves a dimension unchanged.
        x = x.repeat( 1, self.expand, 1, 1 )
        return self.combined_model(x)
class PruningKDTrainer(KDTrainer):
    """KD trainer that alternates pruning of the student with KD training.

    ``run`` splits the iteration budget evenly over the pruning rounds; in
    each round the student is pruned, a fresh optimizer/scheduler pair is
    created, and the parent trainer is run for that round's share.
    """

    def setup(
        self,
        student,
        teachers,
        task,
        dataloader: torch.utils.data.DataLoader,
        get_optimizer_and_scheduler:Callable=None,
        pruning_rounds=5,
        device=None,
    ):
        """Store the training components.

        **Parameters:**
        - student: network to prune and train (moved to ``device``)
        - teachers: iterable of teacher networks (wrapped in a ModuleList)
        - task: accepted but not stored by this method
          (NOTE(review): confirm the parent ``step_fn`` does not need it)
        - dataloader: training data loader
        - get_optimizer_and_scheduler: optional callable taking the student and
          returning ``(optimizer, scheduler)``; called once per pruning round
        - pruning_rounds: accepted but not used here; the number of rounds is
          the ``pruning_rounds`` argument of ``run``
        - device: target device; CUDA when available, else CPU, if omitted
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.get_optimizer_and_scheduler = get_optimizer_and_scheduler

    @property
    def device(self):
        """Device selected in ``setup``."""
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None, pruning_rounds=3, target_model_size=0.6 ):
        """Prune and train the student in ``pruning_rounds`` rounds.

        The per-round pruning rate is chosen so that the remaining fraction
        after all rounds equals ``target_model_size``.
        """
        # This module does not import set_mode at the top level.
        from kamal.utils import set_mode

        pruning_size_per_round = 1 - math.pow( target_model_size, 1/pruning_rounds )
        step_iter = (max_iter - start_iter)//pruning_rounds
        pruner = Pruner( strategy.LNStrategy(n=1) )

        for pruning_round in range(pruning_rounds):
            # NOTE(review): the example input is created on the CPU while the
            # student may live on another device — confirm Pruner handles this.
            pruner.prune( self.student, rate=pruning_size_per_round, example_inputs=torch.randn(1,3,240,240) )
            self.student.to(self.device)

            # Pruning changes the student's parameters, so the optimizer and
            # scheduler are rebuilt every round.
            if self.get_optimizer_and_scheduler:
                self.optimizer, self.scheduler = self.get_optimizer_and_scheduler( self.student )
            else:
                self.optimizer = torch.optim.Adam( self.student.parameters(), lr=1e-4, weight_decay=1e-5 )
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=step_iter )

            with set_mode(self.student, training=True), \
                 set_mode(self.teachers, training=False):
                super().run(self.step_fn, self._dataloader,
                            start_iter=start_iter+step_iter*pruning_round, max_iter=start_iter+step_iter*(pruning_round+1), epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Run the parent's training step, then advance the LR scheduler."""
        metrics = super().step_fn( engine, batch )
        self.scheduler.step()
        return metrics
class RecombinationAmalgamator(PruningKDTrainer):
    """Alias of ``PruningKDTrainer``; adds no behaviour of its own."""
@@ -0,0 +1,295 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from kamal.core.engine.engine import Engine | |||
from kamal.core.engine.hooks import FeatureHook | |||
from kamal.core import tasks | |||
from kamal.utils import move_to_device, set_mode | |||
from kamal.core.hub import meta | |||
from kamal import vision | |||
import kamal | |||
from kamal.utils import set_mode | |||
import typing | |||
import time | |||
from copy import deepcopy | |||
import random | |||
import numpy as np | |||
from collections import defaultdict | |||
import numbers | |||
class BranchySegNet(nn.Module):
    """Student network with a shared trunk and one decoder branch per task.

    For every entry of ``out_channels`` a network is built with ``segnet_fn``;
    its children from index 5 onwards are deep-copied as that task's branch
    decoders, and five gating adaptors (1x1 conv -> ReLU -> 1x1 conv ->
    sigmoid) are created for it. The shared encoders and decoders are copied
    from the network built for the LAST entry of ``out_channels``.

    Args:
        out_channels: one class count per task; one branch is built for each.
        segnet_fn: factory called as
            ``segnet_fn(num_classes=oc, pretrained_backbone=True)``.
            NOTE(review): assumed to return a module whose first five children
            are encoder blocks and the rest decoder blocks - confirm.

    Raises:
        ValueError: if ``out_channels`` is empty (there would be no network to
            copy the shared trunk from).
    """

    def __init__(self, out_channels, segnet_fn=vision.models.segmentation.segnet_vgg16_bn):
        super(BranchySegNet, self).__init__()
        if len(out_channels) == 0:
            raise ValueError("out_channels must contain at least one entry")

        channels = [512, 512, 256, 128, 64]
        # Integer buffer: its entries are used as list / ModuleList indices in
        # forward(), which a floating point tensor cannot be.
        self.register_buffer( 'branch_indices', torch.zeros((len(out_channels),), dtype=torch.long) )

        self.student_b_decoders_list = nn.ModuleList()
        self.student_adaptors_list = nn.ModuleList()

        # Hidden width of each adaptor: a quarter of the channels, at least 16.
        ses = [ max(16, int(c/4)) for c in channels ]

        for oc in out_channels:
            segnet = self.get_segnet( oc, segnet_fn )
            decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))
            adaptors = nn.ModuleList()
            for i in range(5):
                adaptors.append( nn.Sequential(
                    nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
                    nn.ReLU(),
                    nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
                    nn.Sigmoid()
                ) )
            self.student_b_decoders_list.append(decoders)
            self.student_adaptors_list.append(adaptors)

        # Shared trunk, copied from the network built for the last task.
        self.student_encoders = nn.ModuleList(deepcopy(list(segnet.children())[0:5]))
        self.student_decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))

    def set_branch(self, branch_indices):
        """Store, for every task, the trunk position its branch starts from."""
        assert len(branch_indices)==len(self.student_b_decoders_list)
        self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).long().to(self.branch_indices.device)

    def get_segnet(self, oc, segnet_fn):
        """Build one network with ``oc`` output classes via ``segnet_fn``."""
        return segnet_fn( num_classes=oc, pretrained_backbone=True )

    def forward(self, inputs):
        """Return a list with one output per task branch."""
        output_list = []

        down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
        down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
        down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
        down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
        down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)

        up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
        up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
        up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
        up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
        # The result is not consumed below; the call is kept so the last shared
        # decoder still runs (hooks / normalisation statistics may rely on it).
        up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)

        decoder_features = [down5, up5, up4, up3, up2]
        decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
        decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]

        # Mimic teachers.
        for i, out_idx in enumerate(self.branch_indices.tolist()):
            out_idx = int(out_idx)  # plain Python int, safe for list indexing and range()
            output = decoder_features[out_idx]
            # Pool over the full (H, W) extent so the gate is one value per
            # channel for non-square feature maps as well.
            gate = self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:]))
            output = output * gate
            for j in range(out_idx, 5):
                output = self.student_b_decoders_list[i][j](
                    output,
                    decoder_indices[j],
                    decoder_shapes[j]
                )
            output_list.append( output )
        return output_list
class JointSegNet(nn.Module):
    """Online student that learns from any number of teachers with a 'SegNet' structure.

    **Parameters:**
        - **teachers** (list of 'Module' object): Teachers to learn from. The
          children of each teacher from index 5 onwards are deep-copied as that
          teacher's branch decoders.
        - **student** ('Module' object, optional): Model whose first five
          children become the shared encoders and whose remaining children
          become the shared decoders. Defaults to ``teachers[0]``.
        - **channels** (list of int, optional): Channel counts used to build the
          adaptor modules, one per branching position.

    Where each task branches out is chosen afterwards with ``set_branch``.
    """

    def __init__(self, teachers, student=None, channels=(512, 512, 256, 128, 64)):
        super(JointSegNet, self).__init__()
        # One integer entry per teacher: the entries are used as list /
        # ModuleList indices in forward(), which a float tensor cannot be.
        self.register_buffer( 'branch_indices', torch.zeros((len(teachers),), dtype=torch.long) )

        if student is None:
            student = teachers[0]
        # Build the shared trunk from the requested student model.
        self.student_encoders = nn.ModuleList(deepcopy(list(student.children())[0:5]))
        self.student_decoders = nn.ModuleList(deepcopy(list(student.children())[5:]))

        self.student_b_decoders_list = nn.ModuleList()
        self.student_adaptors_list = nn.ModuleList()

        # Hidden width of each adaptor: a quarter of the channels, at least 16.
        ses = [ max(16, int(channels[i]/4)) for i in range(5) ]

        for teacher in teachers:
            decoders = nn.ModuleList(deepcopy(list(teacher.children())[5:]))
            adaptors = nn.ModuleList()
            for i in range(5):
                adaptors.append( nn.Sequential(
                    nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
                    nn.ReLU(),
                    nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
                    nn.Sigmoid()
                ) )
            self.student_b_decoders_list.append(decoders)
            self.student_adaptors_list.append(adaptors)

    def set_branch(self, branch_indices):
        """Store, for every teacher, the trunk position its branch starts from."""
        assert len(branch_indices)==len(self.student_b_decoders_list)
        self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).long().to(self.branch_indices.device)

    def forward(self, inputs):
        """Return a list with one output per teacher branch."""
        output_list = []

        down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
        down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
        down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
        down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
        down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)

        up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
        up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
        up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
        up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
        # The result is not consumed below; the call is kept so the last shared
        # decoder still runs (hooks / normalisation statistics may rely on it).
        up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)

        decoder_features = [down5, up5, up4, up3, up2]
        decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
        decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]

        # Mimic teachers.
        for i, out_idx in enumerate(self.branch_indices.tolist()):
            out_idx = int(out_idx)  # plain Python int, safe for list indexing and range()
            output = decoder_features[out_idx]
            # Pool over the full (H, W) extent so the gate is one value per
            # channel for non-square feature maps as well.
            gate = self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:]))
            output = output * gate
            for j in range(out_idx, 5):
                output = self.student_b_decoders_list[i][j](
                    output,
                    decoder_indices[j],
                    decoder_shapes[j]
                )
            output_list.append( output )
        return output_list
class TaskBranchingAmalgamator(Engine):
    """Two-stage trainer for a branching student.

    Stage one ("branching") trains with a randomly drawn branch position per
    teacher at every step. ``find_the_best_branch`` then fixes, per task, the
    position with the lowest accumulated loss, and stage two ("finetuning")
    continues training with those positions.
    """

    def setup(
        self,
        joint_student: JointSegNet,
        teachers,
        tasks,
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        device=None,
    ):
        """Attach models, tasks, data and optimizer to the engine.

        Args:
            joint_student: student model; moved to ``device``.
            teachers: iterable of teacher modules; wrapped in a ``ModuleList``
                and moved to ``device``.
            tasks: one task object per teacher, each exposing
                ``get_loss(student_out, teacher_out)`` that returns a dict of losses.
            dataloader: source of training batches.
            optimizer: optimizer stepping the student parameters.
            device: target device; CUDA when available, else CPU, if omitted.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._dataloader = dataloader
        self.student = self.model = joint_student.to(self._device)
        self.teachers = nn.ModuleList(teachers).to(self._device)
        self.tasks = tasks
        self.optimizer = optimizer
        self.is_finetuning=False

    @property
    def device(self):
        """Device chosen in ``setup``."""
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None ):
        """Run the branching stage, select the branches, then run the finetuning stage.

        Args:
            max_iter: end (exclusive) of the whole schedule; the first half of
                it is the branching stage.
            start_iter: iteration to start from.
            epoch_length: forwarded to the parent ``run``.
            stage_callback: optional zero-argument callable invoked between
                the two stages.
        """
        branching_end = max_iter//2

        # Branching
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super().run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=branching_end, epoch_length=epoch_length)

        branch = self.find_the_best_branch( self._dataloader )
        self.logger.info("[Task Branching] the best branch indices: %s"%(branch))

        if stage_callback is not None:
            stage_callback()

        # Finetuning
        self.is_finetuning = True
        # Continue exactly where the branching stage stopped, so that no
        # iteration is skipped when max_iter is odd; never rewind behind
        # start_iter.
        finetune_start = max(start_iter, branching_end)
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super().run(self.step_fn, self._dataloader, start_iter=finetune_start, max_iter=max_iter, epoch_length=epoch_length)

    def find_the_best_branch(self, dataloader):
        """Pick, per task, the branch position with the lowest summed loss.

        Every candidate position is evaluated for all tasks at once over the
        whole ``dataloader``. The winning positions are installed on the
        student and returned as a list of ints.
        """
        with set_mode(self.student, training=False), \
             set_mode(self.teachers, training=False), \
             torch.no_grad():
            n_blocks = len(self.student.student_decoders)
            n_teachers = len(self.teachers)
            # branch_loss[task_position][candidate] as plain floats, so that
            # argmin below never has to convert (possibly CUDA) tensors.
            branch_loss = [ [0.0 for _ in range(n_blocks)] for _ in self.tasks ]
            for batch in dataloader:
                batch = move_to_device(batch, self.device)
                data = batch if isinstance(batch, torch.Tensor) else batch[0]
                for candidate_branch in range( n_blocks ):
                    self.student.set_branch( [candidate_branch for _ in range(n_teachers)] )
                    s_out_list = self.student( data )
                    t_out_list = [ teacher( data ) for teacher in self.teachers ]
                    for task_pos, (task, s_out, t_out) in enumerate(zip( self.tasks, s_out_list, t_out_list )):
                        task_loss = task.get_loss( s_out, t_out )
                        branch_loss[task_pos][candidate_branch] += float(sum(task_loss.values()))
            best_branch = [ int(np.argmin( losses )) for losses in branch_loss ]
            self.student.set_branch(best_branch)
            return best_branch

    def step_fn(self, engine, batch):
        """Run one optimisation step and return its metrics as a dict."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        # Same input convention as find_the_best_branch: a bare tensor is the
        # input itself, anything else carries the input in position 0.
        data = batch if isinstance(batch, torch.Tensor) else batch[0]

        n_blocks = len(self.student.student_decoders)
        if not self.is_finetuning:
            # Branching stage: draw a fresh random position for every teacher.
            rand_branch_indices = [ random.randint(0, n_blocks-1) for _ in range(len(self.teachers)) ]
            self.student.set_branch(rand_branch_indices)

        s_out_list = self.student( data )
        with torch.no_grad():
            t_out_list = [ teacher( data ) for teacher in self.teachers ]

        loss_dict = {}
        loss = 0
        for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
            task_loss = task.get_loss( s_out, t_out )
            # Accumulate per task: when two tasks report a loss under the same
            # name, loss_dict keeps only the last one, but every term must
            # still contribute to the optimised total.
            loss = loss + sum(task_loss.values())
            loss_dict.update( task_loss )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time

        metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float( self.optimizer.param_groups[0]['lr'] ),
            'branch': self.student.branch_indices.cpu().numpy().tolist()
        })
        return metrics
@@ -0,0 +1,4 @@ | |||
from . import engine, tasks, metrics, callbacks, exceptions, hub | |||
from .attach import AttachTo | |||
from .hub import load, save |
@@ -0,0 +1,45 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from typing import Sequence, Callable | |||
from numbers import Number | |||
from kamal.core import exceptions | |||
class AttachTo(object):
    """ Attach task, metrics or visualizer to specified model outputs

    ``attach_to`` selects what is taken from each positional argument of a call:

    * ``None``  - arguments are passed through untouched;
    * callable  - called with all arguments, its result is returned;
    * number or string - used as the index / key for every argument;
    * any other sequence - the i-th entry is the index / key for the i-th argument.

    A single result is returned bare instead of inside a one-element container.
    """

    def __init__(self, attach_to=None):
        if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable) ):
            raise exceptions.InvalidMapping
        self._attach_to = attach_to

    def __call__(self, *tensors):
        if self._attach_to is not None:
            if isinstance(self._attach_to, Callable):
                return self._attach_to( *tensors )
            # A string is itself a Sequence, but it denotes ONE key that is
            # applied to every argument, not one character per argument.
            if isinstance(self._attach_to, Sequence) and not isinstance(self._attach_to, str):
                _attach_to = self._attach_to
            else:
                _attach_to = [ self._attach_to for _ in range(len(tensors)) ]
            _attach_to = _attach_to[:len(tensors)]
            tensors = [ tensor[index] for (tensor, index) in zip( tensors, _attach_to ) ]
        if len(tensors)==1:
            tensors = tensors[0]
        return tensors

    def __repr__(self):
        # One-element tuple so that tuple-valued mappings format correctly too.
        return "AttachTo: %s"%(self._attach_to,)
@@ -0,0 +1,5 @@ | |||
from .logging import MetricsLogging, ProgressCallback | |||
from .base import Callback | |||
from .eval_and_ckpt import EvalAndCkpt | |||
from .scheduler import LRSchedulerCallback | |||
from .visualize import VisualizeOutputs, VisualizeSegmentation, VisualizeDepth |
@@ -0,0 +1,28 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import abc | |||
class Callback(abc.ABC):
    r""" Abstract base for callbacks.

    Concrete subclasses have to provide ``__call__(engine)``; the base class
    itself cannot be instantiated.
    """

    def __init__(self):
        """Nothing to initialise at this level."""

    @abc.abstractmethod
    def __call__(self, engine):
        """Handle one invocation for ``engine``; overridden by subclasses."""
        return None
@@ -0,0 +1,145 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .base import Callback | |||
import weakref | |||
from kamal import utils | |||
from typing import Sequence, Optional | |||
import numbers | |||
import os, shutil | |||
import torch | |||
class EvalAndCkpt(Callback):
    """Evaluate a model, log the scalar results and write checkpoints.

    Args:
        model: model to evaluate and save; only a weak reference is kept.
        evaluator: object whose ``eval(model, device=...)`` returns the results.
        metric_name: key (after flattening) of the score used for 'best'.
        metric_mode: 'max' if larger scores are better, 'min' otherwise.
        save_type: ``None`` (never save), a single name, or a sequence of names:
            'best' keeps the best-scoring checkpoint, 'latest' keeps the most
            recent one, 'interval' (alias 'all') keeps one file per evaluation.
        ckpt_dir: directory the checkpoints are written to.
        ckpt_prefix: optional file-name prefix (an underscore is appended).
        log_tag: tag used in log lines and TensorBoard scalar names.
        weights_only: save ``state_dict()`` instead of the whole model.
        verbose: also log the path of every written file.
    """

    _SAVE_TYPES = ('best', 'latest', 'interval', 'all')

    def __init__(self,
                 model,
                 evaluator,
                 metric_name:str,
                 metric_mode:str ='max',
                 save_type:Optional[Sequence]=('best', 'latest'),
                 ckpt_dir:str ='checkpoints',
                 ckpt_prefix:str =None,
                 log_tag:str ='model',
                 weights_only:bool =True,
                 verbose:bool =False,):
        super(EvalAndCkpt, self).__init__()
        self.metric_name = metric_name
        assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'"
        self._metric_mode = metric_mode

        self._model = weakref.ref( model )
        self._evaluator = evaluator
        self._ckpt_dir = ckpt_dir
        self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix+'_')
        if isinstance(save_type, str):
            save_type = ( save_type, )
        self._save_type = save_type
        if self._save_type is not None:
            # 'interval' is what __call__ checks for, so it has to be accepted
            # here, otherwise that branch could never be selected.
            for name in self._save_type:
                assert name in self._SAVE_TYPES, \
                    'save_type should be None or a subset of ("best", "latest", "interval", "all")'
        self._log_tag = log_tag
        self._weights_only = weights_only
        self._verbose = verbose

        self._best_score = -999999 if self._metric_mode=='max' else 99999.
        self._best_ckpt = None
        self._latest_ckpt = None

    @property
    def best_ckpt(self):
        """Path of the current best checkpoint, or None."""
        return self._best_ckpt

    @property
    def latest_ckpt(self):
        """Path of the most recent 'latest' checkpoint, or None."""
        return self._latest_ckpt

    @property
    def best_score(self):
        """Best score seen so far (the initial sentinel before any evaluation)."""
        return self._best_score

    @staticmethod
    def _is_scalar(value):
        """True for Python numbers and for objects with an empty ``shape``."""
        if isinstance(value, numbers.Number):
            return True
        shape = getattr(value, 'shape', None)
        return shape is not None and len(shape)==0

    def __call__(self, trainer):
        model = self._model()
        results = self._evaluator.eval( model, device=trainer.device )
        results = utils.flatten_dict(results)
        current_score = results[self.metric_name]

        # Only scalar entries can be logged; everything else is ignored.
        scalar_results = { k: float(v) for (k, v) in results.items() if self._is_scalar(v) }
        if trainer.logger is not None:
            trainer.logger.info( "[Eval %s] Iter %d/%d: %s"%(self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results) )
        trainer.state.metrics.update( scalar_results )

        # Visualize results if trainer.tb_writer is not None
        if trainer.tb_writer is not None:
            for k, v in scalar_results.items():
                log_tag = "%s:%s"%(self._log_tag, k)
                trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter)

        if self._save_type is None:
            return

        pth_path_list = []
        # interval model: one file per evaluation, never removed
        if 'interval' in self._save_type or 'all' in self._save_type:
            pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f.pth"
                                    % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
            pth_path_list.append(pth_path)

        # the latest model
        if 'latest' in self._save_type:
            pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f.pth"
                                    % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
            # remove the old ckpt
            if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt):
                os.remove(self._latest_ckpt)
            pth_path_list.append(pth_path)
            self._latest_ckpt = pth_path

        # the best model
        if 'best' in self._save_type:
            if (current_score >= self._best_score and self._metric_mode=='max') or \
               (current_score <= self._best_score and self._metric_mode=='min'):
                pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f.pth" %
                                        (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                # remove the old ckpt
                if self._best_ckpt is not None and os.path.exists(self._best_ckpt):
                    os.remove(self._best_ckpt)
                pth_path_list.append(pth_path)
                self._best_score = current_score
                self._best_ckpt = pth_path

        # save model
        if self._verbose and trainer.logger:
            trainer.logger.info("Model saved as:")
        obj = model.state_dict() if self._weights_only else model
        os.makedirs( self._ckpt_dir, exist_ok=True )
        for pth_path in pth_path_list:
            torch.save(obj, pth_path)
            if self._verbose and trainer.logger:
                trainer.logger.info("\t%s" % (pth_path))

    def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False):
        """Copy the best checkpoint to ``<ckpt_dir>/<ckpt_prefix>best[-<md5>].pth``.

        Nothing is copied unless 'best' is among the save types and a best
        checkpoint has been written. ``ckpt_dir`` and ``ckpt_prefix`` default
        to the values given at construction time.
        """
        if ckpt_dir is None:
            ckpt_dir = self._ckpt_dir
        if ckpt_prefix is None:
            ckpt_prefix = self._ckpt_prefix
        if self._save_type is not None:
            if 'best' in self._save_type and self._best_ckpt is not None:
                os.makedirs(ckpt_dir, exist_ok=True)
                save_name = "%sbest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._best_ckpt))
                shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name))
@@ -0,0 +1,61 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .base import Callback | |||
import numbers | |||
from tqdm import tqdm | |||
class MetricsLogging(Callback):
    """Log selected entries of ``engine.state.metrics`` on every call.

    Args:
        keys: names of the metrics to report. Numeric values are printed with
            four decimals and, when the engine has a TensorBoard writer, also
            written as scalars; list / tuple values are printed as they are;
            missing keys and other value types are skipped.
    """

    def __init__(self, keys):
        super(MetricsLogging, self).__init__()
        self._keys = keys

    def __call__(self, engine):
        if engine.logger is None:
            return
        state = engine.state
        parts = [ "Iter %d/%d (Epoch %d/%d, Batch %d/%d)"%(
            state.iter, state.max_iter,
            state.current_epoch, state.max_epoch,
            state.current_batch_index, state.max_batch_index
        ) ]
        for key in self._keys:
            value = state.metrics.get(key, None)
            if value is None:
                continue
            if isinstance(value, numbers.Number):
                parts.append( " %s=%.4f"%(key, value) )
                if engine.tb_writer is not None:
                    engine.tb_writer.add_scalar(key, value, global_step=state.iter)
            elif isinstance(value, (list, tuple)):
                parts.append( " %s=%s"%(key, value) )
        engine.logger.info("".join(parts))
class ProgressCallback(Callback):
    """Advance a tqdm progress bar by one on every call.

    Args:
        max_iter: total number of updates the bar is sized for.
        tag: description shown next to the bar.
    """

    def __init__(self, max_iter=100, tag=None):
        super(ProgressCallback, self).__init__()
        self._max_iter = max_iter
        self._tag = tag
        # Created by reset(), or lazily on the first call, so that the
        # callback is usable even when reset() was never invoked.
        self._pbar = None

    def __call__(self, engine):
        if self._pbar is None:
            self.reset()
        self._pbar.update(1)
        if self._pbar.n==self._max_iter:
            self._pbar.close()

    def reset(self):
        """Start a fresh progress bar, closing the previous one if any."""
        if self._pbar is not None:
            self._pbar.close()
        self._pbar = tqdm(total=self._max_iter, desc=self._tag)
@@ -0,0 +1,34 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .base import Callback | |||
from typing import Sequence | |||
class LRSchedulerCallback(Callback):
    r""" LR scheduler callback

    Calls ``step()`` on every configured scheduler each time it is invoked.

    Args:
        schedulers: a single scheduler, a sequence of schedulers, or ``None``
            (the callback then does nothing).
    """

    def __init__(self, schedulers=None):
        super(LRSchedulerCallback, self).__init__()
        # Keep None as None: wrapping it into a tuple would defeat the guard
        # in __call__ and end in calling step() on None.
        if schedulers is not None and not isinstance(schedulers, Sequence):
            schedulers = ( schedulers, )
        self._schedulers = schedulers

    def __call__(self, trainer):
        if self._schedulers is None:
            return
        for sched in self._schedulers:
            sched.step()
@@ -0,0 +1,161 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .base import Callback | |||
from typing import Callable, Union, Sequence | |||
import weakref | |||
import random | |||
from kamal.utils import move_to_device, set_mode, split_batch, colormap | |||
from kamal.core.attach import AttachTo | |||
import torch | |||
import numpy as np | |||
import matplotlib.pyplot as plt | |||
import matplotlib | |||
matplotlib.use('agg') | |||
import math | |||
import numbers | |||
class VisualizeOutputs(Callback):
    """Write input / target / prediction image triples to TensorBoard.

    Args:
        model: model to run; only a weak reference is kept.
        dataset: indexable dataset the visualised samples are taken from.
        idx_list_or_num_vis: either the dataset indices to visualise, or how
            many indices to draw at random.
        normalizer: optional callable applied to the inputs before plotting.
        prepare_fn: callable ``(model, batch) -> (inputs, targets, preds)``;
            defaults to ``get_prepare_fn()``.
        decode_fn: optional callable turning target / prediction arrays into
            images.
        tag: prefix of the TensorBoard image tags.

    Raises:
        TypeError: if ``idx_list_or_num_vis`` is neither an int nor a sequence.
    """

    def __init__(self,
                 model,
                 dataset: torch.utils.data.Dataset,
                 idx_list_or_num_vis: Union[int, Sequence]=5,
                 normalizer: Callable=None,
                 prepare_fn: Callable=None,
                 decode_fn: Callable=None, # decode targets and preds
                 tag: str='viz'):
        super(VisualizeOutputs, self).__init__()
        self._dataset = dataset
        self._model = weakref.ref(model)
        if isinstance(idx_list_or_num_vis, int):
            self.idx_list = self._get_vis_idx_list(self._dataset, idx_list_or_num_vis)
        elif isinstance(idx_list_or_num_vis, Sequence):
            self.idx_list = idx_list_or_num_vis
        else:
            # Fail here instead of with an AttributeError on the first call.
            raise TypeError(
                "idx_list_or_num_vis should be an int or a sequence of indices, got %s"
                % type(idx_list_or_num_vis).__name__)
        self._normalizer = normalizer
        self._decode_fn = decode_fn
        if prepare_fn is None:
            prepare_fn = VisualizeOutputs.get_prepare_fn()
        self._prepare_fn = prepare_fn
        self._tag = tag

    def _get_vis_idx_list(self, dataset, num_vis):
        """Draw ``num_vis`` distinct random indices (at most ``len(dataset)``)."""
        n_samples = len(dataset)
        return random.sample(range(n_samples), min(num_vis, n_samples))

    @torch.no_grad()
    def __call__(self, trainer):
        if trainer.tb_writer is None:
            if trainer.logger is not None:
                trainer.logger.warning("summary writer was not found in trainer")
            return
        device = trainer.device
        model = self._model()
        with set_mode(model, training=False):
            for img_id, idx in enumerate(self.idx_list):
                batch = move_to_device(self._dataset[idx], device)
                # Add a leading batch dimension to every element of the sample.
                batch = [ d.unsqueeze(0) for d in batch ]
                inputs, targets, preds = self._prepare_fn(model, batch)
                if self._normalizer is not None:
                    inputs = self._normalizer(inputs)
                inputs = inputs.detach().cpu().numpy()
                preds = preds.detach().cpu().numpy()
                targets = targets.detach().cpu().numpy()
                if self._decode_fn: # to RGB 0~1 NCHW
                    preds = self._decode_fn(preds)
                    targets = self._decode_fn(targets)
                # Drop the batch dimension again and show the three images together.
                inputs = inputs[0]
                preds = preds[0]
                targets = targets[0]
                trainer.tb_writer.add_images("%s-%d"%(self._tag, img_id), np.stack( [inputs, targets, preds], axis=0), global_step=trainer.state.iter)

    @staticmethod
    def get_prepare_fn(attach_to=None, pred_fn=lambda x: x):
        """Build the default ``(model, batch) -> (inputs, targets, preds)`` function."""
        attach_to = AttachTo(attach_to)
        def wrapper(model, batch):
            inputs, targets = split_batch(batch)
            outputs = model(inputs)
            outputs, targets = attach_to(outputs, targets)
            return inputs, targets, pred_fn(outputs)
        return wrapper

    @staticmethod
    def get_seg_decode_fn(cmap=colormap(), index_transform=lambda x: x+1): # 255->0, 0->1,
        """Build a decoder that maps label arrays to colours through ``cmap``."""
        def wrapper(preds):
            if len(preds.shape)>3:
                preds = preds.squeeze(1)
            out = cmap[ index_transform(preds.astype('uint8')) ]
            out = out.transpose(0, 3, 1, 2) / 255
            return out
        return wrapper

    @staticmethod
    def get_depth_decode_fn(max_depth, log_scale=True, cmap=plt.get_cmap('jet')):
        """Build a decoder that colours depth arrays, optionally on a log scale."""
        def wrapper(preds):
            if log_scale:
                _max_depth = np.log( max_depth )
                preds = np.log( preds )
            else:
                _max_depth = max_depth
            if len(preds.shape)>3:
                preds = preds.squeeze(1)
            # Clip to [0, max], scale to [0, 1], colour, and keep the first 3 channels.
            out = (cmap(preds.clip(0, _max_depth)/_max_depth)).transpose(0, 3, 1, 2)[:, :3]
            return out
        return wrapper
class VisualizeSegmentation(VisualizeOutputs):
    """``VisualizeOutputs`` preset for segmentation.

    Unless overridden, predictions are the arg-max over dimension 1 of the
    model output, and labels are shifted by one before the colour lookup in
    ``cmap``.
    """

    def __init__(
        self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5,
        cmap = colormap(),
        attach_to=None,
        normalizer: Callable=None,
        prepare_fn: Callable=None,
        decode_fn: Callable=None,
        tag: str='seg'
    ):
        def _argmax_dim1(x):
            # second element of max(1) holds the arg-max indices
            return x.max(1)[1]

        def _shift_by_one(x):
            return x+1

        if prepare_fn is None:
            prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=_argmax_dim1)
        if decode_fn is None:
            decode_fn = VisualizeOutputs.get_seg_decode_fn(cmap=cmap, index_transform=_shift_by_one)

        base_kwargs = dict(
            model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis,
            normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn,
            tag=tag,
        )
        super(VisualizeSegmentation, self).__init__(**base_kwargs)
class VisualizeDepth(VisualizeOutputs):
    """``VisualizeOutputs`` preset for depth.

    Unless overridden, the model output is used as the prediction unchanged,
    and arrays are coloured by ``get_depth_decode_fn(max_depth, log_scale)``.
    """

    def __init__(
        self, model, dataset: torch.utils.data.Dataset,
        idx_list_or_num_vis: Union[int, Sequence]=5,
        max_depth = 10,
        log_scale = True,
        attach_to = None,
        normalizer: Callable=None,
        prepare_fn: Callable=None,
        decode_fn: Callable=None,
        tag: str='depth'
    ):
        def _identity(x):
            return x

        if prepare_fn is None:
            prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=_identity)
        if decode_fn is None:
            decode_fn = VisualizeOutputs.get_depth_decode_fn(max_depth=max_depth, log_scale=log_scale)

        base_kwargs = dict(
            model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis,
            normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn,
            tag=tag,
        )
        super(VisualizeDepth, self).__init__(**base_kwargs)
@@ -0,0 +1,7 @@ | |||
from . import evaluator | |||
from . import trainer | |||
from . import lr_finder | |||
from . import hooks | |||
from . import events | |||
from .engine import DefaultEvents, Event |
@@ -0,0 +1,190 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
import abc, math, weakref, typing, time | |||
from typing import Any, Callable, Optional, Sequence | |||
import numpy as np | |||
from kamal.core.engine.events import DefaultEvents, Event | |||
from kamal.core import tasks | |||
from kamal.utils import set_mode, move_to_device, get_logger | |||
from collections import defaultdict | |||
import numbers | |||
import contextlib | |||
class State(object):
    """Mutable record of a run: iteration counters, data source and metrics.

    The epoch-related properties are derived from ``iter``, ``max_iter`` and
    ``epoch_length`` and yield ``None`` while ``epoch_length`` is unset.
    """

    def __init__(self):
        # Assignment order matters: __repr__ lists attributes in this order.
        self.iter = 0
        self.max_iter = None
        self.epoch_length = None
        self.dataloader = None
        self.seed = None
        self.metrics = {}
        self.batch = None

    @property
    def current_epoch(self):
        """Number of whole epochs contained in ``iter``."""
        if self.epoch_length is None:
            return None
        return self.iter // self.epoch_length

    @property
    def max_epoch(self):
        """Number of whole epochs contained in ``max_iter``."""
        if self.epoch_length is None:
            return None
        return self.max_iter // self.epoch_length

    @property
    def current_batch_index(self):
        """Position of ``iter`` inside its epoch."""
        if self.epoch_length is None:
            return None
        return self.iter % self.epoch_length

    @property
    def max_batch_index(self):
        """Same value as ``epoch_length``."""
        return self.epoch_length

    def __repr__(self):
        printable = (numbers.Number, str, dict)
        pieces = ["State:\n"]
        for name, val in vars(self).items():
            # Anything that is not a number, string or dict is shown by type only.
            shown = val if isinstance(val, printable) else type(val)
            pieces.append("\t{}: {}\n".format(name, shown))
        return "".join(pieces)
class Engine(abc.ABC):
    """Event-driven iteration loop.

    An engine repeatedly fetches a batch, hands it to a step function and
    fires events before/after the run, each epoch and each step. Callbacks
    are registered per event with :meth:`add_callback` and are invoked as
    ``callback(engine)``.
    """

    def __init__(self, logger=None, tb_writer=None):
        """
        Args:
            logger: logger to use; a default one is created via ``get_logger``
                when this is falsy.
            tb_writer: stored as-is and exposed through :attr:`tb_writer`.
        """
        self._logger = logger if logger else get_logger(name='kamal', color=True)
        self._tb_writer = tb_writer
        self._callbacks = defaultdict(list)
        self._allowed_events = [ *DefaultEvents ]
        self._state = State()

    def reset(self):
        """Replace the current state with a fresh :class:`State`."""
        self._state = State()

    def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None):
        """Run ``step_fn`` for iterations ``start_iter`` .. ``max_iter - 1``.

        Args:
            step_fn: called as ``step_fn(engine, batch)``; a returned dict is
                merged into ``state.metrics``.
            dataloader: iterable of batches; restarted when exhausted.
            max_iter: exclusive upper bound of the iteration counter.
            start_iter: first value of the iteration counter.
            epoch_length: iterations per epoch; defaults to ``len(dataloader)``
                when falsy.
        """
        self.state.iter = self._state.start_iter = start_iter
        self.state.max_iter = max_iter
        self.state.epoch_length = epoch_length if epoch_length else len(dataloader)
        self.state.dataloader = dataloader
        self.state.dataloader_iter = iter(dataloader)
        self.state.step_fn = step_fn

        self.trigger_events(DefaultEvents.BEFORE_RUN)
        for current_iter in range( start_iter, max_iter ):
            self.state.iter = current_iter
            if self.state.epoch_length is not None and \
                    self.state.iter % self.state.epoch_length == 0:  # Epoch Start
                self.trigger_events(DefaultEvents.BEFORE_EPOCH)
            self.trigger_events(DefaultEvents.BEFORE_STEP)
            self.state.batch = self._get_batch()
            step_output = step_fn(self, self.state.batch)
            if isinstance(step_output, dict):
                self.state.metrics.update(step_output)
            self.trigger_events(DefaultEvents.AFTER_STEP)
            if self.state.epoch_length is not None and \
                    (self.state.iter + 1) % self.state.epoch_length == 0:  # Epoch End
                self.trigger_events(DefaultEvents.AFTER_EPOCH)
        self.trigger_events(DefaultEvents.AFTER_RUN)

    def _get_batch(self):
        """Return the next batch as a list/tuple, restarting the iterator when exhausted."""
        try:
            batch = next( self.state.dataloader_iter )
        except StopIteration:
            self.state.dataloader_iter = iter(self.state.dataloader)  # reset iterator
            batch = next( self.state.dataloader_iter )
        if not isinstance(batch, (list, tuple)):
            batch = [ batch, ]  # no targets
        return batch

    @property
    def state(self):
        return self._state

    @property
    def logger(self):
        return self._logger

    @property
    def tb_writer(self):
        return self._tb_writer

    def add_callback(self, event: "Event", callbacks ):
        """Register one callback or a sequence of callbacks for ``event``.

        Callbacks are only stored when ``event`` is an allowed event. If the
        event carries a non-default trigger, each callback is wrapped so that
        it only fires when the trigger returns a truthy value.

        Returns:
            A ``RemovableCallback`` handle, or a list of handles when more
            than one callback was given.
        """
        if not isinstance(callbacks, Sequence):
            callbacks = [callbacks]
        if event in self._allowed_events:
            for callback in callbacks:
                if callback not in self._callbacks[event]:
                    if event.trigger != event.default_trigger:
                        callback = self._trigger_wrapper(self, event.trigger, callback )
                    self._callbacks[event].append( callback )
        callbacks = [ RemovableCallback(self, event, c) for c in callbacks ]
        return ( callbacks[0] if len(callbacks)==1 else callbacks )

    def remove_callback(self, event, callback):
        """Unregister ``callback`` from ``event``.

        Also matches callbacks that were wrapped by :meth:`_trigger_wrapper`
        (the wrapper records the original in ``__wrapped__``).

        Returns:
            ``True`` if a callback was removed, ``False`` otherwise.
        """
        registered = self._callbacks.get(event)
        if not registered:
            return False
        for position, candidate in enumerate(registered):
            # Default to the candidate itself so that an unwrapped entry is
            # never matched through the ``__wrapped__`` lookup by accident.
            original = getattr(candidate, '__wrapped__', candidate)
            if candidate == callback or original == callback:
                del registered[position]
                return True
        return False

    @staticmethod
    def _trigger_wrapper(engine, trigger, callback):
        """Wrap ``callback`` so it only runs when ``trigger(engine)`` is truthy."""
        def wrapper(*args, **kwargs) -> Any:
            if trigger(engine):
                return callback(engine)
        # Keep a link to the original so remove_callback() can find it.
        wrapper.__wrapped__ = callback
        return wrapper

    def trigger_events(self, *events):
        """Invoke, in registration order, the callbacks of every allowed event given."""
        for e in events:
            if e in self._allowed_events:
                for callback in self._callbacks[e]:
                    callback(self)

    def register_events(self, *events):
        """Add ``events`` to the set of events this engine accepts."""
        for e in events:
            if e not in self._allowed_events:
                self._allowed_events.append( e )

    @contextlib.contextmanager
    def save_current_callbacks(self):
        """Temporarily run with an empty callback registry.

        The previous callbacks are restored on exit, including when the body
        of the ``with`` statement raises.
        """
        saved = self._callbacks
        self._callbacks = defaultdict(list)
        try:
            yield
        finally:
            self._callbacks = saved
class RemovableCallback:
    """Handle returned when a callback is registered; lets the caller undo it.

    The engine and the callback are referenced weakly so that the handle does
    not keep them alive. The event is referenced strongly: events created on
    the fly (for example by calling an event to attach a trigger) have no
    other owner and would otherwise be gone before :meth:`remove` is called.
    """

    def __init__(self, engine, event, callback):
        self._engine = weakref.ref(engine)
        self._callback = weakref.ref(callback)
        self._event = event

    @property
    def callback(self):
        """The registered callback, or ``None`` if it has been garbage collected."""
        return self._callback()

    def remove(self):
        """Unregister the callback from the engine.

        Returns:
            The result of ``engine.remove_callback(event, callback)``, or
            ``False`` when the engine or the callback no longer exists.
        """
        engine = self._engine()
        callback = self._callback()
        if engine is None or callback is None:
            return False
        return engine.remove_callback(self._event, callback)
@@ -0,0 +1,130 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import abc, sys | |||
import torch | |||
from tqdm import tqdm | |||
from kamal.core import metrics | |||
from kamal.utils import set_mode | |||
from typing import Any, Callable | |||
from .engine import Engine | |||
from .events import DefaultEvents | |||
from kamal.core import callbacks | |||
import weakref | |||
from kamal.utils import move_to_device, split_batch | |||
class BasicEvaluator(Engine):
    """Engine that makes one pass over a dataloader and accumulates a metric."""

    def __init__(self,
                 dataloader: torch.utils.data.DataLoader,
                 metric: metrics.MetricCompose,
                 eval_fn: Callable=None,
                 tag: str='Eval',
                 progress: bool=False ):
        """
        Args:
            dataloader: batches to evaluate on.
            metric: object that is reset, updated and queried for results.
            eval_fn: called as ``eval_fn(evaluator, batch)`` for every batch;
                defaults to :meth:`default_eval_fn`.
            tag: label handed to the progress callback.
            progress: when true, a progress callback is attached to
                ``AFTER_STEP``.
        """
        super().__init__()
        self.dataloader = dataloader
        self.metric = metric
        self.progress = progress
        if progress:
            progress_cb = callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag)
            self.porgress_callback = self.add_callback(
                DefaultEvents.AFTER_STEP, callbacks=progress_cb)
        self._model = None
        self._tag = tag
        self.eval_fn = BasicEvaluator.default_eval_fn if eval_fn is None else eval_fn

    def eval(self, model, device=None):
        """Evaluate ``model`` on the dataloader and return ``metric.get_results()``.

        Args:
            model: module to evaluate; it is moved to ``device``.
            device: target device; CUDA when available, otherwise CPU, if
                not given.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        # Only a weak reference is kept so the evaluator does not own the model.
        self._model = weakref.ref(model)
        self.device = device
        self.metric.reset()
        model.to(device)
        if self.progress:
            self.porgress_callback.callback.reset()
        num_batches = len(self.dataloader)
        with torch.no_grad(), set_mode(model, training=False):
            super().run( self.step_fn, self.dataloader, max_iter=num_batches )
        return self.metric.get_results()

    @property
    def model(self):
        """The model under evaluation, or ``None`` before :meth:`eval` was called."""
        if self._model is None:
            return None
        return self._model()

    def step_fn(self, engine, batch):
        """Move ``batch`` to the evaluation device and delegate to ``eval_fn``."""
        self.eval_fn( engine, move_to_device(batch, self.device) )

    @staticmethod
    def default_eval_fn(evaluator, batch):
        """Run the model on the batch inputs and update the metric with its targets."""
        inputs, targets = split_batch(batch)
        predictions = evaluator.model( inputs )
        evaluator.metric.update( predictions, targets )
class TeacherEvaluator(BasicEvaluator):
    """Evaluator that scores a model against the predictions of teacher model(s).

    The targets coming from the dataloader are replaced by
    ``task.predict(teacher(inputs))``.
    """

    def __init__(self,
                 dataloader: torch.utils.data.DataLoader,
                 teacher: torch.nn.Module,
                 task,
                 metric: metrics.MetricCompose,
                 eval_fn: Callable=None,
                 tag: str='Eval',
                 progress: bool=False ):
        """
        Args:
            dataloader: batches to evaluate on.
            teacher: a module, or a ``torch.nn.ModuleList`` of teachers.
            task: object exposing ``predict``; when ``teacher`` is a
                ``ModuleList`` this must be iterable in step with it.
            metric: metric updated with ``(outputs, teacher_targets)``.
            eval_fn: per-batch function; defaults to :meth:`default_eval_fn`.
            tag: label handed to the progress callback.
            progress: whether to attach a progress callback.
        """
        if eval_fn is None:
            eval_fn = TeacherEvaluator.default_eval_fn
        super( TeacherEvaluator, self ).__init__(dataloader=dataloader, metric=metric, eval_fn=eval_fn, tag=tag, progress=progress)
        self._teacher = teacher
        self.task = task

    def eval(self, model, device=None):
        """Evaluate ``model`` with the teacher placed on the same device.

        The default device is resolved here, before the teacher is moved, so
        that teacher and model always end up on the same device.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self.teacher.to(device)
        with set_mode(self.teacher, training=False):
            return super(TeacherEvaluator, self).eval( model, device=device )

    @property
    def model(self):
        """The model under evaluation, or ``None`` before :meth:`eval` was called."""
        if self._model is not None:
            return self._model()
        return None

    @property
    def teacher(self):
        return self._teacher

    def step_fn(self, engine, batch):
        """Move ``batch`` to the evaluation device and delegate to ``eval_fn``."""
        batch = move_to_device(batch, self.device)
        self.eval_fn( engine, batch )

    @staticmethod
    def default_eval_fn(evaluator, batch):
        """Update the metric with the model outputs and teacher-derived targets."""
        model = evaluator.model
        teacher = evaluator.teacher
        # The dataset targets are discarded; the teacher supplies the targets.
        inputs, targets = split_batch(batch)
        outputs = model( inputs )
        # get teacher outputs
        if isinstance(teacher, torch.nn.ModuleList):
            targets = [ task.predict(tea(inputs)) for (tea, task) in zip(teacher, evaluator.task) ]
        else:
            t_outputs = teacher(inputs)
            targets = evaluator.task.predict( t_outputs )
        evaluator.metric.update( outputs, targets )
@@ -0,0 +1,92 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from typing import Callable, Optional | |||
from enum import Enum | |||
class Event(object):
    """A named event with an attached trigger predicate.

    The trigger is a callable ``trigger(engine) -> bool`` that decides whether
    callbacks registered for the event should fire. Two events compare equal
    (and hash equally) when their values are equal, regardless of trigger.
    """

    def __init__(self, value: str, event_trigger: Optional[Callable]=None ):
        """
        Args:
            value: identifier of the event; used as both name and value.
            event_trigger: predicate called with the engine; defaults to
                :meth:`default_trigger`, which always fires.
        """
        if event_trigger is None:
            event_trigger = Event.default_trigger
        self._trigger = event_trigger
        self._name_ = self._value_ = value

    @property
    def trigger(self):
        return self._trigger

    @property
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @property
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @staticmethod
    def default_trigger(engine):
        """Trigger that always fires."""
        return True

    @staticmethod
    def once_trigger():
        """Build a trigger that fires on its first call only."""
        is_triggered = False

        def wrapper(engine):
            # The flag lives in the enclosing scope; without ``nonlocal`` the
            # assignment below would make it a local and the read would fail.
            nonlocal is_triggered
            if is_triggered:
                return False
            is_triggered = True
            return True
        return wrapper

    @staticmethod
    def every_trigger(every: int):
        """Build a trigger that fires when ``engine.state.iter`` is a multiple of ``every``.

        A non-positive ``every`` never fires.
        """
        def wrapper(engine):
            return every>0 and (engine.state.iter % every)==0
        return wrapper

    def __call__(self, every: Optional[int]=None, once: Optional[bool]=None ):
        """Return a copy of this event with a different trigger.

        Args:
            every: fire on iterations that are multiples of this number.
            once: when true, fire a single time. Must not be combined with
                ``every``.
        """
        if every is not None:
            assert once is None, "'every' and 'once' are mutually exclusive"
            return Event(self.value, event_trigger=Event.every_trigger(every) )
        if once:
            return Event(self.value, event_trigger=Event.once_trigger() )
        return Event(self.value)

    def __hash__(self):
        return hash(self._name_)

    def __eq__(self, other):
        if hasattr(other, 'value'):
            return self.value==other.value
        # Let Python fall back to its default comparison (which yields False
        # for unrelated objects) instead of returning a bare ``None``.
        return NotImplemented
class DefaultEvents(Event, Enum):
    # Built-in events. Each member is also an Event, so calling a member
    # (e.g. with every=... or once=...) goes through Event.__call__ and
    # returns a plain Event carrying the same value.

    # Whole-run boundaries. Note that the string values say "train" although
    # the member names say RUN.
    BEFORE_RUN = "before_train"
    AFTER_RUN = "after_train"
    # Epoch boundaries.
    BEFORE_EPOCH = "before_epoch"
    AFTER_EPOCH = "after_epoch"
    # Single-iteration boundaries.
    BEFORE_STEP = "before_step"
    AFTER_STEP = "after_step"
    # NOTE(review): presumably meant to surround batch fetching; no code
    # visible alongside this definition fires these two -- confirm.
    BEFORE_GET_BATCH = "before_get_batch"
    AFTER_GET_BATCH = "after_get_batch"
@@ -0,0 +1,34 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
class FeatureHook():
    """Capture the input and output of a module through a forward hook.

    After each forward pass of ``module``, ``feat_in`` holds the first
    element of the hook's input argument and ``feat_out`` holds the output.
    """

    def __init__(self, module):
        self.module = module
        self.feat_in = self.feat_out = None
        self.register()

    def register(self):
        """Attach the forward hook to the module and keep its handle."""
        handle = self.module.register_forward_hook(self.hook_fn_forward)
        self._hook = handle

    def remove(self):
        """Detach the forward hook."""
        self._hook.remove()

    def hook_fn_forward(self, module, fea_in, fea_out):
        """Hook body: store the first input and the output of the last call."""
        first_input = fea_in[0]
        self.feat_in = first_input
        self.feat_out = fea_out
@@ -0,0 +1,214 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import os | |||
import tempfile, uuid | |||
import contextlib | |||
import numpy as np | |||
from tqdm import tqdm | |||
from kamal.core import callbacks | |||
from kamal.core.engine import evaluator | |||
from kamal.core.engine.events import DefaultEvents | |||
class _ProgressCallback(callbacks.Callback):
    """Callback that advances a tqdm progress bar by one on every call."""

    def __init__(self, max_iter, tag):
        self._tag = tag
        self._max_iter = max_iter
        self._pbar = self._new_bar()

    def _new_bar(self):
        # One bar spanning ``max_iter`` ticks, labelled with the tag.
        return tqdm(total=self._max_iter, desc=self._tag)

    def __call__(self, enigine):
        self._pbar.update(1)

    def reset(self):
        """Start over with a fresh progress bar."""
        self._pbar = self._new_bar()
class _LinearLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Scheduler that moves the learning rate linearly across ``lr_range``.

    At step ``t`` (``last_epoch``) every parameter group gets
    ``lr_range[0] + (lr_range[1] - lr_range[0]) * t / max_iter``.
    """

    def __init__(self,
                 optimizer,
                 lr_range,
                 max_iter,
                 last_epoch=-1):
        # Both fields must exist before the base constructor runs, because
        # get_lr() below reads them.
        self.lr_range = lr_range
        self.max_iter = max_iter
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        progress = self.last_epoch / self.max_iter
        lr = self.lr_range[0] + (self.lr_range[1]-self.lr_range[0]) * progress
        # The same value is used for every parameter group.
        return [lr] * len(self.base_lrs)
class _ExponentialLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Scheduler that moves the learning rate geometrically across ``lr_range``.

    At step ``t`` (``last_epoch``) every parameter group gets
    ``lr_range[0] * (lr_range[1] / lr_range[0]) ** (t / (max_iter - 1))``,
    so ``lr_range[1]`` is reached at ``t == max_iter - 1``.
    """

    def __init__(self, optimizer, lr_range, max_iter, last_epoch=-1):
        # Both fields must exist before the base constructor runs, because
        # get_lr() below reads them.
        self.lr_range = lr_range
        self.max_iter = max_iter
        super(_ExponentialLRScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # A single-iteration schedule has no interval to interpolate over;
        # clamp the divisor so that case yields lr_range[0] instead of a
        # division by zero.
        span = max(self.max_iter - 1, 1)
        r = self.last_epoch / span
        lr = self.lr_range[0] * (self.lr_range[1] / self.lr_range[0]) ** r
        # The same value is used for every parameter group.
        return [lr for _ in self.base_lrs]
class _LRFinderCallback(object):
    """Record ``(learning rate, score)`` pairs while the learning rate is swept.

    The score is taken from the evaluator results when an evaluator is given,
    otherwise from ``trainer.state.metrics``. With a positive
    ``smooth_momentum`` each new score is blended with the previously
    recorded one.
    """

    def __init__(self,
                 model,
                 optimizer,
                 metric_name,
                 evaluator,
                 smooth_momentum):
        self._model = model
        self._optimizer = optimizer
        self._metric_name = metric_name
        self._evaluator = evaluator
        self._smooth_momentum = smooth_momentum
        self._records = []

    @property
    def records(self):
        """List of ``(lr, score)`` tuples collected so far."""
        return self._records

    def reset(self):
        """Forget all collected records."""
        self._records = []

    def __call__(self, trainer):
        if self._evaluator is None:
            raw_score = trainer.state.metrics[self._metric_name]
        else:
            raw_score = self._evaluator.eval( self._model )[ self._metric_name ]
        score = float(raw_score)

        momentum = self._smooth_momentum
        if momentum > 0 and self._records:
            # Exponential smoothing against the last recorded score.
            last_score = self._records[-1][1]
            score = last_score * momentum + score * (1-momentum)

        # The first parameter group is taken as the current learning rate.
        current_lr = self._optimizer.param_groups[0]['lr']
        self._records.append( ( current_lr, score ) )
class LRFinder(object):
    """Learning-rate range test.

    :meth:`find` sweeps the learning rate over a range while training,
    records a score per learning rate, restores the initial model/optimizer
    state and returns the learning rate picked by :meth:`suggest`.
    """

    def _reset(self):
        """Restore model and optimizer from the checkpoint written by :meth:`find`."""
        init_state = torch.load( self._temp_file )
        self.trainer.model.load_state_dict( init_state['model'] )
        self.trainer.optimizer.load_state_dict( init_state['optimizer'] )
        try:
            self.trainer.reset()
        except Exception:
            # Best effort: resetting the trainer state is optional, but do not
            # swallow KeyboardInterrupt/SystemExit like a bare except would.
            pass

    def adjust_learning_rate(self, optimizer, lr):
        """Set ``lr`` on every parameter group of ``optimizer``."""
        for group in optimizer.param_groups:
            group['lr'] = lr

    def _get_default_lr_range(self, optimizer):
        """Return the ``(low, high)`` sweep range used when none is given."""
        if isinstance( optimizer, torch.optim.Adam ):
            return ( 1e-5, 1e-2 )
        elif isinstance( optimizer, torch.optim.SGD ):
            return ( 1e-3, 0.2 )
        else:
            return ( 1e-5, 0.5)

    def plot(self, polyfit: int=3, log_x=True):
        """Plot the recorded scores against the learning rate.

        Args:
            polyfit: degree of a polynomial fit drawn on top of the raw
                scores; ``None`` draws the raw scores only.
            log_x: use a logarithmic x axis.

        Returns:
            The matplotlib figure.
        """
        import matplotlib.pyplot as plt
        lrs = [ rec[0] for rec in self.records ]
        scores = [ rec[1] for rec in self.records ]

        fig, ax = plt.subplots()
        ax.plot(lrs, scores, label='score')
        if polyfit is not None:
            z = np.polyfit( lrs, scores, deg=polyfit )
            ax.plot(lrs, np.polyval( z, lrs ), label='polyfit')
        if log_x:
            ax.set_xscale('log')
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Score")
        return fig

    def suggest( self, mode='min', skip_begin=10, skip_end=5, polyfit=None ):
        """Pick the record where the score changes fastest.

        Args:
            mode: ``'min'`` selects the most negative gradient of the score,
                anything else the most positive one.
            skip_begin: number of leading records to ignore.
            skip_end: number of trailing records to ignore (``0`` keeps all).
            polyfit: if given, scores are replaced by a polynomial fit of
                this degree before the gradient is taken.

        Returns:
            ``(index, lr)`` of the selected record.

        Raises:
            ValueError: if no record is left after skipping.
        """
        lrs = np.array( [ rec[0] for rec in self.records ] )
        scores = np.array( [ rec[1] for rec in self.records ] )
        if polyfit is not None:
            z = np.polyfit( lrs, scores, deg=polyfit )
            scores = np.polyval( z, lrs )
        # An explicit stop index is used because ``[a:-0]`` would be empty.
        stop = max( len(scores) - skip_end, 0 )
        grad = np.gradient( scores )[skip_begin:stop]
        if grad.size == 0:
            raise ValueError(
                "not enough records (%d) for skip_begin=%d, skip_end=%d"
                % (len(scores), skip_begin, skip_end))
        index = grad.argmin() if mode=='min' else grad.argmax()
        index = skip_begin + index
        return index, self.records[index][0]

    def find(self,
             optimizer,
             model,
             trainer,
             metric_name='total_loss',
             metric_mode='min',
             evaluator=None,
             lr_range=(1e-4, 0.1),
             max_iter=100,
             num_eval=None,
             smooth_momentum=0.9,
             scheduler='exp', # exp
             polyfit=None, # None
             skip=(10, 5),
             progress=True):
        """Sweep the learning rate and return the suggested value.

        Args:
            optimizer: optimizer whose learning rate is swept.
            model: model whose initial weights are checkpointed.
            trainer: engine that is run for ``max_iter`` iterations; its
                existing callbacks are suspended for the duration.
            metric_name: key of the score in the evaluator results, or in
                ``trainer.state.metrics`` when no evaluator is given.
            metric_mode: passed to :meth:`suggest` as ``mode``.
            evaluator: optional evaluator used to compute the score.
            lr_range: ``(low, high)`` sweep bounds; ``None`` selects a range
                based on the optimizer type.
            max_iter: number of training iterations.
            num_eval: number of score recordings; defaults to ``max_iter``.
            smooth_momentum: smoothing factor for the recorded scores.
            scheduler: ``'exp'`` for a geometric sweep, anything else linear.
            polyfit: passed to :meth:`suggest`.
            skip: ``(skip_begin, skip_end)`` passed to :meth:`suggest`.
            progress: show a progress bar while sweeping.

        Returns:
            The suggested learning rate.
        """
        self.optimizer = optimizer
        self.model = model
        self.trainer = trainer
        # save init state
        _filename = str(uuid.uuid4())+'.pth'
        _tempdir = tempfile.gettempdir()
        self._temp_file = os.path.join(_tempdir, _filename)
        init_state = {
            'optimizer': optimizer.state_dict(),
            'model': model.state_dict()
        }
        torch.save(init_state, self._temp_file)
        try:
            if num_eval is None or num_eval > max_iter or num_eval < 1:
                num_eval = max_iter
            if lr_range is None:
                lr_range = self._get_default_lr_range(optimizer)
            interval = max_iter // num_eval
            if scheduler=='exp':
                lr_sched = _ExponentialLRScheduler(optimizer, lr_range, max_iter=max_iter)
            else:
                lr_sched = _LinearLRScheduler(optimizer, lr_range, max_iter=max_iter)
            self._lr_callback = callbacks.LRSchedulerCallback(schedulers=[lr_sched])
            self._finder_callback = _LRFinderCallback(model, optimizer, metric_name, evaluator, smooth_momentum)
            self.adjust_learning_rate( self.optimizer, lr_range[0] )

            step_callbacks = [ self._lr_callback ]
            if progress:
                step_callbacks.append( _ProgressCallback(max_iter, '[LR Finder]') )
            with self.trainer.save_current_callbacks():
                trainer.add_callback(
                    DefaultEvents.AFTER_STEP, callbacks=step_callbacks)
                trainer.add_callback(
                    DefaultEvents.AFTER_STEP(interval), callbacks=self._finder_callback)
                self.trainer.run( start_iter=0, max_iter=max_iter )
            self.records = self._finder_callback.records  # get records
            index, best_lr = self.suggest(mode=metric_mode, skip_begin=skip[0], skip_end=skip[1], polyfit=polyfit)
        finally:
            # Always put model/optimizer back and clean up the checkpoint,
            # even when the sweep itself failed.
            try:
                self._reset()
            finally:
                if os.path.exists(self._temp_file):
                    os.remove(self._temp_file)
                del self.model, self.optimizer, self.trainer
        return best_lr
@@ -0,0 +1,129 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
from kamal.core.engine.engine import Engine, Event, DefaultEvents, State | |||
from kamal.core import tasks | |||
from kamal.utils import set_mode, move_to_device, get_logger, split_batch | |||
from typing import Callable, Mapping, Any, Sequence | |||
import time | |||
import weakref | |||
class BasicTrainer(Engine):
    """Engine that trains a single model against the loss of a task."""

    def __init__( self,
                  logger=None,
                  tb_writer=None):
        super().__init__(logger=logger, tb_writer=tb_writer)

    def setup(self,
              model: torch.nn.Module,
              task: tasks.Task,
              dataloader: torch.utils.data.DataLoader,
              optimizer: torch.optim.Optimizer,
              device: torch.device=None):
        """Attach everything needed for training and return ``self``.

        Args:
            model: module to train.
            task: object providing ``get_loss``; a sequence of tasks is
                wrapped in ``tasks.TaskCompose``.
            dataloader: source of training batches.
            optimizer: optimizer stepping the model parameters.
            device: training device; CUDA when available, otherwise CPU, if
                not given.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self.device = device
        self.task = tasks.TaskCompose(task) if isinstance(task, Sequence) else task
        self.model = model
        self.dataloader = dataloader
        self.optimizer = optimizer
        return self

    def run( self, max_iter, start_iter=0, epoch_length=None):
        """Move the model to the device and run the training loop in training mode."""
        self.model.to(self.device)
        with set_mode(self.model, training=True):
            super().run(
                self.step_fn,
                self.dataloader,
                start_iter=start_iter,
                max_iter=max_iter,
                epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Run one optimisation step and return its metrics.

        Returns:
            A dict with one entry per loss term plus ``total_loss``,
            ``step_time`` (seconds spent on forward, backward and update)
            and ``lr`` (learning rate of the first parameter group).
        """
        tic = time.perf_counter()
        inputs, targets = split_batch(move_to_device(batch, self.device))
        outputs = self.model(inputs)
        loss_dict = self.task.get_loss(outputs, targets)  # get loss
        loss = sum( loss_dict.values() )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - tic

        metrics = {}
        for loss_name, loss_value in loss_dict.items():
            metrics[loss_name] = loss_value.item()
        metrics['total_loss'] = loss.item()
        metrics['step_time'] = step_time
        metrics['lr'] = float( self.optimizer.param_groups[0]['lr'] )
        return metrics
class KDTrainer(BasicTrainer):
    """Trainer that fits a student to the outputs of one or more teachers."""

    def setup(self,
              student: torch.nn.Module,
              teacher: torch.nn.Module,
              task: tasks.Task,
              dataloader: torch.utils.data.DataLoader,
              optimizer: torch.optim.Optimizer,
              device: torch.device=None):
        """Attach student, teacher(s), task, data and optimizer; return ``self``.

        A list/tuple holding a single teacher is unwrapped; a longer one is
        turned into an ``nn.ModuleList``.
        """
        super().setup(
            model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device)
        if isinstance(teacher, (list, tuple)):
            teacher = teacher[0] if len(teacher)==1 else nn.ModuleList(teacher)
        self.student = self.model
        self.teacher = teacher
        return self

    def run( self, max_iter, start_iter=0, epoch_length=None):
        """Run the loop with the student in training mode and the teacher in eval mode."""
        self.student.to(self.device)
        self.teacher.to(self.device)
        with set_mode(self.student, training=True), \
                set_mode(self.teacher, training=False):
            # Deliberately resolves past BasicTrainer.run to the engine loop,
            # which takes the step function and dataloader explicitly.
            super( BasicTrainer, self ).run(
                self.step_fn,
                self.dataloader,
                start_iter=start_iter,
                max_iter=max_iter,
                epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Run one distillation step and return its metrics.

        The targets contained in the batch are not used; the loss is computed
        between the student outputs and the teacher outputs.

        Returns:
            A dict with one entry per loss term plus ``total_loss``,
            ``step_time`` and ``lr``.
        """
        tic = time.perf_counter()
        inputs, _ = split_batch(move_to_device(batch, self.device))
        outputs = self.model(inputs)
        if isinstance(self.teacher, nn.ModuleList):
            soft_targets = [ member(inputs) for member in self.teacher ]
        else:
            soft_targets = self.teacher(inputs)
        loss_dict = self.task.get_loss(outputs, soft_targets)  # get loss
        loss = sum( loss_dict.values() )

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - tic

        metrics = {}
        for loss_name, loss_value in loss_dict.items():
            metrics[loss_name] = loss_value.item()
        metrics['total_loss'] = loss.item()
        metrics['step_time'] = step_time
        metrics['lr'] = float( self.optimizer.param_groups[0]['lr'] )
        return metrics
@@ -0,0 +1,22 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
class DataTypeError(Exception):
    """Error type for data-type problems.

    Derives from ``Exception`` so that it can actually be raised and caught;
    a class deriving from ``object`` cannot be used in a ``raise`` statement.
    """
    pass
class InvalidMapping(Exception):
    """Error type for invalid mappings.

    Derives from ``Exception`` so that it can actually be raised and caught;
    a class deriving from ``object`` cannot be used in a ``raise`` statement.
    """
    pass
@@ -0,0 +1,2 @@ | |||
from ._hub import * | |||
from . import meta |
@@ -0,0 +1,288 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from pip._internal import main as pipmain | |||
from typing import Sequence, Optional | |||
import torch | |||
import os, sys, re | |||
import shutil, inspect | |||
import importlib.util | |||
from glob import glob | |||
from ruamel.yaml import YAML | |||
from ruamel.yaml import comments | |||
from ._module_mapping import PACKAGE_NAME_TO_IMPORT_NAME | |||
import hashlib | |||
import inspect | |||
# NOTE(review): _DEFAULT_PROTOCOL and _DEPENDENCY_FILE are not referenced by
# the functions visible in this part of the file -- confirm they are still used.
_DEFAULT_PROTOCOL = 2
_DEPENDENCY_FILE = "requirements.txt"
# Sub-directory names inside a saved model package: exported source code,
# weight files (<entry>-<spec>.pth) and tag files (<entry>-<spec>.yml).
_CODE_DIR = "code"
_WEIGHTS_DIR = "weight"
_TAGS_DIR = "tag"
# Shared ruamel.yaml instance used by _yaml_dump() and _yaml_load().
yaml = YAML()
def _replace_invalid_char(name):
    """Return ``name`` with every ``'-'`` and ``' '`` replaced by ``'_'``."""
    substitutions = str.maketrans({'-': '_', ' ': '_'})
    return name.translate(substitutions)
def save(
    model: torch.nn.Module,
    save_path: str,
    entry_name: str,
    spec_name: str,
    code_path: str, # path to SOURCE_CODE_DIR or hubconf.py
    metadata: dict,
    tags: dict = None,
    ignore_files: Sequence = None,
    save_arch: bool = False,
    overwrite: bool = False,
):
    """Export a model as a package directory holding code, weights and tags.

    Args:
        model: module to export. Its ``SETUP_INFO`` and ``METADATA``
            attributes are deleted if present.
        save_path: package directory; created if missing.
        entry_name: entry name; ``'-'`` and ``' '`` are replaced by ``'_'``.
        spec_name: spec name, sanitised like ``entry_name``; when ``None``
            the value returned by ``md5`` for the weight file is used.
        code_path: a source directory, or a ``.py`` file that is exported
            as ``hubconf.py``.
        metadata: stored next to the weights under the ``'metadata'`` key.
        tags: written with ``save_tags`` when not ``None``.
        ignore_files: not used by this function body.
        save_arch: store the module object itself instead of its
            ``state_dict()``.
        overwrite: replace previously exported code; forced to ``True``
            when ``save_path`` is not an existing directory.
    """
    entry_name = _replace_invalid_char(entry_name)
    if spec_name is not None:
        spec_name = _replace_invalid_char(spec_name)
    # A package that does not exist yet has nothing to preserve.
    if not os.path.isdir(save_path):
        overwrite = True
    save_path = os.path.abspath(save_path)
    export_code_path = os.path.join(save_path, _CODE_DIR)
    export_weights_path = os.path.join(save_path, _WEIGHTS_DIR)
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(export_code_path, exist_ok=True)
    os.makedirs(export_weights_path, exist_ok=True)
    code_path = os.path.abspath(code_path)
    if os.path.isdir( code_path ):
        if overwrite:
            shutil.rmtree(export_code_path)
        _copy_file_or_tree(src=code_path, dst=export_code_path) # overwrite old files
    elif code_path.endswith('.py'):
        # A single file is only exported when overwriting is enabled.
        if overwrite:
            shutil.copy2(src=code_path, dst=os.path.join( export_code_path, 'hubconf.py' )) # overwrite old files
    # Drop the attributes that load() attaches to a model.
    if hasattr(model, 'SETUP_INFO'):
        del model.SETUP_INFO
    if hasattr(model, 'METADATA'):
        del model.METADATA
    model_and_metadata = {
        'model': model if save_arch else model.state_dict(),
        'metadata': metadata,
    }
    # Written under a temporary name first because the final file name may
    # depend on the digest of the written file.
    temp_pth = os.path.join(export_weights_path, 'temp.pth')
    torch.save(model_and_metadata, temp_pth)
    if spec_name is None:
        _md5 = md5( temp_pth )
        spec_name = _md5
    shutil.move( temp_pth, os.path.join(export_weights_path, '%s-%s.pth'%(entry_name, spec_name)) )
    if tags is not None:
        save_tags( tags, save_path, entry_name, spec_name )
def list_entry(path):
    """List the public callables defined by an entry file.

    Args:
        path: either a ``.py`` file, or a package directory whose code
            sub-directory contains ``app.py``.

    Returns:
        A list of ``(name, callable)`` pairs for every callable attribute of
        the imported module whose name does not start with ``'_'``.

    Raises:
        NotImplementedError: if ``path`` is neither a ``.py`` file nor a
            directory.
    """
    path = os.path.abspath( os.path.expanduser( path ) )
    if path.endswith('.py'):
        code_dir = os.path.dirname( path )
        entry_file = path
    elif os.path.isdir(path):
        code_dir = os.path.join( path, _CODE_DIR )
        entry_file = os.path.join(path, _CODE_DIR, 'app.py')
    else:
        raise NotImplementedError

    # The code directory is made importable while the entry file executes,
    # and is taken off sys.path again even if that import fails.
    sys.path.insert(0, code_dir)
    try:
        spec = importlib.util.spec_from_file_location('app', entry_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        entry_list = []
        for name in dir(module):
            if name.startswith('_'):
                continue
            member = getattr(module, name)
            if callable(member):
                entry_list.append( ( name, member ) )
        return entry_list
    finally:
        if code_dir in sys.path:
            sys.path.remove(code_dir)
def list_spec(path, entry_name=None):
    """List the ``(entry_name, spec_name)`` pairs stored in a package.

    Weight files are named ``<entry>-<spec>.pth``. Files in the weight
    directory that do not end in ``.pth`` or contain no ``'-'`` are skipped.

    Args:
        path: package directory.
        entry_name: if given, only specs of this entry are returned.
    """
    path = os.path.abspath( os.path.expanduser( path ) )
    weight_dir = os.path.join( path, _WEIGHTS_DIR )
    suffix = '.pth'
    spec_list = []
    for fname in os.listdir( weight_dir ):
        if not fname.endswith(suffix):
            continue
        # Split on the first '-' only, so the whole remainder is the spec.
        entry, sep, spec = fname[:-len(suffix)].partition('-')
        if not sep:
            continue  # not an <entry>-<spec>.pth file
        if entry_name is None or entry == entry_name:
            spec_list.append( (entry, spec) )
    return spec_list
def load(path: str, entry_name:str=None, spec_name: str=None, pretrained=True, **kwargs):
    """
    check dependencies and load pytorch models.

    The package's ``hubconf.py`` is imported with the code directory as the
    working directory and on ``sys.path``; both are restored afterwards, also
    when loading fails.

    Args:
        path: path to the model package
        entry_name: name of the entry function in ``hubconf.py``; when
            ``None`` it is derived from the single weight file present.
        spec_name: spec of the weight file ``<entry>-<spec>.pth``; when
            ``None`` the single weight file of the entry is used.
        pretrained: load the pretrained model

    Returns:
        The model, with ``METADATA`` and ``SETUP_INFO`` attributes attached.

    Raises:
        FileNotFoundError: if ``path`` does not exist, or the weight file is
            missing or cannot be loaded.
        NotImplementedError: if ``path`` is not a directory.
    """
    path = os.path.abspath(path)
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    if not os.path.isdir(path):
        raise NotImplementedError

    code_dir = os.path.join(path, _CODE_DIR)
    weights_dir = os.path.join(path, _WEIGHTS_DIR)
    ambiguous_msg = "Loading models with more than one weight files (.pth) is ambiguous"

    cwd = os.getcwd()
    os.chdir(code_dir)
    sys.path.insert(0, code_dir)
    try:
        spec = importlib.util.spec_from_file_location(
            'hubconf', os.path.join(code_dir, 'hubconf.py'))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        for dep in getattr( module, 'dependencies', () ):
            _import_with_auto_install(dep)

        if entry_name is None:  # auto select
            pth_files = [ f for f in os.listdir( weights_dir ) if f.endswith('.pth') ]
            assert len(pth_files)<=1, ambiguous_msg
            if not pth_files:
                raise FileNotFoundError("no weight file (.pth) found in %s" % weights_dir)
            pth_file = pth_files[0]
            entry_name = pth_file.split('-')[0]
            entry_fn = getattr( module, entry_name )
        else:
            entry_fn = getattr( module, entry_name )
            if spec_name is None:
                # Match "<entry>-" rather than "<entry>" so that entries
                # sharing a prefix are not confused with each other.
                prefix = entry_name + '-'
                pth_files = [ f for f in os.listdir( weights_dir )
                              if f.startswith(prefix) and f.endswith('.pth') ]
                assert len(pth_files)<=1, ambiguous_msg
                if not pth_files:
                    raise FileNotFoundError(
                        "no weight file for entry '%s' found in %s" % (entry_name, weights_dir))
                pth_file = pth_files[0]
            else:
                pth_file = '%s-%s.pth'%(entry_name, spec_name)

        weights_file = os.path.join(weights_dir, pth_file)
        try:
            model_and_metadata = torch.load(weights_file, map_location='cpu' )
        except Exception as err:
            # Same exception type as before, but the real cause stays
            # attached and the offending file is named.
            raise FileNotFoundError(weights_file) from err

        if isinstance( model_and_metadata['model'], torch.nn.Module ):
            model = model_and_metadata['model']
        else:
            entry_args = model_and_metadata['metadata']['entry_args']
            if entry_args is None:
                entry_args = dict()
            model = entry_fn( **entry_args )
            if pretrained:
                model.load_state_dict( model_and_metadata['model'], False)

        # setup metadata and atlas info
        model.METADATA = model_and_metadata['metadata']
        model.SETUP_INFO = {"path": path, "entry_name": entry_name}
        return model
    finally:
        if code_dir in sys.path:
            sys.path.remove(code_dir)
        os.chdir(cwd)
def load_metadata(path, entry_name=None, spec_name=None):
    """Load only the ``metadata`` entry stored in a model package's weight file.

    Args:
        path: Directory of the model package; it must contain the weights
            sub-directory named by ``_WEIGHTS_DIR``.
        entry_name: Entry-point name used as the weight-file prefix. When
            ``None`` the single file in the weights directory is used.
        spec_name: Spec suffix. When both ``entry_name`` and ``spec_name``
            are given the file name is ``'<entry_name>-<spec_name>.pth'``.

    Returns:
        The object stored under the ``'metadata'`` key of the checkpoint.

    Raises:
        NotImplementedError: ``path`` is not a directory (same as ``load``).
        AssertionError: more than one candidate weight file was found.
        FileNotFoundError: no candidate weight file exists, or the selected
            file could not be read.
    """
    path = os.path.abspath(path)
    if not os.path.isdir(path):
        # Only directory-style packages are supported, mirroring load().
        raise NotImplementedError

    weights_dir = os.path.join(path, _WEIGHTS_DIR)
    if entry_name is not None and spec_name is not None:
        pth_file = '%s-%s.pth' % (entry_name, spec_name)
    else:
        # Auto select: every file when no entry name is given, otherwise
        # only the files that start with the entry name.
        candidates = [
            pth for pth in os.listdir(weights_dir)
            if entry_name is None or pth.startswith(entry_name)
        ]
        assert len(candidates) <= 1, "Loading models with more than one weight files (.pth) is ambiguous"
        if not candidates:
            raise FileNotFoundError('no weight file found in %s' % weights_dir)
        pth_file = candidates[0]

    weight_path = os.path.join(weights_dir, pth_file)
    try:
        return torch.load(weight_path, map_location='cpu')['metadata']
    except Exception as err:
        # The original evaluated a bare ``FileNotFoundError`` expression here
        # (a no-op) and then failed on an unbound local; raise it for real.
        raise FileNotFoundError('cannot load metadata from %s' % weight_path) from err
def load_tags(path, entry_name, spec_name):
    """Read the tags file ``'<entry_name>-<spec_name>.yml'`` of a model package.

    Returns the parsed content converted by ``_to_python_type``, or an empty
    dict when the tags file does not exist.
    """
    tags_file = os.path.join(
        os.path.abspath(path), _TAGS_DIR, '%s-%s.yml' % (entry_name, spec_name))
    if not os.path.isfile(tags_file):
        return dict()
    return _to_python_type(_yaml_load(tags_file))
def save_tags(tags, path, entry_name, spec_name):
    """Write ``tags`` to ``'<entry_name>-<spec_name>.yml'`` in the tags directory.

    The tags directory (``_TAGS_DIR`` under ``path``) is created if missing.
    ``None`` is stored as an empty mapping.
    """
    tags_dir = os.path.join(os.path.abspath(path), _TAGS_DIR)
    payload = {} if tags is None else tags
    os.makedirs(tags_dir, exist_ok=True)
    _yaml_dump(os.path.join(tags_dir, '%s-%s.yml' % (entry_name, spec_name)), payload)
def _yaml_dump(f, obj):
    """Serialize ``obj`` as YAML into the file at path ``f`` (overwriting it)."""
    with open(f, 'w') as stream:
        yaml.dump(obj, stream)
def _yaml_load(f):
    """Parse the YAML file at path ``f`` and return the loaded object."""
    # NOTE(review): ``yaml`` is bound outside this view; if it is PyYAML rather
    # than a ruamel ``YAML()`` instance, ``load`` without a Loader is unsafe on
    # untrusted files -- confirm.
    with open(f, 'r') as stream:
        return yaml.load(stream)
def _to_python_type(data): | |||
if isinstance(data, dict): | |||
for k, v in data.items(): | |||
data[k] = _to_python_type(v) | |||
return dict(data) | |||
elif isinstance(data, comments.CommentedSeq ): | |||
for idx, v in enumerate(data): | |||
data[idx] = _to_python_type(v) | |||
return list(data) | |||
return data | |||
def md5(fname):
    """Return the hexadecimal MD5 digest of the file at ``fname``.

    The file is read in 4096-byte chunks so large files are not loaded into
    memory at once.
    """
    digest = hashlib.md5()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()
def _get_package_name_and_version(package): | |||
_version_sym_list = ('==', '>=', '<=') | |||
for sym in _version_sym_list: | |||
if sym in package: | |||
return package.split(sym) | |||
return package, None | |||
def _import_with_auto_install(package):
    """Import the module for ``package``, installing it with pip if missing.

    The requirement string is reduced to its package name, mapped through
    ``PACKAGE_NAME_TO_IMPORT_NAME`` (falling back to the name itself) and
    ``-`` is replaced by ``_`` to obtain the import name. On ``ImportError``
    the full requirement string is passed to pip and the import is retried.

    Returns:
        The imported top-level module.
    """
    raw_name, _ = _get_package_name_and_version(package)
    stripped_name = raw_name.strip()
    mapped_name = PACKAGE_NAME_TO_IMPORT_NAME.get(stripped_name, stripped_name)
    module_name = mapped_name.replace('-', '_')
    try:
        return __import__(module_name)
    except ImportError:
        install_args = ['install', package]
        # NOTE(review): ``pipmain`` is bound outside this view; it is tried
        # first as an object exposing ``main`` and then as a callable --
        # presumably to cover different pip versions; confirm.
        try:
            pipmain.main(install_args)
        except:
            pipmain(install_args)
    return __import__(module_name)
from distutils.dir_util import copy_tree | |||
def _copy_file_or_tree(src, dst): | |||
if os.path.isfile(src): | |||
shutil.copy2(src=src, dst=dst) | |||
else: | |||
copy_tree(src=src, dst=dst) | |||
def _glob_list(path_list, recursive=False): | |||
results = [] | |||
for path in path_list: | |||
if '*' in path: | |||
path = list(glob(path, recursive=recursive)) | |||
results.extend(path) | |||
else: | |||
results.append(path) | |||
return results |
@@ -0,0 +1,6 @@ | |||
# Maps a pip distribution name to the module name used in ``import`` when
# the two differ. Lookups are case-sensitive.
PACKAGE_NAME_TO_IMPORT_NAME = {
    'opencv-python': 'cv2',
    'pillow': 'PIL',
    'Pillow': 'PIL',
    'scikit-learn': 'sklearn',
    # scikit-image is imported as ``skimage``; the previous value
    # ('scikit-image') could never be imported.
    'scikit-image': 'skimage',
}
@@ -0,0 +1,22 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
# Task identifiers stored in model metadata (plain strings).
CLASSIFICATION = 'classification'
SEGMENTATION = 'segmentation'
DETECTION = 'detection'
DEPTH = 'depth'
AUTOENCODER = 'autoencoder'
@@ -0,0 +1,3 @@ | |||
from .meta import Metadata | |||
from .input import ImageInput | |||
from . import TASK |
@@ -0,0 +1,31 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from typing import Union, Sequence | |||
# INPUT Metadata | |||
def ImageInput( size: Union[Sequence[int], int],
                range: Union[Sequence[int], int],
                space: str,
                normalize: Union[Sequence]=None):
    """Build the input-description dict stored in model metadata.

    Args:
        size: Input size, a single int or a sequence of ints.
        range: Value range of the input, a single int or a sequence of ints.
        space: Colour space; one of ``'rgb'``, ``'gray'``, ``'bgr'``, ``'rgbd'``.
        normalize: Optional normalization description, stored as given.

    Returns:
        ``dict`` with the keys ``size``, ``range``, ``space``, ``normalize``.

    Raises:
        AssertionError: ``space`` is not one of the supported values.
    """
    supported_spaces = ['rgb', 'gray', 'bgr', 'rgbd']
    assert space in supported_spaces
    return {
        'size': size,
        'range': range,
        'space': space,
        'normalize': normalize,
    }
@@ -0,0 +1,44 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from copy import deepcopy | |||
import abc, os | |||
from . import TASK | |||
import torch | |||
# Only names actually defined in this module are exported; the previous list
# named objects that do not exist here, which breaks ``from ... import *``.
__all__ = ['Metadata']

# Model Metadata
def Metadata( name: str,
              dataset: str,
              task: str,
              url: str,
              input: dict,
              entry_args: dict,
              other_metadata: dict):
    """Build the metadata dict saved alongside a model.

    Args:
        name: Model name.
        dataset: Name of the dataset the model refers to.
        task: Task identifier; one of the string constants in ``TASK``.
        url: URL associated with the model.
        input: Input description (e.g. the result of ``ImageInput``); it is
            copied with ``dict(input)``.
        entry_args: Keyword arguments for the model entry function.
        other_metadata: Extra metadata. Must contain ``'num_classes'`` for
            segmentation and classification tasks.

    Returns:
        ``dict`` holding all of the arguments under their parameter names.

    Raises:
        AssertionError: a segmentation/classification task is given without
            ``'num_classes'`` in ``other_metadata``.
    """
    if task in [TASK.SEGMENTATION, TASK.CLASSIFICATION]:
        # ``or {}`` turns a missing (None) mapping into the same assertion
        # failure instead of a TypeError.
        assert 'num_classes' in (other_metadata or {}), \
            "other_metadata must provide 'num_classes' for task %s" % task
    return dict(
        name=name,
        dataset=dataset,
        task=task,
        url=url,
        input=dict(input),
        entry_args=entry_args,
        other_metadata=other_metadata
    )
@@ -0,0 +1,8 @@ | |||
from .stream_metrics import Metric, MetricCompose | |||
from .accuracy import Accuracy, TopkAccuracy | |||
from .confusion_matrix import ConfusionMatrix, IoU, mIoU | |||
from .regression import * | |||
from .average import AverageMetric | |||
@@ -0,0 +1,118 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import numpy as np | |||
import torch | |||
from kamal.core.metrics.stream_metrics import Metric | |||
from typing import Callable | |||
__all__=['Accuracy', 'TopkAccuracy'] | |||
class Accuracy(Metric):
    """Streaming top-1 accuracy: correct predictions / number of target elements."""

    def __init__(self, attach_to=None):
        super(Accuracy, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate one batch; the prediction is the arg-max over dim 1."""
        outputs, targets = self._attach(outputs, targets)
        _, predicted = outputs.max(1)
        hits = predicted.view(-1) == targets.view(-1)
        self._correct += hits.sum()
        self._cnt += targets.numel()

    def get_results(self):
        """Return the accumulated accuracy as a detached CPU tensor."""
        ratio = self._correct / self._cnt
        return ratio.detach().cpu()

    def reset(self):
        """Clear both accumulators."""
        self._cnt = 0.0
        self._correct = 0.0
class TopkAccuracy(Metric):
    """Streaming top-k accuracy: a sample counts when its target is in the k best outputs."""

    def __init__(self, topk=5, attach_to=None):
        super(TopkAccuracy, self).__init__(attach_to=attach_to)
        self._topk = topk
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate one batch of scores (dim 1 indexes classes) and targets."""
        outputs, targets = self._attach(outputs, targets)
        _, top_idx = outputs.topk(self._topk, dim=1, largest=True, sorted=True)
        expected = targets.view(-1, 1).expand_as(top_idx)
        matches = top_idx.eq(expected)
        n_hits = matches[:, :self._topk].view(-1).float().sum(0).item()
        self._correct += n_hits
        self._cnt += len(targets)

    def get_results(self):
        """Return hits / samples as a Python float."""
        return self._correct / self._cnt

    def reset(self):
        """Clear both accumulators."""
        self._cnt = 0.0
        self._correct = 0.0
class StreamCEMAPMetrics():
    """Example-based and class-based average precision for multi-label scores.

    Targets use ``1`` for positive, ``-1`` for negative and ``0`` for
    "difficult" entries, which are left out of the ranking.
    """

    @property
    def PRIMARY_METRIC(self):
        return "eap"

    def __init__(self):
        self.reset()

    def update(self, logits, targets):
        """Append a batch of scores and targets, both shaped (N, num_labels).

        The raw scores are stored. The previous code stored the arg-max index
        per row, which cannot be indexed per label in ``get_results``.
        """
        # targets: -1 negative, 0 difficult, 1 positive
        preds = logits
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu().numpy()
        if isinstance(targets, torch.Tensor):
            targets = targets.detach().cpu().numpy()
        preds = np.asarray(preds)
        targets = np.asarray(targets)
        self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0)
        self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0)

    @staticmethod
    def _average_precision(scores, labels):
        """Average precision of one score vector; NaN when it has no positives."""
        positive = labels == 1
        ranked = labels != 0  # difficult entries (0) are not ranked
        n_positive = int(positive.sum())
        if n_positive == 0:
            return np.nan
        total = 0.0
        for score in scores[positive]:
            rank = np.sum(scores[ranked] >= score)
            rank_among_positives = np.sum(scores[positive] >= score)
            total += rank_among_positives / (rank * 1.0)
        return total / n_positive

    def get_results(self):
        """Return ``eap`` / ``cap`` arrays and their NaN-ignoring means.

        Raises:
            RuntimeError: ``update`` has not been called since the last reset.
        """
        if self.targets is None or self.preds is None:
            raise RuntimeError('get_results() called before any update()')
        n_test, n_label = self.targets.shape[0], self.targets.shape[1]
        # Example-based AP: one value per sample (row).
        eap = np.array([self._average_precision(self.preds[i, :], self.targets[i, :])
                        for i in range(n_test)], dtype=float)
        # Class-based AP: one value per label (column).
        cap = np.array([self._average_precision(self.preds[:, j], self.targets[:, j])
                        for j in range(n_label)], dtype=float)
        return {
            'eap': eap,
            'emap': np.nanmean(eap),
            'cap': cap,
            'cmap': np.nanmean(cap),
        }

    def reset(self):
        """Drop all stored scores and targets."""
        self.preds = None
        self.targets = None
@@ -0,0 +1,49 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import numpy as np | |||
import torch | |||
from kamal.core.metrics.stream_metrics import Metric | |||
from typing import Callable | |||
__all__=['AverageMetric'] | |||
class AverageMetric(Metric):
    """Running average of the value returned by a user supplied function."""

    def __init__(self, fn:Callable, attach_to=None):
        super(AverageMetric, self).__init__(attach_to=attach_to)
        self._fn = fn
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Evaluate ``fn(outputs, targets)`` and fold it into the running sums.

        A result with more than one dimension is summed over dim 0 and counts
        as ``shape[0]`` observations; any other result counts as one.
        """
        outputs, targets = self._attach(outputs, targets)
        value = self._fn( outputs, targets )
        batched = value.ndim > 1
        self._cnt += value.shape[0] if batched else 1
        self._accum += value.sum(0) if batched else value

    def get_results(self):
        """Return accumulated sum / count as a detached CPU tensor."""
        average = self._accum / self._cnt
        return average.detach().cpu()

    def reset(self):
        """Clear the running sum and count."""
        self._cnt = 0.
        self._accum = 0.
@@ -0,0 +1,68 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .stream_metrics import Metric | |||
import torch | |||
from typing import Callable | |||
class ConfusionMatrix(Metric):
    """Streaming confusion matrix; rows index targets, columns index predictions."""

    def __init__(self, num_classes, ignore_idx=None, attach_to=None):
        super(ConfusionMatrix, self).__init__(attach_to=attach_to)
        self._num_classes = num_classes
        self._ignore_idx = ignore_idx
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Add one batch; predictions are the arg-max of ``outputs`` over dim 1.

        Elements whose prediction or target lies outside
        ``[0, num_classes)`` are skipped, as are targets equal to
        ``ignore_idx`` when one is set.
        """
        outputs, targets = self._attach(outputs, targets)
        if self.confusion_matrix.device != outputs.device:
            self.confusion_matrix = self.confusion_matrix.to(device=outputs.device)
        preds = outputs.max(1)[1].flatten()
        targets = targets.flatten()
        n = self._num_classes
        # Out-of-range targets (e.g. a 255 "void" label) would push bincount
        # past n*n entries and break the reshape below, so mask them as well.
        mask = (preds >= 0) & (preds < n) & (targets >= 0) & (targets < n)
        # ``is not None`` so that ignore_idx=0 is honoured too.
        if self._ignore_idx is not None:
            mask = mask & (targets != self._ignore_idx)
        preds, targets = preds[mask], targets[mask]
        hist = torch.bincount(n * targets + preds, minlength=n ** 2).view(n, n)
        self.confusion_matrix += hist

    def get_results(self):
        """Return the accumulated matrix as a detached CPU tensor."""
        return self.confusion_matrix.detach().cpu()

    def reset(self):
        """Zero the matrix (int64, ``num_classes`` x ``num_classes``)."""
        self._cnt = 0
        self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False)
class IoU(Metric):
    """Per-class intersection-over-union read from a shared ``ConfusionMatrix``."""

    def __init__(self, confusion_matrix: ConfusionMatrix, attach_to=None):
        self._confusion_matrix = confusion_matrix

    def update(self, outputs, targets):
        # Nothing to accumulate here: the wrapped confusion matrix is updated
        # on its own.
        pass

    def reset(self):
        pass

    def get_results(self):
        """Return ``diag / (row_sum + col_sum - diag + 1e-9)`` per class."""
        cm = self._confusion_matrix.get_results()
        intersection = cm.diag()
        union = cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9
        return intersection / union
class mIoU(IoU):
    """Mean of the per-class IoU values."""

    def get_results(self):
        per_class = super(mIoU, self).get_results()
        return per_class.mean()
@@ -0,0 +1,86 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import numpy as np | |||
import torch | |||
from kamal.core.metrics.stream_metrics import StreamMetricsBase | |||
class NormalPredictionMetrics(StreamMetricsBase):
    """Angular-error statistics for surface-normal prediction."""

    @property
    def PRIMARY_METRIC(self):
        return 'mean angle'

    def __init__(self, thresholds):
        self.thresholds = thresholds
        self.preds = None
        self.targets = None
        self.masks = None

    @torch.no_grad()
    def update(self, preds, targets, masks):
        """
        **Type**: numpy.ndarray or torch.Tensor
        **Shape:**
        - **preds**: $(N, 3, H, W)$.
        - **targets**: $(N, 3, H, W)$.
        - **masks**: $(N, 1, H, W)$.
        """
        if isinstance(preds, torch.Tensor):
            preds = preds.cpu().numpy()
            targets = targets.cpu().numpy()
            masks = masks.cpu().numpy()
        self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0)
        self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0)
        self.masks = masks if self.masks is None else np.append(self.masks, masks, axis=0)

    def get_results(self):
        """
        **Returns:**
        - **mean angle**
        - **median angle**
        - **percents for angle within thresholds**
        """
        products = np.sum(self.preds * self.targets, axis=1)
        angles = np.arccos(np.clip(products, -1.0, 1.0)) / np.pi * 180
        # Squeeze into a local: overwriting self.masks (as before) made a
        # second get_results() call, or a later update(), fail.
        valid = np.squeeze(self.masks, axis=1) == 1
        angles = angles[valid]
        mean_angle = np.mean(angles)
        median_angle = np.median(angles)
        threshold_percents = {}
        for threshold in self.thresholds:
            threshold_percents[threshold] = np.mean(angles < threshold)
        # A stray branch referencing the undefined names ``return_key_metric``
        # and ``ard`` used to sit here and raised NameError on every call.
        return {
            'mean angle': mean_angle,
            'median angle': median_angle,
            'percents within thresholds': threshold_percents
        }

    def reset(self):
        """Drop all stored predictions, targets and masks."""
        self.preds = None
        self.targets = None
        self.masks = None
@@ -0,0 +1,199 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import numpy as np | |||
import torch | |||
from kamal.core.metrics.stream_metrics import Metric | |||
from typing import Callable | |||
__all__=['MeanSquaredError', 'RootMeanSquaredError', 'MeanAbsoluteError', | |||
'ScaleInveriantMeanSquaredError', 'RelativeDifference', | |||
'AbsoluteRelativeDifference', 'SquaredRelativeDifference', 'Threshold' ] | |||
class MeanSquaredError(Metric):
    """Streaming mean squared error, optionally computed on log values."""

    def __init__(self, log_scale=False, attach_to=None):
        super(MeanSquaredError, self).__init__( attach_to=attach_to )
        self.log_scale=log_scale
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate the summed squared difference and the element count."""
        outputs, targets = self._attach(outputs, targets)
        n_elements = torch.numel(outputs)
        if self.log_scale:
            # 1e-8 keeps the logarithm finite for zero inputs.
            residual = torch.log(outputs+1e-8) - torch.log(targets+1e-8)
        else:
            residual = outputs - targets
        self._accum_sq_diff += torch.sum(residual**2)
        self._cnt += n_elements

    def get_results(self):
        """Return summed squared difference / count as a detached CPU tensor."""
        mean_sq = self._accum_sq_diff / self._cnt
        return mean_sq.detach().cpu()

    def reset(self):
        """Clear both accumulators."""
        self._cnt = 0.
        self._accum_sq_diff = 0.
class RootMeanSquaredError(MeanSquaredError):
    """Square root of the streaming mean squared error."""

    def get_results(self):
        mean_sq = self._accum_sq_diff / self._cnt
        return torch.sqrt( mean_sq ).detach().cpu()
class MeanAbsoluteError(Metric):
    """Streaming mean absolute error."""

    def __init__(self, attach_to=None ):
        super(MeanAbsoluteError, self).__init__( attach_to=attach_to )
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate the summed absolute difference and the element count."""
        outputs, targets = self._attach(outputs, targets)
        abs_err = torch.abs(outputs - targets)
        self._accum_abs_diff += torch.sum(abs_err)
        self._cnt += outputs.numel()

    def get_results(self):
        """Return summed absolute difference / count as a detached CPU tensor."""
        mean_abs = self._accum_abs_diff / self._cnt
        return mean_abs.detach().cpu()

    def reset(self):
        """Clear both accumulators."""
        self._cnt = 0.
        self._accum_abs_diff = 0.
class ScaleInveriantMeanSquaredError(Metric):
    """Streaming scale-invariant error on log differences.

    With ``d = log(output) - log(target)`` summed over all ``n`` elements seen
    so far, the result is ``sum(d^2)/n - 0.5 * sum(d)^2 / n^2``.
    """

    def __init__(self, attach_to=None ):
        super(ScaleInveriantMeanSquaredError, self).__init__( attach_to=attach_to )
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate the log-difference sums for one batch."""
        outputs, targets = self._attach(outputs, targets)
        # 1e-8 keeps the logarithm finite for zero inputs.
        diff_log = torch.log( outputs+1e-8 ) - torch.log( targets+1e-8 )
        # Accumulate across batches. These were plain assignments before, so
        # only the last batch's sums were divided by the total count.
        self._accum_log_diff += diff_log.sum()
        self._accum_sq_log_diff += (diff_log**2).sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return the scale-invariant error as a detached CPU tensor."""
        return ( self._accum_sq_log_diff / self._cnt - 0.5 * (self._accum_log_diff**2 / self._cnt**2) ).detach().cpu()

    def reset(self):
        """Clear all accumulators."""
        self._accum_log_diff = 0.
        self._accum_sq_log_diff = 0.
        self._cnt = 0.
class RelativeDifference(Metric):
    """Streaming mean of ``|output - target| / target``."""

    def __init__(self, attach_to=None ):
        super(RelativeDifference, self).__init__( attach_to=attach_to )
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate the summed relative error and the element count."""
        outputs, targets = self._attach(outputs, targets)
        relative = (outputs - targets).abs() / targets
        self._accum_abs_rel += relative.sum()
        self._cnt += outputs.numel()

    def get_results(self):
        """Return summed relative error / count as a detached CPU tensor."""
        mean_rel = self._accum_abs_rel / self._cnt
        return mean_rel.detach().cpu()

    def reset(self):
        """Clear both accumulators."""
        self._cnt = 0.
        self._accum_abs_rel = 0.
class AbsoluteRelativeDifference(Metric):
    """Streaming mean of ``|output - target| / target``.

    This class used to be defined twice, back to back and verbatim; the
    duplicate definition has been removed.
    """

    def __init__(self, attach_to=None ):
        super(AbsoluteRelativeDifference, self).__init__( attach_to=attach_to )
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate the summed absolute relative error and the element count."""
        outputs, targets = self._attach(outputs, targets)
        diff = (outputs - targets).abs()
        self._accum_abs_rel += (diff/targets).sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return summed relative error / count as a detached CPU tensor."""
        return (self._accum_abs_rel / self._cnt).detach().cpu()

    def reset(self):
        """Clear both accumulators."""
        self._accum_abs_rel = 0.
        self._cnt = 0.
class SquaredRelativeDifference(Metric):
    """Streaming mean of ``(output - target)^2 / target``."""

    def __init__(self, attach_to=None ):
        super(SquaredRelativeDifference, self).__init__( attach_to=attach_to )
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate the summed squared relative error and the element count."""
        outputs, targets = self._attach(outputs, targets)
        relative = (outputs - targets)**2 / targets
        self._accum_sq_rel += relative.sum()
        self._cnt += outputs.numel()

    def get_results(self):
        """Return summed squared relative error / count as a detached CPU tensor."""
        mean_rel = self._accum_sq_rel / self._cnt
        return mean_rel.detach().cpu()

    def reset(self):
        """Clear both accumulators."""
        self._cnt = 0.
        self._accum_sq_rel = 0.
class Threshold(Metric):
    """Fraction of elements whose ratio ``max(out/tgt, tgt/out)`` is below each threshold."""

    def __init__(self, thresholds=[1.25, 1.25**2, 1.25**3], attach_to=None ):
        super(Threshold, self).__init__( attach_to=attach_to )
        self.thresholds = thresholds
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Count, per threshold, the elements whose ratio is below it."""
        outputs, targets = self._attach(outputs, targets)
        forward_ratio = outputs / targets
        backward_ratio = targets / outputs
        sigma = torch.max(forward_ratio, backward_ratio)
        for limit in self.thresholds:
            self._accum_thres[limit] += torch.sum( sigma<limit )
        self._cnt += outputs.numel()

    def get_results(self):
        """Return ``{threshold: hits / count}`` with detached CPU tensors as values."""
        results = {}
        for limit in self.thresholds:
            results[limit] = (self._accum_thres[limit] / self._cnt).detach().cpu()
        return results

    def reset(self):
        """Clear the element count and every per-threshold counter."""
        self._cnt = 0.
        self._accum_thres = dict((limit, 0.) for limit in self.thresholds)
@@ -0,0 +1,82 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from __future__ import division | |||
import torch | |||
import numpy as np | |||
from abc import ABC, abstractmethod | |||
from typing import Callable, Union, Any, Mapping, Sequence | |||
import numbers | |||
from kamal.core.attach import AttachTo | |||
class Metric(ABC):
    """Abstract base of all streaming metrics.

    ``self._attach`` (an ``AttachTo`` built from ``attach_to``) is what
    subclasses call on ``(outputs, targets)`` before computing.
    """

    def __init__(self, attach_to=None):
        self._attach = AttachTo(attach_to)

    @abstractmethod
    def update(self, pred, target):
        """Fold one batch into the metric state. Overridden by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def get_results(self):
        """Return the metric value for everything seen so far. Overridden by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def reset(self):
        """Clear the accumulated state. Overridden by subclasses."""
        raise NotImplementedError()
class MetricCompose(dict):
    """A named collection of metrics that are updated, read and reset together.

    Entries of the wrapped mapping that are not ``Metric`` instances are
    skipped by ``update``, ``get_results`` and ``reset``.
    """

    def __init__(self, metric_dict: Mapping):
        self._metric_dict = metric_dict

    def add_metrics( self, metric_dict: Mapping):
        """Merge more metrics (a mapping or another ``MetricCompose``); returns ``self``."""
        extra = metric_dict.metrics if isinstance(metric_dict, MetricCompose) else metric_dict
        self._metric_dict.update(extra)
        return self

    @property
    def metrics(self):
        """The wrapped name -> metric mapping."""
        return self._metric_dict

    @torch.no_grad()
    def update(self, outputs, targets):
        """Forward one batch to every contained metric."""
        for metric in self._metric_dict.values():
            if not isinstance(metric, Metric):
                continue
            metric.update(outputs, targets)

    def get_results(self):
        """Return ``{name: metric.get_results()}`` for every contained metric."""
        return {
            name: metric.get_results()
            for name, metric in self._metric_dict.items()
            if isinstance(metric, Metric)
        }

    def reset(self):
        """Reset every contained metric."""
        for metric in self._metric_dict.values():
            if not isinstance(metric, Metric):
                continue
            metric.reset()

    def __getitem__(self, name):
        return self._metric_dict[name]
@@ -0,0 +1,3 @@ | |||
from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose | |||
from . import loss | |||
@@ -0,0 +1,2 @@ | |||
from . import functional, loss | |||
from .loss import * |
@@ -0,0 +1,107 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn.functional as F | |||
import torch.nn as nn | |||
def kldiv(logits, targets, T=1.0):
    """ KL divergence between temperature-softened distributions

    Parameters:
        - logits (Tensor): logits score (e.g. outputs of fc layer)
        - targets (Tensor): logits of soft targets
        - T (float): temperature of distill

    Returns the divergence scaled by T**2, summed over dim 1 and averaged
    over the remaining elements.
    """
    soft_targets = F.softmax(targets/T, dim=1)
    log_probs = F.log_softmax(logits/T, dim=1)
    pointwise = F.kl_div(log_probs, soft_targets, reduction='none')
    scaled = pointwise * (T**2)
    return scaled.sum(1).mean()
def jsdiv(logits, targets, T=1.0, reduction='mean'):
    """ Jensen-Shannon divergence between two softened distributions

    Parameters:
        - logits (Tensor): logits score (e.g. outputs of fc layer)
        - targets (Tensor): logits of soft targets
        - T (float): temperature applied to both inputs before the softmax
          (it was accepted but ignored before; the default 1.0 leaves
          results unchanged)
        - reduction (str): passed through to ``F.kl_div``
    """
    p = F.softmax(logits / T, dim=1)
    q = F.softmax(targets / T, dim=1)
    # log of the mixture distribution m = (p + q) / 2
    log_m = torch.log( (p+q) / 2 )
    return 0.5* ( F.kl_div( log_m, p, reduction=reduction) + F.kl_div( log_m, q, reduction=reduction) )
def mmd_loss(f1, f2, sigmas, normalized=False):
    """ MMD between two feature batches, computed by ``_mmd_rbf2``

    Parameters:
        - f1, f2 (Tensor): features with the batch on dim 0; any tensor that
          is not already 2-D is flattened to (N, -1)
        - sigmas: kernel bandwidths forwarded to ``_mmd_rbf2``
        - normalized (bool): L2-normalize each flattened row first
    """
    # Flatten each input on its own. The old code unpacked exactly four
    # dimensions and only looked at f1's rank, so 3-D/5-D features, or a
    # 2-D f1 paired with a 4-D f2, raised.
    if f1.dim() != 2:
        f1 = f1.reshape(f1.shape[0], -1)
    if f2.dim() != 2:
        f2 = f2.reshape(f2.shape[0], -1)

    if normalized:
        f1 = F.normalize(f1, p=2, dim=1)
        f2 = F.normalize(f2, p=2, dim=1)

    return _mmd_rbf2(f1, f2, sigmas=sigmas)
def psnr(img1, img2, size_average=True, data_range=255):
    """ Peak signal-to-noise ratio per sample, clamped to [0, 99.99]

    Parameters:
        - img1, img2 (Tensor): batches with the sample index on dim 0
        - size_average (bool): return the batch mean instead of per-sample values
        - data_range: peak value of the signal
    """
    batch_size = img1.shape[0]
    squared_error = (img1-img2)**2
    mse = torch.mean(squared_error.view(batch_size, -1), dim=1)
    # An MSE of zero gives +inf here; the clamp caps it at 99.99.
    value = torch.clamp(torch.log10(data_range**2 / mse) * 10, 0.0, 99.99)
    return value.mean() if size_average == True else value
def soft_cross_entropy(logits, targets, T=1.0, size_average=True):
    """ Cross Entropy for soft targets

    **Parameters:**
        - **logits** (Tensor): logits score (e.g. outputs of fc layer)
        - **targets** (Tensor): logits of soft targets
        - **T** (float): temperature of distill
        - **size_average**: average the outputs

    The result is scaled by T*T; without averaging one value per row is returned.
    """
    soft_targets = F.softmax(targets/T, dim=1)
    log_probs = F.log_softmax(logits/T, dim=1)
    per_sample = torch.sum(-soft_targets * log_probs, dim=1)
    if not size_average:
        return per_sample * T * T
    return per_sample.mean() * T * T
def _mmd_rbf2(x, y, sigmas=None): | |||
N, _ = x.shape | |||
xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) | |||
rx = (xx.diag().unsqueeze(0).expand_as(xx)) | |||
ry = (yy.diag().unsqueeze(0).expand_as(yy)) | |||
K = L = P = 0.0 | |||
XX2 = rx.t() + rx - 2*xx | |||
YY2 = ry.t() + ry - 2*yy | |||
XY2 = rx.t() + ry - 2*zz | |||
if sigmas is None: | |||
sigma2 = torch.mean((XX2.detach()+YY2.detach()+2*XY2.detach()) / 4) | |||
sigmas2 = [sigma2/4, sigma2/2, sigma2, sigma2*2, sigma2*4] | |||
alphas = [1.0 / (2 * sigma2) for sigma2 in sigmas2] | |||
else: | |||
alphas = [1.0 / (2 * sigma**2) for sigma in sigmas] | |||
for alpha in alphas: | |||
K += torch.exp(- alpha * (XX2.clamp(min=1e-12))) | |||
L += torch.exp(- alpha * (YY2.clamp(min=1e-12))) | |||
P += torch.exp(- alpha * (XY2.clamp(min=1e-12))) | |||
beta = (1./(N*(N))) | |||
gamma = (2./(N*N)) | |||
return F.relu(beta * (torch.sum(K)+torch.sum(L)) - gamma * torch.sum(P)) |
@@ -0,0 +1,386 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import torch.nn.functional as F | |||
import torch.nn as nn | |||
import numpy as np | |||
#from pytorch_msssim import ssim, ms_ssim, MS_SSIM, SSIM | |||
from .functional import * | |||
class KLDiv(object):
    """Callable wrapper that evaluates ``kldiv`` at a fixed temperature."""

    def __init__(self, T=1.0):
        # Temperature forwarded to ``kldiv`` on every call.
        self.T = T

    def __call__(self, logits, targets):
        """Return ``kldiv(logits, targets)`` computed with the stored ``T``."""
        temperature = self.T
        return kldiv(logits, targets, T=temperature)
class KDLoss(nn.Module):
    """Knowledge-distillation loss with an optional hard-label term.

    The soft term is ``kldiv`` when ``use_kldiv`` is true and
    ``soft_cross_entropy`` otherwise, both evaluated at temperature ``T``.
    When hard labels are supplied and ``alpha`` is non-zero,
    ``alpha * cross_entropy(logits, hard_targets)`` is added on top.
    """

    def __init__(self, T=1.0, alpha=1.0, use_kldiv=False):
        super(KDLoss, self).__init__()
        self.T = T
        self.alpha = alpha
        if use_kldiv:
            self.kdloss = kldiv
        else:
            self.kdloss = soft_cross_entropy

    def forward(self, logits, targets, hard_targets=None):
        """Return the soft loss, plus the weighted hard-label loss if requested."""
        loss = self.kdloss(logits, targets, T=self.T)
        skip_hard_term = hard_targets is None or self.alpha == 0.0
        if skip_hard_term:
            return loss
        loss += self.alpha * F.cross_entropy(logits, hard_targets)
        return loss
class CFLLoss(nn.Module):
    """ Common Feature Learning Loss
        CFL Loss = MMD + MSE

    ``forward`` returns the two terms separately as ``(mmd, mse)``; combining
    them is left to the caller.
    """

    def __init__(self, sigmas, normalized=True):
        super(CFLLoss, self).__init__()
        self.sigmas = sigmas
        self.normalized = normalized

    def forward(self, hs, hts, fts_, fts):
        """Accumulate the MMD and MSE terms.

        Args:
            hs: feature compared against every entry of ``hts`` via ``mmd_loss``.
            hts: iterable of features, one MMD term each.
            fts_: sequence of tensors compared index-wise with ``fts``.
            fts: sequence of reference tensors, indexed like ``fts_``.

        Returns:
            Tuple ``(mmd, mse)``; each stays the float ``0.0`` when its
            sequence is empty.
        """
        mmd = 0.0
        for teacher_hidden in hts:
            mmd += mmd_loss(hs, teacher_hidden,
                            sigmas=self.sigmas, normalized=self.normalized)

        mse = 0.0
        for idx, reconstructed in enumerate(fts_):
            mse += F.mse_loss(reconstructed, fts[idx])

        return mmd, mse
class PSNR_Loss(nn.Module):
    """Loss defined as ``100 - psnr(img1, img2)``."""

    def __init__(self, data_range=1.0, size_average=True):
        super(PSNR_Loss, self).__init__()
        self.size_average = size_average
        self.data_range = data_range

    def forward(self, img1, img2):
        """Return 100 minus the PSNR of the two images."""
        score = psnr(img1, img2,
                     size_average=self.size_average,
                     data_range=self.data_range)
        return 100 - score
#class MS_SSIM_Loss(MS_SSIM): | |||
# def forward(self, img1, img2): | |||
# return 100*(1 - super(MS_SSIM_Loss, self).forward(img1, img2)) | |||
class ScaleInvariantLoss(nn.Module):
    """Scale-invariant log-difference criterion for depth prediction.

    **Parameters:**
        - **la** (float, optional): weight of the squared-mean term. Default is 0.5.
        - **ignore_index** (int, optional): target value marking pixels to ignore.

    **Shape:**
        - **inputs**: $(N, H, W)$.
        - **targets**: $(N, H, W)$.
        - **output**: scalar.
    """

    def __init__(self, la=0.5, ignore_index=0):
        super(ScaleInvariantLoss, self).__init__()
        self.la = la
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the batch mean of the per-sample scale-invariant loss."""
        shape = inputs.size()
        if len(shape) == 3:
            # Flatten the spatial dimensions so reductions run over dim 1.
            batch = shape[0]
            inputs = inputs.view(batch, -1)
            targets = targets.view(batch, -1)

        ignored = targets.eq(self.ignore_index)
        valid_count = (1 - ignored.float()).sum(1)

        log_diff = torch.log(inputs) - torch.log(targets)
        # Ignored pixels contribute nothing to either sum.
        log_diff[ignored] = 0

        mean_of_squares = torch.div(torch.pow(log_diff, 2).sum(1), valid_count)
        square_of_mean = torch.pow(torch.div(log_diff.sum(1), valid_count), 2)
        per_sample = mean_of_squares - self.la * square_of_mean
        return per_sample.mean()
class AngleLoss(nn.Module):
    """Negative masked dot-product criterion for surface normal prediction.

    **Shape:**
        - **inputs**: $(N, 3, H, W)$. Predicted vector per pixel; presumably
          expected to be normalized beforehand (the code does not normalize).
        - **targets**: $(N, 3, H, W)$. Ground truth, same expectation.
        - **masks**: $(N, 1, H, W)$. One for valid pixels, else zero.
        - **output**: scalar.
    """

    def forward(self, inputs, targets, masks):
        """Return the batch mean of minus the average masked dot product."""
        valid_count = masks.sum(dim=[1, 2, 3])
        dot = (inputs * targets).sum(1, keepdim=True)
        masked_total = (dot * masks.float()).sum([1, 2, 3])
        per_sample = -torch.div(masked_total, valid_count)
        return per_sample.mean()
class FocalLoss(nn.Module):
    """Focal loss built on per-element cross entropy.

    Each element contributes ``alpha * (1 - p_t) ** gamma * ce`` where
    ``p_t = exp(-ce)``. The result is averaged when ``size_average`` is true
    and summed otherwise. Targets equal to ``ignore_index`` are passed to
    ``F.cross_entropy`` as ignored.
    """

    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.size_average = size_average
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """Return the reduced focal loss for ``inputs`` logits and ``targets``."""
        ce = F.cross_entropy(inputs, targets, reduction='none',
                             ignore_index=self.ignore_index)
        prob_true = torch.exp(-ce)
        weighted = self.alpha * (1 - prob_true) ** self.gamma * ce
        reduce = weighted.mean if self.size_average else weighted.sum
        return reduce()
class AttentionLoss(nn.Module):
    """ Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer"""

    def __init__(self, p=2):
        super(AttentionLoss, self).__init__()
        # Exponent applied to activations before averaging over channels.
        self.p = p

    def forward(self, g_s, g_t):
        """Sum the attention-transfer loss over paired student/teacher features."""
        return sum(self.at_loss(student, teacher)
                   for student, teacher in zip(g_s, g_t))

    def at_loss(self, f_s, f_t):
        """Mean squared difference between the two normalized attention maps.

        The spatially larger feature map (judged by dim 2) is average-pooled
        to the smaller one's square size first.
        """
        student_h, teacher_h = f_s.shape[2], f_t.shape[2]
        if student_h > teacher_h:
            f_s = F.adaptive_avg_pool2d(f_s, (teacher_h, teacher_h))
        elif student_h < teacher_h:
            f_t = F.adaptive_avg_pool2d(f_t, (student_h, student_h))

        gap = self.at(f_s) - self.at(f_t)
        return gap.pow(2).mean()

    def at(self, f):
        """Flattened, L2-normalized channel mean of ``f ** p``."""
        attention = f.pow(self.p).mean(1)
        return F.normalize(attention.view(f.size(0), -1))
class NSTLoss(nn.Module):
    """like what you like: knowledge distill via neuron selectivity transfer

    Args:
        full_loss: when true (default, the previous hard-coded behaviour) the
            teacher-teacher kernel term is included as a detached constant.
            Set it to ``False`` to skip that term and avoid the extra
            computation; gradients are unaffected because the term is detached.
    """

    def __init__(self, full_loss=True):
        super(NSTLoss, self).__init__()
        self.full_loss = full_loss

    def forward(self, g_s, g_t):
        """Sum the NST loss over paired student/teacher feature maps."""
        return sum([self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)])

    def nst_loss(self, f_s, f_t):
        """Polynomial-kernel discrepancy between two feature maps.

        The spatially larger map (judged by dim 2) is average-pooled to the
        smaller one's square size, then both are flattened per channel and
        L2-normalized along the spatial axis.
        """
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))

        f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
        f_s = F.normalize(f_s, dim=2)
        f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
        f_t = F.normalize(f_t, dim=2)

        student_term = self.poly_kernel(f_s, f_s).mean()
        cross_term = 2 * self.poly_kernel(f_s, f_t).mean()
        if self.full_loss:
            teacher_term = self.poly_kernel(f_t, f_t).mean().detach()
            return teacher_term + student_term - cross_term
        return student_term - cross_term

    def poly_kernel(self, a, b):
        """Squared inner product between every channel of ``a`` and of ``b``."""
        a = a.unsqueeze(1)
        b = b.unsqueeze(2)
        res = (a * b).sum(-1).pow(2)
        return res
class SPLoss(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""

    def __init__(self):
        super(SPLoss, self).__init__()

    def forward(self, g_s, g_t):
        """Sum the similarity loss over paired student/teacher features."""
        return sum(self.similarity_loss(student, teacher)
                   for student, teacher in zip(g_s, g_t))

    def similarity_loss(self, f_s, f_t):
        """Squared difference of row-normalized Gram matrices over ``bsz**2``.

        Both inputs are flattened using the student's batch size. The result
        keeps shape ``(1,)``.
        """
        batch = f_s.shape[0]
        gram_s = self._row_normalized_gram(f_s.view(batch, -1))
        gram_t = self._row_normalized_gram(f_t.view(batch, -1))
        delta = gram_t - gram_s
        return (delta * delta).view(-1, 1).sum(0) / (batch * batch)

    @staticmethod
    def _row_normalized_gram(flat):
        """Gram matrix of the rows of ``flat`` with each row L2-normalized."""
        gram = torch.mm(flat, torch.t(flat))
        return torch.nn.functional.normalize(gram)
class RKDLoss(nn.Module):
    """Relational Knowledge Distillation, CVPR2019

    Combines a distance-wise and an angle-wise smooth-L1 term, weighted by
    ``w_d`` and ``w_a`` respectively.
    """

    def __init__(self, w_d=25, w_a=50):
        super(RKDLoss, self).__init__()
        self.w_d = w_d
        self.w_a = w_a

    def forward(self, f_s, f_t):
        """Return ``w_d * distance_loss + w_a * angle_loss`` for the batch."""
        student = f_s.view(f_s.shape[0], -1)
        teacher = f_t.view(f_t.shape[0], -1)

        distance_term = self._distance_loss(student, teacher)
        angle_term = self._angle_loss(student, teacher)
        return self.w_d * distance_term + self.w_a * angle_term

    def _distance_loss(self, student, teacher):
        """Smooth-L1 between mean-normalized pairwise distance matrices."""
        with torch.no_grad():
            teacher_dist = self.pdist(teacher, squared=False)
            teacher_dist = teacher_dist / teacher_dist[teacher_dist > 0].mean()

        student_dist = self.pdist(student, squared=False)
        student_dist = student_dist / student_dist[student_dist > 0].mean()
        return F.smooth_l1_loss(student_dist, teacher_dist)

    def _angle_loss(self, student, teacher):
        """Smooth-L1 between the cosine tables of pairwise difference vectors."""
        with torch.no_grad():
            teacher_angle = self._angles(teacher)
        student_angle = self._angles(student)
        return F.smooth_l1_loss(student_angle, teacher_angle)

    @staticmethod
    def _angles(embedding):
        """Flattened cosines between all normalized pairwise differences."""
        diff = embedding.unsqueeze(0) - embedding.unsqueeze(1)
        unit = F.normalize(diff, p=2, dim=2)
        return torch.bmm(unit, unit.transpose(1, 2)).view(-1)

    @staticmethod
    def pdist(e, squared=False, eps=1e-12):
        """Pairwise (squared) Euclidean distances between rows of ``e``.

        Values are clamped below at ``eps`` before the optional square root,
        and the diagonal is zeroed on a copy of the result.
        """
        sq_norms = e.pow(2).sum(dim=1)
        gram = e @ e.t()
        dist = (sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * gram).clamp(min=eps)

        if not squared:
            dist = dist.sqrt()

        dist = dist.clone()
        diagonal = range(len(e))
        dist[diagonal, diagonal] = 0
        return dist
class PKTLoss(nn.Module):
    """Probabilistic Knowledge Transfer for deep representation learning"""

    def __init__(self):
        super(PKTLoss, self).__init__()

    def forward(self, f_s, f_t):
        """Delegate to :meth:`cosine_similarity_loss`."""
        return self.cosine_similarity_loss(f_s, f_t)

    @staticmethod
    def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
        """KL-style divergence between cosine-similarity neighbour distributions.

        Each input is row-normalized, turned into a pairwise cosine similarity
        matrix rescaled to [0, 1], and row-normalized into probabilities. The
        return value is the mean of ``target * log(target / model)``, with
        ``eps`` added inside the ratio.
        """

        def unit_rows(feats):
            # Normalize each row by its norm; NaN entries are reset to zero.
            norms = torch.sqrt(torch.sum(feats ** 2, dim=1, keepdim=True))
            feats = feats / (norms + eps)
            feats[feats != feats] = 0
            return feats

        def neighbour_probs(feats):
            # Cosine similarity, shifted to 0..1, then row-normalized.
            similarity = torch.mm(feats, feats.transpose(0, 1))
            similarity = (similarity + 1.0) / 2.0
            return similarity / torch.sum(similarity, dim=1, keepdim=True)

        model_prob = neighbour_probs(unit_rows(output_net))
        target_prob = neighbour_probs(unit_rows(target_net))

        ratio = (target_prob + eps) / (model_prob + eps)
        return torch.mean(target_prob * torch.log(ratio))
class SVDLoss(nn.Module):
    """
    Self-supervised Knowledge Distillation using Singular Value Decomposition

    Args:
        k: number of singular vectors kept for the teacher; the student keeps
            ``k + 3`` before alignment.
    """

    def __init__(self, k=1):
        super(SVDLoss, self).__init__()
        self.k = k

    def forward(self, g_s, g_t):
        """Sum, over consecutive layer pairs, the squared gap between the
        student's and teacher's RBF correlation of right singular vectors,
        divided by the batch size of ``g_s[0]``.
        """
        v_sb = None
        v_tb = None
        losses = []
        for i, (f_s, f_t) in enumerate(zip(g_s, g_t)):
            _, s_t, v_t = self.svd(f_t, self.k)
            _, _, v_s = self.svd(f_s, self.k + 3)
            v_s, v_t = self.align_rsv(v_s, v_t)
            # Both sets of vectors are scaled by the teacher's singular values.
            s_t = s_t.unsqueeze(1)
            v_t = v_t * s_t
            v_s = v_s * s_t

            if i > 0:
                # Compare the current layer with the previous one.
                s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8)
                t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8)

                l2loss = (s_rbf - t_rbf.detach()).pow(2)
                l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss))
                losses.append(l2loss.sum())

            v_tb = v_t
            v_sb = v_s

        bsz = g_s[0].shape[0]
        losses = [l / bsz for l in losses]
        return sum(losses)

    def svd(self, feat, n=1):
        """Batched SVD of a 4-D feature map flattened to ``(N, H*W, C)``.

        Non-finite entries are zeroed; when ``n > 0`` only the first ``n``
        components are kept and normalized along dim 1.
        """
        size = feat.shape
        assert len(size) == 4

        # Flatten with H * W (not H * H) so non-square maps keep their batch axis.
        x = feat.view(-1, size[1], size[2] * size[3]).transpose(-2, -1)
        u, s, v = torch.svd(x)

        u = self.removenan(u)
        s = self.removenan(s)
        v = self.removenan(v)

        if n > 0:
            u = F.normalize(u[:, :, :n], dim=1)
            s = F.normalize(s[:, :n], dim=1)
            v = F.normalize(v[:, :, :n], dim=1)
        return u, s, v

    @staticmethod
    def removenan(x):
        """Replace every non-finite entry of ``x`` with zero."""
        x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
        return x

    @staticmethod
    def align_rsv(a, b):
        """Select and sign-align the columns of ``a`` that best match ``b``.

        For each column of ``b`` the column of ``a`` with the largest absolute
        inner product is picked, carrying the sign of that product.
        """
        cosine = torch.matmul(a.transpose(-2, -1), b)
        max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True)
        mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)),
                           torch.sign(cosine), torch.zeros_like(cosine))
        a = torch.matmul(a, mask)
        return a, b
@@ -0,0 +1,186 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import abc | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import sys | |||
import typing | |||
from typing import Callable, Dict, List, Any | |||
from collections import Mapping, Sequence | |||
from . import loss | |||
from kamal.core import metrics, exceptions | |||
from kamal.core.attach import AttachTo | |||
class Task(object):
    """Named base class declaring the ``get_loss`` / ``predict`` interface.

    The class does not use an ABC metaclass, so the ``abstractmethod``
    markers are not enforced at instantiation; the default bodies return
    ``None``.
    """

    def __init__(self, name):
        self.name = name

    @abc.abstractmethod
    def get_loss( self, outputs, targets ) -> Dict:
        """Return a dict of losses for ``outputs`` against ``targets``."""

    @abc.abstractmethod
    def predict(self, outputs) -> Any:
        """Turn raw ``outputs`` into predictions."""
class GeneralTask(Task):
    """Task built from a loss callable, a prediction callable and a scale.

    Outputs (and targets) are first passed through an ``AttachTo`` instance
    created from ``attach_to``.
    """

    def __init__(self,
                 name: str,
                 loss_fn: Callable,
                 scaling:float=1.0,
                 pred_fn: Callable=lambda x: x,
                 attach_to=None):
        super(GeneralTask, self).__init__(name)
        self.loss_fn = loss_fn
        self.pred_fn = pred_fn
        self.scaling = scaling
        self._attach = AttachTo(attach_to)

    def get_loss(self, outputs, targets):
        """Return ``{name: loss_fn(outputs, targets) * scaling}``."""
        attached_outputs, attached_targets = self._attach(outputs, targets)
        scaled_loss = self.loss_fn(attached_outputs, attached_targets) * self.scaling
        return {self.name: scaled_loss}

    def predict(self, outputs):
        """Apply ``pred_fn`` to the attached outputs."""
        attached_outputs = self._attach(outputs)
        return self.pred_fn(attached_outputs)

    def __repr__(self):
        fields = (self.name, str(self.loss_fn), self.scaling, self._attach)
        return "Task: [%s loss_fn=%s scaling=%.4f attach=%s]" % fields
class TaskCompose(list):
    """List of ``Task`` objects evaluated together.

    Entries of ``tasks`` that are not ``Task`` instances are silently dropped.
    """

    def __init__(self, tasks: list):
        super(TaskCompose, self).__init__(
            task for task in tasks if isinstance(task, Task))

    def get_loss(self, outputs, targets):
        """Merge the loss dicts of all tasks; later tasks win on name clashes."""
        loss_dict = {}
        for task in self:
            loss_dict.update(task.get_loss(outputs, targets))
        return loss_dict

    def predict(self, outputs):
        """Return the list of each task's prediction, in task order."""
        return [task.predict(outputs) for task in self]

    def __repr__(self):
        # The original built this string but never returned it, so repr()
        # raised TypeError ("__repr__ returned non-string").
        return "TaskCompose: \n" + "".join("\t%s\n" % task for task in self)
class StandardTask:
    """Factory helpers returning pre-configured ``GeneralTask`` objects."""

    @staticmethod
    def classification(name='ce', scaling=1.0, attach_to=None):
        """Cross-entropy task; predictions are the argmax over dim 1."""
        criterion = nn.CrossEntropyLoss()
        return GeneralTask(name=name,
                           loss_fn=criterion,
                           scaling=scaling,
                           pred_fn=lambda x: x.max(1)[1],
                           attach_to=attach_to)

    @staticmethod
    def binary_classification(name='bce', scaling=1.0, attach_to=None):
        """Binary cross-entropy-with-logits task; predictions are ``x > 0.5``.

        NOTE(review): the loss consumes logits while the prediction thresholds
        the raw output at 0.5 (probability 0.5 corresponds to logit 0) --
        confirm this is intended.
        """
        criterion = F.binary_cross_entropy_with_logits
        return GeneralTask(name=name,
                           loss_fn=criterion,
                           scaling=scaling,
                           pred_fn=lambda x: (x>0.5),
                           attach_to=attach_to)

    @staticmethod
    def regression(name='mse', scaling=1.0, attach_to=None):
        """Mean-squared-error task; predictions are the raw outputs."""
        criterion = nn.MSELoss()
        return GeneralTask(name=name,
                           loss_fn=criterion,
                           scaling=scaling,
                           pred_fn=lambda x: x,
                           attach_to=attach_to)

    @staticmethod
    def segmentation(name='ce', scaling=1.0, attach_to=None):
        """Cross-entropy task ignoring label 255; predictions are the argmax over dim 1."""
        criterion = nn.CrossEntropyLoss(ignore_index=255)
        return GeneralTask(name=name,
                           loss_fn=criterion,
                           scaling=scaling,
                           pred_fn=lambda x: x.max(1)[1],
                           attach_to=attach_to)

    @staticmethod
    def monocular_depth(name='l1', scaling=1.0, attach_to=None):
        """L1 task; predictions are the raw outputs."""
        criterion = nn.L1Loss()
        return GeneralTask(name=name,
                           loss_fn=criterion,
                           scaling=scaling,
                           pred_fn=lambda x: x,
                           attach_to=attach_to)

    @staticmethod
    def detection():
        """Not available; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None):
        """``loss.KLDiv`` task at temperature ``T``; predictions are the argmax over dim 1."""
        criterion = loss.KLDiv(T=T)
        return GeneralTask(name=name,
                           loss_fn=criterion,
                           scaling=scaling,
                           pred_fn=lambda x: x.max(1)[1],
                           attach_to=attach_to)
class StandardMetrics(object):
    """Factory helpers returning pre-configured ``metrics.MetricCompose`` objects."""

    @staticmethod
    def classification(attach_to=None):
        """Accuracy under the key ``'acc'``."""
        metric_dict = {'acc': metrics.Accuracy(attach_to=attach_to)}
        return metrics.MetricCompose(metric_dict=metric_dict)

    @staticmethod
    def regression(attach_to=None):
        """Mean squared error under the key ``'mse'``."""
        metric_dict = {'mse': metrics.MeanSquaredError(attach_to=attach_to)}
        return metrics.MetricCompose(metric_dict=metric_dict)

    @staticmethod
    def segmentation(num_classes, ignore_idx=255, attach_to=None):
        """Accuracy, a confusion matrix, and an mIoU built on that same matrix."""
        confusion_matrix = metrics.ConfusionMatrix(
            num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to)
        metric_dict = {
            'acc': metrics.Accuracy(attach_to=attach_to),
            'confusion_matrix': confusion_matrix,
            'miou': metrics.mIoU(confusion_matrix),
        }
        return metrics.MetricCompose(metric_dict=metric_dict)

    @staticmethod
    def monocular_depth(attach_to=None):
        """RMSE variants, relative differences and threshold percentages."""
        metric_dict = {
            'rmse': metrics.RootMeanSquaredError(attach_to=attach_to),
            'rmse_log': metrics.RootMeanSquaredError(log_scale=True, attach_to=attach_to),
            'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to),
            'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to),
            'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to),
            'percents within thresholds': metrics.Threshold(
                thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to),
        }
        return metrics.MetricCompose(metric_dict=metric_dict)

    @staticmethod
    def loss_metric(loss_fn):
        """Wrap ``loss_fn`` in an ``AverageMetric`` under the key ``'loss'``."""
        metric_dict = {'loss': metrics.AverageMetric(loss_fn)}
        return metrics.MetricCompose(metric_dict=metric_dict)
@@ -0,0 +1,2 @@ | |||
from .prunning import Pruner, strategy | |||
from .distillation import * |
@@ -0,0 +1,12 @@ | |||
from .kd import KDDistiller | |||
from .hint import * | |||
from .attention import AttentionDistiller | |||
from .nst import NSTDistiller | |||
from .sp import SPDistiller | |||
from .rkd import RKDDistiller | |||
from .pkt import PKTDistiller | |||
from .svd import SVDDistiller | |||
from .cc import * | |||
from .vid import * | |||
from . import data_free |
@@ -0,0 +1,47 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import AttentionLoss | |||
from kamal.core.tasks.loss import KDLoss | |||
from kamal.utils import set_mode, move_to_device | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
class AttentionDistiller(KDDistiller):
    """KD distiller whose additional loss is ``AttentionLoss`` on hooked features."""

    def __init__(self, logger=None, tb_writer=None ):
        super(AttentionDistiller, self).__init__( logger, tb_writer )

    def setup(self,
              student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller and remember the feature hooks.

        ``out_flags[i]`` selects, for the i-th hook pair, whether the hook's
        ``feat_out`` (true) or ``feat_in`` (false) is used.
        """
        super(AttentionDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self._at_loss = AttentionLoss()
        self.stu_hooks = stu_hooks
        self.tea_hooks = tea_hooks
        self.out_flags = out_flags

    def additional_kd_loss(self, engine, batch):
        """Attention loss over all hooked features except the first and last."""
        student_feats = []
        for hook, use_output in zip(self.stu_hooks, self.out_flags):
            student_feats.append(hook.feat_out if use_output else hook.feat_in)

        teacher_feats = []
        for hook, use_output in zip(self.tea_hooks, self.out_flags):
            # Only the teacher's output features are explicitly detached.
            teacher_feats.append(hook.feat_out.detach() if use_output else hook.feat_in)

        return self._at_loss(student_feats[1:-1], teacher_feats[1:-1])
@@ -0,0 +1,55 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import KDLoss | |||
import torch.nn as nn | |||
import torch._ops | |||
import time | |||
class CCDistiller(KDDistiller):
    """KD distiller with an additional loss on embedded last-hook features."""

    def __init__(self, logger=None, tb_writer=None ):
        super(CCDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, embed_s, embed_t, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None ):
        """Configure the base distiller, the embedding heads and the hooks.

        Args:
            embed_s / embed_t: modules applied to the last hooked student /
                teacher feature; both are moved to ``self.device`` and put in
                training mode.
            stu_hooks / tea_hooks / out_flags: feature hooks and, per hook
                pair, whether ``feat_out`` (true) or ``feat_in`` (false) is used.
            Remaining arguments are forwarded to the parent ``setup``.
        """
        # ``beta`` was previously accepted but not forwarded, so the parent
        # always fell back to its default; pass it like the sibling distillers.
        super(CCDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.embed_s = embed_s.to(self.device).train()
        self.embed_t = embed_t.to(self.device).train()
        self.stu_hooks = stu_hooks
        self.tea_hooks = tea_hooks
        self.out_flags = out_flags

    def additional_kd_loss(self, engine, batch):
        """Mean over consecutive sample pairs of the summed product of their
        absolute student/teacher embedding differences.
        """
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = self.embed_s(feat_s[-1])
        f_t = self.embed_t(feat_t[-1])
        # Compute the absolute difference once and reuse it for both slices.
        abs_diff = torch.abs(f_s - f_t)
        return torch.mean((abs_diff[:-1] * abs_diff[1:]).sum(1))
class LinearEmbed(nn.Module):
    """Flatten each sample and project it with a single linear layer."""

    def __init__(self, dim_in=1024, dim_out=128):
        super(LinearEmbed, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        """Return ``linear(x)`` after reshaping ``x`` to ``(batch, -1)``."""
        batch = x.shape[0]
        flattened = x.view(batch, -1)
        return self.linear(flattened)
@@ -0,0 +1 @@ | |||
from .zskt import ZSKTDistiller |
@@ -0,0 +1,99 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import time | |||
import torch.nn.functional as F | |||
from kamal.slim.distillation.kd import KDDistiller | |||
from kamal.utils import set_mode | |||
from kamal.core.tasks.loss import kldiv | |||
class ZSKTDistiller(KDDistiller):
    """Data-free distiller that trains a generator against the student.

    Each step first updates the generator to maximise ``kldiv`` between the
    student and teacher outputs on generated samples, then updates the
    student ten times to minimise it on a fresh batch from the same noise.
    """

    def __init__( self,
                  student,
                  teacher,
                  generator,
                  z_dim,
                  logger=None,
                  viz=None):
        super(ZSKTDistiller, self).__init__(logger, viz)
        self.teacher = teacher
        self.model = self.student = student
        self.generator = generator
        self.z_dim = z_dim

    def train(self, start_iter, max_iter, optim_s, optim_g, device=None):
        """Move all three networks to ``device`` and run the parent ``train``.

        ``device`` defaults to CUDA when available, else CPU. The student and
        generator are put in training mode and the teacher in eval mode for
        the duration of the run.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self.device = device
        self.optim_s, self.optim_g = optim_s, optim_g

        self.model.to(self.device)
        self.teacher.to(self.device)
        self.generator.to(self.device)
        # Placeholder loader: samples come from the generator, not from data.
        self.train_loader = [0, ]

        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False), \
             set_mode(self.generator, training=True):
            # NOTE(review): relies on a ``train`` method of the base classes
            # that is not visible in this module -- confirm it exists.
            super( ZSKTDistiller, self ).train( start_iter, max_iter )

    def search_optimizer(self, evaluator, train_loader, hpo_space=None, mode='min', max_evals=20, max_iters=400):
        # NOTE(review): ``hpo`` is not imported in this module, so this call
        # raises NameError as written -- confirm the intended import.
        optimizer = hpo.search_optimizer(self, train_loader, evaluator=evaluator, hpo_space=hpo_space, mode=mode, max_evals=max_evals, max_iters=max_iters)
        return optimizer

    def step(self):
        """Run one adversarial generator update and ten student updates."""
        start_time = time.perf_counter()

        # Adversarial phase: the generator maximises student/teacher divergence.
        z = torch.randn( self.z_dim ).to(self.device)
        fake = self.generator( z )
        self.optim_g.zero_grad()
        t_out = self.teacher( fake )
        s_out = self.student( fake )
        loss_g = -kldiv( s_out, t_out )
        loss_g.backward()
        self.optim_g.step()

        # Regenerate with the updated generator; teacher targets need no grad.
        with torch.no_grad():
            fake = self.generator( z )
            t_out = self.teacher( fake.detach() )

        # Distillation phase: the student minimises the divergence.
        for _ in range(10):
            self.optim_s.zero_grad()
            s_out = self.student( fake.detach() )
            loss_s = kldiv( s_out, t_out )
            loss_s.backward()
            self.optim_s.step()

        loss_dict = {
            'loss_g': loss_g,
            'loss_s': loss_s,
        }

        step_time = time.perf_counter() - start_time

        # record training info
        info = loss_dict
        info['step_time'] = step_time
        info['lr_s'] = float( self.optim_s.param_groups[0]['lr'] )
        info['lr_g'] = float( self.optim_g.param_groups[0]['lr'] )
        self.history.put_scalars( **info )

    def reset(self):
        """Clear the history and restart iteration over ``self.train_loader``."""
        self.history = None
        # The original referenced the undefined global ``train_loader``.
        self._train_loader_iter = iter(self.train_loader)
        self.iter = self.start_iter
@@ -0,0 +1,86 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import KDLoss | |||
import torch.nn as nn | |||
import torch._ops | |||
import time | |||
class HintDistiller(KDDistiller):
    """KD distiller with an MSE hint loss between a regressed student feature
    and the teacher feature at ``hint_layer``.
    """

    def __init__(self, logger=None, tb_writer=None ):
        super(HintDistiller, self).__init__( logger, tb_writer )

    def setup(self,
              student, teacher, regressor, dataloader, optimizer,
              hint_layer=2, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller, the regressor and the feature hooks.

        Args:
            regressor: module mapping the student hint feature to the
                teacher's; moved to the same device as the networks.
            hint_layer: index into the hooked feature lists.
            stu_hooks / tea_hooks / out_flags: feature hooks and, per hook
                pair, whether ``feat_out`` (true) or ``feat_in`` (false) is used.
            Remaining arguments are forwarded to the parent ``setup``.
        """
        super( HintDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.regressor = regressor
        self._hint_layer = hint_layer
        self._beta = beta
        self.stu_hooks = stu_hooks
        self.tea_hooks = tea_hooks
        self.out_flags = out_flags
        # Use the device resolved by the parent setup: the raw ``device``
        # argument may be None, which would leave the regressor behind when
        # the networks were moved to CUDA.
        self.regressor.to(self.device)

    def additional_kd_loss(self, engine, batch):
        """MSE between the regressed student hint and the teacher hint."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = self.regressor(feat_s[self._hint_layer])
        f_t = feat_t[self._hint_layer]
        return nn.functional.mse_loss(f_s, f_t)
class Regressor(nn.Module):
    """
    Convolutional regression for FitNet

    @inproceedings{tian2019crd,
    title={Contrastive Representation Distillation},
    author={Yonglong Tian and Dilip Krishnan and Phillip Isola},
    booktitle={International Conference on Learning Representations},
    year={2020}
    }

    Args:
        s_shape: 4-tuple ``(_, C, H, W)`` describing the student feature.
        t_shape: 4-tuple ``(_, C, H, W)`` describing the teacher feature.
        is_relu: apply ReLU after batch norm when true.

    Raises:
        NotImplementedError: if the student map is smaller than the teacher
            map and not exactly half its height.
    """

    def __init__(self, s_shape, t_shape, is_relu=True):
        super(Regressor, self).__init__()
        self.is_relu = is_relu
        _, s_C, s_H, s_W = s_shape
        _, t_C, t_H, t_W = t_shape
        if s_H == 2 * t_H:
            # Student twice as large: strided conv halves the resolution.
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
        elif s_H * 2 == t_H:
            # Student half as large: transposed conv doubles the resolution.
            self.conv = nn.ConvTranspose2d(
                s_C, t_C, kernel_size=4, stride=2, padding=1)
        elif s_H >= t_H:
            # Unpadded conv whose kernel shrinks the map to the teacher size.
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
        else:
            # ``NotImplemented`` is a constant, not an exception; calling it
            # raised TypeError instead of the intended error.
            raise NotImplementedError(
                'student size {}, teacher size {}'.format(s_H, t_H))
        self.bn = nn.BatchNorm2d(t_C)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm and (optionally) ReLU."""
        x = self.conv(x)
        if self.is_relu:
            return self.relu(self.bn(x))
        else:
            return self.bn(x)
@@ -0,0 +1,90 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from kamal.core.engine.trainer import Engine | |||
from kamal.core.tasks.loss import kldiv | |||
import torch.nn.functional as F | |||
from kamal.utils.logger import get_logger | |||
from kamal.utils import set_mode, move_to_device | |||
import weakref | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
import numpy as np | |||
class KDDistiller(Engine):
    """Engine that trains a student against a teacher's soft targets.

    The step loss is ``alpha * kldiv + beta * cross_entropy +
    gamma * additional_kd_loss``; subclasses override ``additional_kd_loss``.
    """

    def __init__( self,
                  logger=None,
                  tb_writer=None):
        super(KDDistiller, self).__init__(logger=logger, tb_writer=tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, device=None):
        """Store the training objects and move both networks to ``device``.

        ``device`` defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self.device = device

        self.model = self.student = student
        self.teacher = teacher
        self.dataloader = dataloader
        self.optimizer = optimizer

        self.T = T
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.student.to(self.device)
        self.teacher.to(self.device)

    def run( self, max_iter, start_iter=0, epoch_length=None):
        """Run the engine with the student in train mode and the teacher in eval mode."""
        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False):
            super( KDDistiller, self ).run(
                self.step_fn, self.dataloader,
                start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def additional_kd_loss(self, engine, batch):
        """Extra distillation term; the base class contributes a single zero."""
        return batch[0].new_zeros(1)

    def step_fn(self, engine, batch):
        """Perform one optimisation step and return scalar metrics.

        Returns:
            Dict with each loss component, ``total_loss``, ``step_time`` and
            the learning rate of the first parameter group.
        """
        student, teacher = self.student, self.teacher
        tic = time.perf_counter()

        batch = move_to_device(batch, self.device)
        inputs, targets = batch

        student_logits = student(inputs)
        with torch.no_grad():
            teacher_logits = teacher(inputs)

        loss_dict = {
            "loss_kld": self.alpha * kldiv(student_logits, teacher_logits, T=self.T),
            "loss_ce": self.beta * F.cross_entropy( student_logits, targets ),
            "loss_additional": self.gamma * self.additional_kd_loss(engine, batch),
        }
        total = sum( loss_dict.values() )

        self.optimizer.zero_grad()
        total.backward()
        self.optimizer.step()
        elapsed = time.perf_counter() - tic

        metrics = {}
        for loss_name, loss_value in loss_dict.items():
            metrics[loss_name] = loss_value.item()
        metrics['total_loss'] = total.item()
        metrics['step_time'] = elapsed
        metrics['lr'] = float( self.optimizer.param_groups[0]['lr'] )
        return metrics
@@ -0,0 +1,44 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import KDLoss | |||
from kamal.core.tasks.loss import NSTLoss | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
class NSTDistiller(KDDistiller):
    """KD distiller whose additional term is ``NSTLoss`` over hooked features."""

    def __init__(self, logger=None, tb_writer=None):
        super(NSTDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller and remember the feature hooks.

        Args:
            stu_hooks / tea_hooks: hook objects exposing ``feat_in`` and
                ``feat_out`` for the student and the teacher.
            out_flags: per-hook flags; truthy selects ``feat_out``, falsy
                selects ``feat_in``.

        The remaining arguments are forwarded unchanged to ``KDDistiller.setup``.
        NOTE: the list defaults are shared objects; they are only stored here,
        never mutated.
        """
        super(NSTDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks, self.tea_hooks, self.out_flags = stu_hooks, tea_hooks, out_flags
        self._nst_loss = NSTLoss()

    def additional_kd_loss(self, engine, batch):
        """Apply the NST loss to every hooked feature except the first and last."""
        student_feats = []
        for hook, use_output in zip(self.stu_hooks, self.out_flags):
            student_feats.append(hook.feat_out if use_output else hook.feat_in)
        teacher_feats = []
        for hook, use_output in zip(self.tea_hooks, self.out_flags):
            # Teacher outputs are detached; teacher inputs are used as captured.
            teacher_feats.append(hook.feat_out.detach() if use_output else hook.feat_in)
        return self._nst_loss(student_feats[1:-1], teacher_feats[1:-1])
@@ -0,0 +1,44 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import PKTLoss | |||
from kamal.core.tasks.loss import KDLoss | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
class PKTDistiller(KDDistiller):
    """KD distiller whose additional term is ``PKTLoss`` on the last hooked feature."""

    def __init__(self, logger=None, tb_writer=None):
        super(PKTDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller and remember the feature hooks.

        Args:
            stu_hooks / tea_hooks: hook objects exposing ``feat_in`` and
                ``feat_out`` for the student and the teacher.
            out_flags: per-hook flags; truthy selects ``feat_out``, falsy
                selects ``feat_in``.

        The remaining arguments are forwarded unchanged to ``KDDistiller.setup``.
        NOTE: the list defaults are shared objects; they are only stored here,
        never mutated.
        """
        super(PKTDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks, self.tea_hooks, self.out_flags = stu_hooks, tea_hooks, out_flags
        self._pkt_loss = PKTLoss()

    def additional_kd_loss(self, engine, batch):
        """Apply the PKT loss to the last hooked student/teacher feature pair."""
        student_feats = []
        for hook, use_output in zip(self.stu_hooks, self.out_flags):
            student_feats.append(hook.feat_out if use_output else hook.feat_in)
        teacher_feats = []
        for hook, use_output in zip(self.tea_hooks, self.out_flags):
            # Teacher outputs are detached; teacher inputs are used as captured.
            teacher_feats.append(hook.feat_out.detach() if use_output else hook.feat_in)
        return self._pkt_loss(student_feats[-1], teacher_feats[-1])
@@ -0,0 +1,45 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import RKDLoss | |||
from kamal.core.tasks.loss import KDLoss | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
class RKDDistiller(KDDistiller):
    """KD distiller whose additional term is ``RKDLoss`` on the last hooked feature."""

    def __init__(self, logger=None, tb_writer=None):
        super(RKDDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller and remember the feature hooks.

        Args:
            student, teacher, dataloader, optimizer, T, alpha, beta, gamma,
                device: forwarded unchanged to ``KDDistiller.setup``.
            stu_hooks / tea_hooks: hook objects exposing ``feat_in`` and
                ``feat_out`` for the student and the teacher.
            out_flags: per-hook flags; truthy selects ``feat_out``, falsy
                selects ``feat_in``.

        NOTE: the list defaults are shared objects; they are only stored here,
        never mutated.
        """
        # ``beta`` must be forwarded explicitly; otherwise the base class
        # silently falls back to its default cross-entropy weight of 1.0.
        super(RKDDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks = stu_hooks
        self.tea_hooks = tea_hooks
        self.out_flags = out_flags
        self._rkd_loss = RKDLoss()

    def additional_kd_loss(self, engine, batch):
        """Apply the RKD loss to the last hooked student/teacher feature pair."""
        feat_s = [f.feat_out if flag else f.feat_in
                  for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        # Teacher outputs are detached; teacher inputs are used as captured.
        feat_t = [f.feat_out.detach() if flag else f.feat_in
                  for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = feat_s[-1]
        f_t = feat_t[-1]
        return self._rkd_loss(f_s, f_t)
@@ -0,0 +1,44 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import SPLoss | |||
from kamal.core.tasks.loss import KDLoss | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
class SPDistiller(KDDistiller):
    """KD distiller whose additional term is ``SPLoss`` on the second-to-last hooked feature."""

    def __init__(self, logger=None, tb_writer=None):
        super(SPDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller and remember the feature hooks.

        Args:
            stu_hooks / tea_hooks: hook objects exposing ``feat_in`` and
                ``feat_out`` for the student and the teacher.
            out_flags: per-hook flags; truthy selects ``feat_out``, falsy
                selects ``feat_in``.

        The remaining arguments are forwarded unchanged to ``KDDistiller.setup``.
        NOTE: the list defaults are shared objects; they are only stored here,
        never mutated.
        """
        super(SPDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks, self.tea_hooks, self.out_flags = stu_hooks, tea_hooks, out_flags
        self._sp_loss = SPLoss()

    def additional_kd_loss(self, engine, batch):
        """Apply the SP loss to the second-to-last hooked feature pair.

        Each selected feature is passed as a single-element list.
        """
        student_feats = []
        for hook, use_output in zip(self.stu_hooks, self.out_flags):
            student_feats.append(hook.feat_out if use_output else hook.feat_in)
        teacher_feats = []
        for hook, use_output in zip(self.tea_hooks, self.out_flags):
            # Teacher outputs are detached; teacher inputs are used as captured.
            teacher_feats.append(hook.feat_out.detach() if use_output else hook.feat_in)
        return self._sp_loss([student_feats[-2]], [teacher_feats[-2]])
@@ -0,0 +1,45 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from .kd import KDDistiller | |||
from kamal.core.tasks.loss import SVDLoss | |||
from kamal.core.tasks.loss import KDLoss | |||
import torch | |||
import torch.nn as nn | |||
import time | |||
class SVDDistiller(KDDistiller):
    """KD distiller whose additional term is ``SVDLoss`` over hooked features."""

    def __init__(self, logger=None, tb_writer=None):
        super(SVDDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller and remember the feature hooks.

        Args:
            stu_hooks / tea_hooks: hook objects exposing ``feat_in`` and
                ``feat_out`` for the student and the teacher.
            out_flags: per-hook flags; truthy selects ``feat_out``, falsy
                selects ``feat_in``.

        The remaining arguments are forwarded unchanged to ``KDDistiller.setup``.
        NOTE: the list defaults are shared objects; they are only stored here,
        never mutated.
        """
        super(SVDDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks, self.tea_hooks, self.out_flags = stu_hooks, tea_hooks, out_flags
        self._svd_loss = SVDLoss()

    def additional_kd_loss(self, engine, batch):
        """Apply the SVD loss to every hooked feature except the first and last."""
        student_feats = []
        for hook, use_output in zip(self.stu_hooks, self.out_flags):
            student_feats.append(hook.feat_out if use_output else hook.feat_in)
        teacher_feats = []
        for hook, use_output in zip(self.tea_hooks, self.out_flags):
            # Teacher outputs are detached; teacher inputs are used as captured.
            teacher_feats.append(hook.feat_out.detach() if use_output else hook.feat_in)
        return self._svd_loss(student_feats[1:-1], teacher_feats[1:-1])
@@ -0,0 +1,90 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import numpy as np | |||
import time | |||
import torch.nn as nn | |||
import torch._ops | |||
import torch.nn.functional as F | |||
from .kd import KDDistiller | |||
from kamal.utils import set_mode | |||
from kamal.core.tasks.loss import KDLoss | |||
class VIDDistiller(KDDistiller):
    """KD distiller whose additional term sums per-layer regressor losses."""

    def __init__(self, logger=None, tb_writer=None):
        super(VIDDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, regressor_l, T=1.0, alpha=1.0, beta=1.0,
              gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None):
        """Configure the base distiller, the feature hooks and the regressors.

        Args:
            regressor_l: iterable of modules called as ``regressor(f_s, f_t)``;
                each is moved to ``self.device`` and put in train mode.
            stu_hooks / tea_hooks: hook objects exposing ``feat_in`` and
                ``feat_out`` for the student and the teacher.
            out_flags: per-hook flags; truthy selects ``feat_out``, falsy
                selects ``feat_in``.

        The remaining arguments are forwarded unchanged to ``KDDistiller.setup``.
        NOTE: the list defaults are shared objects; they are only stored here,
        never mutated.
        """
        super(VIDDistiller, self).setup(
            student, teacher, dataloader, optimizer,
            T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks, self.tea_hooks, self.out_flags = stu_hooks, tea_hooks, out_flags
        prepared = []
        for regressor in regressor_l:
            prepared.append(regressor.to(self.device).train())
        self.regressor_l = prepared

    def additional_kd_loss(self, engine, batch):
        """Sum the regressor losses over all hooked features except the first and last."""
        student_feats = []
        for hook, use_output in zip(self.stu_hooks, self.out_flags):
            student_feats.append(hook.feat_out if use_output else hook.feat_in)
        teacher_feats = []
        for hook, use_output in zip(self.tea_hooks, self.out_flags):
            # Teacher outputs are detached; teacher inputs are used as captured.
            teacher_feats.append(hook.feat_out.detach() if use_output else hook.feat_in)
        triples = zip(student_feats[1:-1], teacher_feats[1:-1], self.regressor_l)
        return sum([criterion(f_s, f_t) for f_s, f_t, criterion in triples])
class VIDRegressor(nn.Module):
    """Regress teacher features from student features with a learned per-channel variance.

    The regressor is three bias-free 1x1 convolutions separated by ReLUs.
    ``log_scale`` is initialised so that the softplus-derived variance starts
    at ``init_pred_var``.
    """

    def __init__(self,
                 num_input_channels,
                 num_mid_channel,
                 num_target_channels,
                 init_pred_var=5.0,
                 eps=1e-5):
        super(VIDRegressor, self).__init__()
        widths = [num_input_channels, num_mid_channel, num_mid_channel, num_target_channels]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=1, padding=0, bias=False, stride=1))
            layers.append(nn.ReLU())
        layers.pop()  # no activation after the final projection
        self.regressor = nn.Sequential(*layers)

        # Inverse softplus of (init_pred_var - eps), replicated per target channel.
        raw_scale = np.log(np.exp(init_pred_var - eps) - 1.0)
        self.log_scale = torch.nn.Parameter(raw_scale * torch.ones(num_target_channels))
        self.eps = eps

    def forward(self, input, target):
        """Return the mean Gaussian negative log-likelihood of ``target`` given ``input``.

        The larger of the two maps is average-pooled to the other's height;
        the pooled size is square (height is used for both dimensions).
        """
        src_h, dst_h = input.shape[2], target.shape[2]
        if src_h > dst_h:
            input = F.adaptive_avg_pool2d(input, (dst_h, dst_h))
        elif src_h < dst_h:
            target = F.adaptive_avg_pool2d(target, (src_h, src_h))

        pred_mean = self.regressor(input)
        pred_var = torch.log(1.0 + torch.exp(self.log_scale)) + self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        squared_error = (pred_mean - target) ** 2
        neg_log_prob = 0.5 * (squared_error / pred_var + torch.log(pred_var))
        return torch.mean(neg_log_prob)
@@ -0,0 +1,2 @@ | |||
from .pruner import Pruner | |||
from .strategy import LNStrategy, RandomStrategy |
@@ -0,0 +1,37 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from copy import deepcopy | |||
import os | |||
import torch | |||
class Pruner(object):
    """Apply a pruning strategy to a copy of a model and report the size change."""

    def __init__(self, strategy):
        # Callable invoked as strategy(model, rate=..., example_inputs=...).
        self.strategy = strategy

    def prune(self, model, rate=0.1, example_inputs=None):
        """Prune a deep copy of ``model`` on CPU and return the pruned copy.

        The original model is left untouched. A one-line summary of the
        parameter counts before and after pruning is printed.
        """
        before = sum([torch.numel(p) for p in model.parameters()])
        working_copy = deepcopy(model).cpu()
        pruned = self._prune(working_copy, rate=rate, example_inputs=example_inputs)
        after = sum([torch.numel(p) for p in pruned.parameters()])
        percent = 100 * (before - after) / before
        print("%d=>%d, %.2f%% params were pruned" % (before, after, percent))
        return pruned

    def _prune(self, model, **kargs):
        """Delegate to the configured strategy."""
        return self.strategy(model, **kargs)
@@ -0,0 +1,85 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch_pruning as tp | |||
import abc | |||
import torch | |||
import torch.nn as nn | |||
import random | |||
import numpy as np | |||
_PRUNABLE_MODULES = tp.DependencyGraph.PRUNABLE_MODULES


class BaseStrategy(abc.ABC):
    """Pruning driver: repeatedly pick a layer at random and prune the indices chosen by ``select``."""

    @abc.abstractmethod
    def select(self, layer_to_prune):
        """Return the list of output indices of ``layer_to_prune`` to remove."""
        pass

    def __call__(self, model, rate=0.1, example_inputs=None):
        """Prune ``model`` in place until ``rate`` of its prunable parameters are removed.

        Args:
            model: module to prune.
            rate: fraction of the total prunable parameter count to remove.
            example_inputs: input used to build the dependency graph; defaults
                to a random ``1x3x256x256`` tensor.

        Returns:
            The same ``model`` object.
        """
        if example_inputs is None:
            example_inputs = torch.randn(1, 3, 256, 256)
        dep_graph = tp.DependencyGraph()
        dep_graph.build_dependency(model, example_inputs=example_inputs)

        candidates = []
        total_params = 0
        boundaries = [0]
        for module in model.modules():
            if not isinstance(module, _PRUNABLE_MODULES):
                continue
            count = tp.utils.count_prunable_params(module)
            total_params += count
            if isinstance(module, (nn.modules.conv._ConvNd, nn.Linear)):
                candidates.append(module)
                boundaries.append(boundaries[-1] + count)

        # The last conv/linear layer is excluded from pruning.
        candidates.pop()
        boundaries.pop()
        num_conv_params = boundaries[-1]
        # Half-open parameter-index interval owned by each candidate, so a
        # uniform draw over indices picks layers proportionally to their size.
        spans = list(zip(boundaries[:-1], boundaries[1:]))

        def locate(param_idx):
            for layer, (low, high) in zip(candidates, spans):
                if low <= param_idx < high:
                    return layer

        target = total_params * rate
        num_pruned = 0
        while num_pruned < target:
            layer_to_prune = locate(random.randint(0, num_conv_params - 1))
            if layer_to_prune.weight.shape[0] < 1:
                continue
            idx = self.select(layer_to_prune)
            if isinstance(layer_to_prune, nn.modules.conv._ConvNd):
                prune_fn = tp.prune_conv
            else:
                prune_fn = tp.prune_linear
            plan = dep_graph.get_pruning_plan(layer_to_prune, prune_fn, idxs=idx)
            # The value returned by exec() is accumulated as the pruned count.
            num_pruned += plan.exec()
        return model
class RandomStrategy(BaseStrategy):
    """Strategy that removes one uniformly random output index per step."""

    def select(self, layer_to_prune):
        """Return a single random index along the first weight dimension."""
        last_index = layer_to_prune.weight.shape[0] - 1
        chosen = random.randint(0, last_index)
        return [chosen]
class LNStrategy(BaseStrategy):
    """Strategy that removes the output index whose weights have the smallest L-n norm."""

    def __init__(self, n=2):
        # Order of the norm passed as ``p`` to torch.norm.
        self.n = n

    def select(self, layer_to_prune):
        """Return the index of the row with the minimal norm, as a one-element list."""
        rows = torch.flatten(layer_to_prune.weight, 1)
        row_norms = torch.norm(rows, p=self.n, dim=1)
        _, position = row_norms.min(dim=0)
        return [int(position.item())]
@@ -0,0 +1,18 @@ | |||
# Deep Model Transferability from Attribution Maps
- [*"Paper: Deep Model Transferability from Attribution Maps"*](https:), NeurIPS 2019. (released soon)
J. Song, Y. Chen, X. Wang, C. Shen, M. Song | |||
[Homepage of VIPA Group](https://www.vipazoo.cn/index_en.html), Zhejiang University, China | |||
This repo was rewritten by zhfeing in `PyTorch`.
[Homepage of VIPA Group](https://www.vipazoo.cn/index_en.html), Zhejiang University, China | |||
<div align="left"> | |||
<img src="vipa-logo.png" width = "40%" height = "40%" alt="icon"/> | |||
</div> | |||
@@ -0,0 +1,20 @@ | |||
from kamal.transferability.trans_graph import TransferabilityGraph | |||
from kamal.transferability.trans_metric import AttrMapMetric | |||
import kamal | |||
from kamal.vision import sync_transforms as sT | |||
import os | |||
import torch | |||
from PIL import Image | |||
if __name__=='__main__':
    # Build a transferability graph for every probe set and export it as JSON.
    zoo = '/tmp/pycharm_project_225/kamal/transferability/model2'
    # TransferabilityGraph iterates over a collection of zoo roots; a bare
    # string would be walked character by character, so wrap it in a list.
    TG = TransferabilityGraph([zoo])
    probe_set_root = '/tmp/pycharm_project_225/kamal/transferability/probe_data'
    # export_to_json opens its target for writing and does not create parents.
    os.makedirs('exported_metrics', exist_ok=True)
    for probe_set in os.listdir(probe_set_root):
        print("Add %s"%(probe_set))
        probe_dir = os.path.join(probe_set_root, probe_set)
        imgs_set = list(os.listdir(probe_dir))
        images = [Image.open(os.path.join(probe_dir, img)) for img in imgs_set]
        metric = AttrMapMetric(images, device=torch.device('cuda'))
        TG.add_metric(probe_set, metric)
        TG.export_to_json(probe_set, 'exported_metrics/%s.json'%(probe_set), topk=3, normalize=True)
@@ -0,0 +1,3 @@ | |||
from .attribution_graph import get_attribution_graph, graph_similarity | |||
from .attribution_map import attribution_map, attr_map_distance |
@@ -0,0 +1,184 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from typing import Type, Dict, Any | |||
import itertools | |||
import numpy as np | |||
from numpy import Inf | |||
import networkx | |||
import torch | |||
from torch.nn import Module | |||
from torch.utils.data import Dataset | |||
from .attribution_map import attribution_map, attr_map_similarity | |||
def graph_to_array(graph: networkx.Graph):
    """Return the dense weight matrix of ``graph`` in node-iteration order.

    Entry ``[i, j]`` is the ``"weight"`` attribute of the edge between the
    i-th and j-th node. Pairs without such an edge attribute, including the
    diagonal when there are no self-loops, are set to 1.
    """
    nodes = list(graph.nodes)
    matrix = np.ones((len(nodes), len(nodes)))
    for row, source in enumerate(nodes):
        for col, dest in enumerate(nodes):
            try:
                matrix[row, col] = graph[source][dest]["weight"]
            except KeyError:
                # Missing edge or missing attribute: keep the default of 1.
                pass
    return matrix
class FeatureMapExtractor():
    """Record the inputs of every sub-module whose qualified name contains ``"pool"``."""

    def __init__(self, module: Module):
        self.module = module
        # name -> {"handle": hook handle, "feature": last captured input tuple}
        self.feature_pool: Dict[str, Dict[str, Any]] = dict()
        self.register_hooks()

    def register_hooks(self):
        """Attach a forward hook to each matching sub-module and tag it with its name."""
        def record_input(layer: Module, input, output):
            self.feature_pool[layer.name]["feature"] = input

        for name, layer in self.module.named_modules():
            if "pool" not in name:
                continue
            layer.name = name
            entry = dict()
            self.feature_pool[name] = entry
            entry["handle"] = layer.register_forward_hook(record_input)

    def _forward(self, x):
        """Run the wrapped module for its hook side effects; the output is discarded."""
        self.module(x)

    def remove_hooks(self):
        """Detach every registered hook and drop all captured features."""
        for cfg in self.feature_pool.values():
            cfg["handle"].remove()
            cfg.clear()
        self.feature_pool.clear()

    def extract_final_map(self, x):
        """Forward ``x`` and return the preferred captured 4-D feature map.

        Candidates are visited in registration order. A candidate replaces the
        current choice when it has at least as many channels and a spatial
        area that is greater than 1 and no larger than the current one.
        Returns ``None`` when nothing qualifies.
        """
        self._forward(x)
        best_map = None
        best_channels = 0
        best_area = float("inf")
        for cfg in self.feature_pool.values():
            captured = cfg["feature"]
            if not (len(captured) == 1 and isinstance(captured[0], torch.Tensor)):
                continue
            tensor = captured[0]
            if tensor.dim() != 4:  # only BxCxHxW maps are considered
                continue
            _, channels, height, width = tensor.shape
            area = height * width
            if channels >= best_channels and 1 < area <= best_area:
                best_map, best_channels, best_area = tensor, channels, area
        return best_map
def get_attribution_graph(
    model: Module,
    attribution_type: Type,
    with_noise: bool,
    probe_data: Dataset,
    device: torch.device,
    norm_square: bool = False,
):
    """Build a complete graph whose nodes are probe samples and whose edges are attribution-map similarities.

    Args:
        model: network to analyse; it is moved to ``device``.
        attribution_type: attribution class passed on to ``attribution_map``.
        with_noise: forwarded to ``attribution_map``.
        probe_data: iterable of tensors; each is unsqueezed to a batch of one.
        device: device used for the model and the samples.
        norm_square: forwarded to ``attribution_map``.

    Returns:
        ``networkx.Graph`` where node ``i`` carries ``attribution_map`` and
        every pair ``i < j`` is joined by an edge weighted with
        ``attr_map_similarity``.
    """
    attribution_graph = networkx.Graph()
    model = model.to(device)
    extractor = FeatureMapExtractor(model)
    try:
        for i, x in enumerate(probe_data):
            x = x.to(device)
            x.requires_grad_()
            attribution = attribution_map(
                func=lambda x: extractor.extract_final_map(x),
                attribution_type=attribution_type,
                with_noise=with_noise,
                probe_data=x.unsqueeze(0),
                norm_square=norm_square
            )
            attribution_graph.add_node(i, attribution_map=attribution)
    finally:
        # Always detach the forward hooks, so they neither linger on the
        # caller's model nor keep captured inputs alive after this call.
        extractor.remove_hooks()

    node_data = attribution_graph.nodes
    # Nodes are the ascending integers added above, so combinations() yields
    # exactly the i < j pairs.
    for i, j in itertools.combinations(list(node_data), 2):
        weight = attr_map_similarity(
            node_data[i]["attribution_map"],
            node_data[j]["attribution_map"]
        )
        attribution_graph.add_edge(i, j, weight=weight)
    return attribution_graph
def edge_to_embedding(graph: networkx.Graph):
    """Flatten the graph's weight matrix into a 1-D vector.

    The mask comes from ``np.tri`` with ``k=0``, so it selects the lower
    triangle including the diagonal.
    """
    weights = graph_to_array(graph)
    rows, cols = weights.shape[-2:]
    triangle_mask = np.tri(rows, cols, k=0, dtype=bool)
    return weights[triangle_mask]
def embedding_to_rank(embedding: np.ndarray):
    """Return the zero-based rank of each element of ``embedding``.

    Applying ``argsort`` twice turns the sorting permutation into per-element
    ranks.
    """
    return embedding.argsort().argsort()
def graph_similarity(g1: networkx.Graph, g2: networkx.Graph, Lambda: float = 1.0):
    """Score two attribution graphs as vertex similarity plus ``Lambda`` * edge similarity.

    Vertex similarity is the mean ``attr_map_similarity`` between the
    attribution maps stored at the same node index in both graphs. Edge
    similarity applies the Spearman rank-correlation formula
    ``1 - 6 * sum(d^2) / (k^3 - k)`` to the ranked edge embeddings.
    """
    nodes_1 = g1.nodes(data=True)
    nodes_2 = g2.nodes(data=True)
    assert len(nodes_1) == len(nodes_2)

    # Vertex term: average similarity of corresponding attribution maps.
    n = len(g1.nodes)
    v_s = 0
    for idx in range(n):
        v_s += attr_map_similarity(
            map_1=nodes_1[idx]["attribution_map"],
            map_2=nodes_2[idx]["attribution_map"]
        )
    vertex_similarity = v_s / n

    # Edge term: rank correlation between the two edge-weight embeddings.
    emb_1 = edge_to_embedding(g1)
    emb_2 = edge_to_embedding(g2)
    rank_gap = embedding_to_rank(emb_1) - embedding_to_rank(emb_2)
    k = emb_1.shape[0]
    edge_similarity = 1 - 6 * np.sum(np.square(rank_gap)) / (k ** 3 - k)

    return vertex_similarity + Lambda * edge_similarity
if __name__ == "__main__":
    # Smoke test: compare the attribution graphs of two randomly initialised
    # ResNet-34 models on independent random probe batches.
    from captum.attr import InputXGradient
    from torchvision.models import resnet34

    demo_graphs = []
    for _ in range(2):
        demo_model = resnet34(num_classes=10)
        demo_graphs.append(
            get_attribution_graph(
                demo_model,
                attribution_type=InputXGradient,
                with_noise=False,
                probe_data=torch.rand(10, 3, 244, 244),
                device=torch.device("cpu")
            )
        )

    print(graph_similarity(demo_graphs[0], demo_graphs[1]))
@@ -0,0 +1,87 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
from typing import Type, Callable | |||
from captum.attr import Attribution | |||
from captum.attr import NoiseTunnel | |||
def with_norm(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, square: bool = False):
    """Apply ``func`` to ``x`` and reduce each sample to the L2 norm of its flattened output.

    When ``square`` is true the squared norm is returned instead.
    """
    flattened = func(x).flatten(1)
    norms = torch.norm(flattened, dim=1, p=2)
    return torch.pow(norms, 2) if square else norms
def attribution_map(
    func: Callable[[torch.Tensor], torch.Tensor],
    attribution_type: Type,
    with_noise: bool,
    probe_data: torch.Tensor,
    norm_square: bool = False,
    **attribution_kwargs
) -> torch.Tensor:
    """
    Calculate attribution map with given attribution type (algorithm).

    Args:
        func: forward function; its output is reduced with ``with_norm``
            before being attributed
        attribution_type: attribution algorithm class, e.g.
            IntegratedGradients, InputXGradient, ...
        with_noise: whether to wrap the algorithm in a ``NoiseTunnel``
        probe_data: input tensor to attribute
        norm_square: use the squared norm instead of the norm
        attribution_kwargs: other kwargs for the ``attribute`` call

    Return: detached attribution map
    """
    def scalar_forward(x):
        return with_norm(func, x, norm_square)

    explainer = attribution_type(scalar_forward)
    if with_noise:
        explainer = NoiseTunnel(explainer)
    raw_map = explainer.attribute(
        inputs=probe_data,
        target=None,
        **attribution_kwargs
    )
    return raw_map.detach()
def attr_map_distance(map_1: torch.Tensor, map_2: torch.Tensor):
    """Return one minus the mean per-sample cosine similarity of two attribution maps.

    When the shapes differ, ``map_1`` is first interpolated to the last two
    dimensions of ``map_2``.
    """
    if map_1.shape != map_2.shape:
        map_1 = torch.nn.functional.interpolate(map_1, size=map_2.shape[-2:])
    flat_1 = map_1.flatten(1)
    flat_2 = map_2.flatten(1)
    mean_similarity = torch.cosine_similarity(flat_1, flat_2).mean()
    return (1 - mean_similarity).item()
def attr_map_similarity(map_1: torch.Tensor, map_2: torch.Tensor):
    """Return the mean per-sample cosine similarity of two equally shaped attribution maps."""
    assert(map_1.shape == map_2.shape)
    flat_1 = map_1.flatten(1)
    flat_2 = map_2.flatten(1)
    similarity = torch.cosine_similarity(flat_1, flat_2)
    return similarity.mean().item()
if __name__ == "__main__":
    # Smoke test: attribute the squared norm of an element-wise square.
    import captum

    def ff(x):
        return x ** 2

    demo_input = torch.tensor([[1, 2, 3, 4]], dtype=torch.float, requires_grad=True)
    m = attribution_map(
        ff,
        captum.attr.InputXGradient,
        with_noise=False,
        probe_data=demo_input,
        norm_square=True
    )
    print(m)
@@ -0,0 +1,135 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
import networkx as nx | |||
from . import depara | |||
import os, abc | |||
from typing import Callable | |||
from kamal import hub | |||
import json, numbers | |||
from tqdm import tqdm | |||
class Node(object):
    """Handle on one model-hub entry; model, tags and metadata are loaded on each access."""

    def __init__(self, hub_root, entry_name, spec_name):
        self.hub_root, self.entry_name, self.spec_name = hub_root, entry_name, spec_name

    def _spec(self):
        """Keyword arguments that identify this entry in every hub call."""
        return dict(entry_name=self.entry_name, spec_name=self.spec_name)

    @property
    def model(self):
        """Load the model from the hub and return it in eval mode."""
        loaded = hub.load(self.hub_root, **self._spec())
        return loaded.eval()

    @property
    def tag(self):
        """Load the tags of this entry from the hub."""
        return hub.load_tags(self.hub_root, **self._spec())

    @property
    def metadata(self):
        """Load the metadata of this entry from the hub."""
        return hub.load_metadata(self.hub_root, **self._spec())
class TransferabilityGraph(object):
    """Directed graphs of pairwise model distances over the models found in a set of zoos."""

    def __init__(self, model_zoo_set):
        """Register every hub model found under the given zoo directories.

        Args:
            model_zoo_set: iterable of zoo root directories. A single path
                (``str`` or ``os.PathLike``) is also accepted.
        """
        # A lone path would otherwise be iterated character by character in
        # _register_models, so wrap it to keep the collection contract.
        if isinstance(model_zoo_set, (str, os.PathLike)):
            model_zoo_set = [model_zoo_set]
        self.model_zoo_set = model_zoo_set
        self._graphs = dict()  # metric name -> nx.DiGraph
        self._models = dict()  # model name (from metadata) -> Node
        self._register_models()

    def _register_models(self):
        """Create a Node for every (entry, spec) of every hub found in the zoos."""
        cnt = 0
        for model_zoo in self.model_zoo_set:
            model_zoo = os.path.abspath(os.path.expanduser(model_zoo))
            for hub_root in self._list_modelzoo(model_zoo):
                for entry_name, spec_name in hub.list_spec(hub_root):
                    node = Node(hub_root, entry_name, spec_name)
                    name = node.metadata['name']
                    self._models[name] = node
                    cnt += 1
        print("%d models has been registered!"%cnt)

    def _list_modelzoo(self, zoo_dir):
        """Return all directories below ``zoo_dir`` that contain ``code/hubconf.py``.

        A directory recognised as a hub root is not searched any deeper.
        """
        zoo_list = []

        def _traverse(path):
            for item in os.listdir(path):
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path):
                    if os.path.exists(os.path.join(item_path, 'code/hubconf.py')):
                        zoo_list.append(item_path)
                    else:
                        _traverse(item_path)

        _traverse(zoo_dir)
        return zoo_list

    def add_metric(self, metric_name, metric):
        """Compute ``metric(n1, n2)`` for every ordered pair of distinct models.

        The result is stored as a new directed graph under ``metric_name``,
        with the distance in the ``dist`` edge attribute. If the metric raises,
        it is retried once with ``metric.device`` set to CPU.
        """
        self._graphs[metric_name] = g = nx.DiGraph()
        g.add_nodes_from(self._models.values())
        for n1 in self._models.values():
            for n2 in tqdm(self._models.values()):
                if n1 != n2 and not g.has_edge(n1, n2):
                    try:
                        dist = metric(n1, n2)
                    except Exception:
                        # Catch Exception only, so KeyboardInterrupt and
                        # SystemExit are not swallowed by the retry.
                        ori_device = metric.device
                        metric.device = torch.device('cpu')
                        try:
                            dist = metric(n1, n2)
                        finally:
                            # Restore the device even if the retry fails too.
                            metric.device = ori_device
                    g.add_edge(n1, n2, dist=dist)

    def export_to_json(self, metric_name, output_filename, topk=None, normalize=False):
        """Write the graph of ``metric_name`` to ``output_filename`` as JSON.

        The file holds ``nodes`` (a list of ``{'tags': {...}}``) and ``edges``
        (a list of ``[source_id, target_id, distance]``).

        Args:
            metric_name: name previously passed to ``add_metric``.
            output_filename: path of the JSON file to write.
            topk: when an int, keep per source node only the edges strictly
                closer than its ``topk``-th sorted distance (0-based). Nodes
                with ``topk`` or fewer outgoing edges keep all of them.
            normalize: rescale the kept distances with min-max normalisation.

        Raises:
            AssertionError: if no graph exists for ``metric_name``.
        """
        graph = self._graphs.get(metric_name, None)
        assert graph is not None, "unknown metric: %s" % metric_name
        graph_data = {
            'nodes': [],
            'edges': [],
        }
        node_to_idx = {}
        for i, node in enumerate(self._models.values()):
            tags = node.tag
            metadata = node.metadata
            # Only scalar tags are exported.
            node_data = {k: v for (k, v) in tags.items() if isinstance(v, (numbers.Number, str))}
            node_data['name'] = metadata['name']
            node_data['task'] = metadata['task']
            node_data['dataset'] = metadata['dataset']
            node_data['url'] = metadata['url']
            node_data['id'] = i
            graph_data['nodes'].append({'tags': node_data})
            node_to_idx[node] = i

        # Record edges and collect, per source node, its outgoing distances.
        edge_list = graph_data['edges']
        dists_by_source = {idx: [] for idx in range(len(self._models))}
        for edge in graph.edges.data('dist'):
            s, t, d = int(node_to_idx[edge[0]]), int(node_to_idx[edge[1]]), float(edge[2])
            dists_by_source[s].append(d)
            edge_list.append([s, t, d])  # source, target, distance

        if isinstance(topk, int):
            cutoff = {}
            for idx, dists in dists_by_source.items():
                dists.sort()
                # None means "keep everything": there is no topk-th distance
                # to cut at when the node has too few outgoing edges.
                cutoff[idx] = dists[topk] if len(dists) > topk else None
            graph_data['edges'] = [
                edge for edge in edge_list
                if cutoff[edge[0]] is None or edge[2] < cutoff[edge[0]]
            ]

        if normalize and graph_data['edges']:
            edge_dist = [e[2] for e in graph_data['edges']]
            min_dist, max_dist = min(edge_dist), max(edge_dist)
            for e in graph_data['edges']:
                # The epsilon keeps the division defined when all distances are equal.
                e[2] = (e[2] - min_dist+1e-8) / (max_dist - min_dist+1e-8)

        with open(output_filename, 'w') as fp:
            json.dump(graph_data, fp)
@@ -0,0 +1,109 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import abc | |||
from . import depara | |||
from captum.attr import InputXGradient | |||
import torch | |||
from kamal.core import hub | |||
from kamal.vision import sync_transforms as sT | |||
class TransMetric(abc.ABC):
    """Base class for callables that score a pair of objects."""

    def __init__(self):
        """The base metric keeps no state."""

    def __call__(self, a, b) -> float:
        """Return the score for ``a`` and ``b``; the base class always gives 0."""
        return 0
class DeparaMetric(TransMetric):
    """Metric comparing two nodes through their attribution graphs.

    Attribution graphs are computed with ``depara.get_attribution_graph`` on
    the probe ``data`` and cached per node, so each node is evaluated once.
    """

    def __init__(self, data, device):
        """Store the probe data and the device used for attribution.

        Args:
            data: iterable of probe samples; each one is passed through the
                node-specific transform before being stacked.
            device: device handed to ``depara.get_attribution_graph``.
        """
        self.data = data
        self.device = device
        self._cache = {}  # node -> attribution graph (moved to CPU)

    def _get_transform(self, metadata):
        """Build the preprocessing pipeline described by ``metadata['input']``.

        Reads the ``size``, ``space``, ``range`` and ``normalize`` entries.

        Raises:
            NotImplementedError: if ``range`` is neither ``[0, 1]`` nor
                ``[0, 255]``.
        """
        input_metadata = metadata['input']
        size = input_metadata['size']
        space = input_metadata['space']
        drange = input_metadata['range']
        normalize = input_metadata['normalize']

        if size is None:
            size = 224  # default when the metadata gives no size
        if isinstance(size, (list, tuple)):
            size = size[-1]

        transform = [
            sT.Resize(size),
            sT.CenterCrop(size),
        ]
        if space == 'bgr':
            transform.append(sT.FlipChannels())

        if list(drange) == [0, 1]:
            transform.append(sT.ToTensor())
        elif list(drange) == [0, 255]:
            transform.append(sT.ToTensor(normalize=False, dtype=torch.float))
        else:
            raise NotImplementedError

        if normalize is not None:
            transform.append(sT.Normalize(mean=normalize['mean'], std=normalize['std']))
        return sT.Compose(transform)

    def _get_attr_graph(self, n):
        """Compute the attribution graph of node ``n`` on the probe data."""
        transform = self._get_transform(n.metadata)
        data = torch.stack([transform(d) for d in self.data], dim=0)
        return depara.get_attribution_graph(
            n.model,
            attribution_type=InputXGradient,
            with_noise=False,
            probe_data=data,
            device=self.device
        )

    def __call__(self, n1, n2):
        """Return ``depara.graph_similarity`` of the two nodes' attribution graphs.

        Missing graphs are computed, moved to CPU and cached before the
        comparison.
        """
        attrgraph_1 = self._cache.get(n1, None)
        attrgraph_2 = self._cache.get(n2, None)
        if attrgraph_1 is None:
            self._cache[n1] = attrgraph_1 = self._get_attr_graph(n1).cpu()
        if attrgraph_2 is None:
            self._cache[n2] = attrgraph_2 = self._get_attr_graph(n2).cpu()
        # The original computed this value but never returned it.
        return depara.graph_similarity(attrgraph_1, attrgraph_2)
class AttrMapMetric(DeparaMetric):
    """Metric comparing two nodes through their attribution maps."""

    def _get_attr_map(self, n):
        """Compute the attribution map of node ``n`` on the probe data."""
        preprocess = self._get_transform(n.metadata)
        samples = [preprocess(sample).to(self.device) for sample in self.data]
        probe = torch.stack(samples, dim=0)
        return depara.attribution_map(
            n.model.to(self.device),
            attribution_type=InputXGradient,
            with_noise=False,
            probe_data=probe,
        )

    def __call__(self, n1, n2):
        """Return ``depara.attr_map_distance`` between the maps of ``n1`` and ``n2``.

        Both cache lookups happen up front; any missing map is then computed,
        moved to CPU and stored.
        """
        map_1, map_2 = self._cache.get(n1, None), self._cache.get(n2, None)
        if map_1 is None:
            map_1 = self._get_attr_map(n1).cpu()
            self._cache[n1] = map_1
        if map_2 is None:
            map_2 = self._get_attr_map(n2).cpu()
            self._cache[n2] = map_2
        return depara.attr_map_distance(map_1, map_2)
@@ -0,0 +1,2 @@ | |||
from ._utils import * | |||
from .logger import get_logger |
@@ -0,0 +1,153 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import numpy as np | |||
import math | |||
import torch | |||
import random | |||
from copy import deepcopy | |||
import contextlib, hashlib | |||
def split_batch(batch):
    """Split a batch into inputs and targets.

    A list or tuple yields ``(first_item, remaining_items)``, where a single
    remaining item is unwrapped from its list. Any other object is returned
    as ``[batch, None]``.
    """
    if not isinstance(batch, (list, tuple)):
        return [batch, None]
    inputs, *rest = batch
    targets = rest[0] if len(rest) == 1 else rest
    return inputs, targets
@contextlib.contextmanager
def set_mode(model, training=True):
    """Temporarily switch ``model`` to the given training mode.

    Reads ``model.training``, calls ``model.train(training)`` for the
    duration of the ``with`` block, and restores the original mode on exit.
    The mode is restored even if the body raises.
    """
    ori_mode = model.training
    model.train(training)
    try:
        yield
    finally:
        # Without the finally clause an exception in the body would leave
        # the model stuck in the temporary mode.
        model.train(ori_mode)
def move_to_device(obj, device):
    """Move a tensor, a module, or a list/tuple of them to ``device``.

    Returns:
        The moved tensor or module; a new ``list`` of moved items for a list
        or tuple input; ``None`` for any other input type.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device=device)
    if isinstance(obj, (list, tuple)):
        return [o.to(device=device) for o in obj]
    # ``nn`` is not imported in this module, so the original check raised
    # NameError; reference the class through ``torch`` instead.
    if isinstance(obj, torch.nn.Module):
        return obj.to(device=device)
    return None
def pack_images(images, col=None, channel_last=False):
    """Tile a batch of images into one channel-first grid image.

    Args:
        images: array shaped (N, C, H, W), or a list/tuple of per-image
            arrays that is stacked along a new first axis.
        col: number of grid columns; defaults to ``ceil(sqrt(N))``.
        channel_last: set when the input is (N, H, W, C); it is transposed
            to channel-first before packing.

    Returns:
        Array shaped (C, H*row, W*col); unused cells stay zero.
    """
    if isinstance(images, (list, tuple)):
        images = np.stack(images, 0)
    if channel_last:
        images = images.transpose(0, 3, 1, 2)  # make it channel first

    assert len(images.shape) == 4
    assert isinstance(images, np.ndarray)

    count, channels, height, width = images.shape
    if col is None:
        col = int(math.ceil(math.sqrt(count)))
    row = int(math.ceil(count / col))

    canvas = np.zeros((channels, height * row, width * col), dtype=images.dtype)
    for position, image in enumerate(images):
        grid_row, grid_col = divmod(position, col)
        top = grid_row * height
        left = grid_col * width
        canvas[:, top:top + height, left:left + width] = image
    return canvas
def normalize(tensor, mean, std, reverse=False):
    """Apply per-channel ``(tensor - mean) / std`` along dimension 1.

    With ``reverse=True`` the inverse statistics ``-mean/std`` and ``1/std``
    are used instead, which undoes a previous normalization. The statistics
    are broadcast as ``[None, :, None, None]``.
    """
    shift = [-m / s for m, s in zip(mean, std)] if reverse else mean
    scale = [1 / s for s in std] if reverse else std

    options = dict(dtype=tensor.dtype, device=tensor.device)
    shift = torch.as_tensor(shift, **options)
    scale = torch.as_tensor(scale, **options)
    return (tensor - shift[None, :, None, None]) / (scale[None, :, None, None])
class Normalizer(object):
    """Callable wrapper around the module-level ``normalize`` function."""

    def __init__(self, mean, std, reverse=False):
        """Remember the statistics and whether calls should denormalize."""
        self.mean, self.std = mean, std
        self.reverse = reverse

    def __call__(self, x):
        """Denormalize ``x`` when ``reverse`` is set, otherwise normalize it."""
        apply = self.denormalize if self.reverse else self.normalize
        return apply(x)

    def normalize(self, x):
        """Normalize ``x`` with the stored mean and std."""
        return normalize(x, self.mean, self.std)

    def denormalize(self, x):
        """Undo the normalization of ``x`` using the stored mean and std."""
        return normalize(x, self.mean, self.std, reverse=True)
def colormap(N=256, normalized=False):
    """Build an ``(N, 3)`` color table by spreading each index's bits over R, G, B.

    For every index, bits 0/1/2 of successive 3-bit groups are written into
    the red/green/blue bytes from the most significant bit downwards.

    Args:
        N: number of colors to generate.
        normalized: when true, return ``float32`` values divided by 255;
            otherwise ``uint8`` values.
    """
    table = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')
    for index in range(N):
        rgb = [0, 0, 0]
        remaining = index
        for shift in range(7, -1, -1):
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << shift
            remaining >>= 3
        table[index] = np.array(rgb)

    if normalized:
        table = table / 255
    return table


DEFAULT_COLORMAP = colormap()
def flatten_dict(dic):
    """Flatten nested dicts into one dict with ``/``-joined key paths.

    Leading and trailing ``/`` characters are stripped from each final key.
    """
    flattned = dict()

    def _walk(prefix, mapping):
        for key, value in mapping.items():
            path = prefix + '%s/' % key
            if isinstance(value, dict):
                _walk(path, value)
            else:
                flattned[path.strip('/')] = value

    _walk('', dic)
    return flattned
def setup_seed(seed):
    """Seed torch (CPU and CUDA), numpy and the stdlib ``random`` module."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def count_parameters(model):
    """Return the total number of elements over ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def md5(fname):
    """Return the hex MD5 digest of the file at ``fname``, read in 4096-byte chunks."""
    digest = hashlib.md5()
    with open(fname, "rb") as stream:
        while True:
            block = stream.read(4096)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()
@@ -0,0 +1,56 @@ | |||
import logging | |||
import os, sys | |||
from termcolor import colored | |||
class _ColorfulFormatter(logging.Formatter):
    """Formatter that prepends a colored level tag to warnings and errors."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``logging.Formatter``."""
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        """Format ``record`` and prefix it with a tag for WARNING/ERROR/CRITICAL.

        Records at any other level are returned without a prefix.
        """
        message = super(_ColorfulFormatter, self).formatMessage(record)
        level = record.levelno
        if level == logging.WARNING:
            tag = colored("WARNING", "yellow", attrs=["blink"])
        elif level in (logging.ERROR, logging.CRITICAL):
            tag = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return message
        return tag + " " + message
def get_logger(name='Kamal', output=None, color=True):
    """Return the logger ``name`` configured with stdout and optional file output.

    The logger is set to DEBUG and does not propagate to its parents. Calling
    this function again for the same name reuses the stdout handler and any
    file handler already attached for the same file, so log lines are not
    duplicated.

    Args:
        name: logger name passed to ``logging.getLogger``.
        output: ``None`` for stdout only; a path ending in ``.txt`` or
            ``.log`` to log to that file; any other path is treated as a
            directory and ``log.txt`` inside it is used.
        color: use the colored formatter on stdout (file output is always
            plain).
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S")
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S")
    else:
        formatter = plain_formatter

    # STDOUT: reuse a handler from an earlier call instead of stacking a new one.
    stdout_handler = None
    for handler in logger.handlers:
        if (isinstance(handler, logging.StreamHandler)
                and not isinstance(handler, logging.FileHandler)
                and getattr(handler, 'stream', None) is sys.stdout):
            stdout_handler = handler
            break
    if stdout_handler is None:
        stdout_handler = logging.StreamHandler(stream=sys.stdout)
        logger.addHandler(stdout_handler)
    stdout_handler.setLevel(logging.DEBUG)
    stdout_handler.setFormatter(formatter)

    # FILE
    if output is not None:
        if output.endswith(('.txt', '.log')):
            log_dir = os.path.dirname(output)
            # A bare file name has an empty dirname, which makedirs rejects.
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            filename = output
        else:
            os.makedirs(output, exist_ok=True)
            filename = os.path.join(output, "log.txt")

        target = os.path.abspath(filename)
        already_attached = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == target
            for handler in logger.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(filename)
            file_handler.setFormatter(plain_formatter)
            file_handler.setLevel(logging.DEBUG)
            logger.addHandler(file_handler)
    return logger
@@ -0,0 +1,3 @@ | |||
from . import models | |||
from . import datasets | |||
from . import sync_transforms |
@@ -0,0 +1,16 @@ | |||
from .ade20k import ADE20K | |||
from .caltech import Caltech101, Caltech256 | |||
from .camvid import CamVid | |||
from .cityscapes import Cityscapes | |||
from .cub200 import CUB200 | |||
from .fgvc_aircraft import FGVCAircraft | |||
from .nyu import NYUv2 | |||
from .stanford_cars import StanfordCars | |||
from .stanford_dogs import StanfordDogs | |||
from .sunrgbd import SunRGBD | |||
from .voc import VOCClassification, VOCSegmentation | |||
from .dataset import LabelConcatDataset | |||
from torchvision import datasets as torchvision_datasets | |||
from .unlabeled import UnlabeledDataset |
@@ -0,0 +1,70 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import collections | |||
import torch | |||
import torchvision | |||
import numpy as np | |||
from PIL import Image | |||
import os | |||
from torchvision.datasets import VisionDataset | |||
from .utils import colormap | |||
class ADE20K(VisionDataset):
    """Segmentation dataset reading ``images/<split>`` and ``annotations/<split>`` under ``root``."""

    cmap = colormap()

    def __init__(
        self,
        root,
        split="training",
        transform=None,
        target_transform=None,
        transforms=None,
    ):
        """Index every file in the image directory of ``split``.

        The label path of an image is built by replacing the last three
        characters of its file name with ``png``.
        """
        super(ADE20K, self).__init__(
            root=root,
            transforms=transforms,
            transform=transform,
            target_transform=target_transform,
        )
        assert split in ['training', 'validation'], "split should be \'training\' or \'validation\'"
        self.root = os.path.expanduser(root)
        self.split = split
        self.num_classes = 150

        img_dir = os.path.join(self.root, 'images', self.split)
        lbl_dir = os.path.join(self.root, 'annotations', self.split)
        names = os.listdir(img_dir)
        self.img_list = [os.path.join(img_dir, name) for name in names]
        self.lbl_list = [os.path.join(lbl_dir, name[:-3] + 'png') for name in names]

    def __len__(self):
        """Number of indexed images."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Load the image/label pair at ``index`` and apply the joint transforms.

        The label is converted to a ``uint8`` array and decremented by one.
        """
        img = Image.open(self.img_list[index])
        lbl = Image.open(self.lbl_list[index])
        if self.transforms:
            img, lbl = self.transforms(img, lbl)
        # 1-150 => 0-149 + 255
        lbl = np.array(lbl, dtype='uint8') - 1
        return img, lbl

    @classmethod
    def decode_seg_to_color(cls, mask):
        """decode semantic mask to RGB image"""
        shifted = mask + 1
        return cls.cmap[shifted]
@@ -0,0 +1,226 @@ | |||
# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py | |||
from __future__ import print_function | |||
from PIL import Image | |||
import os | |||
import os.path | |||
from torchvision.datasets.vision import VisionDataset | |||
from torchvision.datasets.utils import download_url | |||
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    Images are read from ``101_ObjectCategories_split/train`` or
    ``101_ObjectCategories_split/test`` under ``<root>/caltech101``.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified target types.
            ``category`` represents the target class, and ``annotation`` is a list of points
            from a hand-generated outline. Defaults to ``category``.
        train (bool, optional): Selects the ``train`` (default) or ``test``
            sub-directory of the split folder.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(self, root, target_type="category", train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
        self.train = train
        self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test'

        # Was ``os.makdirs``, which raised AttributeError on every construction.
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class

        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {"Faces": "Faces_2",
                    "Faces_easy": "Faces_3",
                    "Motorbikes": "Motorbikes_16",
                    "airplanes": "Airplanes_Side_2"}
        self.annotation_categories = [name_map.get(c, c) for c in self.categories]

        # ``index`` holds image file names; ``y`` holds the matching class ids.
        self.index = []
        self.y = []
        for (i, c) in enumerate(self.categories):
            file_names = os.listdir(os.path.join(self.root, self.dir_name, c))
            n = len(file_names)
            self.index.extend(file_names)
            self.y.extend(n * [i])
        print(self.train, len(self.index))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """
        import scipy.io

        img = Image.open(os.path.join(self.root,
                                      self.dir_name,
                                      self.categories[self.y[index]],
                                      self.index[index])).convert("RGB")

        target = []
        for t in self.target_type:
            if t == "category":
                target.append(self.y[index])
            elif t == "annotation":
                # ``self.index`` stores file names, so formatting the entry
                # with ``{:04d}`` raised ValueError. Recover the numeric id
                # from the file name instead.
                # NOTE(review): assumes names look like ``image_0001.jpg``
                # and match the annotation numbering -- confirm for the
                # split folders.
                stem = os.path.splitext(self.index[index])[0]
                image_number = int(stem.rsplit('_', 1)[-1])
                data = scipy.io.loadmat(os.path.join(self.root,
                                                     "Annotations",
                                                     self.annotation_categories[self.y[index]],
                                                     "annotation_{:04d}.mat".format(image_number)))
                target.append(data["obj_contour"])
            else:
                raise ValueError("Target type \"{}\" is not recognized.".format(t))
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_integrity(self):
        """Return whether the ``101_ObjectCategories`` folder exists under root."""
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self):
        """Number of indexed images."""
        return len(self.index)

    def download(self):
        """Download and extract the image and annotation archives if missing."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
                     self.root,
                     "101_ObjectCategories.tar.gz",
                     "b224c7392d521a49829488ab0f1120d9")
        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
                     self.root,
                     "101_Annotations.tar",
                     "6f83eeb1f24d99cab4eb377263132c91")

        # extract file
        with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
            tar.extractall(path=self.root)

        with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
            tar.extractall(path=self.root)

    def extra_repr(self):
        """Return the ``Target type: ...`` line built from the instance attributes."""
        return "Target type: {target_type}".format(**self.__dict__)
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.

    Args:
        root (string): Parent directory; the data lives in (or is downloaded
            to) its ``caltech256`` sub-directory.
        transform (callable, optional): Applied to the loaded PIL image.
        target_transform (callable, optional): Applied to the class index.
        download (bool, optional): Fetch and extract the archive when the
            ``256_ObjectCategories`` folder is not present yet.
    """

    def __init__(self, root,
                 transform=None, target_transform=None,
                 download=False):
        super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
        os.makedirs(self.root, exist_ok=True)
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        category_root = os.path.join(self.root, "256_ObjectCategories")
        self.categories = sorted(os.listdir(category_root))

        # ``index`` holds 1-based image numbers within a category,
        # ``y`` the class id of each entry.
        self.index = []
        self.y = []
        for class_id, category in enumerate(self.categories):
            count = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", category)))
            self.index.extend(range(1, count + 1))
            self.y.extend([class_id] * count)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        target = self.y[index]
        file_name = "{:03d}_{:04d}.jpg".format(target + 1, self.index[index])
        img = Image.open(os.path.join(self.root,
                                      "256_ObjectCategories",
                                      self.categories[target],
                                      file_name))

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def _check_integrity(self):
        """Return whether the ``256_ObjectCategories`` folder exists under root."""
        # can be more robust and check hash of files
        expected = os.path.join(self.root, "256_ObjectCategories")
        return os.path.exists(expected)

    def __len__(self):
        """Number of indexed images."""
        return len(self.index)

    def download(self):
        """Download and extract the image archive unless it is already present."""
        import tarfile

        if self._check_integrity():
            print('Files already downloaded and verified')
            return

        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
                     self.root,
                     "256_ObjectCategories.tar",
                     "67b4f42ca05d46448c6bb8ecd2220f6d")

        # extract file
        archive = os.path.join(self.root, "256_ObjectCategories.tar")
        with tarfile.open(archive, "r:") as tar:
            tar.extractall(path=self.root)
@@ -0,0 +1,78 @@ | |||
# Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py | |||
import os | |||
import torch.utils.data as data | |||
from glob import glob | |||
from PIL import Image | |||
import numpy as np | |||
from torchvision.datasets import VisionDataset | |||
class CamVid(VisionDataset):
    """CamVid dataset loader where the dataset is arranged as in https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid.

    Args:
        root (string): Directory holding the ``<split>`` image folders and
            the matching ``<split>annot`` label folders.
        split (string): One of 'train', 'val', 'trainval' or 'test'.
        transform (callable, optional): Per-image transform. Default: None.
        target_transform (callable, optional): Per-target transform. Default: None.
        transforms (callable, optional): Joint transform applied to image and target. Default: None.
    """

    cmap = np.array([
        (128, 128, 128),
        (128, 0, 0),
        (192, 192, 128),
        (128, 64, 128),
        (60, 40, 222),
        (128, 128, 0),
        (192, 128, 128),
        (64, 64, 128),
        (64, 0, 128),
        (64, 64, 0),
        (0, 128, 192),
        (0, 0, 0),
    ])

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 target_transform=None,
                 transforms=None):
        assert split in ('train', 'val', 'test', 'trainval')
        super(CamVid, self).__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform)
        self.root = os.path.expanduser(root)
        self.split = split

        # 'trainval' is the union of the 'train' and 'val' folders.
        subsets = ['train', 'val'] if split == 'trainval' else [self.split]
        self.images = []
        self.labels = []
        for subset in subsets:
            self.images += glob(os.path.join(self.root, subset, '*.png'))
            self.labels += glob(os.path.join(self.root, subset + 'annot', '*.png'))

        self.images.sort()
        self.labels.sort()

    def __getitem__(self, idx):
        """
        Args:
        - index (``int``): index of the item in the dataset

        Returns:
        A tuple (image, label); label has class 11 replaced by 255 and its
        first dimension squeezed.
        """
        img = Image.open(self.images[idx])
        label = Image.open(self.labels[idx])
        if self.transforms is not None:
            img, label = self.transforms(img, label)
        void = label == 11
        label[void] = 255  # ignore void
        return img, label.squeeze(0)

    def __len__(self):
        """Number of images found for the split."""
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask):
        """decode semantic mask to RGB image"""
        # Map the ignore value back to the last palette row (modifies ``mask`` in place).
        ignored = mask == 255
        mask[ignored] = 11
        return cls.cmap[mask]
@@ -0,0 +1,146 @@ | |||
# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py | |||
import json | |||
import os | |||
from collections import namedtuple | |||
import torch | |||
import torch.utils.data as data | |||
from PIL import Image | |||
import numpy as np | |||
from torchvision.datasets import VisionDataset | |||
class Cityscapes(VisionDataset):
    """Cityscapes <http://www.cityscapes-dataset.com/> Dataset.

    Args:
        root (string): Directory containing 'leftImg8bit' and the folder named by ``mode``.
        split (string, optional): Image split to use: 'train', 'test' or 'val'.
        mode (string, optional): Annotation folder / file-name prefix, e.g. 'gtFine' or 'gtCoarse'.
        target_type (string, optional): Selects the target file suffix: 'instance', 'semantic',
            'color', 'polygon' or 'depth'.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        transforms (callable, optional): Joint transform applied to image and target.
    """

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # Colors of the classes that have a real train id, followed by one black
    # entry used for the ignore label.
    _TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
    _TRAIN_ID_TO_COLOR.append([0, 0, 0])
    _TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR)
    _ID_TO_TRAIN_ID = np.array([c.train_id for c in classes])

    def __init__(self, root, split='train', mode='gtFine', target_type='semantic', transform=None, target_transform=None, transforms=None):
        super(Cityscapes, self).__init__(root, transform=transform, target_transform=target_transform, transforms=transforms)
        self.root = os.path.expanduser(root)
        self.mode = mode
        self.target_type = target_type
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []

        if split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')

        folders_present = os.path.isdir(self.images_dir) and os.path.isdir(self.targets_dir)
        if not folders_present:
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        # Pair every image with the target file derived from its name.
        for city in os.listdir(self.images_dir):
            city_images = os.path.join(self.images_dir, city)
            city_targets = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(city_images):
                self.images.append(os.path.join(city_images, file_name))
                stem = file_name.split('_leftImg8bit')[0]
                suffix = self._get_target_suffix(self.mode, self.target_type)
                self.targets.append(os.path.join(city_targets, '{}_{}'.format(stem, suffix)))

    @classmethod
    def encode_target(cls, target):
        """Map label ids to train ids; tensor input yields a tensor."""
        if not isinstance(target, torch.Tensor):
            return cls._ID_TO_TRAIN_ID[target]
        return torch.from_numpy(cls._ID_TO_TRAIN_ID[np.array(target)])

    @classmethod
    def decode_fn(cls, target):
        """Turn train ids into colors; 255 is remapped to 19 in place first."""
        ignored = target == 255
        target[ignored] = 19
        #target = target.astype('uint8') + 1
        return cls._TRAIN_ID_TO_COLOR[target]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target). The image is opened as RGB; when joint
            transforms are set they are applied to both. The target is
            passed through ``encode_target``.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, self.encode_target(target)

    def __len__(self):
        """Number of indexed images."""
        return len(self.images)

    def _load_json(self, path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path, 'r') as file:
            return json.load(file)

    def _get_target_suffix(self, mode, target_type):
        """Return the target file suffix for ``target_type`` (None if unknown)."""
        templates = (
            ('instance', '{}_instanceIds.png'),
            ('semantic', '{}_labelIds.png'),
            ('color', '{}_color.png'),
            ('polygon', '{}_polygons.json'),
            ('depth', '{}_disparity.png'),
        )
        for kind, template in templates:
            if target_type == kind:
                return template.format(mode)
        return None
@@ -0,0 +1,70 @@ | |||
# Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py | |||
import os | |||
import pandas as pd | |||
from torchvision.datasets.folder import default_loader | |||
from .utils import download_url | |||
from torch.utils.data import Dataset | |||
import shutil | |||
class CUB200(Dataset):
    """CUB-200-2011 image classification dataset.

    Reads the metadata files under ``<root>/CUB_200_2011`` and serves
    ``(image, label)`` pairs for the requested split.

    Args:
        root (str): Directory that contains (or will receive) ``CUB_200_2011``.
        split (str): ``'train'`` keeps rows with ``is_training_img == 1``;
            any other value keeps rows with ``is_training_img == 0``.
        transform (callable, optional): Applied to the loaded image.
        target_transform (callable, optional): Applied to the label.
        loader (callable, optional): Loads an image given its path.
        download (bool, optional): Download and extract the archive first.
    """
    base_folder = 'images'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, split='train', transform=None, target_transform=None, loader=default_loader, download=False):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.transform = transform
        self.target_transform = target_transform
        # Honour the caller's loader; it used to be overwritten with
        # default_loader unconditionally. ``None`` keeps the old default.
        self.loader = loader if loader is not None else default_loader
        self.split = split

        if download:
            self.download()
        self._load_metadata()

        categories = sorted(os.listdir(os.path.join(
            self.root, 'CUB_200_2011', 'images')))
        # Drops the first four characters of every folder name
        # (presumably a "NNN." numeric prefix -- TODO confirm).
        self.object_categories = [c[4:] for c in categories]
        print('CUB200, Split: %s, Size: %d' % (self.split, len(self)))

    def _load_metadata(self):
        """Join the three metadata tables on ``img_id`` and keep the active split."""
        meta_dir = os.path.join(self.root, 'CUB_200_2011')
        images = pd.read_csv(os.path.join(meta_dir, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(meta_dir, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'], encoding='latin-1', engine='python')
        train_test_split = pd.read_csv(os.path.join(meta_dir, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'], encoding='latin-1', engine='python')

        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        wanted_flag = 1 if self.split == 'train' else 0
        self.data = self.data[self.data.is_training_img == wanted_flag]

    def download(self):
        """Fetch the archive into ``root`` (unless already there) and extract it."""
        import tarfile

        os.makedirs(self.root, exist_ok=True)
        archive_path = os.path.join(self.root, self.filename)
        if not os.path.isfile(archive_path):
            download_url(self.url, self.root, self.filename)
        print("Extracting %s..." % self.filename)
        # NOTE(review): the archive is extracted as-is and ``tgz_md5`` is not
        # passed to anything in this method -- only use trusted archives.
        with tarfile.open(archive_path, "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the active split."""
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, 'CUB_200_2011',
                            self.base_folder, sample.filepath)
        # Stored targets are shifted down by one (presumably 1-based in the
        # label file -- TODO confirm).
        lbl = sample.target - 1
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl
@@ -0,0 +1,57 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
from torchvision.datasets import VisionDataset | |||
from PIL import Image | |||
import torch | |||
class LabelConcatDataset(VisionDataset):
    """Dataset that concatenates the labels of several datasets.

    Every wrapped dataset is indexed with the same ``idx``; the image of the
    last dataset and the targets of all datasets are returned together. This
    is useful to assemble several label sets for the same images.

    Arguments:
        datasets (sequence): List of datasets whose labels are concatenated.
            Their own ``transform`` / ``transforms`` / ``target_transform``
            attributes are reset to ``None``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        transforms (callable, optional): A function/transform that takes the image followed by all targets and returns their transformed versions.
    """

    def __init__(self, datasets, transforms=None, transform=None, target_transform=None):
        super(LabelConcatDataset, self).__init__(
            root=None, transforms=transforms, transform=transform, target_transform=target_transform)
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        # Transforms are applied once, jointly, in __getitem__; disable the
        # per-dataset ones so samples come back untransformed.
        for d in self.datasets:
            for t in ('transform', 'transforms', 'target_transform'):
                if hasattr(d, t):
                    setattr(d, t, None)

    def __getitem__(self, idx):
        """Return ``[image, target_0, target_1, ...]`` for position ``idx``."""
        targets_list = []
        for dst in self.datasets:
            # ``image`` ends up being the one returned by the last dataset.
            image, target = dst[idx]
            targets_list.append(target)
        if self.transforms is not None:
            image, *targets_list = self.transforms(image, *targets_list)
        return [image, *targets_list]

    def __len__(self):
        # Ask the dataset for its length instead of requiring it to expose an
        # ``images`` attribute.
        return len(self.datasets[0])
@@ -0,0 +1,143 @@ | |||
#Modified from https://github.com/pytorch/vision/pull/467/files | |||
from __future__ import print_function | |||
import torch.utils.data as data | |||
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader | |||
from PIL import Image | |||
import os | |||
import numpy as np | |||
from .utils import download_url, mkdir | |||
def make_dataset(dir, image_ids, targets):
    """Pair every image id with its target as ``(image_path, target)`` tuples.

    Paths point to ``<dir>/fgvc-aircraft-2013b/data/images/<id>.jpg``, with a
    leading ``~`` in ``dir`` expanded.
    """
    assert(len(image_ids) == len(targets))
    images_root = os.path.join(os.path.expanduser(dir),
                               'fgvc-aircraft-2013b', 'data', 'images')
    return [
        (os.path.join(images_root, '%s.jpg' % image_id), target)
        for image_id, target in zip(image_ids, targets)
    ]
def find_classes(classes_file):
    """Parse an annotation file with one ``<image_id> <class name>`` per line.

    The class name is everything after the first space and may itself
    contain spaces. Line endings are stripped from the class names, and
    empty lines are ignored.

    Args:
        classes_file (str): Path to the annotation text file.

    Returns:
        tuple: ``(image_ids, targets, classes, class_to_idx)`` where
        ``classes`` is the sorted array of unique class names,
        ``class_to_idx`` maps each name to its position in ``classes`` and
        ``targets`` holds that index for every entry of ``image_ids``.
    """
    image_ids = []
    class_names = []
    # ``with`` guarantees the handle is closed even if parsing fails.
    with open(classes_file, 'r') as f:
        for line in f:
            # Without stripping, names keep their trailing newline, so a last
            # line lacking one would be counted as a separate class.
            line = line.rstrip('\r\n')
            if not line:
                continue
            image_id, _, class_name = line.partition(' ')
            image_ids.append(image_id)
            class_names.append(class_name)

    # index class names
    classes = np.unique(class_names)
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    targets = [class_to_idx[name] for name in class_names]

    return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.

    Args:
        root (string): Root directory path to dataset.
        class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
            to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
        split (string, optional): One of ``train``, ``val``, ``trainval`` or ``test``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        ValueError: If ``split`` or ``class_type`` is not one of the valid values.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', transform=None,
                 target_transform=None, loader=default_loader, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        # make_dataset() expands "~" for the image paths; expand it here too
        # so the annotation files are looked up in the same place.
        self.root = os.path.expanduser(root)
        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
        samples = make_dataset(self.root, image_ids, targets)

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

        # Category names always come from variants.txt, whatever class_type is.
        with open(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'variants.txt')) as f:
            self.object_categories = [
                line.strip('\n') for line in f.readlines()]
        print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, len(self)))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += ' Root Location: {}\n'.format(self.root)
        # The attributes are named ``transform`` / ``target_transform``; the
        # plural spellings used before do not exist and raised AttributeError.
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        """Return True when the extracted images folder and the split file exist."""
        # The archive unpacks into <root>/fgvc-aircraft-2013b/, which is also
        # where classes_file and make_dataset() look.
        return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download the FGVC-Aircraft data if it doesn't exist already."""
        import tarfile

        if self._check_exists():
            return

        mkdir(self.root)
        fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz')
        if not os.path.isfile(fpath):
            download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz')
        print("Extracting fgvc-aircraft-2013b.tar.gz...")
        with tarfile.open(fpath, "r:gz") as tar:
            tar.extractall(path=self.root)
@@ -0,0 +1,84 @@ | |||
# Modified from https://github.com/VainF/nyuv2-python-toolkit | |||
import os | |||
import torch | |||
import torch.utils.data as data | |||
from PIL import Image | |||
from scipy.io import loadmat | |||
import numpy as np | |||
import glob | |||
from torchvision import transforms | |||
from torchvision.datasets import VisionDataset | |||
import random | |||
from .utils import colormap | |||
class NYUv2(VisionDataset):
    """NYUv2 dataset

    See https://github.com/VainF/nyuv2-python-toolkit for more details.

    Args:
        root (string): Root directory path.
        split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'.
        target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``.
        num_classes (int, optional): Selects the ``seg<num_classes>`` label folder for semantic targets. Default: 13.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version.

    Raises:
        ValueError: If ``target_type`` is not one of the supported values.
    """
    cmap = colormap()

    def __init__(self,
                 root,
                 split='train',
                 target_type='semantic',
                 num_classes=13,
                 transforms=None,
                 transform=None,
                 target_transform=None):
        super(NYUv2, self).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
        assert(split in ('train', 'test'))
        # Fail at construction time: an unknown target type used to leave
        # ``self.targets`` undefined until the first item access.
        if target_type not in ('semantic', 'depth', 'normal'):
            raise ValueError('Unknown target_type "%s", expected semantic, depth or normal' % (target_type,))

        self.root = root
        self.split = split
        self.target_type = target_type
        self.num_classes = num_classes

        # The former load of splits.mat was removed: the indices read from it
        # were never used, the split is defined by the folder layout below.
        images_dir = os.path.join(self.root, 'image', self.split)
        img_names = sorted(os.listdir(images_dir))
        self.images = [os.path.join(images_dir, name) for name in img_names]

        # Targets share the image file names, in a folder per target type.
        self._is_depth = False
        if self.target_type == 'semantic':
            semantic_dir = os.path.join(self.root, 'seg%d' % self.num_classes, self.split)
            self.labels = [os.path.join(semantic_dir, name) for name in img_names]
            self.targets = self.labels
        elif self.target_type == 'depth':
            depth_dir = os.path.join(self.root, 'depth', self.split)
            self.depths = [os.path.join(depth_dir, name) for name in img_names]
            self.targets = self.depths
            self._is_depth = True
        else:
            normal_dir = os.path.join(self.root, 'normal', self.split)
            self.normals = [os.path.join(normal_dir, name) for name in img_names]
            self.targets = self.normals

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``, after ``self.transforms`` if set."""
        image = Image.open(self.images[idx])
        target = Image.open(self.targets[idx])
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask: np.ndarray):
        """decode semantic mask to RGB image"""
        # uint8 arithmetic wraps around, so adding one maps 255 to index 0.
        mask = mask.astype('uint8') + 1  # 255 => 0
        return cls.cmap[mask]
@@ -0,0 +1,61 @@ | |||
import os, sys | |||
from glob import glob | |||
import random | |||
import argparse | |||
from PIL import Image | |||
if __name__=='__main__':
    # Split every class folder of --data_root into train/test copies under
    # <parent of data_root>/caltech101_data/{train,test}/<class>/.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='../101_ObjectCategories')
    parser.add_argument('--test_split', type=float, default=0.3)
    # Optional: makes the random split reproducible. Default keeps the
    # previous behaviour (unseeded shuffle).
    parser.add_argument('--seed', type=int, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    SAVE_DIR = os.path.join(os.path.dirname(args.data_root), 'caltech101_data')
    TRAIN_DIR = os.path.join(SAVE_DIR, 'train')
    TEST_DIR = os.path.join(SAVE_DIR, 'test')
    for out_dir in (SAVE_DIR, TRAIN_DIR, TEST_DIR):
        os.makedirs(out_dir, exist_ok=True)

    for folder in sorted(os.listdir(args.data_root)):
        # NOTE(review): 'Faces' is skipped on purpose by the original script;
        # presumably because it overlaps with another folder -- confirm.
        if folder=='Faces':
            continue
        print('Processing %s'%(folder))
        img_paths = glob(os.path.join(args.data_root, folder, '*.jpg'))
        # glob order depends on the file system; sort first so that a fixed
        # seed always yields the same split.
        img_name = sorted(os.path.split(p)[-1] for p in img_paths)
        random.shuffle(img_name)

        test_n = int(args.test_split * len(img_name))
        assignments = (
            (TEST_DIR, img_name[:test_n]),
            (TRAIN_DIR, img_name[test_n:]),
        )
        for dst_root, names in assignments:
            dst_path = os.path.join(dst_root, folder)
            os.makedirs(dst_path, exist_ok=True)
            for name in names:
                # Context manager releases the source file handle promptly.
                with Image.open(os.path.join(args.data_root, folder, name)) as img:
                    img.save(os.path.join(dst_path, name))
@@ -0,0 +1,196 @@ | |||
from __future__ import print_function | |||
import sys | |||
import os, sys, tarfile, errno | |||
import numpy as np | |||
import matplotlib.pyplot as plt | |||
import argparse | |||
if sys.version_info >= (3, 0, 0): | |||
import urllib.request as urllib # ugly but works | |||
else: | |||
import urllib | |||
try: | |||
from imageio import imsave | |||
except: | |||
from scipy.misc import imsave | |||
print(sys.version_info) | |||
# dimensions of one image
HEIGHT, WIDTH, DEPTH = 96, 96, 3

# number of bytes a single image occupies
SIZE = DEPTH * WIDTH * HEIGHT
# path to the directory with the data | |||
#DATA_DIR = './data' | |||
# url of the binary data | |||
#DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' | |||
# path to the binary train file with image data | |||
#DATA_PATH = './data/stl10_binary/train_X.bin' | |||
# path to the binary train file with labels | |||
#LABEL_PATH = './data/stl10_binary/train_y.bin' | |||
def read_labels(path_to_labels):
    """
    Load a label file, one unsigned byte per label.
    :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
    :return: a uint8 array containing the labels
    """
    with open(path_to_labels, 'rb') as label_file:
        return np.fromfile(label_file, dtype=np.uint8)
def read_all_images(path_to_data):
    """
    Load every image of a binary STL-10 file at once.
    :param path_to_data: the file containing the binary images from the STL-10 dataset
    :return: a uint8 array of shape (N, 96, 96, 3) containing all the images
    """
    with open(path_to_data, 'rb') as image_file:
        raw_bytes = np.fromfile(image_file, dtype=np.uint8)

    # Every image is stored as three consecutive 96*96 planes (red, green,
    # blue), each in column-major order. Letting numpy infer the leading
    # dimension (-1) yields one 3x96x96 chunk per image.
    planes = raw_bytes.reshape((-1, 3, 96, 96))

    # Reverse the last three axes to get the usual channel-last image layout
    # (readable by e.g. matplotlib.imshow). Skip or undo this step if the
    # consumer wants the channels kept separate.
    return planes.transpose((0, 3, 2, 1))
def read_single_image(image_file):
    """
    Read the next image from an already opened file.

    CAREFUL! - the argument is an open file, not a path, so the read position
    advances and stays advanced after this call returns.
    :param image_file: the open file containing the images
    :return: a single image of shape (96, 96, 3)
    """
    # ``count`` limits the read to the SIZE uint8 values of one image
    flat = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
    # three stacked colour planes ...
    planes = flat.reshape((3, 96, 96))
    # ... reordered to the channel-last layout; skip or undo this if the
    # consumer wants the channels kept separate
    return planes.transpose((2, 1, 0))
def plot_image(image):
    """
    Show one image in a matplotlib window.
    :param image: the image to be plotted in a 3-D matrix format
    :return: None
    """
    plt.imshow(image)
    plt.show()
    return None
def save_image(image, name):
    """
    Write ``image`` as a PNG file called ``<name>.png``.
    :param image: the image data handed to ``imsave``
    :param name: output path without the ``.png`` extension
    :return: None
    """
    target_path = "%s.png" % name
    imsave(target_path, image, format="png")
def download_and_extract(DATA_DIR, data_url='http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'):
    """
    Download and extract the STL-10 dataset.

    Nothing is done when the archive file is already present in DATA_DIR.
    :param DATA_DIR: directory that receives the archive and its content
    :param data_url: location of the ``.tar.gz`` archive; the module-level
        DATA_URL constant this function used to read is commented out, so
        the address is now carried by this parameter
    :return: None
    """
    dest_directory = DATA_DIR
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # The total size may be unknown (reported as <= 0); skip the
            # percentage then instead of dividing by it.
            if total_size <= 0:
                return
            sys.stdout.write('\rDownloading %s %.2f%%' % (filename,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.urlretrieve(data_url, filepath, reporthook=_progress)
        print('Downloaded', filename)
        # NOTE(review): members are extracted as-is -- only point data_url at
        # a trusted archive. The ``with`` block closes the tar handle.
        with tarfile.open(filepath, 'r:gz') as archive:
            archive.extractall(dest_directory)
def save_images(images, labels, save_dir):
    """
    Write every image to disk as ``<index>.png``.

    With labels, image ``i`` goes into ``<save_dir>/<labels[i]>/``; without
    labels (``None``) all images go directly into ``save_dir``.
    :param images: iterable of images
    :param labels: indexable labels aligned with ``images``, or None
    :param save_dir: output directory, created if missing
    :return: None
    """
    print("Saving images to disk")
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    for index, image in enumerate(images):
        directory = save_dir
        if labels is not None:
            directory = os.path.join(save_dir, str(labels[index]))
            if not os.path.exists(directory):
                os.makedirs(directory)
        filename = os.path.join(directory, str(index))
        print(filename)
        save_image(image, filename)
if __name__ == "__main__":
    # The archive is expected to be downloaded and unpacked already
    # (download_and_extract is intentionally not called here).
    parser = argparse.ArgumentParser()
    parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary')
    args = parser.parse_args()

    BASE_PATH = os.path.join(args.DATA_DIR, 'stl10_data')
    if not os.path.exists(BASE_PATH):
        os.mkdir(BASE_PATH)

    # Labelled splits: (banner, image file, label file, output folder)
    labelled_splits = (
        ('Preparing Train', 'train_X.bin', 'train_y.bin', 'train'),
        ('Preparing Test', 'test_X.bin', 'test_y.bin', 'test'),
    )
    for banner, image_file, label_file, out_name in labelled_splits:
        DATA_PATH = os.path.join(args.DATA_DIR, image_file)
        LABEL_PATH = os.path.join(args.DATA_DIR, label_file)
        print(banner)
        images = read_all_images(DATA_PATH)
        print(images.shape)
        labels = read_labels(LABEL_PATH)
        print(labels.shape)
        save_images(images, labels, os.path.join(BASE_PATH, out_name))

    # Unlabeled split: no label file, everything lands in one folder
    print('Preparing Unlabeled')
    DATA_PATH = os.path.join(args.DATA_DIR, 'unlabeled_X.bin')
    images = read_all_images(DATA_PATH)
    save_images(images, None, os.path.join(BASE_PATH, 'unlabeled'))
@@ -0,0 +1,53 @@ | |||
import argparse | |||
import os | |||
from PIL import Image | |||
import numpy as np | |||
# Output size handed to Image.resize, as (width, height).
SIZE = (480, 320)
def is_image(path: str):
    """Return True when ``path`` ends with ``png``, ``jpg`` or ``jpeg`` (case-sensitive)."""
    return path.endswith(('png', 'jpg', 'jpeg'))
def copy_and_resize(src_dir, dst_dir, resize_fn):
    """Mirror ``src_dir`` into ``dst_dir``, resizing every image on the way.

    Sub-directories are created under ``dst_dir`` and handled recursively;
    files rejected by ``is_image`` are skipped.
    """
    for entry in os.listdir(src_dir):
        src, dst = os.path.join(src_dir, entry), os.path.join(dst_dir, entry)
        if os.path.isdir(src):
            os.mkdir(dst)
            copy_and_resize(src, dst, resize_fn)
            continue
        if not is_image(src):
            continue
        print(src, ' -> ', dst)
        resize_fn(Image.open(src)).save(dst)
def resize_input(image: Image.Image):
    """Resize an input image to ``SIZE`` with bicubic resampling."""
    resized = image.resize(SIZE, resample=Image.BICUBIC)
    return resized
def resize_seg(image: Image.Image):
    """Resize a segmentation map to ``SIZE`` with nearest-neighbour resampling."""
    resized = image.resize(SIZE, resample=Image.NEAREST)
    return resized
def main():
    """Create resized copies of the train/val/test folders and their annotations.

    For every split, ``<root>/<split>`` is resized with ``resize_input`` and
    ``<root>/<split>annot`` with ``resize_seg``; results are written below
    ``<root>/<width>_<height>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    ROOT = parser.parse_args().root

    # SIZE is already a 2-tuple, so it feeds both %d placeholders directly;
    # the former "(*SIZE)" spelling is not valid Python syntax.
    NEW_ROOT = os.path.join(ROOT, '%d_%d' % SIZE)
    # Same guard as the other resize scripts: refuse to reuse an old output.
    if os.path.exists(NEW_ROOT):
        print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
        return
    os.mkdir(NEW_ROOT)

    for split in ['train', 'val', 'test']:
        IMG_DIR = os.path.join(ROOT, split)
        GT_DIR = os.path.join(ROOT, split+'annot')

        NEW_IMG_DIR = os.path.join(NEW_ROOT, split)
        NEW_GT_DIR = os.path.join(NEW_ROOT, split+'annot')
        os.mkdir(NEW_IMG_DIR)
        os.mkdir(NEW_GT_DIR)

        copy_and_resize(IMG_DIR, NEW_IMG_DIR, resize_input)
        copy_and_resize(GT_DIR, NEW_GT_DIR, resize_seg)
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,65 @@ | |||
import argparse | |||
import os | |||
from PIL import Image | |||
import numpy as np | |||
# Output size handed to Image.resize.
SIZE = (640, 320)
def is_image(path: str):
    """Tell whether ``path`` ends with one of the accepted image endings."""
    return any(path.endswith(ending) for ending in ('png', 'jpg', 'jpeg'))
def copy_and_resize(src_dir, dst_dir, resize_fn):
    """Recreate the tree below ``src_dir`` in ``dst_dir`` with resized images.

    Directories are created and descended into; image files are opened,
    passed through ``resize_fn`` and saved under the same name; every other
    file is ignored.
    """
    for name in os.listdir(src_dir):
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        if os.path.isdir(src):
            os.mkdir(dst)
            copy_and_resize(src, dst, resize_fn)
        elif is_image(src):
            print(src, ' -> ', dst)
            original = Image.open(src)
            resize_fn(original).save(dst)
def resize_input(image: Image.Image):
    """Return ``image`` scaled to ``SIZE`` using bicubic resampling."""
    scaled = image.resize(SIZE, resample=Image.BICUBIC)
    return scaled
def resize_seg(image: Image.Image):
    """Return the label map scaled to ``SIZE`` using nearest-neighbour resampling."""
    scaled = image.resize(SIZE, resample=Image.NEAREST)
    return scaled
def resize_depth(image: Image.Image):
    """Return the disparity image scaled to ``SIZE`` using nearest-neighbour resampling."""
    scaled = image.resize(SIZE, resample=Image.NEAREST)
    return scaled
def main():
    """Write resized copies of the image, label and disparity trees.

    ``leftImg8bit``, ``gtFine`` and ``disparity`` below ``--root`` are copied
    to ``<root>/<SIZE[0]>_<SIZE[1]>/`` with their matching resize function.
    The script stops without touching anything if that output folder already
    exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    ROOT = parser.parse_args().root

    IMG_DIR = os.path.join(ROOT, 'leftImg8bit')
    GT_DIR = os.path.join(ROOT, 'gtFine')
    DEPTH_DIR = os.path.join(ROOT, 'disparity')

    # SIZE is already a 2-tuple, so it feeds both %d placeholders directly;
    # the former "(*SIZE)" spelling is not valid Python syntax.
    NEW_ROOT = os.path.join(ROOT, '%d_%d' % SIZE)
    NEW_IMG_DIR = os.path.join(NEW_ROOT, 'leftImg8bit')
    NEW_GT_DIR = os.path.join(NEW_ROOT, 'gtFine')
    NEW_DEPTH_DIR = os.path.join(NEW_ROOT, 'disparity')

    if os.path.exists(NEW_ROOT):
        print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
        return

    os.mkdir(NEW_ROOT)
    os.mkdir(NEW_IMG_DIR)
    os.mkdir(NEW_GT_DIR)
    os.mkdir(NEW_DEPTH_DIR)

    copy_and_resize(IMG_DIR, NEW_IMG_DIR, resize_input)
    copy_and_resize(GT_DIR, NEW_GT_DIR, resize_seg)
    copy_and_resize(DEPTH_DIR, NEW_DEPTH_DIR, resize_depth)
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,59 @@ | |||
import argparse | |||
import os | |||
from PIL import Image | |||
import numpy as np | |||
from torchvision import transforms | |||
import shutil | |||
# Size argument passed to transforms.functional.resize.
SIZE = 240
def is_image(path: str):
    """Check ``path`` against the accepted endings ``png``, ``jpg`` and ``jpeg``."""
    for ending in ('png', 'jpg', 'jpeg'):
        if path.endswith(ending):
            return True
    return False
def copy_and_resize(src_dir, dst_dir, resize_fn):
    """Copy the directory tree ``src_dir`` to ``dst_dir``, resizing images.

    Each sub-directory is created in ``dst_dir`` and processed recursively.
    Files accepted by ``is_image`` are resized with ``resize_fn`` and saved
    under the same name; all other files are left out.
    """
    for item in os.listdir(src_dir):
        src_path = os.path.join(src_dir, item)
        dst_path = os.path.join(dst_dir, item)
        if os.path.isdir(src_path):
            os.mkdir(dst_path)
            copy_and_resize(src_path, dst_path, resize_fn)
            continue
        if is_image(src_path):
            print(src_path, ' -> ', dst_path)
            resize_fn(Image.open(src_path)).save(dst_path)
def resize_input(image: Image.Image):
    """Resize an input image with size argument ``SIZE`` and bilinear interpolation."""
    resized = transforms.functional.resize(image, SIZE, interpolation=Image.BILINEAR)
    return resized
def resize_seg(image: Image.Image):
    """Resize a segmentation map with size argument ``SIZE`` and nearest-neighbour interpolation."""
    resized = transforms.functional.resize(image, SIZE, interpolation=Image.NEAREST)
    return resized
def main():
    """Build a resized copy of the dataset below ``<root>/<SIZE>``.

    ``JPEGImages`` is resized with ``resize_input``, ``SegmentationClass``
    with ``resize_seg`` and ``ImageSets`` is copied unchanged. Nothing is
    written when the output folder already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    ROOT = parser.parse_args().root

    NEW_ROOT = os.path.join(ROOT, str(SIZE))
    if os.path.exists(NEW_ROOT):
        print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
        return

    # (sub-folder, resize function) pairs, processed in this order
    jobs = (
        ('JPEGImages', resize_input),
        ('SegmentationClass', resize_seg),
    )

    os.mkdir(NEW_ROOT)
    for sub_dir, _ in jobs:
        os.mkdir(os.path.join(NEW_ROOT, sub_dir))
    for sub_dir, resize_fn in jobs:
        copy_and_resize(os.path.join(ROOT, sub_dir), os.path.join(NEW_ROOT, sub_dir), resize_fn)

    shutil.copytree(os.path.join(ROOT, 'ImageSets'), os.path.join(NEW_ROOT, 'ImageSets'))
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,57 @@ | |||
import argparse | |||
import os | |||
from PIL import Image | |||
import numpy as np | |||
from torchvision import transforms | |||
import shutil | |||
def is_image(path: str):
    """Return whether ``path`` has one of the accepted image endings (case-sensitive)."""
    accepted_endings = ('png', 'jpg', 'jpeg')
    return path.endswith(accepted_endings)
def copy_and_resize(src_dir, dst_dir, resize_fn):
    """Walk ``src_dir`` and write resized copies of its images to ``dst_dir``.

    The folder structure is reproduced (one ``os.mkdir`` per sub-directory,
    followed by a recursive call); files that ``is_image`` rejects are not
    copied.
    """
    for child in os.listdir(src_dir):
        source = os.path.join(src_dir, child)
        destination = os.path.join(dst_dir, child)
        if os.path.isdir(source):
            os.mkdir(destination)
            copy_and_resize(source, destination, resize_fn)
        elif is_image(source):
            print(source, ' -> ', destination)
            opened = Image.open(source)
            resized = resize_fn(opened)
            resized.save(destination)
def resize_input(image: Image.Image):
    """Resize an input image with size argument 240 and bilinear interpolation."""
    output = transforms.functional.resize(image, 240, interpolation=Image.BILINEAR)
    return output
def resize_seg(image: Image.Image):
    """Resize a segmentation map with size argument 240 and nearest-neighbour interpolation."""
    output = transforms.functional.resize(image, 240, interpolation=Image.NEAREST)
    return output
def main():
    """Build a resized copy of the dataset below ``<root>/240``.

    ``JPEGImages`` goes through ``resize_input``, ``SegmentationClass``
    through ``resize_seg`` and ``ImageSets`` is copied as is. The script
    returns early when ``<root>/240`` already exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    ROOT = parser.parse_args().root

    NEW_ROOT = os.path.join(ROOT, '240')
    if os.path.exists(NEW_ROOT):
        print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
        return

    # source folder, destination folder and resize function per sub-tree
    plan = [
        (os.path.join(ROOT, 'JPEGImages'), os.path.join(NEW_ROOT, 'JPEGImages'), resize_input),
        (os.path.join(ROOT, 'SegmentationClass'), os.path.join(NEW_ROOT, 'SegmentationClass'), resize_seg),
    ]

    os.mkdir(NEW_ROOT)
    for _, new_dir, _ in plan:
        os.mkdir(new_dir)
    for old_dir, new_dir, resize_fn in plan:
        copy_and_resize(old_dir, new_dir, resize_fn)

    shutil.copytree(os.path.join(ROOT, 'ImageSets'), os.path.join(NEW_ROOT, 'ImageSets'))
if __name__=='__main__': | |||
main() |
@@ -0,0 +1,80 @@ | |||
import os | |||
import glob | |||
from PIL import Image | |||
import numpy as np | |||
from scipy.io import loadmat | |||
from torch.utils import data | |||
from .utils import download_url, mkdir | |||
from shutil import copyfile | |||
class StanfordCars(data.Dataset):
    """Dataset for Stanford Cars

    Args:
        root (str): Directory holding ``cars_<split>/`` and ``devkit/``.
        split (str): ``'train'`` reads ``cars_train_annos.mat``; any other
            value reads ``cars_test_annos_withlabels.mat``. Images always
            come from ``cars_<split>/``.
        download (bool): Download and extract the archives first.
        transform (callable, optional): Applied to the loaded image.
        target_transform (callable, optional): Applied to the label.
    """

    urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz',
            'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz',
            'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz',
            'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'}

    def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        if self.split == 'train':
            annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat')
        else:
            annos = os.path.join(self.root, 'devkit',
                                 'cars_test_annos_withlabels.mat')
        annos = loadmat(annos)

        # Files are sorted by name; labels are taken in annotation order
        # (assumes both orders line up -- TODO confirm).
        self.files = sorted(glob.glob(os.path.join(
            self.root, 'cars_'+self.split, '*.jpg')))
        # Field 4 of every annotation record holds the class id, shifted
        # down by one here.
        self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]])

        lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat'))
        self.object_categories = [str(c[0])
                                  for c in lbl_annos['class_names'][0]]

        print('Stanford Cars, Split: %s, Size: %d' %
              (self.split, len(self)))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image is opened as RGB."""
        # ``self.files`` already holds full paths below the absolute root.
        # The former join with ``root/Images`` only worked because
        # os.path.join discards everything before an absolute component.
        img = Image.open(self.files[idx]).convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl

    def download(self):
        """Download missing files, unpack the archives and place the test labels in devkit/."""
        import tarfile

        mkdir(self.root)
        for fname, url in self.urls.items():
            if not os.path.isfile(os.path.join(self.root, fname)):
                download_url(url, self.root, fname)
            if fname.endswith('tgz'):
                print("Extracting %s..." % fname)
                with tarfile.open(os.path.join(self.root, fname), "r:gz") as tar:
                    tar.extractall(path=self.root)
        # __init__ looks for the test labels inside devkit/.
        copyfile(os.path.join(self.root, 'cars_test_annos_withlabels.mat'),
                 os.path.join(self.root, 'devkit', 'cars_test_annos_withlabels.mat'))
@@ -0,0 +1,58 @@ | |||
import os | |||
import numpy as np | |||
from PIL import Image | |||
from scipy.io import loadmat | |||
from torch.utils import data | |||
from .utils import download_url | |||
from shutil import move | |||
class StanfordDogs(data.Dataset):
    """Stanford Dogs image-classification dataset.

    Reads ``<split>_list.mat`` from ``root`` for the file names and labels and
    loads the pictures from ``root/Images``.

    Args:
        root: Dataset directory; ``~`` is expanded and the path made absolute.
        split: Prefix of the ``*_list.mat`` file to read (default ``'train'``).
        download: If true, :meth:`download` runs before the lists are read.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        target_transform: Optional callable applied to the label.
    """

    urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
            "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
            "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"}

    def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()

        split_info = loadmat(os.path.join(self.root, self.split + '_list.mat'))
        file_entries = split_info['file_list']
        label_entries = split_info['labels']
        num_samples = len(file_entries)

        self.files = [str(file_entries[k][0][0]) for k in range(num_samples)]
        # Stored labels are shifted down by one so that they start at 0.
        self.labels = np.array(
            [label_entries[k][0] - 1 for k in range(num_samples)])

        # Category name = folder name without its first 10 characters.
        folder_names = sorted(os.listdir(os.path.join(self.root, 'Images')))
        self.object_categories = [name[10:] for name in folder_names]

        print('Stanford Dogs, Split: %s, Size: %d' %
              (self.split, len(self)))

    def __len__(self):
        """Return the number of entries in the split's file list."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image and label at ``idx``."""
        image_path = os.path.join(self.root, 'Images', self.files[idx])
        img = Image.open(image_path).convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl

    def download(self):
        """Download any missing archive from ``urls`` and extract every archive.

        Extraction runs for each archive on every call, whether or not the
        archive had to be downloaded.
        """
        import tarfile

        os.makedirs(self.root, exist_ok=True)
        for archive_name, archive_url in self.urls.items():
            archive_path = os.path.join(self.root, archive_name)
            if not os.path.isfile(archive_path):
                download_url(archive_url, self.root, archive_name)

            # extract file
            print("Extracting %s..." % archive_name)
            with tarfile.open(archive_path, "r") as tar:
                tar.extractall(path=self.root)
@@ -0,0 +1,65 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import os | |||
from glob import glob | |||
from PIL import Image | |||
from .utils import colormap | |||
from torchvision.datasets import VisionDataset | |||
class SunRGBD(VisionDataset):
    """SUN RGB-D semantic-segmentation dataset (13-label variant).

    Images are globbed from ``root/SUNRGBD-<split>_images/*.jpg`` and masks
    from ``root/<split>13labels/*.png``; both lists are sorted and paired by
    position.

    Args:
        root: Dataset directory.
        split: Split name used to build the two folder names.
        transform: Legacy joint transform, called as ``transform(img, label)``
            and expected to return ``(img, label)``.
        target_transform: Forwarded to ``VisionDataset``.
        transforms: Joint transform called as ``transforms(img, label)``.
    """

    cmap = colormap()

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super(SunRGBD, self).__init__(root, transform=transform,
                                      target_transform=target_transform,
                                      transforms=transforms)
        self.root = root
        self.split = split

        self.images = glob(os.path.join(self.root, 'SUNRGBD-%s_images' % self.split, '*.jpg'))
        self.labels = glob(os.path.join(self.root, '%s13labels' % self.split, '*.png'))

        self.images.sort()
        self.labels.sort()

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``.

        Args:
            idx (int): Index of the item in the dataset.

        Returns:
            tuple: ``(image, label)`` after the joint transform, with the
            label shifted down by one.
        """
        img, label = Image.open(self.images[idx]), Image.open(self.labels[idx])
        if self.transform is not None:
            # Historical contract of this class: ``transform`` receives both
            # the image and the mask. Kept first so existing callers behave
            # exactly as before.
            img, label = self.transform(img, label)
        elif self.transforms is not None:
            # Previously ignored: a dataset built with ``transforms=`` (or
            # only ``target_transform=``) returned untransformed PIL images.
            img, label = self.transforms(img, label)
        # NOTE(review): the subtraction needs an array/tensor-like label, so a
        # transform that converts the mask is presumably always expected.
        label = label - 1  # void 0=>255
        return img, label

    def __len__(self):
        """Return the number of images found for the split."""
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask):
        """Decode a semantic mask to an RGB image via the class colormap.

        The mask is cast to ``uint8`` and shifted by +1 (mirroring the -1
        applied in ``__getitem__``) before indexing ``cmap``.
        """
        return cls.cmap[mask.astype('uint8') + 1]
@@ -0,0 +1,67 @@ | |||
""" | |||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
============================================================= | |||
""" | |||
import torch | |||
from torch.utils.data import Dataset | |||
import os | |||
from PIL import Image | |||
import random | |||
from copy import deepcopy | |||
def _collect_all_images(root, postfix=('png', 'jpg', 'jpeg', 'JPEG')):
    """Recursively collect files under ``root`` whose name ends with a postfix.

    Args:
        root: Directory to walk.
        postfix: A single suffix string or an iterable of suffixes. Matching
            uses ``str.endswith`` on the bare file name, so no leading dot is
            required.

    Returns:
        list: Paths (``dirpath`` joined with the file name). Within each
        directory, results are grouped by postfix in the given order; a file
        matching several postfixes is listed once per match.
    """
    if isinstance(postfix, str):
        postfix = [postfix]
    images = []
    for dirpath, _dirnames, files in os.walk(root):
        for pos in postfix:
            images.extend(os.path.join(dirpath, f) for f in files if f.endswith(pos))
    return images
def get_train_val_set(root, val_size=0.3):
    """Split the images found under one or more roots into train / val lists.

    For every root, all images are collected recursively. If the root has a
    ``test`` sub-directory, its images form the validation part; otherwise a
    random sample of ``int(len(images) * val_size)`` images is held out.
    Validation images are always removed from the training part.

    Args:
        root: A directory path, or a list/tuple of directory paths.
        val_size: Fraction held out when a root has no ``test`` directory.

    Returns:
        tuple: ``(train_set, val_set)`` lists of image paths.
    """
    if not isinstance(root, (list, tuple)):
        root = [root]
    train_set = []
    val_set = []
    for _root in root:
        part_train = _collect_all_images(_root)
        test_dir = os.path.join(_root, 'test')
        if os.path.isdir(test_dir):
            part_val = _collect_all_images(test_dir)
        else:
            part_val = random.sample(part_train, k=int(len(part_train) * val_size))
        # Set membership keeps the filter linear instead of quadratic in the
        # number of images; the resulting lists are unchanged.
        held_out = set(part_val)
        train_set.extend(path for path in part_train if path not in held_out)
        val_set.extend(part_val)
    return train_set, val_set
class UnlabeledDataset(Dataset):
    """Dataset yielding images only (no labels) from a list of file paths.

    Args:
        data: Indexable collection of image paths.
        transform: Optional callable applied to each opened ``PIL.Image``.
        postfix: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, data, transform=None, postfix=['png', 'jpg', 'jpeg', 'JPEG']):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of paths held by the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Open the image at ``idx`` and return it, transformed if configured."""
        sample = Image.open(self.data[idx])
        if self.transform is None:
            return sample
        return self.transform(sample)
@@ -0,0 +1,161 @@ | |||
# Modified from https://github.com/pytorch/vision | |||
import os | |||
import os.path | |||
import hashlib | |||
import errno | |||
from tqdm import tqdm | |||
import numpy as np | |||
import torch | |||
import random | |||
def mkdir(dir):
    """Create directory ``dir`` (and any missing parents) if it does not exist.

    Unlike a bare ``os.mkdir`` guarded by ``isdir``, this also works for
    nested paths and does not fail when another process creates the
    directory between the check and the call.

    Raises:
        OSError: If ``dir`` exists but is not a directory, or cannot be created.
    """
    os.makedirs(dir, exist_ok=True)
def colormap(N=256, normalized=False):
    """Build an ``(N, 3)`` RGB lookup table indexed by label id.

    For each label, bits 0/1/2 of successive 3-bit groups of the label are
    spread over the red/green/blue channels, starting at the most significant
    bit of each channel.

    Args:
        N: Number of rows (labels) in the table.
        normalized: If true, return ``float32`` values divided by 255;
            otherwise return ``uint8`` values.

    Returns:
        numpy.ndarray: The colour table.
    """
    def channel_value(label, bit):
        # Consume the label three bits at a time; the selected bit of each
        # group fills the channel from bit 7 downwards.
        value = 0
        for shift in range(7, -1, -1):
            if label & (1 << bit):
                value |= 1 << shift
            label >>= 3
        return value

    palette = np.zeros((N, 3), dtype='float32' if normalized else 'uint8')
    for label in range(N):
        palette[label] = np.array([channel_value(label, 0),
                                   channel_value(label, 1),
                                   channel_value(label, 2)])

    if normalized:
        palette = palette / 255
    return palette


# Default 256-entry uint8 colour table, built once at import time.
DEFAULT_COLORMAP = colormap()
def gen_bar_updater(pbar):
    """Return a ``urlretrieve``-style report hook that drives ``pbar``.

    The hook sets ``pbar.total`` the first time a non-zero total size is
    reported, and advances the bar to ``count * block_size`` bytes.
    """
    def bar_update(count, block_size, total_size):
        total_unknown = pbar.total is None
        if total_unknown and total_size:
            pbar.total = total_size
        downloaded = count * block_size
        # update() takes an increment, so pass the distance from the
        # bar's current position.
        pbar.update(downloaded - pbar.n)

    return bar_update
def check_integrity(fpath, md5=None):
    """Check a file against an expected MD5 hex digest.

    Args:
        fpath: Path of the file to verify.
        md5: Expected hex digest. If ``None`` the check is skipped and
            ``True`` is returned without touching the file.

    Returns:
        bool: ``False`` if the file is missing or its digest differs,
        ``True`` otherwise.
    """
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False

    digest = hashlib.md5()
    with open(fpath, 'rb') as stream:
        # read in 1MB chunks
        while True:
            block = stream.read(1024 * 1024)
            if block == b'':
                break
            digest.update(block)
    return digest.hexdigest() == md5
def makedir_exist_ok(dirpath):
    """Create ``dirpath`` recursively, ignoring an "already exists" error.

    Python2-compatible stand-in for ``os.makedirs(.., exist_ok=True)``.
    Any ``OSError`` other than ``EEXIST`` is propagated.
    """
    try:
        os.makedirs(dirpath)
    except OSError as err:
        if err.errno != errno.EEXIST:
            raise
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    If the target file already exists and passes ``check_integrity`` it is
    reused. When an ``https`` download fails with ``OSError`` the download is
    retried once over plain ``http``.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str): Name to save the file under. If None, use the basename of the URL
        md5 (str): MD5 checksum of the download. If None, do not check

    Raises:
        OSError: If the download fails and no ``https`` -> ``http`` fallback
            applies, or if the fallback fails as well.
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # downloads file
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
        return

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
        )
    except OSError:
        if url[:5] != 'https':
            # No fallback available: surface the failure instead of
            # silently returning as if the file had been downloaded.
            raise
        url = url.replace('https:', 'http:')
        print('Failed download. Trying https -> http instead.'
              ' Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(
            url, fpath,
            reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
        )
def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If exactly ``True``, prepends the path to each
            result, otherwise only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = [
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    ]

    if prefix is True:
        return [os.path.join(root, entry) for entry in directories]
    return directories
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose files need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If exactly ``True``, prepends the path to each
            result, otherwise only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = [
        entry for entry in os.listdir(root)
        if os.path.isfile(os.path.join(root, entry)) and entry.endswith(suffix)
    ]

    if prefix is True:
        return [os.path.join(root, entry) for entry in files]
    return files
def set_seed(random_seed):
    """Seed the torch (CPU and CUDA), NumPy and ``random`` generators alike."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(random_seed)
@@ -0,0 +1,209 @@ | |||
# Modified from https://github.com/pytorch/vision | |||
import os | |||
import sys | |||
import tarfile | |||
import collections | |||
import torch.utils.data as data | |||
import shutil | |||
import numpy as np | |||
from .utils import colormap | |||
from torchvision.datasets import VisionDataset | |||
import torch | |||
from PIL import Image | |||
from torchvision.datasets.utils import download_url, check_integrity | |||
# Per-release download metadata for Pascal VOC, keyed by the ``year`` argument
# of VOCSegmentation:
#   url      -- archive location
#   filename -- name the archive is stored under inside the dataset root
#   md5      -- checksum used to validate the download
#   base_dir -- directory (relative to the root) the archive unpacks into
DATASET_YEAR_DICT = {
    '2012aug': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # Was 'VOCtrainval_11-May-2012.tar' (copy-paste from the 2012 entry),
        # which made the 2008 and 2012 archives overwrite each other on disk.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}
class VOCSegmentation(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
            ``'2012aug'`` selects VOC2012 and, for ``image_set='train'``, the
            ``SegmentationClassAug`` masks together with ``<root>/train_aug.txt``.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): Transform for the target, forwarded
            to ``VisionDataset``.
        transforms (callable, optional): Joint transform called as
            ``transforms(img, target)``.

    Raises:
        RuntimeError: If the dataset directory for the requested year is missing.
        ValueError: If the split file for ``image_set`` does not exist.
    """
    cmap = colormap()

    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None,
                 ):
        super(VOCSegmentation, self).__init__(root, transform=transform,
                                              target_transform=target_transform,
                                              transforms=transforms)

        is_aug = (year == '2012aug')
        if is_aug:
            year = '2012'

        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)

        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if is_aug and image_set == 'train':
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            split_f = os.path.join(self.root, 'train_aug.txt')
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')

        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')

        with open(split_f, "r") as f:
            file_names = [x.strip() for x in f.readlines()]

        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
            When the target exposes ``squeeze`` (e.g. after a tensor-producing
            transform) its leading dimension is squeezed; a raw PIL mask is
            returned unchanged.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        # A PIL image has no squeeze(); guarding keeps the dataset usable
        # without transforms instead of raising AttributeError.
        if hasattr(target, 'squeeze'):
            target = target.squeeze(0)
        return img, target

    def __len__(self):
        """Return the number of images listed in the split file."""
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask):
        """decode semantic mask to RGB image"""
        return cls.cmap[mask]
def download_extract(url, root, filename, md5):
    """Download ``filename`` from ``url`` into ``root`` and unpack the tar there.

    The download (including the ``md5`` check) is delegated to
    ``download_url``; the archive is then extracted into ``root``.
    """
    download_url(url, root, filename, md5)
    archive_path = os.path.join(root, filename)
    with tarfile.open(archive_path, "r") as archive:
        archive.extractall(path=root)
# Object categories, in the order used for the per-image label vector.
CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
           'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']


class VOCClassification(data.Dataset):
    """Pascal VOC multi-label classification dataset.

    Reads ``<root>/VOC<year>/ImageSets/Main/<split>.txt`` for the image ids
    and ``<class>_<split>.txt`` for the label of every image for each entry
    of ``CLASSES``.

    Args:
        root: Directory containing the ``VOC<year>`` folder.
        year: Dataset year used to build the folder name.
        split: Name of the split file (without extension).
        download: Accepted for signature compatibility; this class does not
            download anything.
        transforms: Optional callable applied to the RGB ``PIL.Image``.
        target_transforms: Optional callable applied to the label list.

    Raises:
        RuntimeError: If ``<root>/VOC<year>`` is not a directory.
    """

    def __init__(self,
                 root,
                 year='2010',
                 split='train',
                 download=False,
                 transforms=None,
                 target_transforms=None):
        voc_root = os.path.join(root, 'VOC{}'.format(year))

        if not os.path.isdir(voc_root):
            # The old message suggested download=True, which is not
            # implemented here; point at the missing directory instead.
            raise RuntimeError('Dataset not found or corrupted: %s is not a directory.'
                               ' Automatic download is not supported by this class' % voc_root)

        self.transforms = transforms
        self.target_transforms = target_transforms

        image_dir = os.path.join(voc_root, 'JPEGImages')
        label_dir = os.path.join(voc_root, 'ImageSets/Main')

        fname = os.path.join(label_dir, '{}.txt'.format(split))
        with open(fname) as f:
            self.images = [os.path.join(image_dir, line.split()[0] + '.jpg') for line in f]

        # labels_list[k][i] is the second column of line i in the file of
        # class CLASSES[k].
        self.labels_list = []
        for clas in CLASSES:
            with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f:
                self.labels_list.append([int(line.split()[1]) for line in f])

        assert (len(self.images) == len(self.labels_list[0]))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels holds one integer per entry
            of ``CLASSES`` (or whatever ``target_transforms`` returns).
        """
        img = Image.open(self.images[index]).convert('RGB')
        labels = [labels[index] for labels in self.labels_list]

        if self.transforms is not None:
            img = self.transforms(img)
        if self.target_transforms is not None:
            labels = self.target_transforms(labels)
        return img, labels

    def __len__(self):
        """Return the number of images listed in the split file."""
        return len(self.images)
@@ -0,0 +1,3 @@ | |||
from . import classification, segmentation | |||
from torchvision import models as torchvision_models |
@@ -0,0 +1,7 @@ | |||
from .darknet import * | |||
from .mobilenetv2 import * | |||
from .resnet import * | |||
from .vgg import * | |||
from . import cifar | |||
from .alexnet import alexnet |
@@ -0,0 +1,63 @@ | |||
# Modified from https://github.com/pytorch/vision | |||
import torch.nn as nn | |||
import torch.utils.model_zoo as model_zoo | |||
# Names exported by a star-import of this module.
__all__ = ['AlexNet', 'alexnet']

# Checkpoint location per architecture; read by ``alexnet(pretrained=True)``.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
class AlexNet(nn.Module):
    """AlexNet: a five-convolution feature extractor and a three-layer classifier.

    Args:
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        head = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Run the feature extractor, flatten to 256*6*6 per sample, classify."""
        feature_maps = self.features(x)
        # The explicit size makes an unexpected feature-map shape fail here.
        flat = feature_maps.view(feature_maps.size(0), 256 * 6 * 6)
        return self.classifier(flat)
def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    The network is always built with its default output size, so that the
    pre-trained checkpoint can be loaded; the last classifier layer is then
    replaced when a different ``num_classes`` is requested.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: ``num_classes`` is handled here; the rest goes to ``AlexNet``.
    """
    requested_classes = kwargs.pop('num_classes', None)
    net = AlexNet(**kwargs)
    if pretrained:
        weights = model_zoo.load_url(model_urls['alexnet'])
        net.load_state_dict(weights)

    needs_new_head = requested_classes is not None and requested_classes != 1000
    if needs_new_head:
        net.classifier[-1] = nn.Linear(4096, requested_classes)
    return net
@@ -0,0 +1 @@ | |||
from . import wrn |
@@ -0,0 +1,108 @@ | |||
#Adapted from https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py | |||
import math | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
# NOTE(review): no object named ``wrn`` is visible in this module (the public
# factories are ``wrn_16_1`` ... ``wrn_40_2``), so a star-import of this module
# would presumably fail -- confirm and update this list.
__all__ = ['wrn']
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3, BN-ReLU-dropout-conv3x3.

    When the channel count changes, a 1x1 convolution (``convShortcut``)
    projects the pre-activated input onto the residual branch; otherwise the
    raw input is added back unchanged.

    Args:
        in_planes: Input channels.
        out_planes: Output channels.
        stride: Stride of the first convolution (and of the shortcut).
        dropout_rate: Dropout probability between the two convolutions.
    """

    def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0):
        super(BasicBlock, self).__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, stride=stride, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, stride=1, **conv3x3)
        self.dropout = nn.Dropout(dropout_rate)

        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                          stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Apply the block to ``x`` and return the residual sum."""
        activated = self.relu1(self.bn1(x))
        if self.equalInOut:
            shortcut = x
        else:
            # With a projection, the shortcut also starts from the
            # pre-activated tensor.
            shortcut = self.convShortcut(activated)

        out = self.conv1(activated)
        out = self.relu2(self.bn2(out))
        out = self.dropout(out)
        out = self.conv2(out)
        return torch.add(shortcut, out)
class NetworkBlock(nn.Module):
    """A sequence of ``nb_layers`` blocks built from the ``block`` factory.

    Only the first block receives ``in_planes`` and ``stride``; every later
    block maps ``out_planes`` to ``out_planes`` with stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate):
        """Instantiate the blocks and wrap them in ``nn.Sequential``."""
        stack = []
        for index in range(nb_layers):
            if index == 0:
                # The ``or`` fallbacks keep the historical behaviour for
                # falsy arguments.
                block_in, block_stride = in_planes or out_planes, stride or 1
            else:
                block_in, block_stride = out_planes, 1
            stack.append(block(block_in, out_planes, block_stride, dropout_rate))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Pass ``x`` through all blocks in order."""
        return self.layer(x)
class WideResNet(nn.Module):
    """Wide residual network: one 3x3 convolution, three block groups, a classifier.

    Args:
        depth: Total depth; must be of the form ``6n + 4``.
        num_classes: Output size of the final linear layer.
        widen_factor: Multiplier applied to the 16/32/64 base channel widths.
        dropout_rate: Dropout probability inside every residual block.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0):
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style initialisation based on the fan-out of the kernel.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, return_features=False):
        """Compute the logits for ``x``.

        Args:
            x: Input batch with 3 channels.
            return_features: If true, also return the pooled feature vector
                fed to the classifier.

        Returns:
            The logits, or ``(logits, features)`` when ``return_features`` is set.
        """
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Global average pooling. The former fixed 8x8 window was only global
        # for 32x32 inputs; for other sizes the view() below silently folded
        # spatial positions into the batch dimension.
        out = F.adaptive_avg_pool2d(out, 1)
        features = out.view(-1, self.nChannels)
        out = self.fc(features)

        if return_features:
            return out, features
        else:
            return out
def wrn_16_1(num_classes, dropout_rate=0):
    """Build a WideResNet with depth 16 and widen factor 1."""
    shape = dict(depth=16, widen_factor=1)
    return WideResNet(num_classes=num_classes, dropout_rate=dropout_rate, **shape)
def wrn_16_2(num_classes, dropout_rate=0):
    """Build a WideResNet with depth 16 and widen factor 2."""
    shape = dict(depth=16, widen_factor=2)
    return WideResNet(num_classes=num_classes, dropout_rate=dropout_rate, **shape)
def wrn_40_1(num_classes, dropout_rate=0):
    """Build a WideResNet with depth 40 and widen factor 1."""
    shape = dict(depth=40, widen_factor=1)
    return WideResNet(num_classes=num_classes, dropout_rate=dropout_rate, **shape)
def wrn_40_2(num_classes, dropout_rate=0):
    """Build a WideResNet with depth 40 and widen factor 2."""
    shape = dict(depth=40, widen_factor=2)
    return WideResNet(num_classes=num_classes, dropout_rate=dropout_rate, **shape)