@@ -0,0 +1,2 @@ | |||||
/.idea/ | |||||
*.iml |
@@ -0,0 +1,7 @@ | |||||
# TensorFlow 2.4.1 base image ships python3 + pip preinstalled.
FROM tensorflow/tensorflow:2.4.1
WORKDIR /app
# web.py serves the HTTP API; tf2onnx performs the SavedModel -> ONNX conversion.
RUN pip install web.py tf2onnx
COPY . /app
ENTRYPOINT ["python3", "main.py"]
@@ -0,0 +1,86 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import json | |||||
import os | |||||
import subprocess | |||||
import logging | |||||
import web | |||||
from subprocess import PIPE | |||||
# URL routing table: maps request paths to handler class names.
urls = (
    '/hello', 'Hello',
    '/model_convert', 'ModelConvert'
)
# Log everything at DEBUG and above to a local file next to the service.
logging.basicConfig(filename='onnx.log', level=logging.DEBUG)
class Hello(object):
    """Health-check endpoint for GET /hello."""

    def GET(self):
        """Return a fixed liveness message."""
        return 'service alive'
class ModelConvert(object):
    """POST /model_convert: convert a TF SavedModel directory to ONNX."""

    def POST(self):
        """Validate the JSON request and run the conversion.

        Expects a body with ``model_path`` (an existing directory holding a
        SavedModel) and ``output_path`` (a directory path ending in '/').
        Returns a JSON envelope ``{code, msg, data}``; non-200 codes map to
        specific validation/conversion failures.
        """
        payload = web.data()
        web.header('Content-Type', 'application/json')
        try:
            request = json.loads(payload)
            model_path = request['model_path']
            output_path = request['output_path']
            # Each failed precondition gets its own error code.
            if not os.path.isdir(model_path):
                msg = 'model_path is not a dir: %s' % model_path
                logging.error(msg)
                return json.dumps({'code': 501, 'msg': msg, 'data': ''})
            if not output_path.endswith('/'):
                msg = 'output_path is not a dir: %s' % output_path
                logging.error(msg)
                return json.dumps({'code': 502, 'msg': msg, 'data': ''})
            if not exist(model_path):
                msg = 'SavedModel file does not exist at: %s' % model_path
                logging.error(msg)
                return json.dumps({'code': 503, 'msg': msg, 'data': ''})
            converted, msg = convert(model_path, output_path)
            if not converted:
                return json.dumps({'code': 504, 'msg': msg, 'data': ''})
        except Exception as e:
            logging.error(str(e))
            return json.dumps({'code': 505, 'msg': str(e), 'data': ''})
        # On success `msg` holds the path of the generated .onnx file.
        return json.dumps({'code': 200, 'msg': 'ok', 'data': msg})
def exist(model_path):
    """Return True if `model_path` contains a SavedModel graph file.

    TensorFlow writes the graph as either 'saved_model.pb' (binary) or
    'saved_model.pbtxt' (text); the converter needs one of them present.
    """
    # any() over a generator replaces the manual loop-and-return idiom.
    return any(
        name in ('saved_model.pb', 'saved_model.pbtxt')
        for name in os.listdir(model_path)
    )
def convert(model_path, output_path):
    """Convert the SavedModel in `model_path` to ONNX under `output_path`.

    Returns (True, onnx_file_path) on success, or (False, error_message)
    when the tf2onnx subprocess fails or cannot be launched.
    """
    output_path = output_path + 'model.onnx'
    try:
        logging.info('model_path=%s, output_path=%s' % (model_path, output_path))
        result = subprocess.run(
            ["python", "-m", "tf2onnx.convert",
             "--saved-model", model_path, "--output", output_path],
            stdout=PIPE, stderr=PIPE)
        logging.info(repr(result))
        if result.returncode != 0:
            # stderr is bytes; decode it so the API returns readable text
            # instead of a Python repr like "b'...'".
            return False, result.stderr.decode('utf-8', errors='replace')
    except Exception as e:
        logging.error(str(e))
        return False, str(e)
    return True, output_path
if __name__ == '__main__':
    # Build the web.py application from the routing table and start serving.
    app = web.application(urls, globals())
    app.run()
@@ -0,0 +1,130 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
dependencies = ['torch', 'kamal']  # packages required by the hub entries
import json | |||||
import traceback | |||||
import os | |||||
import torch | |||||
import web | |||||
import kamal | |||||
from PIL import Image | |||||
from kamal.transferability import TransferabilityGraph | |||||
from kamal.transferability.trans_metric import AttrMapMetric | |||||
from torchvision.models import * | |||||
# URL routing table: request path -> handler class name.
urls = (
    '/model_measure/measure', 'Measure',
    '/model_measure/package', 'Package',
)
app = web.application(urls, globals())
class Package:
    """POST /model_measure/package: package PyTorch checkpoints into kamal hub format."""

    def POST(self):
        """Package each requested model and return the list of save paths.

        The request body is a JSON array; each element must carry
        ``metadata``, ``entry_name``, ``ckpt`` and ``readme`` fields.
        Responds with a JSON envelope ``{code, msg, data}``.
        """
        req = web.data()
        save_path_list = []
        for json_data in json.loads(req):
            try:
                metadata = json_data['metadata']
            except Exception as e:
                traceback.print_exc()
                return json.dumps(Response(506, 'Failed to package model, Error: %s'% (traceback.format_exc(limit=1)), save_path_list).__dict__)
            entry_name = json_data['entry_name']
            for name, fn in kamal.hub.list_entry(__file__):
                if entry_name in name:
                    try:
                        dataset = metadata['dataset']
                        model = fn(pretrained=False, num_classes=metadata['entry_args']['num_classes'])
                        num_params = sum([torch.numel(p) for p in model.parameters()])
                        save_path_for_measure = '%sfinegraind_%s/' % (json_data['ckpt'], entry_name)
                        save_path = save_path_for_measure + '%s' % (metadata['name'])
                        ckpt = self.file_name(json_data['ckpt'])
                        if ckpt == '':
                            return json.dumps(
                                Response(506, 'Failed to package model [%s]: No .pth file was found in directory ckpt' % (entry_name), save_path_list).__dict__)
                        model.load_state_dict(torch.load(ckpt), False)
                        kamal.hub.save(  # package the user's pytorch model into hub format at the given location
                            model,                  # the nn.Module to save
                            save_path=save_path,    # export directory name
                            entry_name=entry_name,  # entry function name; must match the hub entry above
                            spec_name=None,         # parameter version name; None -> an md5 digest is used
                            code_path=__file__,     # code the model depends on: a directory containing
                                                    # hubconf.py, or this hubconf.py itself
                            metadata=metadata,
                            tags=dict(
                                num_params=num_params,
                                metadata=metadata,
                                name=metadata['name'],
                                url=metadata['url'],
                                dataset=dataset,
                                img_size=metadata['input']['size'],
                                readme=json_data['readme'])
                        )
                        save_path_list.append(save_path_for_measure)
                        return json.dumps(Response(200, 'Success', save_path_list).__dict__)
                    except Exception:
                        traceback.print_exc()
                        return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)
            # No hub entry matched this entry_name.
            return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)

    def file_name(self, file_dir):
        """Return the path of the first *.pth file under `file_dir`, or '' if none.

        Fixed: join root and filename with os.path.join (the old ``root + file``
        dropped the path separator), and match the '.pth' extension exactly
        instead of any name merely ending in 'pth'.
        """
        for root, dirs, files in os.walk(file_dir):
            for fname in files:
                if fname.endswith('.pth'):
                    return os.path.join(root, fname)
        return ''
class Measure:
    """POST /model_measure/measure: compute transferability metrics for a model zoo."""

    def POST(self):
        """Build a transferability graph over `zoo_set` and export it as JSON.

        The request body must carry ``zoo_set``, ``probe_set_root`` (a
        directory of probe images) and ``export_path``. Responds with a
        JSON envelope ``{code, msg, data}``.
        """
        req = web.data()
        json_data = json.loads(req)
        print(json_data)
        # Initialized before the try block so the except handler can always
        # reference them. Previously a KeyError raised before these
        # assignments caused a NameError inside the handler itself.
        probe_set_root = ''
        output_filename_list = []
        try:
            measure_name = 'measure'
            zoo_set = json_data['zoo_set']
            probe_set_root = json_data['probe_set_root']
            export_path = json_data['export_path']
            TG = TransferabilityGraph(zoo_set)
            print("Add %s" % (probe_set_root))
            imgs_set = list(os.listdir(os.path.join(probe_set_root)))
            images = [Image.open(os.path.join(probe_set_root, img)) for img in imgs_set]
            # Fall back to CPU when CUDA is unavailable instead of crashing.
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            metric = AttrMapMetric(images, device=device)
            TG.add_metric(probe_set_root, metric)
            if not os.path.exists(export_path):
                # Create the export directory when it does not exist yet.
                os.makedirs(export_path)
            output_filename = export_path + '%s.json' % (measure_name)
            TG.export_to_json(probe_set_root, output_filename, topk=3, normalize=True)
            output_filename_list.append(output_filename)
        except Exception:
            traceback.print_exc()
            return json.dumps(Response(506, 'Failed to generate measurement file of [%s], Error: %s' % (probe_set_root, traceback.format_exc(limit=1)), output_filename_list).__dict__)
        return json.dumps(Response(200, 'Success', output_filename_list).__dict__)
class Response:
    """JSON-serializable response envelope (code / msg / data)."""

    def __init__(self, code, msg, data):
        self.code, self.msg, self.data = code, msg, data
if __name__ == "__main__":
    # Start the web.py HTTP server for the measure/package endpoints.
    app.run()
@@ -0,0 +1,5 @@ | |||||
from .core import tasks, metrics, engine, callbacks, hub | |||||
from . import amalgamation, slim, vision, transferability | |||||
from .core import load, save |
@@ -0,0 +1,4 @@ | |||||
from .layerwise_amalgamation import LayerWiseAmalgamator | |||||
from .common_feature import CommonFeatureAmalgamator | |||||
from .task_branching import TaskBranchingAmalgamator, JointSegNet | |||||
from .recombination import RecombinationAmalgamator, CombinedModel |
@@ -0,0 +1,291 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from kamal.core.engine.engine import Engine | |||||
from kamal.core.engine.hooks import FeatureHook | |||||
from kamal.core import tasks | |||||
from kamal.utils import set_mode, move_to_device | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import typing, time | |||||
import numpy as np | |||||
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return conv
class ResBlock(nn.Module):
    """Basic residual block: two 3x3 convs with an identity or projection shortcut."""

    def __init__(self, inplanes, planes, stride=1, momentum=0.1):
        super(ResBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=momentum)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=momentum)
        # A 1x1 projection shortcut is needed whenever the spatial size or
        # channel count changes; otherwise the identity is used.
        needs_projection = stride > 1 or inplanes != planes
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes, momentum=momentum),
            )
        else:
            self.downsample = None
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)
class CFL_FCBlock(nn.Module):
    """Common Feature Blocks for Fully-Connected layer

    Captures a shared feature space for multiple teachers and aligns the
    student's features to it; the MMD loss is computed by the caller.

    **Parameters:**
        - cs (int): channel number of student features
        - cts (list or tuple): channel number list of teacher features
        - ch (int): channel number of hidden features
    """

    def __init__(self, cs, cts, ch, k_size=5):
        super(CFL_FCBlock, self).__init__()
        # One aligner per teacher, projecting teacher features into the hidden space.
        self.align_t = nn.ModuleList(
            nn.Sequential(nn.Linear(ct, ch), nn.ReLU(inplace=True))
            for ct in cts
        )
        self.align_s = nn.Sequential(
            nn.Linear(cs, ch),
            nn.ReLU(inplace=True),
        )
        # Extractor shared between aligned student and teacher features.
        self.extractor = nn.Sequential(
            nn.Linear(ch, ch),
            nn.ReLU(),
            nn.Linear(ch, ch),
        )
        # Decoders reconstructing each teacher's original feature dimensions.
        self.dec_t = nn.ModuleList(
            nn.Sequential(nn.Linear(ch, ct), nn.ReLU(inplace=True), nn.Linear(ct, ct))
            for ct in cts
        )

    def init_weights(self):
        """Kaiming-initialize Linear layers; reset BatchNorm affine parameters."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, fs, fts):
        """Return ((hs, hts), (_fts, fts)): hidden features and reconstructions."""
        aligned_t = [align(ft) for align, ft in zip(self.align_t, fts)]
        aligned_s = self.align_s(fs)
        hts = [self.extractor(f) for f in aligned_t]
        hs = self.extractor(aligned_s)
        _fts = [dec(ht) for dec, ht in zip(self.dec_t, hts)]
        return (hs, hts), (_fts, fts)
class CFL_ConvBlock(nn.Module):
    """Common Feature Blocks for Convolutional layer

    Captures a shared feature space for multiple teachers and aligns the
    student's features to it; the MMD loss is computed by the caller.

    **Parameters:**
        - cs (int): channel number of student features
        - cts (list or tuple): channel number list of teacher features
        - ch (int): channel number of hidden features
    """

    def __init__(self, cs, cts, ch, k_size=5):
        super(CFL_ConvBlock, self).__init__()
        # One 1x1-conv aligner per teacher, projecting into the hidden space.
        self.align_t = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_channels=ct, out_channels=ch, kernel_size=1),
                nn.BatchNorm2d(ch),
                nn.ReLU(inplace=True),
            )
            for ct in cts
        )
        self.align_s = nn.Sequential(
            nn.Conv2d(in_channels=cs, out_channels=ch, kernel_size=1),
            nn.BatchNorm2d(ch),
            nn.ReLU(inplace=True),
        )
        # Residual extractor shared between aligned student/teacher features.
        self.extractor = nn.Sequential(
            ResBlock(inplanes=ch, planes=ch, stride=1),
            ResBlock(inplanes=ch, planes=ch, stride=1),
        )
        # Decoders reconstructing each teacher's original channel count.
        self.dec_t = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(ch, ch, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(ch, ct, kernel_size=1, stride=1),
            )
            for ct in cts
        )

    def init_weights(self):
        """Kaiming-initialize Conv layers; reset BatchNorm affine parameters."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, fs, fts):
        """Return ((hs, hts), (_fts, fts)): hidden features and reconstructions."""
        aligned_t = [align(ft) for align, ft in zip(self.align_t, fts)]
        aligned_s = self.align_s(fs)
        hts = [self.extractor(f) for f in aligned_t]
        hs = self.extractor(aligned_s)
        _fts = [dec(ht) for dec, ht in zip(self.dec_t, hts)]
        return (hs, hts), (_fts, fts)
class CommonFeatureAmalgamator(Engine):
    """Knowledge amalgamation engine based on Common Feature Learning.

    The student mimics the channel-concatenated outputs of several teachers
    (KD loss) while CFL blocks align intermediate features (``loss_amal``)
    and reconstruct the teacher features (``loss_recons``).
    """

    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        weights = [1.0, 1.0, 1.0],
        on_layer_input=False,
        device = None,
    ):
        """Prepare models, feature hooks, CFL blocks and the CFL loss.

        Parameters:
            student / teachers: modules to train / to distill from.
            layer_groups: per-group layer sequence; index 0 is the student
                layer, the rest are teacher layers (see step_fn).
            layer_channels: channel counts matching ``layer_groups``.
            weights: scaling for the (kd, amal, recons) losses.
            on_layer_input: use hooked layer inputs instead of outputs.
            device: defaults to CUDA when available, else CPU.
        """
        self._dataloader = dataloader
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._model = self._student = student.to(self._device)
        self._teachers = nn.ModuleList(teachers).to(self._device)
        self._optimizer = optimizer
        self._weights = weights
        self._on_layer_input = on_layer_input
        amal_blocks = []
        for group, C in zip( layer_groups, layer_channels ):
            # Hooks record the features flowing through each grouped layer.
            hooks = [ FeatureHook(layer) for layer in group ]
            # FC layers get an FC CFL block; everything else a conv block.
            if isinstance(group[0], nn.Linear):
                amal_block = CFL_FCBlock( cs=C[0], cts=C[1:], ch=C[0]//4 ).to(self._device).train()
                print("Building FC Blocks")
            else:
                amal_block = CFL_ConvBlock(cs=C[0], cts=C[1:], ch=C[0]//4).to(self._device).train()
                print("Building Conv Blocks")
            amal_blocks.append( (amal_block, hooks, C) )
        self._amal_blocks = amal_blocks
        # Multi-kernel MMD-style loss used for common-feature alignment.
        self._cfl_criterion = tasks.loss.CFLLoss( sigmas=[0.001, 0.01, 0.05, 0.1, 0.2, 1, 2] )

    @property
    def device(self):
        # Device every model/block was moved to in setup().
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None):
        """Run the training loop from ``start_iter`` to ``max_iter``.

        Builds a dedicated optimizer for the CFL block parameters, mirroring
        the student optimizer's type (SGD or Adam) and learning rate, plus a
        cosine LR schedule, then delegates iteration to Engine.run.
        """
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend( list(block.parameters()) )
        if isinstance( self._optimizer, torch.optim.SGD ):
            self._amal_optimimizer = torch.optim.SGD( block_params, lr=self._optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 )
        else:
            self._amal_optimimizer = torch.optim.Adam( block_params, lr=self._optimizer.param_groups[0]['lr'], weight_decay=1e-4 )
        self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimimizer, T_max=max_iter )
        # Student in train mode, teachers frozen in eval mode for the run.
        with set_mode(self._student, training=True), \
             set_mode(self._teachers, training=False):
            super( CommonFeatureAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Execute one optimization step; return a dict of scalar metrics."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self._student( data )
        with torch.no_grad():
            t_out = [ teacher( data ) for teacher in self._teachers ]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            # features[0] comes from the student layer, the rest from teachers.
            features = [ h.feat_in if self._on_layer_input else h.feat_out for h in hooks ]
            fs, fts = features[0], features[1:]
            (hs, hts), (_fts, fts) = amal_block( fs, fts )
            _loss_amal, _loss_recons = self._cfl_criterion( hs, hts, _fts, fts )
            loss_amal += _loss_amal
            loss_recons += _loss_recons
        # KD: student output vs the channel-concatenated teacher outputs.
        loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) )
        loss_dict = {
            'loss_kd': self._weights[0]*loss_kd,
            'loss_amal': self._weights[1]*loss_amal,
            'loss_recons': self._weights[2]*loss_recons
        }
        loss = sum(loss_dict.values())
        # Step both optimizers on the combined loss.
        self._optimizer.zero_grad()
        self._amal_optimimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        self._amal_optimimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float( self._optimizer.param_groups[0]['lr'] )
        })
        return metrics
@@ -0,0 +1,131 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from kamal.core.engine.engine import Engine | |||||
from kamal.core.engine.hooks import FeatureHook | |||||
from kamal.core import tasks | |||||
from kamal.utils import set_mode | |||||
import typing | |||||
import time | |||||
from kamal.utils import move_to_device, set_mode | |||||
class AmalBlock(nn.Module):
    """Layer-wise amalgamation block.

    Encodes the concatenated teacher features into a common representation,
    decodes them back for reconstruction, and adapts the student feature
    with a 1x1 conv so it can be matched against that representation.
    """

    def __init__(self, cs, cts):
        super( AmalBlock, self ).__init__()
        self.cs, self.cts = cs, cts
        total_ct = sum(self.cts)
        self.enc = nn.Conv2d(in_channels=total_ct, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.fam = nn.Conv2d(in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True)
        self.dec = nn.Conv2d(in_channels=self.cs, out_channels=total_ct, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, fs, fts):
        stacked = torch.cat(fts, dim=1)
        rep = self.enc(stacked)
        _fts = torch.split(self.dec(rep), self.cts, dim=1)
        _fs = self.fam(fs)
        return rep, _fs, _fts
class LayerWiseAmalgamator(Engine):
    """Layer-wise knowledge amalgamation engine.

    A KD loss matches the student output to the channel-concatenated teacher
    outputs, while per-layer AmalBlocks align student features to an encoded
    common representation (``loss_amal``) and reconstruct teacher features
    (``loss_recons``).
    """

    def setup(
        self,
        student,
        teachers,
        layer_groups: typing.Sequence[typing.Sequence],
        layer_channels: typing.Sequence[typing.Sequence],
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        weights = [1., 1., 1.],
        device=None,
    ):
        """Prepare models, feature hooks and amalgamation blocks.

        Parameters:
            layer_groups: per-group layer sequence; index 0 is the student
                layer, the rest are the corresponding teacher layers.
            layer_channels: channel counts matching ``layer_groups``.
            weights: scaling for the (kd, amal, recons) losses.
            device: defaults to CUDA when available, else CPU.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.optimizer = optimizer
        self._weights = weights
        amal_blocks = []
        for group, C in zip(layer_groups, layer_channels):
            hooks = [ FeatureHook(layer) for layer in group ]
            amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train()
            amal_blocks.append( (amal_block, hooks, C) )
        self._amal_blocks = amal_blocks

    # Fixed: this property was defined twice in the class body; the duplicate
    # definition has been removed.
    @property
    def device(self):
        """Device selected in setup()."""
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None ):
        """Run the training loop from ``start_iter`` to ``max_iter``.

        Builds a dedicated optimizer for the AmalBlock parameters, mirroring
        the student optimizer's type (SGD or Adam) and learning rate, plus a
        cosine LR schedule, then delegates iteration to Engine.run.
        """
        block_params = []
        for block, _, _ in self._amal_blocks:
            block_params.extend( list(block.parameters()) )
        if isinstance( self.optimizer, torch.optim.SGD ):
            self._amal_optimimizer = torch.optim.SGD( block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 )
        else:
            self._amal_optimimizer = torch.optim.Adam( block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4 )
        self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimimizer, T_max=max_iter )
        # Student in train mode, teachers frozen in eval mode for the run.
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super( LayerWiseAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Execute one optimization step; return a dict of scalar metrics."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        s_out = self.student( data )
        with torch.no_grad():
            t_out = [ teacher( data ) for teacher in self.teachers ]
        loss_amal = 0
        loss_recons = 0
        for amal_block, hooks, C in self._amal_blocks:
            # features[0] comes from the student layer, the rest from teachers.
            features = [ h.feat_out for h in hooks ]
            fs, fts = features[0], features[1:]
            rep, _fs, _fts = amal_block( fs, fts )
            # Student feature chases the (detached) common representation.
            loss_amal += F.mse_loss( _fs, rep.detach() )
            loss_recons += sum( [ F.mse_loss( _ft, ft ) for (_ft, ft) in zip( _fts, fts ) ] )
        loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) )
        #loss_kd = F.mse_loss( s_out, torch.cat( t_out, dim=1 ) )
        loss_dict = { "loss_kd": self._weights[0] * loss_kd,
                      "loss_amal": self._weights[1] * loss_amal,
                      "loss_recons": self._weights[2] * loss_recons }
        loss = sum(loss_dict.values())
        # Step both optimizers on the combined loss.
        self.optimizer.zero_grad()
        self._amal_optimimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self._amal_optimimizer.step()
        self._amal_scheduler.step()
        step_time = time.perf_counter() - start_time
        metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float( self.optimizer.param_groups[0]['lr'] )
        })
        return metrics
@@ -0,0 +1,209 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import math
from copy import deepcopy
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

from kamal.core import tasks
from kamal.core.engine.engine import Engine
from kamal.core.engine.hooks import FeatureHook
from kamal.core.engine.trainer import KDTrainer
from kamal.slim.prunning import Pruner, strategy
from kamal.utils import set_mode
def _assert_same_type(layers, layer_type=None): | |||||
if layer_type is None: | |||||
layer_type = type(layers[0]) | |||||
assert all(isinstance(l, layer_type) for l in layers), 'Model archictures must be the same' | |||||
def _get_layers(model_list):
    """Yield tuples of corresponding submodules across all models.

    Asserts that each tuple of aligned submodules shares the same type.
    """
    iterators = [model.modules() for model in model_list]
    for aligned in zip(*iterators):
        _assert_same_type(aligned)
        yield aligned
def bn_combine_fn(layers):
    """Combine 2D Batch Normalization Layers

    **Parameters:**
        - **layers** (BatchNorm2D): Batch Normalization Layers.
    """
    _assert_same_type(layers, nn.BatchNorm2d)
    template = layers[0]
    # The merged layer stacks the per-layer feature channels; eps/momentum/
    # affine/track flags are taken from the first layer.
    merged = nn.BatchNorm2d(
        num_features=sum(l.num_features for l in layers),
        eps=template.eps,
        momentum=template.momentum,
        affine=template.affine,
        track_running_stats=template.track_running_stats,
    )
    merged.running_mean = torch.cat([l.running_mean for l in layers], dim=0).clone()
    merged.running_var = torch.cat([l.running_var for l in layers], dim=0).clone()
    if merged.affine:
        merged.weight = torch.nn.Parameter(
            torch.cat([l.weight.data.clone() for l in layers], dim=0).clone())
        merged.bias = torch.nn.Parameter(
            torch.cat([l.bias.data.clone() for l in layers], dim=0).clone())
    return merged
def conv2d_combine_fn(layers):
    """Combine 2D Conv Layers

    Stacks several Conv2d layers into one block-diagonal convolution whose
    input/output channel counts are the sums of the originals.

    **Parameters:**
        - **layers** (Conv2d): Conv Layers.
    """
    _assert_same_type(layers, nn.Conv2d)
    CO, CI = 0, 0
    for l in layers:
        O, I, H, W = l.weight.shape
        CO += O
        CI += I
    dtype = layers[0].weight.dtype
    device = layers[0].weight.device
    combined_weight = torch.nn.Parameter(
        torch.zeros(CO, CI, H, W, dtype=dtype, device=device))
    if layers[0].bias is not None:
        combined_bias = torch.nn.Parameter(
            torch.zeros(CO, dtype=dtype, device=device))
    else:
        combined_bias = None
    co_offset = 0
    ci_offset = 0
    # Copy each layer's weight onto the block diagonal of the combined kernel.
    for idx, l in enumerate(layers):
        co_len, ci_len = l.weight.shape[0], l.weight.shape[1]
        combined_weight[co_offset: co_offset+co_len,
                        ci_offset: ci_offset+ci_len, :, :] = l.weight.clone()
        if combined_bias is not None:
            combined_bias[co_offset: co_offset+co_len] = l.bias.clone()
        co_offset += co_len
        # BUG FIX: advance by ci_len; the original `ci_offset += ci_offset`
        # kept the input offset stuck at 0, overlapping every block's weights.
        ci_offset += ci_len
    combined_conv2d = nn.Conv2d(in_channels=CI,
                                out_channels=CO,
                                kernel_size=layers[0].weight.shape[-2:],
                                stride=layers[0].stride,
                                padding=layers[0].padding,
                                bias=layers[0].bias is not None)
    combined_conv2d.weight.data = combined_weight
    if combined_bias is not None:
        combined_conv2d.bias.data = combined_bias
    for p in combined_conv2d.parameters():
        p.requires_grad = True
    return combined_conv2d
def combine_models(models):
    """Combine several models into one block-diagonal model.

    Builds a single model whose Conv2d/BatchNorm2d layers are the combined
    versions of the corresponding layers of every model in `models`.

    **Parameters:**
        - **models** (nn.Module): modules to be combined.
    """
    def _recursively_combine(module):
        # Replace this module with its combined version (if it is a
        # combinable layer type), then recurse into all children.
        module_output = module
        if isinstance( module, nn.Conv2d ):
            combined_module = conv2d_combine_fn( layer_mapping[module] )
        elif isinstance( module, nn.BatchNorm2d ):
            combined_module = bn_combine_fn( layer_mapping[module] )
        else:
            combined_module = module
        if combined_module is not None:
            module_output = combined_module
        for name, child in module.named_children():
            module_output.add_module(name, _recursively_combine(child))
        return module_output
    models = deepcopy(models)
    combined_model = deepcopy(models[0]) # copy the model archicture and modify it with _recursively_combine
    # Map each layer of the template copy to the tuple of corresponding
    # layers across all source models (identity-keyed dict).
    layer_mapping = {}
    for combined_layer, layers in zip(combined_model.modules(), _get_layers(models)):
        layer_mapping[combined_layer] = layers # link to teachers
    combined_model = _recursively_combine(combined_model)
    return combined_model
class CombinedModel(nn.Module):
    """Wrapper around the block-diagonally combined teacher model.

    Replicates the input along the channel dimension so that each
    constituent model's slice of the combined layers receives a copy.
    """

    def __init__(self, models):
        # BUG FIX: the original called super(Combination, self) where
        # `Combination` is an undefined name (NameError on construction).
        super(CombinedModel, self).__init__()
        self.combined_model = combine_models( models )
        self.expand = len(models)

    def forward(self, x):
        # BUG FIX: the original computed x.repeat(-1, x.shape[1]*self.expand, -1, -1)
        # and discarded the result; Tensor.repeat takes per-dimension repeat
        # counts and negative values are invalid. Replicating the channels
        # `expand` times matches the combined model's summed input channels.
        # NOTE(review): assumes NCHW input — confirm with callers.
        x = x.repeat(1, self.expand, 1, 1)
        return self.combined_model(x)
class PruningKDTrainer(KDTrainer):
    """KD trainer that alternates channel pruning of the student with
    knowledge-distillation fine-tuning rounds.
    """

    def setup(
        self,
        student,
        teachers,
        task,
        dataloader: torch.utils.data.DataLoader,
        get_optimizer_and_scheduler:Callable=None,
        pruning_rounds=5,
        device=None,
    ):
        """Store models and data; the optimizer is (re)built per round in run().

        NOTE(review): `task` and this `pruning_rounds` are accepted but not
        stored/used here — run() takes its own `pruning_rounds`; confirm the
        intended API against KDTrainer.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._dataloader = dataloader
        self.model = self.student = student.to(self.device)
        self.teachers = nn.ModuleList(teachers).to(self.device)
        self.get_optimizer_and_scheduler = get_optimizer_and_scheduler

    @property
    def device(self):
        """Device selected in setup()."""
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None, pruning_rounds=3, target_model_size=0.6 ):
        """Prune and fine-tune for `pruning_rounds` rounds.

        Each round removes a fixed fraction of channels so that after all
        rounds the student retains roughly `target_model_size` of its
        original size, then fine-tunes with KD for an equal share of the
        iteration budget.
        """
        pruning_size_per_round = 1 - math.pow( target_model_size, 1/pruning_rounds )
        prunner = Pruner( strategy.LNStrategy(n=1) )
        for pruning_round in range(pruning_rounds):
            # Prune, then rebuild the optimizer for the new parameter set.
            prunner.prune( self.student, rate=pruning_size_per_round, example_inputs=torch.randn(1,3,240,240) )
            self.student.to(self.device)
            if self.get_optimizer_and_scheduler:
                self.optimizer, self.scheduler = self.get_optimizer_and_scheduler( self.student )
            else:
                self.optimizer = torch.optim.Adam( self.student.parameters(), lr=1e-4, weight_decay=1e-5 )
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max= (max_iter-start_iter)//pruning_rounds )
            step_iter = (max_iter - start_iter)//pruning_rounds
            with set_mode(self.student, training=True), \
                 set_mode(self.teachers, training=False):
                # BUG FIX: the original referenced the undefined name
                # `RecombinationAmalgamation` in super(); use this class.
                super( PruningKDTrainer, self ).run(self.step_fn, self._dataloader,
                    start_iter=start_iter+step_iter*pruning_round , max_iter=start_iter+step_iter*(pruning_round+1), epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """Delegate to KDTrainer's step, then advance the LR scheduler."""
        # BUG FIX: same undefined-name super() call as in run().
        metrics = super(PruningKDTrainer, self).step_fn( engine, batch )
        self.scheduler.step()
        return metrics
class RecombinationAmalgamator(PruningKDTrainer):
    """Alias for :class:`PruningKDTrainer` kept under the amalgamator naming scheme.

    Adds no behavior of its own; exists so callers can refer to the pruning-based
    KD trainer by its amalgamation role.
    """
    pass
@@ -0,0 +1,295 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from kamal.core.engine.engine import Engine | |||||
from kamal.core.engine.hooks import FeatureHook | |||||
from kamal.core import tasks | |||||
from kamal.utils import move_to_device, set_mode | |||||
from kamal.core.hub import meta | |||||
from kamal import vision | |||||
import kamal | |||||
from kamal.utils import set_mode | |||||
import typing | |||||
import time | |||||
from copy import deepcopy | |||||
import random | |||||
import numpy as np | |||||
from collections import defaultdict | |||||
import numbers | |||||
class BranchySegNet(nn.Module):
    """SegNet-based multi-branch student: one shared encoder/decoder trunk plus,
    per output head, a private decoder stack and SE-style channel adaptors that
    branch off the trunk at a configurable depth.

    **Parameters:**
        - **out_channels** (sequence of int): number of output channels per branch/head.
        - **segnet_fn** (callable): factory building a SegNet given num_classes.
    """
    def __init__(self, out_channels, segnet_fn=vision.models.segmentation.segnet_vgg16_bn):
        super(BranchySegNet, self).__init__()
        # Decoder channel widths at each of the 5 SegNet stages.
        channels=[512, 512, 256, 128, 64]
        # One branch index per head; buffer so it moves with .to()/state_dict.
        # NOTE(review): initialized as float zeros but used as indices in forward();
        # set_branch() replaces it with an integer tensor — confirm set_branch is
        # always called before forward.
        self.register_buffer( 'branch_indices', torch.zeros((len(out_channels),)) )
        self.student_b_decoders_list = nn.ModuleList()
        self.student_adaptors_list = nn.ModuleList()
        # Squeeze widths for the SE-style adaptors: channels/4, floored at 16.
        ses = []
        for i in range(5):
            se = int(channels[i]/4)
            ses.append(16 if se < 16 else se)
        for oc in out_channels:
            segnet = self.get_segnet( oc, segnet_fn )
            # children()[5:] are the 5 decoder stages of a SegNet.
            decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))
            adaptors = nn.ModuleList()
            for i in range(5):
                # 1x1 squeeze-excite gate: channels -> ses -> channels, sigmoid output.
                adaptor = nn.Sequential(
                    nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
                    nn.ReLU(),
                    nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
                    nn.Sigmoid()
                )
                adaptors.append(adaptor)
            self.student_b_decoders_list.append(decoders)
            self.student_adaptors_list.append(adaptors)
        # NOTE(review): the shared trunk is copied from the *last* segnet built in
        # the loop above — intentional only if all heads share an encoder; verify.
        self.student_encoders = nn.ModuleList(deepcopy(list(segnet.children())[0:5]))
        self.student_decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:]))

    def set_branch(self, branch_indices):
        """Set the per-head branch-off depth (one int in [0,5) per head)."""
        assert len(branch_indices)==len(self.student_b_decoders_list)
        self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device)

    def get_segnet(self, oc, segnet_fn):
        # Build one SegNet with `oc` classes and a pretrained backbone.
        return segnet_fn( num_classes=oc, pretrained_backbone=True )

    def forward(self, inputs):
        """Run the shared trunk, then each head's private decoders from its branch point.

        Returns a list with one output tensor per head.
        """
        output_list = []
        # Shared encoder: each stage returns (features, pool indices, pre-pool shape).
        down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
        down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
        down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
        down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
        down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)
        # Shared decoder trunk (unpooling with the matching encoder indices).
        up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
        up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
        up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
        up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
        up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)
        # Trunk features available as branch-off points, ordered by decoder depth.
        decoder_features = [down5, up5, up4, up3, up2]
        decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
        decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]
        # Mimic teachers.
        for i in range(len(self.branch_indices)):
            out_idx = self.branch_indices[i]
            output = decoder_features[out_idx]
            # Channel-gate the branch input with the head's SE adaptor.
            # NOTE(review): avg_pool2d kernel is shape[2:3] (height only) —
            # shape[2:] (H, W) may have been intended; confirm.
            output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:3]))
            # Run the remaining private decoder stages for this head.
            for j in range(out_idx, 5):
                output = self.student_b_decoders_list[i][j](
                    output,
                    decoder_indices[j],
                    decoder_shapes[j]
                )
            output_list.append( output )
        return output_list
class JointSegNet(nn.Module):
    """The online student model to learn from any number of single teacher with 'SegNet' structure.

    **Parameters:**
        - **teachers** (list of 'Module' object): Teachers with 'SegNet' structure to learn from.
        - **indices** (list of int): Where to branch out for each task.
        - **phase** (string): Should be 'block' or 'finetune'. Useful only in training mode.
        - **channels** (list of int, optional): Parameter to build adaptor modules, corresponding to that of 'SegNet'.
    """
    def __init__(self, teachers, student=None, channels=[512, 512, 256, 128, 64]):
        super(JointSegNet, self).__init__()
        # Branch-off depth per task; buffer so it tracks device/state_dict.
        # NOTE(review): sized (2,) — assumes exactly two tasks; verify against callers.
        self.register_buffer( 'branch_indices', torch.zeros((2,)) )
        if student is None:
            student = teachers[0]
        # NOTE(review): the trunk below is copied from teachers[0] regardless of
        # the `student` argument — `student` is effectively unused; confirm intent.
        self.student_encoders = nn.ModuleList(deepcopy(list(teachers[0].children())[0:5]))
        self.student_decoders = nn.ModuleList(deepcopy(list(teachers[0].children())[5:]))
        self.student_b_decoders_list = nn.ModuleList()
        self.student_adaptors_list = nn.ModuleList()
        # Squeeze widths for SE-style adaptors: channels/4, floored at 16.
        ses = []
        for i in range(5):
            se = int(channels[i]/4)
            ses.append(16 if se < 16 else se)
        for teacher in teachers:
            # Each teacher contributes a private copy of its decoder stack.
            decoders = nn.ModuleList(deepcopy(list(teacher.children())[5:]))
            adaptors = nn.ModuleList()
            for i in range(5):
                # 1x1 squeeze-excite gate: channels -> ses -> channels, sigmoid output.
                adaptor = nn.Sequential(
                    nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False),
                    nn.ReLU(),
                    nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False),
                    nn.Sigmoid()
                )
                adaptors.append(adaptor)
            self.student_b_decoders_list.append(decoders)
            self.student_adaptors_list.append(adaptors)

    def set_branch(self, branch_indices):
        """Set the per-task branch-off depth (one int in [0,5) per task)."""
        assert len(branch_indices)==len(self.student_b_decoders_list)
        self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device)

    def forward(self, inputs):
        """Run the shared trunk, then each task's private decoders from its branch point.

        Returns a list with one output tensor per task.
        """
        output_list = []
        # Shared encoder: each stage returns (features, pool indices, pre-pool shape).
        down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs)
        down2, indices_2, unpool_shape2 = self.student_encoders[1](down1)
        down3, indices_3, unpool_shape3 = self.student_encoders[2](down2)
        down4, indices_4, unpool_shape4 = self.student_encoders[3](down3)
        down5, indices_5, unpool_shape5 = self.student_encoders[4](down4)
        # Shared decoder trunk (unpooling with the matching encoder indices).
        up5 = self.student_decoders[0](down5, indices_5, unpool_shape5)
        up4 = self.student_decoders[1](up5, indices_4, unpool_shape4)
        up3 = self.student_decoders[2](up4, indices_3, unpool_shape3)
        up2 = self.student_decoders[3](up3, indices_2, unpool_shape2)
        up1 = self.student_decoders[4](up2, indices_1, unpool_shape1)
        # Trunk features available as branch-off points, ordered by decoder depth.
        decoder_features = [down5, up5, up4, up3, up2]
        decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1]
        decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1]
        # Mimic teachers.
        for i in range(len(self.branch_indices)):
            out_idx = self.branch_indices[i]
            output = decoder_features[out_idx]
            # Channel-gate the branch input with the task's SE adaptor.
            # NOTE(review): avg_pool2d kernel is shape[2:3] (height only) —
            # shape[2:] (H, W) may have been intended; confirm.
            output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:3]))
            # Run the remaining private decoder stages for this task.
            for j in range(out_idx, 5):
                output = self.student_b_decoders_list[i][j](
                    output,
                    decoder_indices[j],
                    decoder_shapes[j]
                )
            output_list.append( output )
        return output_list
class TaskBranchingAmalgamator(Engine):
    """Two-stage amalgamation of single-task SegNet teachers into a JointSegNet student.

    Stage 1 (branching): random branch points are sampled every step while the
    student mimics the teachers. Stage 2 (finetuning): the per-task branch with
    the lowest distillation loss is fixed and training continues.
    """
    def setup(
        self,
        joint_student: JointSegNet,
        teachers,
        tasks,
        dataloader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        device=None,
    ):
        """Bind student, teachers, tasks, data and optimizer to the engine.

        If `device` is None, CUDA is used when available.
        """
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._device = device
        self._dataloader = dataloader
        self.student = self.model = joint_student.to(self._device)
        self.teachers = nn.ModuleList(teachers).to(self._device)
        self.tasks = tasks
        self.optimizer = optimizer
        self.is_finetuning=False

    @property
    def device(self):
        # BUG FIX: this property was defined twice (identically); the duplicate
        # definition was removed.
        return self._device

    def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None ):
        """Run branching for the first half of `max_iter`, then finetune for the rest.

        `stage_callback`, when given, is invoked between the two stages.
        """
        # Branching
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter//2, epoch_length=epoch_length)
        branch = self.find_the_best_branch( self._dataloader )
        self.logger.info("[Task Branching] the best branch indices: %s"%(branch))
        if stage_callback is not None:
            stage_callback()
        # Finetuning
        self.is_finetuning = True
        with set_mode(self.student, training=True), \
             set_mode(self.teachers, training=False):
            super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=max_iter-max_iter//2, max_iter=max_iter, epoch_length=epoch_length)

    def find_the_best_branch(self, dataloader):
        """Accumulate the distillation loss of every candidate branch depth over
        `dataloader`, pick the argmin per task, set it on the student and return it.
        """
        with set_mode(self.student, training=False), \
             set_mode(self.teachers, training=False), \
             torch.no_grad():
            n_blocks = len(self.student.student_decoders)
            branch_loss = { task: [0 for _ in range(n_blocks)] for task in self.tasks }
            for batch in dataloader:
                batch = move_to_device(batch, self.device)
                data = batch if isinstance(batch, torch.Tensor) else batch[0]
                for candidate_branch in range( n_blocks ):
                    # All tasks share the same candidate depth for this pass.
                    self.student.set_branch( [candidate_branch for _ in range(len(self.teachers))] )
                    s_out_list = self.student( data )
                    t_out_list = [ teacher( data ) for teacher in self.teachers ]
                    for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
                        task_loss = task.get_loss( s_out, t_out )
                        branch_loss[task][candidate_branch] += sum(task_loss.values())
            # (typo fix: local variable renamed from `best_brach`)
            best_branch = []
            for task in self.tasks:
                best_branch.append( int(np.argmin( branch_loss[task] )) )
            self.student.set_branch(best_branch)
            return best_branch

    def step_fn(self, engine, batch):
        """One optimization step; returns a metrics dict for logging."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self._device)
        data = batch[0]
        # BUG FIX: removed the dead statement `data, None` (a no-op tuple expression).
        n_blocks = len(self.student.student_decoders)
        if not self.is_finetuning:
            # During branching, sample a random branch depth per teacher each step.
            rand_branch_indices = [ random.randint(0, n_blocks-1) for _ in range(len(self.teachers)) ]
            self.student.set_branch(rand_branch_indices)
        s_out_list = self.student( data )
        with torch.no_grad():
            t_out_list = [ teacher( data ) for teacher in self.teachers ]
        loss_dict = {}
        for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ):
            task_loss = task.get_loss( s_out, t_out )
            loss_dict.update( task_loss )
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float( self.optimizer.param_groups[0]['lr'] ),
            'branch': self.student.branch_indices.cpu().numpy().tolist()
        })
        return metrics
@@ -0,0 +1,4 @@ | |||||
from . import engine, tasks, metrics, callbacks, exceptions, hub | |||||
from .attach import AttachTo | |||||
from .hub import load, save |
@@ -0,0 +1,45 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from typing import Sequence, Callable | |||||
from numbers import Number | |||||
from kamal.core import exceptions | |||||
class AttachTo(object): | |||||
""" Attach task, metrics or visualizer to specified model outputs | |||||
""" | |||||
def __init__(self, attach_to=None): | |||||
if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable) ): | |||||
raise exceptions.InvalidMapping | |||||
self._attach_to = attach_to | |||||
def __call__(self, *tensors): | |||||
if self._attach_to is not None: | |||||
if isinstance(self._attach_to, Callable): | |||||
return self._attach_to( *tensors ) | |||||
if isinstance(self._attach_to, Sequence): | |||||
_attach_to = self._attach_to | |||||
else: | |||||
_attach_to = [ self._attach_to for _ in range(len(tensors)) ] | |||||
_attach_to = _attach_to[:len(tensors)] | |||||
tensors = [ tensor[index] for (tensor, index) in zip( tensors, _attach_to ) ] | |||||
if len(tensors)==1: | |||||
tensors = tensors[0] | |||||
return tensors | |||||
def __repr__(self): | |||||
rep = "AttachTo: %s"%(self._attach_to) |
@@ -0,0 +1,5 @@ | |||||
from .logging import MetricsLogging, ProgressCallback | |||||
from .base import Callback | |||||
from .eval_and_ckpt import EvalAndCkpt | |||||
from .scheduler import LRSchedulerCallback | |||||
from .visualize import VisualizeOutputs, VisualizeSegmentation, VisualizeDepth |
@@ -0,0 +1,28 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import abc | |||||
class Callback(abc.ABC):
    """Abstract base class for engine callbacks.

    Concrete callbacks implement ``__call__(engine)``, which the engine invokes
    on its registered events.
    """

    def __init__(self):
        ...

    @abc.abstractmethod
    def __call__(self, engine):
        ...
@@ -0,0 +1,145 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .base import Callback | |||||
import weakref | |||||
from kamal import utils | |||||
from typing import Sequence, Optional | |||||
import numbers | |||||
import os, shutil | |||||
import torch | |||||
class EvalAndCkpt(Callback):
    """Callback that evaluates a model and writes checkpoints.

    On each call it runs `evaluator.eval(model)`, logs the scalar results, and
    (depending on `save_type`) saves the best / latest / every checkpoint.

    Parameters:
        model: the model to evaluate (held via weakref so the callback does not
            keep it alive).
        evaluator: object exposing `eval(model, device=...)` returning a metrics dict.
        metric_name: key into the flattened results used to rank checkpoints.
        metric_mode: 'max' or 'min' — whether larger or smaller is better.
        save_type: subset of ('best', 'latest', 'all'), a single string, or None
            to disable saving.
        ckpt_dir: directory for checkpoint files.
        ckpt_prefix: optional filename prefix.
        log_tag: tag used in log lines and TensorBoard scalars.
        weights_only: save `state_dict()` instead of the whole model.
        verbose: log each saved path.
    """
    def __init__(self,
                 model,
                 evaluator,
                 metric_name:str,
                 metric_mode:str ='max',
                 save_type:Optional[Sequence]=('best', 'latest'),
                 ckpt_dir:str ='checkpoints',
                 ckpt_prefix:str =None,
                 log_tag:str ='model',
                 weights_only:bool =True,
                 verbose:bool =False,):
        super(EvalAndCkpt, self).__init__()
        self.metric_name = metric_name
        assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'"
        self._metric_mode = metric_mode
        # Weak reference: the trainer owns the model's lifetime, not this callback.
        self._model = weakref.ref( model )
        self._evaluator = evaluator
        self._ckpt_dir = ckpt_dir
        self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix+'_')
        if isinstance(save_type, str):
            save_type = ( save_type, )
        self._save_type = save_type
        if self._save_type is not None:
            for st in self._save_type:
                assert st in ('best', 'latest', 'all'), \
                    'save_type should be None or a subset of (\"best\", \"latest\", \"all\")'
        self._log_tag = log_tag
        self._weights_only = weights_only
        self._verbose = verbose
        # BUG FIX: use infinities instead of the arbitrary (and asymmetric)
        # sentinels -999999 / 99999., which could mis-rank extreme scores.
        self._best_score = float('-inf') if self._metric_mode=='max' else float('inf')
        self._best_ckpt = None
        self._latest_ckpt = None

    @property
    def best_ckpt(self):
        """Path of the best checkpoint saved so far, or None."""
        return self._best_ckpt

    @property
    def latest_ckpt(self):
        """Path of the most recent checkpoint saved, or None."""
        return self._latest_ckpt

    @property
    def best_score(self):
        """Best metric value observed so far."""
        return self._best_score

    def __call__(self, trainer):
        """Evaluate, log results, and save the configured checkpoint variants."""
        model = self._model()
        results = self._evaluator.eval( model, device=trainer.device )
        results = utils.flatten_dict(results)
        current_score = results[self.metric_name]
        # Keep only scalar entries (plain numbers or 0-d arrays) for logging.
        scalar_results = { k: float(v) for (k, v) in results.items() if isinstance(v, numbers.Number) or len(v.shape)==0 }
        if trainer.logger is not None:
            trainer.logger.info( "[Eval %s] Iter %d/%d: %s"%(self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results) )
        trainer.state.metrics.update( scalar_results )
        # Visualize results if trainer.tb_writer is not None
        if trainer.tb_writer is not None:
            for k, v in scalar_results.items():
                log_tag = "%s:%s"%(self._log_tag, k)
                trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter)
        if self._save_type is not None:
            pth_path_list = []
            # save every evaluation
            # BUG FIX: the original tested 'interval', which the __init__ assert
            # rejects, so this branch was dead. The accepted key is 'all'.
            if 'all' in self._save_type:
                pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f.pth"
                                        % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                pth_path_list.append(pth_path)
            # the latest model
            if 'latest' in self._save_type:
                pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f.pth"
                                        % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                # remove the old ckpt
                if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt):
                    os.remove(self._latest_ckpt)
                pth_path_list.append(pth_path)
                self._latest_ckpt = pth_path
            # the best model
            if 'best' in self._save_type:
                if (current_score >= self._best_score and self._metric_mode=='max') or \
                   (current_score <= self._best_score and self._metric_mode=='min'):
                    pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f.pth" %
                                            (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score))
                    # remove the old ckpt
                    if self._best_ckpt is not None and os.path.exists(self._best_ckpt):
                        os.remove(self._best_ckpt)
                    pth_path_list.append(pth_path)
                    self._best_score = current_score
                    self._best_ckpt = pth_path
            # save model
            if self._verbose and trainer.logger:
                trainer.logger.info("Model saved as:")
            obj = model.state_dict() if self._weights_only else model
            os.makedirs( self._ckpt_dir, exist_ok=True )
            for pth_path in pth_path_list:
                torch.save(obj, pth_path)
                if self._verbose and trainer.logger:
                    trainer.logger.info("\t%s" % (pth_path))

    def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False):
        """Copy the best checkpoint to a stable final name (optionally md5-tagged)."""
        if ckpt_dir is None:
            ckpt_dir = self._ckpt_dir
        if ckpt_prefix is None:
            ckpt_prefix = self._ckpt_prefix
        if self._save_type is not None:
            if 'best' in self._save_type and self._best_ckpt is not None:
                os.makedirs(ckpt_dir, exist_ok=True)
                save_name = "%sbest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._best_ckpt))
                shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name))
@@ -0,0 +1,61 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .base import Callback | |||||
import numbers | |||||
from tqdm import tqdm | |||||
class MetricsLogging(Callback):
    """Callback that logs selected engine metrics and mirrors numeric ones to TensorBoard.

    Parameters:
        keys: metric names to read from `engine.state.metrics` each call.
    """
    def __init__(self, keys):
        super(MetricsLogging, self).__init__()
        self._keys = keys

    def __call__(self, engine):
        # Nothing to do without a logger.
        # BUG FIX (idiom): compare to None with `is`, not `==`.
        if engine.logger is None:
            return
        state = engine.state
        content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)"%(
            state.iter, state.max_iter,
            state.current_epoch, state.max_epoch,
            state.current_batch_index, state.max_batch_index
        )
        for key in self._keys:
            value = state.metrics.get(key, None)
            if value is not None:
                if isinstance(value, numbers.Number):
                    content += " %s=%.4f"%(key, value)
                    # Numeric metrics also go to TensorBoard when a writer exists.
                    if engine.tb_writer is not None:
                        engine.tb_writer.add_scalar(key, value, global_step=state.iter)
                elif isinstance(value, (list, tuple)):
                    content += " %s=%s"%(key, value)
        engine.logger.info(content)
class ProgressCallback(Callback):
    """Callback showing a tqdm progress bar over `max_iter` engine steps.

    Parameters:
        max_iter: total number of steps the bar tracks.
        tag: optional bar description.
    """
    def __init__(self, max_iter=100, tag=None):
        super(ProgressCallback, self).__init__()  # BUG FIX: base __init__ was not called
        self._max_iter = max_iter
        self._tag = tag
        # BUG FIX: _pbar was never created (the tqdm line was commented out), so
        # __call__ raised AttributeError unless reset() had been called first.
        # Create it lazily on first use instead.
        self._pbar = None

    def __call__(self, engine):
        if self._pbar is None:
            self.reset()
        self._pbar.update(1)
        if self._pbar.n==self._max_iter:
            self._pbar.close()

    def reset(self):
        """(Re)create the progress bar from scratch."""
        self._pbar = tqdm(total=self._max_iter, desc=self._tag)
@@ -0,0 +1,34 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .base import Callback | |||||
from typing import Sequence | |||||
class LRSchedulerCallback(Callback):
    r""" LR scheduler callback.

    Steps one or more learning-rate schedulers each time it is invoked.

    Parameters:
        schedulers: a single scheduler, a sequence of schedulers, or None to disable.
    """
    def __init__(self, schedulers=None):
        super(LRSchedulerCallback, self).__init__()
        # BUG FIX: the original wrapped None into (None,), so the `is None`
        # guard in __call__ never fired and None.step() raised AttributeError.
        if schedulers is not None and not isinstance(schedulers, Sequence):
            schedulers = ( schedulers, )
        self._schedulers = schedulers

    def __call__(self, trainer):
        if self._schedulers is None:
            return
        for sched in self._schedulers:
            sched.step()
@@ -0,0 +1,161 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .base import Callback | |||||
from typing import Callable, Union, Sequence | |||||
import weakref | |||||
import random | |||||
from kamal.utils import move_to_device, set_mode, split_batch, colormap | |||||
from kamal.core.attach import AttachTo | |||||
import torch | |||||
import numpy as np | |||||
import matplotlib.pyplot as plt | |||||
import matplotlib | |||||
matplotlib.use('agg') | |||||
import math | |||||
import numbers | |||||
class VisualizeOutputs(Callback):
    """Callback that renders (input, target, prediction) triplets to TensorBoard.

    For each configured dataset index it runs the model, optionally decodes
    targets/predictions to RGB, and writes a 3-image strip per sample under
    ``tag-<i>``. Does nothing (with a warning) when the trainer has no tb_writer.
    """
    def __init__(self,
                 model,
                 dataset: torch.utils.data.Dataset,
                 idx_list_or_num_vis: Union[int, Sequence]=5,
                 normalizer: Callable=None,
                 prepare_fn: Callable=None,
                 decode_fn: Callable=None, # decode targets and preds
                 tag: str='viz'):
        super(VisualizeOutputs, self).__init__()
        self._dataset = dataset
        # Weak reference: the trainer owns the model's lifetime.
        self._model = weakref.ref(model)
        # Either a fixed list of sample indices, or a count of random samples.
        # NOTE(review): any other type leaves self.idx_list unset — __call__
        # would then raise AttributeError; confirm callers only pass int/Sequence.
        if isinstance(idx_list_or_num_vis, int):
            self.idx_list = self._get_vis_idx_list(self._dataset, idx_list_or_num_vis)
        elif isinstance(idx_list_or_num_vis, Sequence):
            self.idx_list = idx_list_or_num_vis
        self._normalizer = normalizer
        self._decode_fn = decode_fn
        if prepare_fn is None:
            prepare_fn = VisualizeOutputs.get_prepare_fn()
        self._prepare_fn = prepare_fn
        self._tag = tag

    def _get_vis_idx_list(self, dataset, num_vis):
        # Sample `num_vis` distinct dataset indices (chosen once, at init).
        return random.sample(list(range(len(dataset))), num_vis)

    @torch.no_grad()
    def __call__(self, trainer):
        """Render each configured sample and write it to trainer.tb_writer."""
        if trainer.tb_writer is None:
            trainer.logger.warning("summary writer was not found in trainer")
            return
        device = trainer.device
        model = self._model()
        with torch.no_grad(), set_mode(model, training=False):
            for img_id, idx in enumerate(self.idx_list):
                batch = move_to_device(self._dataset[idx], device)
                # Add a batch dimension to every element of the sample.
                batch = [ d.unsqueeze(0) for d in batch ]
                inputs, targets, preds = self._prepare_fn(model, batch)
                if self._normalizer is not None:
                    inputs = self._normalizer(inputs)
                inputs = inputs.detach().cpu().numpy()
                preds = preds.detach().cpu().numpy()
                targets = targets.detach().cpu().numpy()
                if self._decode_fn: # to RGB 0~1 NCHW
                    preds = self._decode_fn(preds)
                    targets = self._decode_fn(targets)
                # Drop the batch dimension again before stacking for display.
                inputs = inputs[0]
                preds = preds[0]
                targets = targets[0]
                trainer.tb_writer.add_images("%s-%d"%(self._tag, img_id), np.stack( [inputs, targets, preds], axis=0), global_step=trainer.state.iter)

    @staticmethod
    def get_prepare_fn(attach_to=None, pred_fn=lambda x: x):
        """Build a (model, batch) -> (inputs, targets, preds) function.

        `attach_to` selects model outputs via AttachTo; `pred_fn` maps raw
        outputs to displayable predictions (e.g. argmax for segmentation).
        """
        attach_to = AttachTo(attach_to)
        def wrapper(model, batch):
            inputs, targets = split_batch(batch)
            outputs = model(inputs)
            outputs, targets = attach_to(outputs, targets)
            return inputs, targets, pred_fn(outputs)
        return wrapper

    @staticmethod
    def get_seg_decode_fn(cmap=colormap(), index_transform=lambda x: x+1): # 255->0, 0->1,
        """Build a label-map -> RGB decoder using a color map.

        `index_transform` shifts labels before the cmap lookup (so e.g. the
        255 'ignore' label wraps to entry 0 for uint8 input).
        """
        def wrapper(preds):
            if len(preds.shape)>3:
                preds = preds.squeeze(1)
            out = cmap[ index_transform(preds.astype('uint8')) ]
            # HWC color output -> NCHW floats in [0, 1].
            out = out.transpose(0, 3, 1, 2) / 255
            return out
        return wrapper

    @staticmethod
    def get_depth_decode_fn(max_depth, log_scale=True, cmap=plt.get_cmap('jet')):
        """Build a depth-map -> RGB decoder, optionally in log space, clipped to max_depth."""
        def wrapper(preds):
            if log_scale:
                _max_depth = np.log( max_depth )
                preds = np.log( preds )
            else:
                _max_depth = max_depth
            if len(preds.shape)>3:
                preds = preds.squeeze(1)
            # Colormap returns RGBA; keep only the RGB channels.
            out = (cmap(preds.clip(0, _max_depth)/_max_depth)).transpose(0, 3, 1, 2)[:, :3]
            return out
        return wrapper
class VisualizeSegmentation(VisualizeOutputs):
    """Segmentation-specific visualization callback.

    Defaults `prepare_fn` to an argmax-over-classes prediction and `decode_fn`
    to a colormap-based label decoder, then delegates to VisualizeOutputs.
    """
    def __init__(
        self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5,
        cmap = colormap(),
        attach_to=None,
        normalizer: Callable=None,
        prepare_fn: Callable=None,
        decode_fn: Callable=None,
        tag: str='seg'
    ):
        resolved_prepare = prepare_fn
        if resolved_prepare is None:
            # Predictions are the per-pixel argmax over the class dimension.
            resolved_prepare = VisualizeOutputs.get_prepare_fn(
                attach_to=attach_to, pred_fn=lambda logits: logits.max(1)[1])
        resolved_decode = decode_fn
        if resolved_decode is None:
            # Shift labels by one so the ignore label (255) wraps to cmap entry 0.
            resolved_decode = VisualizeOutputs.get_seg_decode_fn(
                cmap=cmap, index_transform=lambda lbl: lbl+1)
        super(VisualizeSegmentation, self).__init__(
            model=model,
            dataset=dataset,
            idx_list_or_num_vis=idx_list_or_num_vis,
            normalizer=normalizer,
            prepare_fn=resolved_prepare,
            decode_fn=resolved_decode,
            tag=tag,
        )
class VisualizeDepth(VisualizeOutputs):
    """Depth-estimation visualization callback.

    Defaults `prepare_fn` to the identity prediction and `decode_fn` to a
    jet-colormap depth decoder, then delegates to VisualizeOutputs.
    """
    def __init__(
        self, model, dataset: torch.utils.data.Dataset,
        idx_list_or_num_vis: Union[int, Sequence]=5,
        max_depth = 10,
        log_scale = True,
        attach_to = None,
        normalizer: Callable=None,
        prepare_fn: Callable=None,
        decode_fn: Callable=None,
        tag: str='depth'
    ):
        resolved_prepare = prepare_fn
        if resolved_prepare is None:
            # Depth outputs are displayed as-is (no argmax), so pred_fn is identity.
            resolved_prepare = VisualizeOutputs.get_prepare_fn(
                attach_to=attach_to, pred_fn=lambda out: out)
        resolved_decode = decode_fn
        if resolved_decode is None:
            resolved_decode = VisualizeOutputs.get_depth_decode_fn(
                max_depth=max_depth, log_scale=log_scale)
        super(VisualizeDepth, self).__init__(
            model=model,
            dataset=dataset,
            idx_list_or_num_vis=idx_list_or_num_vis,
            normalizer=normalizer,
            prepare_fn=resolved_prepare,
            decode_fn=resolved_decode,
            tag=tag,
        )
@@ -0,0 +1,7 @@ | |||||
from . import evaluator | |||||
from . import trainer | |||||
from . import lr_finder | |||||
from . import hooks | |||||
from . import events | |||||
from .engine import DefaultEvents, Event |
@@ -0,0 +1,190 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import torch.nn as nn | |||||
import abc, math, weakref, typing, time | |||||
from typing import Any, Callable, Optional, Sequence | |||||
import numpy as np | |||||
from kamal.core.engine.events import DefaultEvents, Event | |||||
from kamal.core import tasks | |||||
from kamal.utils import set_mode, move_to_device, get_logger | |||||
from collections import defaultdict | |||||
import numbers | |||||
import contextlib | |||||
class State(object):
    """Mutable bag of bookkeeping values shared across one Engine run.

    Tracks the current iteration, run limits, the active dataloader and
    batch, and a dict of latest metrics. Epoch-derived properties return
    None until `epoch_length` is known.
    """
    def __init__(self):
        self.iter = 0
        self.max_iter = None
        self.epoch_length = None
        self.dataloader = None
        self.seed = None
        self.metrics = dict()
        self.batch = None

    @property
    def current_epoch(self):
        """Zero-based epoch index, or None if epoch_length is unset."""
        if self.epoch_length is None:
            return None
        return self.iter // self.epoch_length

    @property
    def max_epoch(self):
        """Total number of epochs, or None if epoch_length is unset."""
        if self.epoch_length is None:
            return None
        return self.max_iter // self.epoch_length

    @property
    def current_batch_index(self):
        """Index of the current batch within its epoch, or None."""
        if self.epoch_length is None:
            return None
        return self.iter % self.epoch_length

    @property
    def max_batch_index(self):
        """Number of batches per epoch (same as epoch_length)."""
        return self.epoch_length

    def __repr__(self):
        # Show raw values only for simple types; everything else is
        # summarized by its type to keep the repr compact.
        lines = ["State:"]
        for attr, value in self.__dict__.items():
            shown = value if isinstance(value, (numbers.Number, str, dict)) else type(value)
            lines.append("\t{}: {}".format(attr, shown))
        return "\n".join(lines) + "\n"
class Engine(abc.ABC): | |||||
def __init__(self, logger=None, tb_writer=None): | |||||
self._logger = logger if logger else get_logger(name='kamal', color=True) | |||||
self._tb_writer = tb_writer | |||||
self._callbacks = defaultdict(list) | |||||
self._allowed_events = [ *DefaultEvents ] | |||||
self._state = State() | |||||
def reset(self): | |||||
self._state = State() | |||||
def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None): | |||||
self.state.iter = self._state.start_iter = start_iter | |||||
self.state.max_iter = max_iter | |||||
self.state.epoch_length = epoch_length if epoch_length else len(dataloader) | |||||
self.state.dataloader = dataloader | |||||
self.state.dataloader_iter = iter(dataloader) | |||||
self.state.step_fn = step_fn | |||||
self.trigger_events(DefaultEvents.BEFORE_RUN) | |||||
for self.state.iter in range( start_iter, max_iter ): | |||||
if self.state.epoch_length!=None and \ | |||||
self.state.iter%self.state.epoch_length==0: # Epoch Start | |||||
self.trigger_events(DefaultEvents.BEFORE_EPOCH) | |||||
self.trigger_events(DefaultEvents.BEFORE_STEP) | |||||
self.state.batch = self._get_batch() | |||||
step_output = step_fn(self, self.state.batch) | |||||
if isinstance(step_output, dict): | |||||
self.state.metrics.update(step_output) | |||||
self.trigger_events(DefaultEvents.AFTER_STEP) | |||||
if self.state.epoch_length!=None and \ | |||||
(self.state.iter+1)%self.state.epoch_length==0: # Epoch End | |||||
self.trigger_events(DefaultEvents.AFTER_EPOCH) | |||||
self.trigger_events(DefaultEvents.AFTER_RUN) | |||||
def _get_batch(self): | |||||
try: | |||||
batch = next( self.state.dataloader_iter ) | |||||
except StopIteration: | |||||
self.state.dataloader_iter = iter(self.state.dataloader) # reset iterator | |||||
batch = next( self.state.dataloader_iter ) | |||||
if not isinstance(batch, (list, tuple)): | |||||
batch = [ batch, ] # no targets | |||||
return batch | |||||
@property | |||||
def state(self): | |||||
return self._state | |||||
@property | |||||
def logger(self): | |||||
return self._logger | |||||
@property | |||||
def tb_writer(self): | |||||
return self._tb_writer | |||||
def add_callback(self, event: Event, callbacks ): | |||||
if not isinstance(callbacks, Sequence): | |||||
callbacks = [callbacks] | |||||
if event in self._allowed_events: | |||||
for callback in callbacks: | |||||
if callback not in self._callbacks[event]: | |||||
if event.trigger!=event.default_trigger: | |||||
callback = self._trigger_wrapper(self, event.trigger, callback ) | |||||
self._callbacks[event].append( callback ) | |||||
callbacks = [ RemovableCallback(self, event, c) for c in callbacks ] | |||||
return ( callbacks[0] if len(callbacks)==1 else callbacks ) | |||||
def remove_callback(self, event, callback): | |||||
for c in self._callbacks[event]: | |||||
if c==callback: | |||||
self._callbacks.remove( callback ) | |||||
return True | |||||
return False | |||||
@staticmethod | |||||
def _trigger_wrapper(engine, trigger, callback): | |||||
def wrapper(*args, **kwargs) -> Any: | |||||
if trigger(engine): | |||||
return callback(engine) | |||||
return wrapper | |||||
def trigger_events(self, *events): | |||||
for e in events: | |||||
if e in self._allowed_events: | |||||
for callback in self._callbacks[e]: | |||||
callback(self) | |||||
def register_events(self, *events): | |||||
for e in events: | |||||
if e not in self._allowed_events: | |||||
self._allowed_events.apped( e ) | |||||
@contextlib.contextmanager | |||||
def save_current_callbacks(self): | |||||
temp = self._callbacks | |||||
self._callbacks = defaultdict(list) | |||||
yield | |||||
self._callbacks = temp | |||||
class RemovableCallback:
    """Handle returned by Engine.add_callback that can detach the callback.

    Holds only weak references to the engine, event and callback so a
    forgotten handle never keeps any of them alive.
    """

    def __init__(self, engine, event, callback):
        self._engine = weakref.ref(engine)
        self._callback = weakref.ref(callback)
        self._event = weakref.ref(event)

    @property
    def callback(self):
        # Dereferences to None once the callback has been collected.
        return self._callback()

    def remove(self):
        """Detach the callback from its engine; returns the engine's result."""
        return self._engine().remove_callback(self._event(), self._callback())
@@ -0,0 +1,130 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import abc, sys | |||||
import torch | |||||
from tqdm import tqdm | |||||
from kamal.core import metrics | |||||
from kamal.utils import set_mode | |||||
from typing import Any, Callable | |||||
from .engine import Engine | |||||
from .events import DefaultEvents | |||||
from kamal.core import callbacks | |||||
import weakref | |||||
from kamal.utils import move_to_device, split_batch | |||||
class BasicEvaluator(Engine):
    """Runs a model over a dataloader once and aggregates `metric` results.

    The model is held only for the duration of `eval` (via weakref), so
    the evaluator never keeps a model alive by itself.
    """

    def __init__(self,
                 dataloader: torch.utils.data.DataLoader,
                 metric: metrics.MetricCompose,
                 eval_fn: Callable=None,
                 tag: str='Eval',
                 progress: bool=False ):
        super( BasicEvaluator, self ).__init__()
        self.dataloader = dataloader
        self.metric = metric
        self.progress = progress
        if progress:
            # Fix: attribute was misspelled `porgress_callback`; the old
            # name is kept as an alias for backward compatibility.
            self.progress_callback = self.add_callback(
                DefaultEvents.AFTER_STEP,
                callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag))
            self.porgress_callback = self.progress_callback
        self._model = None
        self._tag = tag
        if eval_fn is None:
            eval_fn = BasicEvaluator.default_eval_fn
        self.eval_fn = eval_fn

    def eval(self, model, device=None):
        """Evaluate `model` on the held dataloader; returns metric results."""
        device = device if device is not None else \
            torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self._model = weakref.ref(model)  # weakref: do not keep the model alive
        self.device = device
        self.metric.reset()
        model.to(device)
        if self.progress:
            self.progress_callback.callback.reset()
        # Inference mode: no gradients, model switched to eval() and back.
        with torch.no_grad(), set_mode(model, training=False):
            super(BasicEvaluator, self).run( self.step_fn, self.dataloader, max_iter=len(self.dataloader) )
        return self.metric.get_results()

    @property
    def model(self):
        """The model under evaluation, or None if collected / never set."""
        if self._model is not None:
            return self._model()
        return None

    def step_fn(self, engine, batch):
        batch = move_to_device(batch, self.device)
        self.eval_fn( engine, batch )

    @staticmethod
    def default_eval_fn(evaluator, batch):
        # Standard supervised evaluation: predict and update the metric.
        model = evaluator.model
        inputs, targets = split_batch(batch)
        outputs = model( inputs )
        evaluator.metric.update( outputs, targets )
class TeacherEvaluator(BasicEvaluator):
    """Evaluator that scores a student against a teacher's predictions.

    Instead of ground-truth labels, the metric targets come from running
    the (frozen) teacher on the same inputs and mapping its raw outputs
    through `task.predict`.
    """

    def __init__(self,
                 dataloader: torch.utils.data.DataLoader,
                 teacher: torch.nn.Module,
                 task,
                 metric: metrics.MetricCompose,
                 eval_fn: Callable=None,
                 tag: str='Eval',
                 progress: bool=False ):
        if eval_fn is None:
            eval_fn = TeacherEvaluator.default_eval_fn
        super( TeacherEvaluator, self ).__init__(
            dataloader=dataloader, metric=metric, eval_fn=eval_fn, tag=tag, progress=progress)
        self._teacher = teacher
        self.task = task

    def eval(self, model, device=None):
        # Keep the teacher on-device and in eval mode for the whole run.
        self.teacher.to(device)
        with set_mode(self.teacher, training=False):
            return super(TeacherEvaluator, self).eval( model, device=device )

    @property
    def model(self):
        return None if self._model is None else self._model()

    @property
    def teacher(self):
        return self._teacher

    def step_fn(self, engine, batch):
        self.eval_fn( engine, move_to_device(batch, self.device) )

    @staticmethod
    def default_eval_fn(evaluator, batch):
        student = evaluator.model
        teacher = evaluator.teacher
        inputs, _ = split_batch(batch)
        outputs = student( inputs )
        # Teacher predictions replace the ground-truth targets; with
        # several teachers, each is paired with its own task.
        if isinstance(teacher, torch.nn.ModuleList):
            targets = [ task.predict(t(inputs)) for (t, task) in zip(teacher, evaluator.task) ]
        else:
            targets = evaluator.task.predict( teacher(inputs) )
        evaluator.metric.update( outputs, targets )
@@ -0,0 +1,92 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from typing import Callable, Optional | |||||
from enum import Enum | |||||
class Event(object):
    """A named engine event with a trigger predicate.

    Also serves as the mixin type for `DefaultEvents`, so it mirrors the
    Enum member protocol (`_name_` / `_value_`). Calling an event with
    `every=` or `once=` produces a derived Event whose trigger limits
    how often attached callbacks fire.
    """

    def __init__(self, value: str, event_trigger: Optional[Callable]=None ):
        if event_trigger is None:
            event_trigger = Event.default_trigger
        self._trigger = event_trigger
        self._name_ = self._value_ = value

    @property
    def trigger(self):
        return self._trigger

    @property
    def name(self):
        """The name of the Enum member."""
        return self._name_

    @property
    def value(self):
        """The value of the Enum member."""
        return self._value_

    @staticmethod
    def default_trigger(engine):
        # Fire on every occurrence.
        return True

    @staticmethod
    def once_trigger():
        """Trigger that fires exactly once.

        Bug fix: `is_triggered` must be declared `nonlocal` inside the
        closure; without it, `if is_triggered` raised UnboundLocalError
        on the very first call.
        """
        is_triggered = False
        def wrapper(engine):
            nonlocal is_triggered
            if is_triggered:
                return False
            is_triggered = True
            return True
        return wrapper

    @staticmethod
    def every_trigger(every: int):
        """Trigger that fires every `every`-th iteration (never if every<=0)."""
        def wrapper(engine):
            return every > 0 and (engine.state.iter % every) == 0
        return wrapper

    def __call__(self, every: Optional[int]=None, once: Optional[bool]=None ):
        """Derive a copy of this event with a periodic or one-shot trigger."""
        if every is not None:
            assert once is None
            return Event(self.value, event_trigger=Event.every_trigger(every) )
        if once is not None:
            return Event(self.value, event_trigger=Event.once_trigger() )
        return Event(self.value)

    def __hash__(self):
        return hash(self._name_)

    def __eq__(self, other):
        if hasattr(other, 'value'):
            return self.value == other.value
        # Fix: a bare `return` yielded None; NotImplemented lets Python
        # fall back to identity comparison for foreign types.
        return NotImplemented
class DefaultEvents(Event, Enum):
    """Events fired by Engine.run(); the string values double as tags.

    NOTE(review): BEFORE_RUN/AFTER_RUN carry the values "before_train"/
    "after_train" — likely historical naming; kept as-is since values are
    part of the persisted/compared interface.
    """
    # Fired once around the whole run.
    BEFORE_RUN = "before_train"
    AFTER_RUN = "after_train"
    # Fired at every epoch_length boundary.
    BEFORE_EPOCH = "before_epoch"
    AFTER_EPOCH = "after_epoch"
    # Fired around each step_fn invocation.
    BEFORE_STEP = "before_step"
    AFTER_STEP = "after_step"
    # Declared for batch fetching — not triggered by Engine.run in this file.
    BEFORE_GET_BATCH = "before_get_batch"
    AFTER_GET_BATCH = "after_get_batch"
@@ -0,0 +1,34 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
class FeatureHook():
    """Record the input/output of `module` on every forward pass.

    After each forward call, `feat_in` holds the module's first
    positional input and `feat_out` its output.
    """

    def __init__(self, module):
        self.module = module
        # Populated by the hook after each forward call.
        self.feat_in = None
        self.feat_out = None
        self.register()

    def register(self):
        """(Re-)attach the forward hook; keeps the handle for removal."""
        self._hook = self.module.register_forward_hook(self.hook_fn_forward)

    def remove(self):
        """Detach the hook from the module."""
        self._hook.remove()

    def hook_fn_forward(self, module, fea_in, fea_out):
        # fea_in is the tuple of positional args; keep only the first.
        self.feat_in, self.feat_out = fea_in[0], fea_out
@@ -0,0 +1,214 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import os | |||||
import tempfile, uuid | |||||
import contextlib | |||||
import numpy as np | |||||
from tqdm import tqdm | |||||
from kamal.core import callbacks | |||||
from kamal.core.engine import evaluator | |||||
from kamal.core.engine.events import DefaultEvents | |||||
class _ProgressCallback(callbacks.Callback):
    """Callback driving a tqdm progress bar, one tick per engine step."""

    def __init__(self, max_iter, tag):
        self._tag = tag
        self._max_iter = max_iter
        self._pbar = tqdm(total=self._max_iter, desc=self._tag)

    def __call__(self, engine):  # fix: parameter was misspelled `enigine`
        self._pbar.update(1)

    def reset(self):
        """Start a fresh bar (used when the evaluator is re-run)."""
        self._pbar = tqdm(total=self._max_iter, desc=self._tag)
class _LinearLRScheduler(torch.optim.lr_scheduler._LRScheduler): | |||||
""" Linear Scheduler | |||||
""" | |||||
def __init__(self, | |||||
optimizer, | |||||
lr_range, | |||||
max_iter, | |||||
last_epoch=-1): | |||||
self.lr_range = lr_range | |||||
self.max_iter = max_iter | |||||
super(_LinearLRScheduler, self).__init__(optimizer, last_epoch) | |||||
def get_lr(self): | |||||
r = self.last_epoch / self.max_iter | |||||
return [self.lr_range[0] + (self.lr_range[1]-self.lr_range[0]) * r for base_lr in self.base_lrs] | |||||
class _ExponentialLRScheduler(torch.optim.lr_scheduler._LRScheduler): | |||||
def __init__(self, optimizer, lr_range, max_iter, last_epoch=-1): | |||||
self.lr_range = lr_range | |||||
self.max_iter = max_iter | |||||
super(_ExponentialLRScheduler, self).__init__(optimizer, last_epoch) | |||||
def get_lr(self): | |||||
r = self.last_epoch / (self.max_iter - 1) | |||||
return [self.lr_range[0] * (self.lr_range[1] / self.lr_range[0]) ** r for base_lr in self.base_lrs] | |||||
class _LRFinderCallback(object): | |||||
def __init__(self, | |||||
model, | |||||
optimizer, | |||||
metric_name, | |||||
evaluator, | |||||
smooth_momentum): | |||||
self._model = model | |||||
self._optimizer = optimizer | |||||
self._metric_name = metric_name | |||||
self._evaluator = evaluator | |||||
self._smooth_momentum = smooth_momentum | |||||
self._records = [] | |||||
@property | |||||
def records(self): | |||||
return self._records | |||||
def reset(self): | |||||
self._records = [] | |||||
def __call__(self, trainer): | |||||
model = self._model | |||||
optimizer = self._optimizer | |||||
if self._evaluator is not None: | |||||
results = self._evaluator.eval( model ) | |||||
score = float(results[ self._metric_name ]) | |||||
else: | |||||
score = float(trainer.state.metrics[self._metric_name]) | |||||
if self._smooth_momentum>0 and len(self._records)>0: | |||||
score = self._records[-1][1] * self._smooth_momentum + score * (1-self._smooth_momentum) | |||||
lr = optimizer.param_groups[0]['lr'] | |||||
self._records.append( ( lr, score ) ) | |||||
class LRFinder(object):
    """Sweeps the learning rate over a range to suggest a good initial LR.

    `find` checkpoints the model/optimizer to a temp file, runs the
    trainer while a scheduler sweeps the LR, records (lr, score) pairs,
    restores the original state, and returns the LR at the steepest
    score gradient.
    """

    def _reset(self):
        """Restore model/optimizer to their pre-sweep state and clean up."""
        init_state = torch.load( self._temp_file )
        self.trainer.model.load_state_dict( init_state['model'] )
        self.trainer.optimizer.load_state_dict( init_state['optimizer'] )
        # Fix: remove the temporary checkpoint instead of leaking it.
        try:
            os.remove(self._temp_file)
        except OSError:
            pass
        try:
            self.trainer.reset()
        except Exception:  # fix: was a bare `except`; best-effort reset kept
            pass

    def adjust_learning_rate(self, optimizer, lr):
        """Set `lr` on every param group of `optimizer`."""
        for group in optimizer.param_groups:
            group['lr'] = lr

    def _get_default_lr_range(self, optimizer):
        """Heuristic sweep range keyed on the optimizer family."""
        if isinstance( optimizer, torch.optim.Adam ):
            return ( 1e-5, 1e-2 )
        elif isinstance( optimizer, torch.optim.SGD ):
            return ( 1e-3, 0.2 )
        else:
            return ( 1e-5, 0.5)

    def plot(self, polyfit: int=3, log_x=True):
        """Plot recorded scores against learning rate; returns the figure.

        Bug fix: with polyfit=None the original plotted `fitted_score`
        before it was ever assigned (NameError); the fitted curve is now
        only drawn when a polynomial degree is given.
        """
        import matplotlib.pyplot as plt
        lrs = [ rec[0] for rec in self.records ]
        scores = [ rec[1] for rec in self.records ]
        fig, ax = plt.subplots()
        ax.plot(lrs, scores, label='score')
        if polyfit is not None:
            z = np.polyfit( lrs, scores, deg=polyfit )
            ax.plot(lrs, np.polyval( z, lrs ), label='polyfit')
        if log_x:
            plt.xscale('log')
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Score")
        return fig

    def suggest( self, mode='min', skip_begin=10, skip_end=5, polyfit=None ):
        """Return (index, lr) where the score gradient is steepest.

        mode='min' picks the most negative gradient (fastest-improving
        loss); 'max' the most positive. The first/last few records are
        skipped as warm-up/blow-up noise.
        """
        scores = np.array( [ rec[1] for rec in self.records ] )
        lrs = np.array( [ rec[0] for rec in self.records ] )
        if polyfit is not None:
            z = np.polyfit( lrs, scores, deg=polyfit )
            scores = np.polyval( z, lrs )
        grad = np.gradient( scores )[skip_begin:-skip_end]
        index = skip_begin + ( grad.argmin() if mode=='min' else grad.argmax() )
        return index, self.records[index][0]

    def find(self,
             optimizer,
             model,
             trainer,
             metric_name='total_loss',
             metric_mode='min',
             evaluator=None,
             lr_range=(1e-4, 0.1),   # fix: tuple instead of a mutable default
             max_iter=100,
             num_eval=None,
             smooth_momentum=0.9,
             scheduler='exp',        # 'exp' or linear
             polyfit=None,
             skip=(10, 5),           # fix: tuple instead of a mutable default
             progress=True):         # NOTE(review): currently unused, kept for compat
        """Run the LR sweep and return the suggested learning rate."""
        self.optimizer = optimizer
        self.model = model
        self.trainer = trainer

        # Checkpoint the initial state so the sweep is side-effect free.
        self._temp_file = os.path.join(tempfile.gettempdir(), str(uuid.uuid4()) + '.pth')
        init_state = {
            'optimizer': optimizer.state_dict(),
            'model': model.state_dict()
        }
        torch.save(init_state, self._temp_file)

        if num_eval is None or num_eval > max_iter:
            num_eval = max_iter
        if lr_range is None:
            lr_range = self._get_default_lr_range(optimizer)
        interval = max_iter // num_eval
        if scheduler == 'exp':
            lr_sched = _ExponentialLRScheduler(optimizer, lr_range, max_iter=max_iter)
        else:
            lr_sched = _LinearLRScheduler(optimizer, lr_range, max_iter=max_iter)
        self._lr_callback = callbacks.LRSchedulerCallback(schedulers=[lr_sched])
        self._finder_callback = _LRFinderCallback(model, optimizer, metric_name, evaluator, smooth_momentum)
        self.adjust_learning_rate( self.optimizer, lr_range[0] )
        # Swap in only the sweep callbacks; the trainer's own are restored after.
        with self.trainer.save_current_callbacks():
            trainer.add_callback(
                DefaultEvents.AFTER_STEP, callbacks=[
                    self._lr_callback,
                    _ProgressCallback(max_iter, '[LR Finder]') ])
            trainer.add_callback(
                DefaultEvents.AFTER_STEP(interval), callbacks=self._finder_callback)
            self.trainer.run( start_iter=0, max_iter=max_iter )
        self.records = self._finder_callback.records  # get records
        index, best_lr = self.suggest(mode=metric_mode, skip_begin=skip[0], skip_end=skip[1], polyfit=polyfit)
        self._reset()
        del self.model, self.optimizer, self.trainer
        return best_lr
@@ -0,0 +1,129 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import torch.nn as nn | |||||
from kamal.core.engine.engine import Engine, Event, DefaultEvents, State | |||||
from kamal.core import tasks | |||||
from kamal.utils import set_mode, move_to_device, get_logger, split_batch | |||||
from typing import Callable, Mapping, Any, Sequence | |||||
import time | |||||
import weakref | |||||
class BasicTrainer(Engine):
    """Supervised trainer: forward pass, task loss, backward, optimizer step."""

    def __init__( self,
                  logger=None,
                  tb_writer=None):
        super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer)

    def setup(self,
              model: torch.nn.Module,
              task: tasks.Task,
              dataloader: torch.utils.data.DataLoader,
              optimizer: torch.optim.Optimizer,
              device: torch.device=None):
        """Bind model/task/data/optimizer; returns self for chaining."""
        if device is None:
            device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' )
        self.device = device
        # A list of tasks is wrapped into one composite task.
        self.task = tasks.TaskCompose(task) if isinstance(task, Sequence) else task
        self.model = model
        self.dataloader = dataloader
        self.optimizer = optimizer
        return self

    def run( self, max_iter, start_iter=0, epoch_length=None):
        self.model.to(self.device)
        with set_mode(self.model, training=True):
            super( BasicTrainer, self ).run(
                self.step_fn, self.dataloader,
                start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """One optimization step; returns a dict of scalar metrics."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = self.model(inputs)
        loss_dict = self.task.get_loss(outputs, targets)  # get loss
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {name: value.item() for (name, value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
class KDTrainer(BasicTrainer):
    """Knowledge-distillation trainer: the student learns from teacher outputs."""

    def setup(self,
              student: torch.nn.Module,
              teacher: torch.nn.Module,
              task: tasks.Task,
              dataloader: torch.utils.data.DataLoader,
              optimizer: torch.optim.Optimizer,
              device: torch.device=None):
        super(KDTrainer, self).setup(
            model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device)
        # Normalize teacher input: a length-1 sequence collapses to the single
        # module, several teachers become a ModuleList.
        if isinstance(teacher, (list, tuple)):
            teacher = teacher[0] if len(teacher) == 1 else nn.ModuleList(teacher)
        self.student = self.model
        self.teacher = teacher
        return self

    def run( self, max_iter, start_iter=0, epoch_length=None):
        self.student.to(self.device)
        self.teacher.to(self.device)
        # NOTE: deliberately skips BasicTrainer.run (calls Engine.run) since
        # device placement and train/eval modes are handled right here.
        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False):
            super( BasicTrainer, self ).run(
                self.step_fn, self.dataloader,
                start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def step_fn(self, engine, batch):
        """One distillation step; teacher outputs serve as soft targets."""
        start_time = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = split_batch(batch)
        outputs = self.model(inputs)
        if isinstance(self.teacher, nn.ModuleList):
            soft_targets = [t(inputs) for t in self.teacher]
        else:
            soft_targets = self.teacher(inputs)
        loss_dict = self.task.get_loss(outputs, soft_targets)  # get loss
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        step_time = time.perf_counter() - start_time
        metrics = {name: value.item() for (name, value) in loss_dict.items()}
        metrics.update({
            'total_loss': loss.item(),
            'step_time': step_time,
            'lr': float(self.optimizer.param_groups[0]['lr'])
        })
        return metrics
@@ -0,0 +1,22 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
class DataTypeError(Exception):
    """Raised when a value has an unsupported or unexpected data type.

    Bug fix: previously inherited from `object`, so `raise DataTypeError(...)`
    was a TypeError and it could never be caught as an exception.
    """
    pass
class InvalidMapping(Exception):
    """Raised when a requested mapping/lookup is invalid.

    Bug fix: previously inherited from `object`, so `raise InvalidMapping(...)`
    was a TypeError and it could never be caught as an exception.
    """
    pass
@@ -0,0 +1,2 @@ | |||||
from ._hub import * | |||||
from . import meta |
@@ -0,0 +1,288 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from pip._internal import main as pipmain | |||||
from typing import Sequence, Optional | |||||
import torch | |||||
import os, sys, re | |||||
import shutil, inspect | |||||
import importlib.util | |||||
from glob import glob | |||||
from ruamel.yaml import YAML | |||||
from ruamel.yaml import comments | |||||
from ._module_mapping import PACKAGE_NAME_TO_IMPORT_NAME | |||||
import hashlib | |||||
import inspect | |||||
# Pickle protocol for serialized payloads — NOTE(review): not referenced in
# the visible code; confirm before removing.
_DEFAULT_PROTOCOL = 2
# Layout of an exported model package: a requirements.txt plus
# code/, weight/ and tag/ subdirectories under the package root.
_DEPENDENCY_FILE = "requirements.txt"
_CODE_DIR = "code"
_WEIGHTS_DIR = "weight"
_TAGS_DIR = "tag"
# Shared ruamel round-trip YAML handler (presumably for tag files).
yaml = YAML()
def _replace_invalid_char(name): | |||||
return name.replace('-', '_').replace(' ', '_') | |||||
def save(
    model: torch.nn.Module,
    save_path: str,
    entry_name: str,
    spec_name: str,
    code_path: str, # path to SOURCE_CODE_DIR or hubconf.py
    metadata: dict,
    tags: dict = None,
    ignore_files: Sequence = None,
    save_arch: bool = False,
    overwrite: bool = False,
):
    """Export `model` and its source code as a hub package under `save_path`.

    Produces <save_path>/code/ (copied sources) and
    <save_path>/weight/<entry_name>-<spec_name>.pth containing
    {'model': state_dict (or the full module when ``save_arch``),
     'metadata': metadata}.

    Args:
        model: module to export.
        save_path: package root; created if missing.
        entry_name: entry-point name ('-' and ' ' replaced by '_').
        spec_name: variant name; when None, the weight file's md5 digest
            is used instead (via the module-level ``md5`` helper).
        code_path: a source directory, or a single hubconf.py file.
        metadata: arbitrary dict stored alongside the weights.
        tags: optional tags persisted via ``save_tags``.
        ignore_files: NOTE(review): unused here — presumably intended to
            filter copied files; confirm against `_copy_file_or_tree`.
        save_arch: pickle the whole module instead of just state_dict.
        overwrite: replace previously exported code files.
    """
    entry_name = _replace_invalid_char(entry_name)
    if spec_name is not None:
        spec_name = _replace_invalid_char(spec_name)
    # A missing save_path means a fresh export, so force overwrite.
    if not os.path.isdir(save_path):
        overwrite = True
    save_path = os.path.abspath(save_path)
    export_code_path = os.path.join(save_path, _CODE_DIR)
    export_weights_path = os.path.join(save_path, _WEIGHTS_DIR)
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(export_code_path, exist_ok=True)
    os.makedirs(export_weights_path, exist_ok=True)
    code_path = os.path.abspath(code_path)
    if os.path.isdir( code_path ):
        if overwrite:
            # Replace the previously exported source tree wholesale.
            shutil.rmtree(export_code_path)
            _copy_file_or_tree(src=code_path, dst=export_code_path) # overwrite old files
    elif code_path.endswith('.py'):
        if overwrite:
            shutil.copy2(src=code_path, dst=os.path.join( export_code_path, 'hubconf.py' )) # overwrite old files
    # Strip loader bookkeeping so it is not pickled with the model.
    if hasattr(model, 'SETUP_INFO'):
        del model.SETUP_INFO
    if hasattr(model, 'METADATA'):
        del model.METADATA
    model_and_metadata = {
        'model': model if save_arch else model.state_dict(),
        'metadata': metadata,
    }
    # Write under a temp name first: the final name may embed the file's md5.
    temp_pth = os.path.join(export_weights_path, 'temp.pth')
    torch.save(model_and_metadata, temp_pth)
    if spec_name is None:
        _md5 = md5( temp_pth )
        spec_name = _md5
    shutil.move( temp_pth, os.path.join(export_weights_path, '%s-%s.pth'%(entry_name, spec_name)) )
    if tags is not None:
        save_tags( tags, save_path, entry_name, spec_name )
def list_entry(path):
    """Discover the callable entry points exposed by a model package.

    Args:
        path: either a python source file, or a package directory that
            contains ``<_CODE_DIR>/app.py``.

    Returns:
        list of ``(name, callable)`` pairs for every public callable
        defined by the entry module.

    Raises:
        NotImplementedError: when *path* is neither a .py file nor a directory.
    """
    path = os.path.abspath(os.path.expanduser(path))
    if path.endswith('.py'):
        code_dir = os.path.dirname(path)
        entry_file = path
    elif os.path.isdir(path):
        code_dir = os.path.join(path, _CODE_DIR)
        entry_file = os.path.join(path, _CODE_DIR, 'app.py')
    else:
        raise NotImplementedError
    sys.path.insert(0, code_dir)
    try:
        # Bug fix: the original removed code_dir from sys.path only on the
        # success path; an exception in exec_module leaked the entry forever.
        spec = importlib.util.spec_from_file_location('app', entry_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        entry_list = [(f, getattr(module, f)) for f in dir(module)
                      if callable(getattr(module, f)) and not f.startswith('_')]
    finally:
        sys.path.remove(code_dir)
    return entry_list
def list_spec(path, entry_name=None):
    """List ``(entry_name, spec_name)`` pairs for the weights saved in a package.

    Args:
        path: package root directory (weights live in ``<_WEIGHTS_DIR>``).
        entry_name: optional filter; only specs of this entry are returned.

    Returns:
        list of ``(entry_name, spec_name)`` tuples.
    """
    path = os.path.abspath(os.path.expanduser(path))
    weight_dir = os.path.join(path, _WEIGHTS_DIR)
    spec_list = []
    for fname in os.listdir(weight_dir):
        if not fname.endswith('.pth'):
            continue
        # Bug fix: split on the FIRST '-' only. The original `f.split('-')`
        # indexed [1], which truncated spec names containing '-' and raised
        # IndexError for .pth files without any dash.
        entry, _, spec = fname[:-4].partition('-')
        spec_list.append((entry, spec))
    if entry_name is not None:
        spec_list = [s for s in spec_list if s[0] == entry_name]
    return spec_list
def load(path: str, entry_name:str=None, spec_name: str=None, pretrained=True, **kwargs):
    """
    check dependencies and load pytorch models.

    Args:
        path: path to the model package (a directory holding _CODE_DIR and
            _WEIGHTS_DIR).
        entry_name: entry function defined in hubconf.py; auto-detected from
            the single weight file when None.
        spec_name: spec suffix of the weight file; auto-detected when None.
        pretrained: load the saved state_dict into the freshly built model.
        kwargs: NOTE(review): currently unused — confirm intent.

    Returns:
        the loaded torch.nn.Module with METADATA / SETUP_INFO attached.

    Raises:
        FileNotFoundError: when the package or its weight file is missing.
        NotImplementedError: when *path* exists but is not a directory.
    """
    path = os.path.abspath(path)
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    if not os.path.isdir(path):
        raise NotImplementedError
    code_dir = os.path.join(path, _CODE_DIR)
    cwd = os.getcwd()
    # hubconf.py may rely on relative paths / sibling modules, so run with the
    # package code dir as both cwd and head of sys.path.
    os.chdir(code_dir)
    sys.path.insert(0, code_dir)
    try:
        spec = importlib.util.spec_from_file_location(
            'hubconf', os.path.join(code_dir, 'hubconf.py'))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Install any pip requirements declared by the package.
        if hasattr(module, 'dependencies'):
            deps = getattr(module, 'dependencies')
            for dep in deps:
                _import_with_auto_install(dep)
        weights_dir = os.path.join(path, _WEIGHTS_DIR)
        if entry_name is None:  # auto select from the single weight file
            pth_file = [pth for pth in os.listdir(weights_dir)]
            assert len(pth_file) <= 1, "Loading models with more than one weight files (.pth) is ambiguous"
            pth_file = pth_file[0]
            entry_name = pth_file.split('-')[0]
        else:
            if spec_name is None:
                pth_file = [pth for pth in os.listdir(weights_dir) if pth.startswith(entry_name)]
                assert len(pth_file) <= 1, "Loading models with more than one weight files (.pth) is ambiguous"
                pth_file = pth_file[0]
            else:
                pth_file = '%s-%s.pth' % (entry_name, spec_name)
        entry_fn = getattr(module, entry_name)
        weights_path = os.path.join(weights_dir, pth_file)
        try:
            model_and_metadata = torch.load(weights_path, map_location='cpu')
        except Exception as e:
            # Bug fix: the bare `except: raise FileNotFoundError` hid the real
            # cause; chain it and name the missing file.
            raise FileNotFoundError(weights_path) from e
        if isinstance(model_and_metadata['model'], torch.nn.Module):
            # The full architecture was pickled (save_arch=True).
            model = model_and_metadata['model']
        else:
            # Rebuild the architecture via the entry function, then load weights.
            entry_args = model_and_metadata['metadata']['entry_args']
            if entry_args is None:
                entry_args = dict()
            model = entry_fn(**entry_args)
            if pretrained:
                model.load_state_dict(model_and_metadata['model'], False)  # strict=False
        # setup metadata and atlas info
        model.METADATA = model_and_metadata['metadata']
        model.SETUP_INFO = {"path": path, "entry_name": entry_name}
        return model
    finally:
        # Bug fix: restore sys.path and cwd even when loading fails.
        sys.path.remove(code_dir)
        os.chdir(cwd)
def load_metadata(path, entry_name=None, spec_name=None):
    """Load only the metadata dict stored with a saved model package.

    Args:
        path: package root directory.
        entry_name: entry name; auto-detected from the single weight file when None.
        spec_name: spec suffix; auto-detected when None.

    Returns:
        the metadata dict, or None when *path* is not a directory
        (preserves the original fall-through behavior).

    Raises:
        FileNotFoundError: when the weight file cannot be read.
    """
    path = os.path.abspath(path)
    if os.path.isdir(path):
        if entry_name is None:  # auto select the only weight file
            pth_file = [pth for pth in os.listdir(os.path.join(path, _WEIGHTS_DIR))]
            assert len(pth_file) <= 1, "Loading models with more than one weight files (.pth) is ambiguous"
            pth_file = pth_file[0]
        else:
            if spec_name is None:
                pth_file = [pth for pth in os.listdir(os.path.join(path, _WEIGHTS_DIR)) if pth.startswith(entry_name)]
                assert len(pth_file) <= 1, "Loading models with more than one weight files (.pth) is ambiguous"
                pth_file = pth_file[0]
            else:
                pth_file = '%s-%s.pth' % (entry_name, spec_name)
        weights_file = os.path.join(path, _WEIGHTS_DIR, pth_file)
        try:
            metadata = torch.load(weights_file, map_location='cpu')['metadata']
        except Exception as e:
            # Bug fix: the original `except: FileNotFoundError` evaluated the
            # exception class without raising it, leaving `metadata` unbound
            # and crashing with NameError at the return below.
            raise FileNotFoundError(weights_file) from e
        return metadata
def load_tags(path, entry_name, spec_name):
    """Load the tag dict saved for (entry_name, spec_name); {} when absent."""
    root = os.path.abspath(path)
    tags_file = os.path.join(root, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name ))
    if not os.path.isfile(tags_file):
        return dict()
    return _to_python_type(_yaml_load(tags_file))
def save_tags(tags, path, entry_name, spec_name):
    """Persist the tag dict for (entry_name, spec_name) as YAML under _TAGS_DIR."""
    root = os.path.abspath(path)
    if tags is None:
        tags = {}
    tags_dir = os.path.join(root, _TAGS_DIR)
    os.makedirs(tags_dir, exist_ok=True)
    tags_file = os.path.join(root, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name ))
    _yaml_dump(tags_file, tags)
def _yaml_dump(f, obj):
    """Serialize *obj* as YAML into the file at path *f*."""
    with open(f, 'w') as stream:
        yaml.dump(obj, stream)
def _yaml_load(f):
    """Parse the YAML file at path *f* and return the loaded object."""
    # NOTE(review): with PyYAML this is the deprecated Loader-less yaml.load,
    # which is unsafe on untrusted input; _to_python_type handles
    # ruamel CommentedSeq, so confirm which yaml library is imported here.
    with open(f, 'r') as f:
        return yaml.load(f)
def _to_python_type(data):
    """Recursively convert YAML round-trip containers into plain dict/list."""
    if isinstance(data, dict):
        for key, value in data.items():
            data[key] = _to_python_type(value)
        return dict(data)
    if isinstance(data, comments.CommentedSeq):
        for pos, value in enumerate(data):
            data[pos] = _to_python_type(value)
        return list(data)
    # scalars (str/int/float/bool/None) pass through unchanged
    return data
def md5(fname):
    """Return the hex MD5 digest of the file at *fname*, read in 4 KiB chunks."""
    digest = hashlib.md5()
    with open(fname, "rb") as fp:
        while True:
            block = fp.read(4096)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
def _get_package_name_and_version(package): | |||||
_version_sym_list = ('==', '>=', '<=') | |||||
for sym in _version_sym_list: | |||||
if sym in package: | |||||
return package.split(sym) | |||||
return package, None | |||||
def _import_with_auto_install(package):
    """Import *package*, pip-installing it first if the import fails.

    Args:
        package: pip requirement string, e.g. 'opencv-python>=4.0'.

    Returns:
        the imported top-level module.
    """
    package_name, version = _get_package_name_and_version(package)
    package_name = package_name.strip()
    # Map the PyPI distribution name to its import name (e.g. opencv-python -> cv2).
    import_name = PACKAGE_NAME_TO_IMPORT_NAME.get(
        package_name, package_name).replace('-', '_')
    try:
        return __import__(import_name)
    except ImportError:
        # pip's programmatic entry point moved between pip versions: older
        # pips expose pip.main.main(...), newer ones a callable main — try both.
        try:
            pipmain.main(['install', package])
        except:
            pipmain(['install', package])
        return __import__(import_name)
from distutils.dir_util import copy_tree | |||||
def _copy_file_or_tree(src, dst): | |||||
if os.path.isfile(src): | |||||
shutil.copy2(src=src, dst=dst) | |||||
else: | |||||
copy_tree(src=src, dst=dst) | |||||
def _glob_list(path_list, recursive=False): | |||||
results = [] | |||||
for path in path_list: | |||||
if '*' in path: | |||||
path = list(glob(path, recursive=recursive)) | |||||
results.extend(path) | |||||
else: | |||||
results.append(path) | |||||
return results |
@@ -0,0 +1,6 @@ | |||||
# Maps a PyPI distribution name to the module name it is imported as.
# Consumed by _import_with_auto_install to resolve the import after pip install.
PACKAGE_NAME_TO_IMPORT_NAME = {
    'opencv-python': 'cv2',
    'pillow': 'PIL',
    'scikit-learn': 'sklearn',
    'scikit-image': 'skimage',  # bug fix: the import name is 'skimage', not 'scikit-image'
}
@@ -0,0 +1,22 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
# Task-type identifiers stored in a model package's metadata ('task' field).
CLASSIFICATION = 'classification'
SEGMENTATION = 'segmentation'
DETECTION = 'detection'
DEPTH = 'depth'
AUTOENCODER = 'autoencoder'
@@ -0,0 +1,3 @@ | |||||
from .meta import Metadata | |||||
from .input import ImageInput | |||||
from . import TASK |
@@ -0,0 +1,31 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from typing import Union, Sequence | |||||
# INPUT Metadata
def ImageInput( size: Union[Sequence[int], int],
                range: Union[Sequence[int], int],
                space: str,
                normalize: Sequence = None):
    """Describe the image input a model expects, as a plain dict.

    Args:
        size: spatial size (int) or shape sequence.
        range: value range of pixel intensities (int or (low, high) sequence).
        space: color space; one of 'rgb', 'gray', 'bgr', 'rgbd'.
        normalize: optional normalization parameters (annotation fixed: the
            original ``Union[Sequence]`` was a one-member Union, i.e. meaningless).

    Returns:
        dict with keys 'size', 'range', 'space', 'normalize'.
    """
    assert space in ['rgb', 'gray', 'bgr', 'rgbd']
    return dict(
        size=size,
        range=range,
        space=space,
        normalize=normalize
    )
@@ -0,0 +1,44 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from copy import deepcopy | |||||
import abc, os | |||||
from . import TASK | |||||
import torch | |||||
__all__ = ['yaml', 'ImageInput', 'MetaData', 'AtlasEntryBase'] | |||||
# Model Metadata
def Metadata( name: str,
              dataset: str,
              task: int,
              url: str,
              input: dict,
              entry_args: dict,
              other_metadata: dict):
    """Build the standard metadata dict stored with a saved model package.

    For classification/segmentation tasks, ``other_metadata`` must provide
    ``num_classes``.
    """
    # Class-prediction tasks must declare how many classes the model outputs.
    if task in (TASK.SEGMENTATION, TASK.CLASSIFICATION):
        assert 'num_classes' in other_metadata
    return {
        'name': name,
        'dataset': dataset,
        'task': task,
        'url': url,
        'input': dict(input),
        'entry_args': entry_args,
        'other_metadata': other_metadata,
    }
@@ -0,0 +1,8 @@ | |||||
from .stream_metrics import Metric, MetricCompose | |||||
from .accuracy import Accuracy, TopkAccuracy | |||||
from .confusion_matrix import ConfusionMatrix, IoU, mIoU | |||||
from .regression import * | |||||
from .average import AverageMetric | |||||
@@ -0,0 +1,118 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import numpy as np | |||||
import torch | |||||
from kamal.core.metrics.stream_metrics import Metric | |||||
from typing import Callable | |||||
__all__=['Accuracy', 'TopkAccuracy'] | |||||
class Accuracy(Metric):
    """Streaming top-1 classification accuracy over logits/targets batches."""

    def __init__(self, attach_to=None):
        super(Accuracy, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate correct-prediction and element counts from one batch."""
        outputs, targets = self._attach(outputs, targets)
        predicted = outputs.max(1)[1]
        hits = (predicted.view(-1) == targets.view(-1)).sum()
        self._correct = self._correct + hits
        self._cnt = self._cnt + torch.numel(targets)

    def get_results(self):
        """Return accuracy = correct / total as a CPU tensor."""
        return (self._correct / self._cnt).detach().cpu()

    def reset(self):
        self._correct = 0.0
        self._cnt = 0.0
class TopkAccuracy(Metric):
    """Streaming top-k classification accuracy (default k=5)."""

    def __init__(self, topk=5, attach_to=None):
        super(TopkAccuracy, self).__init__(attach_to=attach_to)
        self._topk = topk
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Count samples whose target appears among the top-k predictions."""
        outputs, targets = self._attach(outputs, targets)
        _, topk_idx = outputs.topk(self._topk, dim=1, largest=True, sorted=True)
        # (N, k) boolean: does any of the k best classes match the target?
        matches = topk_idx.eq(targets.view(-1, 1).expand_as(topk_idx))
        self._correct += matches[:, :self._topk].view(-1).float().sum(0).item()
        self._cnt += len(targets)

    def get_results(self):
        """Return top-k accuracy as a python float."""
        return self._correct / self._cnt

    def reset(self):
        self._correct, self._cnt = 0.0, 0.0
class StreamCEMAPMetrics():
    """Streaming example-wise (eap/emap) and class-wise (cap/cmap) average
    precision for multi-label classification.

    Target convention per entry: -1 negative, 0 difficult (ignored), 1 positive.
    """
    @property
    def PRIMARY_METRIC(self):
        return "eap"

    def __init__(self):
        self.reset()

    def update(self, logits, targets):
        """Append one batch of raw scores (N, L) and targets (N, L)."""
        # Bug fix: AP ranks raw scores, so the full (N, L) logit matrix must be
        # kept. The original stored logits.max(1)[1] (1-D class indices), which
        # broke the 2-D indexing in get_results.
        # targets: -1 negative, 0 difficult, 1 positive
        if isinstance(logits, torch.Tensor):
            logits = logits.cpu().numpy()
            targets = targets.cpu().numpy()
        self.preds = logits if self.preds is None else np.append(self.preds, logits, axis=0)
        self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0)

    def get_results(self):
        """Return {'eap', 'emap', 'cap', 'cmap'} computed over all seen batches."""
        nTest = self.targets.shape[0]
        nLabel = self.targets.shape[1]
        # example-wise AP: for each sample, rank its labels by score
        eap = np.zeros(nTest)
        for i in range(0, nTest):
            R = np.sum(self.targets[i, :] == 1)
            for j in range(0, nLabel):
                if self.targets[i, j] == 1:
                    r = np.sum(self.preds[i, np.nonzero(self.targets[i, :] != 0)] >= self.preds[i, j])
                    rb = np.sum(self.preds[i, np.nonzero(self.targets[i, :] == 1)] >= self.preds[i, j])
                    eap[i] = eap[i] + rb / (r * 1.0)
            eap[i] = eap[i] / R
        # class-wise AP: for each label, rank the samples by score
        cap = np.zeros(nLabel)
        for i in range(0, nLabel):
            R = np.sum(self.targets[:, i] == 1)
            for j in range(0, nTest):
                if self.targets[j, i] == 1:
                    r = np.sum(self.preds[np.nonzero(self.targets[:, i] != 0), i] >= self.preds[j, i])
                    rb = np.sum(self.preds[np.nonzero(self.targets[:, i] == 1), i] >= self.preds[j, i])
                    cap[i] = cap[i] + rb / (r * 1.0)
            cap[i] = cap[i] / R
        return {
            'eap': eap,
            'emap': np.nanmean(eap),
            'cap': cap,
            'cmap': np.nanmean(cap),
        }

    def reset(self):
        self.preds = None
        self.targets = None
@@ -0,0 +1,49 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import numpy as np | |||||
import torch | |||||
from kamal.core.metrics.stream_metrics import Metric | |||||
from typing import Callable | |||||
__all__=['AverageMetric'] | |||||
class AverageMetric(Metric):
    """Streams the mean of an arbitrary scoring function fn(outputs, targets)."""

    def __init__(self, fn: Callable, attach_to=None):
        super(AverageMetric, self).__init__(attach_to=attach_to)
        self._fn = fn
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Evaluate fn on one batch and fold it into the running mean."""
        outputs, targets = self._attach(outputs, targets)
        score = self._fn(outputs, targets)
        if score.ndim > 1:
            # batched per-sample scores: accumulate along dim 0
            self._accum += score.sum(0)
            self._cnt += score.shape[0]
        else:
            # scalar-like score: counts as a single observation
            self._accum += score
            self._cnt += 1

    def get_results(self):
        """Return accumulated score / observation count as a CPU tensor."""
        return (self._accum / self._cnt).detach().cpu()

    def reset(self):
        self._accum = 0.
        self._cnt = 0.
@@ -0,0 +1,68 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .stream_metrics import Metric | |||||
import torch | |||||
from typing import Callable | |||||
class ConfusionMatrix(Metric):
    """Streaming confusion matrix; rows index targets, columns index predictions."""

    def __init__(self, num_classes, ignore_idx=None, attach_to=None):
        super(ConfusionMatrix, self).__init__(attach_to=attach_to)
        self._num_classes = num_classes
        self._ignore_idx = ignore_idx  # target label excluded from the statistics
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate one batch of logits (N, C, ...) vs integer targets."""
        outputs, targets = self._attach(outputs, targets)
        # Keep the accumulator on the same device as the incoming batch.
        if self.confusion_matrix.device != outputs.device:
            self.confusion_matrix = self.confusion_matrix.to(device=outputs.device)
        preds = outputs.max(1)[1].flatten()
        targets = targets.flatten()
        mask = (preds < self._num_classes) & (preds >= 0)
        # Bug fix: compare against None explicitly; the original truthiness
        # test (`if self._ignore_idx:`) silently disabled ignore_idx=0.
        if self._ignore_idx is not None:
            mask = mask & (targets != self._ignore_idx)
        preds, targets = preds[mask], targets[mask]
        # Encode (target, pred) pairs into a flat histogram, then reshape.
        hist = torch.bincount(self._num_classes * targets + preds,
                              minlength=self._num_classes ** 2).view(self._num_classes, self._num_classes)
        self.confusion_matrix += hist

    def get_results(self):
        """Return the accumulated (num_classes, num_classes) matrix on CPU."""
        return self.confusion_matrix.detach().cpu()

    def reset(self):
        self._cnt = 0
        self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False)
class IoU(Metric):
    """Per-class intersection-over-union derived from a shared ConfusionMatrix."""

    def __init__(self, confusion_matrix: ConfusionMatrix, attach_to=None):
        # Bug fix: forward attach_to to the Metric base class; the original
        # never called super().__init__, silently dropping the argument.
        super(IoU, self).__init__(attach_to=attach_to)
        self._confusion_matrix = confusion_matrix

    def update(self, outputs, targets):
        pass  # accumulation happens in the shared ConfusionMatrix

    def reset(self):
        pass  # state lives in the shared ConfusionMatrix

    def get_results(self):
        """Return per-class IoU = diag / (row_sum + col_sum - diag)."""
        cm = self._confusion_matrix.get_results()
        # epsilon avoids 0/0 for classes absent from both preds and targets
        iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9)
        return iou
class mIoU(IoU):
    """Mean of the per-class IoU values."""

    def get_results(self):
        per_class_iou = super(mIoU, self).get_results()
        return per_class_iou.mean()
@@ -0,0 +1,86 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import numpy as np | |||||
import torch | |||||
from kamal.core.metrics.stream_metrics import StreamMetricsBase | |||||
class NormalPredictionMetrics(StreamMetricsBase):
    """Streaming angular-error metrics for surface-normal prediction."""

    @property
    def PRIMARY_METRIC(self):
        return 'mean angle'

    def __init__(self, thresholds):
        # thresholds: angle limits (degrees) for the "percent within" stats
        self.thresholds = thresholds
        self.preds = None
        self.targets = None
        self.masks = None

    @torch.no_grad()
    def update(self, preds, targets, masks):
        """
        **Type**: numpy.ndarray or torch.Tensor
        **Shape:**
            - **preds**: $(N, 3, H, W)$.
            - **targets**: $(N, 3, H, W)$.
            - **masks**: $(N, 1, H, W)$.
        """
        if isinstance(preds, torch.Tensor):
            preds = preds.cpu().numpy()
            targets = targets.cpu().numpy()
            masks = masks.cpu().numpy()
        self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0)
        self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0)
        self.masks = masks if self.masks is None else np.append(self.masks, masks, axis=0)

    def get_results(self):
        """
        **Returns:**
            - **mean angle**
            - **median angle**
            - **percents for angle within thresholds**
        """
        # Angle (degrees) between predicted and target normals via their dot
        # product along the channel axis; clip guards arccos's domain.
        products = np.sum(self.preds * self.targets, axis=1)
        angles = np.arccos(np.clip(products, -1.0, 1.0)) / np.pi * 180
        # Bug fix: evaluate the mask locally instead of squeezing self.masks
        # in place, which corrupted state on a second get_results() call.
        masks = self.masks.squeeze(1)
        angles = angles[masks == 1]
        mean_angle = np.mean(angles)
        median_angle = np.median(angles)
        threshold_percents = {}
        for threshold in self.thresholds:
            threshold_percents[threshold] = np.mean(angles < threshold)
        # Bug fix: removed `if return_key_metric: return ('absolute relative', ard)`
        # — both names were undefined, so this raised NameError unconditionally.
        return {
            'mean angle': mean_angle,
            'median angle': median_angle,
            'percents within thresholds': threshold_percents
        }

    def reset(self):
        self.preds = None
        self.targets = None
        self.masks = None
@@ -0,0 +1,199 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import numpy as np | |||||
import torch | |||||
from kamal.core.metrics.stream_metrics import Metric | |||||
from typing import Callable | |||||
# Public API of this module ('ScaleInveriant' typo preserved — it is the
# published name callers import).
__all__=['MeanSquaredError', 'RootMeanSquaredError', 'MeanAbsoluteError',
    'ScaleInveriantMeanSquaredError', 'RelativeDifference',
    'AbsoluteRelativeDifference', 'SquaredRelativeDifference', 'Threshold' ]
class MeanSquaredError(Metric):
    """Streaming mean squared error, optionally computed in log space."""

    def __init__(self, log_scale=False, attach_to=None):
        super(MeanSquaredError, self).__init__(attach_to=attach_to)
        self.reset()
        self.log_scale = log_scale

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate squared error and element count from one batch."""
        outputs, targets = self._attach(outputs, targets)
        if self.log_scale:
            # epsilon keeps log() finite for zero-valued entries
            batch_err = torch.sum((torch.log(outputs + 1e-8) - torch.log(targets + 1e-8)) ** 2)
        else:
            batch_err = torch.sum((outputs - targets) ** 2)
        self._accum_sq_diff += batch_err
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return accumulated squared error / element count as a CPU tensor."""
        return (self._accum_sq_diff / self._cnt).detach().cpu()

    def reset(self):
        self._accum_sq_diff = 0.
        self._cnt = 0.
class RootMeanSquaredError(MeanSquaredError):
    """Square root of the streaming mean squared error."""

    def get_results(self):
        mse = self._accum_sq_diff / self._cnt
        return torch.sqrt(mse).detach().cpu()
class MeanAbsoluteError(Metric):
    """Streaming mean absolute error."""

    def __init__(self, attach_to=None):
        super(MeanAbsoluteError, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate |outputs - targets| and element count from one batch."""
        outputs, targets = self._attach(outputs, targets)
        abs_err = (outputs - targets).abs()
        self._accum_abs_diff += abs_err.sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return accumulated absolute error / element count as a CPU tensor."""
        return (self._accum_abs_diff / self._cnt).detach().cpu()

    def reset(self):
        self._accum_abs_diff = 0.
        self._cnt = 0.
class ScaleInveriantMeanSquaredError(Metric):
    """Streaming scale-invariant MSE computed on log values."""

    def __init__(self, attach_to=None):
        super(ScaleInveriantMeanSquaredError, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate log-difference statistics from one batch."""
        outputs, targets = self._attach(outputs, targets)
        diff_log = torch.log(outputs + 1e-8) - torch.log(targets + 1e-8)
        # Bug fix: accumulate across batches with '+='. The original assigned
        # with '=', silently discarding every batch except the last one.
        self._accum_log_diff += diff_log.sum()
        self._accum_sq_log_diff += (diff_log ** 2).sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return mean(sq log diff) - 0.5 * mean(log diff)^2 as a CPU tensor."""
        return (self._accum_sq_log_diff / self._cnt
                - 0.5 * (self._accum_log_diff ** 2 / self._cnt ** 2)).detach().cpu()

    def reset(self):
        self._accum_log_diff = 0.
        self._accum_sq_log_diff = 0.
        self._cnt = 0.
class RelativeDifference(Metric):
    """Streaming mean of |outputs - targets| / targets."""

    def __init__(self, attach_to=None):
        super(RelativeDifference, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        outputs, targets = self._attach(outputs, targets)
        abs_err = (outputs - targets).abs()
        self._accum_abs_rel += (abs_err / targets).sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return the mean relative difference as a CPU tensor."""
        return (self._accum_abs_rel / self._cnt).detach().cpu()

    def reset(self):
        self._accum_abs_rel = 0.
        self._cnt = 0.
class AbsoluteRelativeDifference(Metric):
    """Streaming mean absolute relative difference: mean(|out - tgt| / tgt)."""

    def __init__(self, attach_to=None):
        super(AbsoluteRelativeDifference, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        outputs, targets = self._attach(outputs, targets)
        rel_err = (outputs - targets).abs() / targets
        self._accum_abs_rel += rel_err.sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return the mean absolute relative difference as a CPU tensor."""
        return (self._accum_abs_rel / self._cnt).detach().cpu()

    def reset(self):
        self._accum_abs_rel = 0.
        self._cnt = 0.
# Dead-code fix: a second, byte-identical definition of
# AbsoluteRelativeDifference was removed here. It merely re-created and
# shadowed the class defined immediately above with an equivalent copy, so
# the module's behavior is unchanged.
class SquaredRelativeDifference(Metric):
    """Streaming mean squared relative difference: mean((out - tgt)^2 / tgt)."""

    def __init__(self, attach_to=None):
        super(SquaredRelativeDifference, self).__init__(attach_to=attach_to)
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        outputs, targets = self._attach(outputs, targets)
        sq_err = (outputs - targets) ** 2
        self._accum_sq_rel += (sq_err / targets).sum()
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return the mean squared relative difference as a CPU tensor."""
        return (self._accum_sq_rel / self._cnt).detach().cpu()

    def reset(self):
        self._accum_sq_rel = 0.
        self._cnt = 0.
class Threshold(Metric):
    """Streaming fraction of elements with max(out/tgt, tgt/out) < threshold,
    reported separately for each configured threshold."""

    def __init__(self, thresholds=(1.25, 1.25**2, 1.25**3), attach_to=None):
        super(Threshold, self).__init__(attach_to=attach_to)
        # Bug fix: the default was a mutable list shared by every instance;
        # a tuple default is immutable and iterates identically.
        self.thresholds = thresholds
        self.reset()

    @torch.no_grad()
    def update(self, outputs, targets):
        """Accumulate per-threshold hit counts from one batch."""
        outputs, targets = self._attach(outputs, targets)
        # symmetric ratio: large when prediction and target disagree either way
        sigma = torch.max(outputs / targets, targets / outputs)
        for thres in self.thresholds:
            self._accum_thres[thres] += torch.sum(sigma < thres)
        self._cnt += torch.numel(outputs)

    def get_results(self):
        """Return {threshold: fraction_below} with CPU tensor values."""
        return {thres: (self._accum_thres[thres] / self._cnt).detach().cpu() for thres in self.thresholds}

    def reset(self):
        self._cnt = 0.
        self._accum_thres = {thres: 0. for thres in self.thresholds}
@@ -0,0 +1,82 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from __future__ import division | |||||
import torch | |||||
import numpy as np | |||||
from abc import ABC, abstractmethod | |||||
from typing import Callable, Union, Any, Mapping, Sequence | |||||
import numbers | |||||
from kamal.core.attach import AttachTo | |||||
class Metric(ABC):
    """Abstract base class for streaming metrics.

    Subclasses accumulate statistics batch-by-batch through update(), report
    aggregated values from get_results(), and clear state with reset().
    attach_to selects which outputs/targets the metric reads (see AttachTo).
    """
    def __init__(self, attach_to=None):
        # AttachTo adapts (outputs, targets) before they reach update().
        self._attach = AttachTo(attach_to)
    @abstractmethod
    def update(self, pred, target):
        """ Overridden by subclasses """
        raise NotImplementedError()
    @abstractmethod
    def get_results(self):
        """ Overridden by subclasses """
        raise NotImplementedError()
    @abstractmethod
    def reset(self):
        """ Overridden by subclasses """
        raise NotImplementedError()
class MetricCompose(dict):
    """A named collection of Metric objects updated/queried together.

    NOTE(review): although this subclasses dict, entries live in the separate
    _metric_dict attribute and dict.__init__ is never called — the dict base
    itself stays empty, so len()/iteration on the instance do not reflect the
    metrics; only the overridden __getitem__ does. Confirm whether the dict
    base class is still required by callers.
    """
    def __init__(self, metric_dict: Mapping):
        self._metric_dict = metric_dict
    def add_metrics( self, metric_dict: Mapping):
        """Merge metrics from a Mapping or another MetricCompose; returns self."""
        if isinstance(metric_dict, MetricCompose):
            metric_dict = metric_dict.metrics
        self._metric_dict.update(metric_dict)
        return self
    @property
    def metrics(self):
        # Underlying name -> Metric mapping.
        return self._metric_dict
    @torch.no_grad()
    def update(self, outputs, targets):
        """Feed one batch to every contained Metric (non-Metric values skipped)."""
        for key, metric in self._metric_dict.items():
            if isinstance(metric, Metric):
                metric.update(outputs, targets)
    def get_results(self):
        """Return {name: metric.get_results()} for every contained Metric."""
        results = {}
        for key, metric in self._metric_dict.items():
            if isinstance(metric, Metric):
                results[key] = metric.get_results()
        return results
    def reset(self):
        """Reset every contained Metric."""
        for key, metric in self._metric_dict.items():
            if isinstance(metric, Metric):
                metric.reset()
    def __getitem__(self, name):
        return self._metric_dict[name]
@@ -0,0 +1,3 @@ | |||||
from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose | |||||
from . import loss | |||||
@@ -0,0 +1,2 @@ | |||||
from . import functional, loss | |||||
from .loss import * |
@@ -0,0 +1,107 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import torch.nn as nn | |||||
def kldiv(logits, targets, T=1.0):
    """ KL divergence between softened teacher and student distributions.

    Computes sum_c p_t(c) * (log p_t(c) - log p_s(c)) per sample and
    averages over the batch. The T**2 factor keeps gradient magnitudes
    comparable across temperatures (Hinton et al., distillation).

    Parameters:
        - logits (Tensor): logits score (e.g. outputs of fc layer), shape (N, C)
        - targets (Tensor): logits of soft targets, shape (N, C)
        - T (float): temperature of distill

    (Doc fix: the old docstring advertised a ``reduction`` parameter that
    this function never accepted; the reduction is always batch-mean.)
    """
    p_targets = F.softmax(targets / T, dim=1)
    logp_logits = F.log_softmax(logits / T, dim=1)
    kld = F.kl_div(logp_logits, p_targets, reduction='none') * (T ** 2)
    return kld.sum(1).mean()
def jsdiv(logits, targets, T=1.0, reduction='mean'):
    """Jensen-Shannon divergence between the softmax distributions of two
    logit tensors.

    NOTE(review): ``T`` is accepted but currently unused — kept for
    interface compatibility; confirm whether temperature was intended here.
    """
    p_dist = F.softmax(logits, dim=1)
    q_dist = F.softmax(targets, dim=1)
    log_mix = torch.log((p_dist + q_dist) / 2)
    kl_p = F.kl_div(log_mix, p_dist, reduction=reduction)
    kl_q = F.kl_div(log_mix, q_dist, reduction=reduction)
    return 0.5 * (kl_p + kl_q)
def mmd_loss(f1, f2, sigmas, normalized=False):
    """Maximum Mean Discrepancy between two feature batches (RBF kernels).

    Parameters:
        - f1, f2 (Tensor): feature batches of shape (N, ...); any trailing
          dimensions are flattened to (N, D).
        - sigmas: iterable of kernel bandwidths (None lets _mmd_rbf2
          estimate a bandwidth family from the data).
        - normalized (bool): L2-normalize features row-wise before MMD.
    """
    # Flatten trailing dims so (N, C, H, W) etc. become (N, D).
    # (The old code unpacked exactly N, C, H, W and crashed on 3-D input.)
    if f1.dim() != 2:
        f1 = f1.flatten(1)
    if f2.dim() != 2:
        f2 = f2.flatten(1)
    if normalized:  # idiomatic truth test (was `== True`)
        f1 = F.normalize(f1, p=2, dim=1)
        f2 = F.normalize(f2, p=2, dim=1)
    return _mmd_rbf2(f1, f2, sigmas=sigmas)
def psnr(img1, img2, size_average=True, data_range=255):
    """Peak Signal-to-Noise Ratio between two image batches, in dB.

    Parameters:
        - img1, img2 (Tensor): shape (N, ...), compared element-wise.
        - size_average (bool): if True return the batch mean, else a
          per-sample tensor of PSNR values.
        - data_range (number): dynamic range of the input (255 for
          uint8-style images, 1.0 for normalized images).

    The result is clamped to [0, 99.99]; the clamp also maps the infinite
    PSNR of identical images to 99.99.
    """
    n = img1.shape[0]
    mse = torch.mean(((img1 - img2) ** 2).view(n, -1), dim=1)
    out = torch.clamp(torch.log10(data_range ** 2 / mse) * 10, 0.0, 99.99)
    if size_average:  # idiomatic truth test (was `== True`, PEP8 E712)
        out = out.mean()
    return out
def soft_cross_entropy(logits, targets, T=1.0, size_average=True):
    """ Cross Entropy for soft targets
    **Parameters:**
        - **logits** (Tensor): logits score (e.g. outputs of fc layer)
        - **targets** (Tensor): logits of soft targets
        - **T** (float): temperature of distill
        - **size_average**: average the outputs
    """
    target_probs = F.softmax(targets / T, dim=1)
    pred_logprobs = F.log_softmax(logits / T, dim=1)
    # Per-sample cross entropy: -sum_c p_t(c) * log p_s(c).
    per_sample_ce = torch.sum(-target_probs * pred_logprobs, dim=1)
    # T^2 rescaling keeps gradient magnitudes comparable across temperatures.
    if size_average:
        return per_sample_ce.mean() * T * T
    return per_sample_ce * T * T
def _mmd_rbf2(x, y, sigmas=None):
    """Squared MMD between batches x and y under a mixture of RBF kernels.

    x, y: (N, D) feature matrices; the estimator below assumes both
    batches share the same size N. sigmas: iterable of kernel bandwidths;
    when None, a bandwidth family is derived from the mean pairwise
    squared distance of the data.
    """
    N, _ = x.shape
    # Gram matrices: xx[i,j] = <x_i, x_j>, yy[i,j] = <y_i, y_j>, zz[i,j] = <x_i, y_j>.
    xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))
    K = L = P = 0.0
    # Pairwise squared distances via ||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>.
    XX2 = rx.t() + rx - 2*xx
    YY2 = ry.t() + ry - 2*yy
    XY2 = rx.t() + ry - 2*zz
    if sigmas is None:
        # Data-driven bandwidths: a family scaled around the mean squared distance.
        sigma2 = torch.mean((XX2.detach()+YY2.detach()+2*XY2.detach()) / 4)
        sigmas2 = [sigma2/4, sigma2/2, sigma2, sigma2*2, sigma2*4]
        alphas = [1.0 / (2 * sigma2) for sigma2 in sigmas2]
    else:
        alphas = [1.0 / (2 * sigma**2) for sigma in sigmas]
    # Accumulate kernel sums over the bandwidth mixture; the clamp keeps
    # tiny negative distances (rounding) out of the exponent.
    for alpha in alphas:
        K += torch.exp(- alpha * (XX2.clamp(min=1e-12)))
        L += torch.exp(- alpha * (YY2.clamp(min=1e-12)))
        P += torch.exp(- alpha * (XY2.clamp(min=1e-12)))
    beta = (1./(N*(N)))
    gamma = (2./(N*N))
    # Biased (V-statistic) MMD^2 estimate; relu guards against small
    # negative values produced by floating-point cancellation.
    return F.relu(beta * (torch.sum(K)+torch.sum(L)) - gamma * torch.sum(P))
@@ -0,0 +1,386 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import torch.nn as nn | |||||
import numpy as np | |||||
#from pytorch_msssim import ssim, ms_ssim, MS_SSIM, SSIM | |||||
from .functional import * | |||||
class KLDiv(object):
    """Callable wrapper around :func:`kldiv` with a fixed temperature."""
    def __init__(self, T=1.0):
        self.T = T

    def __call__(self, logits, targets):
        return kldiv(logits, targets, T=self.T)
class KDLoss(nn.Module):
    """ KD Loss Function

    Distillation term (soft cross-entropy or KL divergence at temperature
    ``T``) plus an optional hard-label cross-entropy term weighted by
    ``alpha``.
    """
    def __init__(self, T=1.0, alpha=1.0, use_kldiv=False):
        super(KDLoss, self).__init__()
        self.T = T
        self.alpha = alpha
        # Select the soft-target criterion once at construction time.
        self.kdloss = kldiv if use_kldiv else soft_cross_entropy

    def forward(self, logits, targets, hard_targets=None):
        total = self.kdloss(logits, targets, T=self.T)
        if hard_targets is not None and self.alpha != 0.0:
            total = total + self.alpha * F.cross_entropy(logits, hard_targets)
        return total
class CFLLoss(nn.Module):
    """ Common Feature Learning Loss
        CFL Loss = MMD + MSE
    """
    def __init__(self, sigmas, normalized=True):
        super(CFLLoss, self).__init__()
        self.sigmas = sigmas
        self.normalized = normalized

    def forward(self, hs, hts, fts_, fts):
        # MMD between the student's common-space feature and each teacher's.
        mmd = sum((mmd_loss(hs, ht_i, sigmas=self.sigmas, normalized=self.normalized)
                   for ht_i in hts), 0.0)
        # MSE between reconstructed and original teacher features.
        mse = sum((F.mse_loss(fts_[i], fts[i]) for i in range(len(fts_))), 0.0)
        return mmd, mse
class PSNR_Loss(nn.Module):
    """Loss = 100 - PSNR, so that better reconstructions yield lower loss
    (PSNR itself is clamped to at most 99.99 dB by :func:`psnr`)."""
    def __init__(self, data_range=1.0, size_average=True):
        super(PSNR_Loss, self).__init__()
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, img1, img2):
        quality = psnr(img1, img2,
                       size_average=self.size_average,
                       data_range=self.data_range)
        return 100 - quality
#class MS_SSIM_Loss(MS_SSIM): | |||||
# def forward(self, img1, img2): | |||||
# return 100*(1 - super(MS_SSIM_Loss, self).forward(img1, img2)) | |||||
class ScaleInvariantLoss(nn.Module):
    """This criterion is used in depth prediction task.

    **Parameters:**
        - **la** (int, optional): Default value is 0.5. No need to change.
        - **ignore_index** (int, optional): Value to ignore.

    **Shape:**
        - **inputs**: $(N, H, W)$.
        - **targets**: $(N, H, W)$.
        - **output**: scalar.
    """
    def __init__(self, la=0.5, ignore_index=0):
        super(ScaleInvariantLoss, self).__init__()
        self.la = la
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        # Flatten spatial dims so each sample is one row.
        if inputs.dim() == 3:
            batch = inputs.size(0)
            inputs = inputs.view(batch, -1)
            targets = targets.view(batch, -1)
        ignored = targets.eq(self.ignore_index)
        valid_counts = (1 - ignored.float()).sum(1)
        log_diff = torch.log(inputs) - torch.log(targets)
        # Ignored pixels contribute nothing to either term.
        log_diff[ignored] = 0
        quadratic = torch.pow(log_diff, 2).sum(1) / valid_counts
        scale_bias = self.la * torch.pow(log_diff.sum(1) / valid_counts, 2)
        return (quadratic - scale_bias).mean()
class AngleLoss(nn.Module):
    """This criterion is used in surface normal prediction task.

    **Shape:**
        - **inputs**: $(N, 3, H, W)$. Predicted space vector for each pixel. Must be formalized before.
        - **targets**: $(N, 3, H, W)$. Ground truth. Must be formalized before.
        - **masks**: $(N, 1, H, W)$. One for valid pixels, else zero.
        - **output**: scalar.
    """
    def forward(self, inputs, targets, masks):
        valid_counts = masks.sum(dim=[1, 2, 3])
        # Per-pixel cosine similarity (vectors assumed pre-normalized).
        cos_map = (inputs * targets).sum(1, keepdim=True)
        masked_sum = (cos_map * masks.float()).sum([1, 2, 3])
        # Negative mean cosine: minimized when predictions align with truth.
        return (-(masked_sum / valid_counts)).mean()
class FocalLoss(nn.Module):
    """Focal loss on top of pixel/sample-wise cross-entropy; with the
    default gamma=0 and alpha=1 it reduces to plain cross-entropy."""
    def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.size_average = size_average

    def forward(self, inputs, targets):
        per_elem_ce = F.cross_entropy(
            inputs, targets, reduction='none', ignore_index=self.ignore_index)
        # exp(-CE) recovers the model's probability for the true class.
        prob_true = torch.exp(-per_elem_ce)
        focal = self.alpha * (1 - prob_true) ** self.gamma * per_elem_ce
        return focal.mean() if self.size_average else focal.sum()
class AttentionLoss(nn.Module):
    """ Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer"""
    def __init__(self, p=2):
        super(AttentionLoss, self).__init__()
        self.p = p

    def forward(self, g_s, g_t):
        # Sum the attention-transfer loss over paired feature groups.
        return sum(self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t))

    def at_loss(self, f_s, f_t):
        # Pool the spatially larger map down so both share the same size.
        h_s, h_t = f_s.shape[2], f_t.shape[2]
        if h_s > h_t:
            f_s = F.adaptive_avg_pool2d(f_s, (h_t, h_t))
        elif h_s < h_t:
            f_t = F.adaptive_avg_pool2d(f_t, (h_s, h_s))
        return (self.at(f_s) - self.at(f_t)).pow(2).mean()

    def at(self, f):
        # Spatial attention map: channel-mean of |f|^p, flattened per sample
        # and L2-normalized.
        return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
class NSTLoss(nn.Module):
    """like what you like: knowledge distill via neuron selectivity transfer

    Parameters:
        - full_loss (bool): when True (default) include the teacher-teacher
          kernel term, giving the complete (biased) MMD estimate; set False
          to skip that gradient-free term and save computation.
          (Generalization: this used to be a hard-coded local variable whose
          comment suggested making it configurable.)
    """
    def __init__(self, full_loss=True):
        super(NSTLoss, self).__init__()
        self.full_loss = full_loss

    def forward(self, g_s, g_t):
        return sum(self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t))

    def nst_loss(self, f_s, f_t):
        """MMD with a polynomial kernel between per-channel activation patterns."""
        s_H, t_H = f_s.shape[2], f_t.shape[2]
        # Pool the spatially larger map so both share the same resolution.
        if s_H > t_H:
            f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
        elif s_H < t_H:
            f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
        # (N, C, H, W) -> (N, C, H*W), rows L2-normalized per channel.
        f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1)
        f_s = F.normalize(f_s, dim=2)
        f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1)
        f_t = F.normalize(f_t, dim=2)
        if self.full_loss:
            # Teacher-teacher term is detached: constant w.r.t. the student.
            return (self.poly_kernel(f_t, f_t).mean().detach()
                    + self.poly_kernel(f_s, f_s).mean()
                    - 2 * self.poly_kernel(f_s, f_t).mean())
        else:
            return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean()

    def poly_kernel(self, a, b):
        # k(a, b) = <a, b>^2, broadcast over all channel pairs.
        a = a.unsqueeze(1)
        b = b.unsqueeze(2)
        return (a * b).sum(-1).pow(2)
class SPLoss(nn.Module):
    """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
    def __init__(self):
        super(SPLoss, self).__init__()

    def forward(self, g_s, g_t):
        return sum(self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t))

    def similarity_loss(self, f_s, f_t):
        # Batch-wise similarity (Gram) matrices over flattened features,
        # row-normalized before comparison.
        batch = f_s.shape[0]
        flat_s = f_s.view(batch, -1)
        flat_t = f_t.view(batch, -1)
        gram_s = torch.nn.functional.normalize(torch.mm(flat_s, torch.t(flat_s)))
        gram_t = torch.nn.functional.normalize(torch.mm(flat_t, torch.t(flat_t)))
        diff = gram_t - gram_s
        return (diff * diff).view(-1, 1).sum(0) / (batch * batch)
class RKDLoss(nn.Module):
    """Relational Knowledge Disitllation, CVPR2019"""
    def __init__(self, w_d=25, w_a=50):
        super(RKDLoss, self).__init__()
        self.w_d = w_d  # weight of the distance-wise term
        self.w_a = w_a  # weight of the angle-wise term

    def forward(self, f_s, f_t):
        stu = f_s.view(f_s.shape[0], -1)
        tea = f_t.view(f_t.shape[0], -1)

        # Distance-wise term: match mean-normalized pairwise distances.
        with torch.no_grad():
            t_dist = self.pdist(tea, squared=False)
            t_dist = t_dist / t_dist[t_dist > 0].mean()
        s_dist = self.pdist(stu, squared=False)
        s_dist = s_dist / s_dist[s_dist > 0].mean()
        loss_d = F.smooth_l1_loss(s_dist, t_dist)

        # Angle-wise term: match cosines of angles formed by sample triplets.
        with torch.no_grad():
            t_diff = tea.unsqueeze(0) - tea.unsqueeze(1)
            t_diff = F.normalize(t_diff, p=2, dim=2)
            t_angle = torch.bmm(t_diff, t_diff.transpose(1, 2)).view(-1)
        s_diff = stu.unsqueeze(0) - stu.unsqueeze(1)
        s_diff = F.normalize(s_diff, p=2, dim=2)
        s_angle = torch.bmm(s_diff, s_diff.transpose(1, 2)).view(-1)
        loss_a = F.smooth_l1_loss(s_angle, t_angle)

        return self.w_d * loss_d + self.w_a * loss_a

    @staticmethod
    def pdist(e, squared=False, eps=1e-12):
        """Pairwise (squared) Euclidean distance matrix with zeroed diagonal."""
        sq_norms = e.pow(2).sum(dim=1)
        res = (sq_norms.unsqueeze(1) + sq_norms.unsqueeze(0) - 2 * (e @ e.t())).clamp(min=eps)
        if not squared:
            res = res.sqrt()
        res = res.clone()
        res[range(len(e)), range(len(e))] = 0
        return res
class PKTLoss(nn.Module):
    """Probabilistic Knowledge Transfer for deep representation learning"""
    def __init__(self):
        super(PKTLoss, self).__init__()

    def forward(self, f_s, f_t):
        return self.cosine_similarity_loss(f_s, f_t)

    @staticmethod
    def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
        """KL divergence between cosine-similarity-derived probability matrices."""
        # L2-normalize rows; zero out rows whose norm was ~0 (NaN guard).
        out_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
        output_net = output_net / (out_norm + eps)
        output_net[output_net != output_net] = 0

        tgt_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
        target_net = target_net / (tgt_norm + eps)
        target_net[target_net != target_net] = 0

        # Pairwise cosine similarities, shifted from [-1, 1] to [0, 1].
        model_similarity = (torch.mm(output_net, output_net.transpose(0, 1)) + 1.0) / 2.0
        target_similarity = (torch.mm(target_net, target_net.transpose(0, 1)) + 1.0) / 2.0

        # Row-normalize into conditional probability distributions.
        model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
        target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)

        # KL(target || model), averaged over all matrix entries.
        return torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
class SVDLoss(nn.Module):
    """
    Self-supervised Knowledge Distillation using Singular Value Decomposition
    """
    def __init__(self, k=1):
        # k: number of leading singular vectors kept from the teacher;
        # the student keeps k + 3 to give the alignment step some slack.
        super(SVDLoss, self).__init__()
        self.k = k
    def forward(self, g_s, g_t):
        # v_sb / v_tb hold the previous layer's scaled right singular
        # vectors so consecutive layers can be correlated via an RBF map.
        v_sb = None
        v_tb = None
        losses = []
        for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t):
            u_t, s_t, v_t = self.svd(f_t, self.k)
            u_s, s_s, v_s = self.svd(f_s, self.k + 3)
            v_s, v_t = self.align_rsv(v_s, v_t)
            s_t = s_t.unsqueeze(1)
            # NOTE(review): both student and teacher vectors are scaled by
            # the *teacher* singular values s_t — looks like a deliberate
            # shared scale, but confirm against the original implementation.
            v_t = v_t * s_t
            v_s = v_s * s_t
            if i > 0:
                # RBF correlation between current and previous layer vectors;
                # the first layer only seeds v_sb / v_tb.
                s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8)
                t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8)
                l2loss = (s_rbf - t_rbf.detach()).pow(2)
                # Drop NaN/Inf entries instead of propagating them.
                l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss))
                losses.append(l2loss.sum())
            v_tb = v_t
            v_sb = v_s
        bsz = g_s[0].shape[0]
        losses = [l / bsz for l in losses]
        return sum(losses)
    def svd(self, feat, n=1):
        # Truncated SVD of a 4-D feature map, treated per-sample as an
        # (H*W, C) matrix. NOTE(review): uses size[2] * size[2], which
        # assumes square spatial maps (H == W) — confirm upstream inputs.
        size = feat.shape
        assert len(size) == 4
        x = feat.view(-1, size[1], size[2] * size[2]).transpose(-2, -1)
        u, s, v = torch.svd(x)
        u = self.removenan(u)
        s = self.removenan(s)
        v = self.removenan(v)
        if n > 0:
            # Keep the n leading components, normalized along the sample dim.
            u = F.normalize(u[:, :, :n], dim=1)
            s = F.normalize(s[:, :n], dim=1)
            v = F.normalize(v[:, :, :n], dim=1)
        return u, s, v
    @staticmethod
    def removenan(x):
        # Replace NaN/Inf (SVD can produce them on degenerate inputs) with 0.
        x = torch.where(torch.isfinite(x), x, torch.zeros_like(x))
        return x
    @staticmethod
    def align_rsv(a, b):
        # Sign/permutation alignment: for each b-vector keep only the
        # a-vectors of maximal |cosine| and match their sign.
        cosine = torch.matmul(a.transpose(-2, -1), b)
        max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True)
        mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)),
                           torch.sign(cosine), torch.zeros_like(cosine))
        a = torch.matmul(a, mask)
        return a, b
@@ -0,0 +1,186 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import abc
import sys
import typing
from typing import Callable, Dict, List, Any
# BUG FIX: Mapping/Sequence live in collections.abc; importing them from
# `collections` was removed in Python 3.10 and raises ImportError there.
from collections.abc import Mapping, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import loss
from kamal.core import metrics, exceptions
from kamal.core.attach import AttachTo
class Task(object):
    """Base class for a learnable task: turns model outputs into losses
    and predictions."""
    def __init__(self, name):
        self.name = name

    @abc.abstractmethod
    def get_loss(self, outputs, targets) -> Dict:
        """Return a dict mapping loss names to loss values."""
        pass

    @abc.abstractmethod
    def predict(self, outputs) -> Any:
        """Convert raw model outputs into final predictions."""
        pass
class GeneralTask(Task):
    """A Task fully described by a loss function, an optional prediction
    function, and a scaling factor applied to the loss."""
    def __init__(self,
                 name: str,
                 loss_fn: Callable,
                 scaling: float = 1.0,
                 pred_fn: Callable = lambda x: x,
                 attach_to=None):
        super(GeneralTask, self).__init__(name)
        # AttachTo selects which outputs/targets this task consumes.
        self._attach = AttachTo(attach_to)
        self.loss_fn = loss_fn
        self.pred_fn = pred_fn
        self.scaling = scaling

    def get_loss(self, outputs, targets):
        outputs, targets = self._attach(outputs, targets)
        return {self.name: self.loss_fn(outputs, targets) * self.scaling}

    def predict(self, outputs):
        return self.pred_fn(self._attach(outputs))

    def __repr__(self):
        return "Task: [%s loss_fn=%s scaling=%.4f attach=%s]" % (
            self.name, str(self.loss_fn), self.scaling, self._attach)
class TaskCompose(list):
    """A list of Tasks treated as one composite task: losses are merged
    into a single dict and predictions are collected into a list."""
    def __init__(self, tasks: list):
        # Keep only real Task instances; anything else is silently dropped.
        for task in tasks:
            if isinstance(task, Task):
                self.append(task)

    def get_loss(self, outputs, targets):
        """Merge the loss dicts of all sub-tasks (later tasks overwrite
        colliding keys)."""
        loss_dict = {}
        for task in self:
            loss_dict.update(task.get_loss(outputs, targets))
        return loss_dict

    def predict(self, outputs):
        """Return each sub-task's prediction, in task order."""
        results = []
        for task in self:
            results.append(task.predict(outputs))
        return results

    def __repr__(self):
        rep = "TaskCompose: \n"
        for task in self:
            rep += "\t%s\n" % task
        # BUG FIX: the original built `rep` but never returned it, so
        # repr() raised TypeError (None returned from __repr__).
        return rep
class StandardTask:
    """Factory methods returning a pre-configured GeneralTask for the
    common supervised task types."""
    @staticmethod
    def classification(name='ce', scaling=1.0, attach_to=None):
        # Softmax classification: CE loss, argmax over class dim as prediction.
        return GeneralTask( name=name,
                            loss_fn=nn.CrossEntropyLoss(),
                            scaling=scaling,
                            pred_fn=lambda x: x.max(1)[1],
                            attach_to=attach_to )
    @staticmethod
    def binary_classification(name='bce', scaling=1.0, attach_to=None):
        # Binary classification on raw logits.
        # NOTE(review): BCE-with-logits pairs with a 0.0 logit threshold;
        # thresholding logits at 0.5 corresponds to p > 0.62 — confirm intent.
        return GeneralTask(name=name,
                           loss_fn=F.binary_cross_entropy_with_logits,
                           scaling=scaling,
                           pred_fn=lambda x: (x>0.5),
                           attach_to=attach_to )
    @staticmethod
    def regression(name='mse', scaling=1.0, attach_to=None):
        # Plain regression: MSE loss, identity prediction.
        return GeneralTask(name=name,
                           loss_fn=nn.MSELoss(),
                           scaling=scaling,
                           pred_fn=lambda x: x,
                           attach_to=attach_to )
    @staticmethod
    def segmentation(name='ce', scaling=1.0, attach_to=None):
        # Semantic segmentation: pixel-wise CE, label 255 treated as void.
        return GeneralTask(name=name,
                           loss_fn=nn.CrossEntropyLoss(ignore_index=255),
                           scaling=scaling,
                           pred_fn=lambda x: x.max(1)[1],
                           attach_to=attach_to )
    @staticmethod
    def monocular_depth(name='l1', scaling=1.0, attach_to=None):
        # Depth estimation: L1 loss, identity prediction.
        return GeneralTask(name=name,
                           loss_fn=nn.L1Loss(),
                           scaling=scaling,
                           pred_fn=lambda x: x,
                           attach_to=attach_to)
    @staticmethod
    def detection():
        raise NotImplementedError
    @staticmethod
    def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None):
        # Knowledge distillation: KL divergence at temperature T.
        return GeneralTask(name=name,
                           loss_fn=loss.KLDiv(T=T),
                           scaling=scaling,
                           pred_fn=lambda x: x.max(1)[1],
                           attach_to=attach_to)
class StandardMetrics(object):
    """Factory methods bundling the usual MetricCompose for each task type."""
    @staticmethod
    def classification(attach_to=None):
        return metrics.MetricCompose(
            metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)}
        )
    @staticmethod
    def regression(attach_to=None):
        return metrics.MetricCompose(
            metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)}
        )
    @staticmethod
    def segmentation(num_classes, ignore_idx=255, attach_to=None):
        # mIoU is derived from the same confusion-matrix instance, so the
        # matrix is shared rather than recomputed.
        confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to)
        return metrics.MetricCompose(
            metric_dict={'acc': metrics.Accuracy(attach_to=attach_to),
                         'confusion_matrix': confusion_matrix ,
                         'miou': metrics.mIoU(confusion_matrix)}
        )
    @staticmethod
    def monocular_depth(attach_to=None):
        # Standard monocular-depth metric suite (RMSE variants, relative
        # differences, and delta-threshold accuracies).
        return metrics.MetricCompose(
            metric_dict={
                'rmse': metrics.RootMeanSquaredError(attach_to=attach_to),
                'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ),
                'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to),
                'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to),
                'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to),
                'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to )
            }
        )
    @staticmethod
    def loss_metric(loss_fn):
        # Wraps an arbitrary loss function as a running-average metric.
        return metrics.MetricCompose(
            metric_dict={
                'loss': metrics.AverageMetric( loss_fn )
            }
        )
@@ -0,0 +1,2 @@ | |||||
from .prunning import Pruner, strategy | |||||
from .distillation import * |
@@ -0,0 +1,12 @@ | |||||
from .kd import KDDistiller | |||||
from .hint import * | |||||
from .attention import AttentionDistiller | |||||
from .nst import NSTDistiller | |||||
from .sp import SPDistiller | |||||
from .rkd import RKDDistiller | |||||
from .pkt import PKTDistiller | |||||
from .svd import SVDDistiller | |||||
from .cc import * | |||||
from .vid import * | |||||
from . import data_free |
@@ -0,0 +1,47 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import AttentionLoss | |||||
from kamal.core.tasks.loss import KDLoss | |||||
from kamal.utils import set_mode, move_to_device | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
class AttentionDistiller(KDDistiller):
    """KD distiller with an added attention-transfer term computed from
    hooked intermediate feature maps (Zagoruyko & Komodakis)."""
    def __init__(self, logger=None, tb_writer=None):
        super(AttentionDistiller, self).__init__(logger, tb_writer)

    def setup(self,
              student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        stu_hooks / tea_hooks are feature hooks on student/teacher layers;
        out_flags[i] selects whether hook i contributes its output feature
        (True) or its input feature (False).
        (BUG FIX: the hook/flag defaults were shared mutable [] lists —
        replaced with None sentinels.)
        """
        super(AttentionDistiller, self).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device)
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        self._at_loss = AttentionLoss()

    def additional_kd_loss(self, engine, batch):
        # Teacher features are detached: no gradient flows into the teacher.
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        # Transfer attention on the intermediate maps only (drop first/last).
        g_s = feat_s[1:-1]
        g_t = feat_t[1:-1]
        return self._at_loss(g_s, g_t)
@@ -0,0 +1,55 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import KDLoss | |||||
import torch.nn as nn | |||||
import torch._ops | |||||
import time | |||||
class CCDistiller(KDDistiller):
    """Correlation Congruence distiller: compares embedded student and
    teacher features of the last hooked layer."""
    def __init__(self, logger=None, tb_writer=None):
        super(CCDistiller, self).__init__(logger, tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, embed_s, embed_t,
              T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller; embed_s/embed_t project student/teacher
        features into a shared embedding space and are trained jointly.
        (BUG FIX: the hook/flag defaults were shared mutable [] lists —
        replaced with None sentinels.)
        """
        super(CCDistiller, self).setup(
            student, teacher, dataloader, optimizer, T=T, gamma=gamma, alpha=alpha, device=device)
        self.embed_s = embed_s.to(self.device).train()
        self.embed_t = embed_t.to(self.device).train()
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []

    def additional_kd_loss(self, engine, batch):
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = self.embed_s(feat_s[-1])
        f_t = self.embed_t(feat_t[-1])
        # Congruence between adjacent samples' embedding differences.
        return torch.mean((torch.abs(f_s - f_t)[:-1] * torch.abs(f_s - f_t)[1:]).sum(1))
class LinearEmbed(nn.Module):
    """Flattens its input per sample and projects it with one linear layer."""
    def __init__(self, dim_in=1024, dim_out=128):
        super(LinearEmbed, self).__init__()
        self.linear = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        flattened = x.view(x.shape[0], -1)
        return self.linear(flattened)
@@ -0,0 +1 @@ | |||||
from .zskt import ZSKTDistiller |
@@ -0,0 +1,99 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import time | |||||
import torch.nn.functional as F | |||||
from kamal.slim.distillation.kd import KDDistiller | |||||
from kamal.utils import set_mode | |||||
from kamal.core.tasks.loss import kldiv | |||||
class ZSKTDistiller(KDDistiller):
    """Zero-Shot Knowledge Transfer: a generator is trained adversarially
    to synthesize inputs on which student and teacher disagree, and the
    student is then trained to match the teacher on those inputs — no
    real training data is required."""

    def __init__(self,
                 student,
                 teacher,
                 generator,
                 z_dim,
                 logger=None,
                 viz=None):
        super(ZSKTDistiller, self).__init__(logger, viz)
        self.teacher = teacher
        self.model = self.student = student
        self.generator = generator
        self.z_dim = z_dim  # shape of the latent noise fed to the generator

    def train(self, start_iter, max_iter, optim_s, optim_g, device=None):
        """Run the adversarial distillation loop for max_iter iterations."""
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.optim_s, self.optim_g = optim_s, optim_g
        self.model.to(self.device)
        self.teacher.to(self.device)
        self.generator.to(self.device)
        # Dummy one-element loader: batches come from the generator, but
        # the base trainer still expects an iterable.
        self.train_loader = [0, ]
        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False), \
             set_mode(self.generator, training=True):
            super(ZSKTDistiller, self).train(start_iter, max_iter)

    def search_optimizer(self, evaluator, train_loader, hpo_space=None, mode='min', max_evals=20, max_iters=400):
        # NOTE(review): `hpo` is not imported anywhere in this module, so
        # this method raises NameError as written — confirm the intended
        # hpo package and add the import.
        optimizer = hpo.search_optimizer(self, train_loader, evaluator=evaluator, hpo_space=hpo_space, mode=mode, max_evals=max_evals, max_iters=max_iters)
        return optimizer

    def step(self):
        """One adversarial round: one generator step, then 10 student steps."""
        start_time = time.perf_counter()
        # Generator step: maximize student/teacher divergence (adversarial).
        z = torch.randn(self.z_dim).to(self.device)
        fake = self.generator(z)
        self.optim_g.zero_grad()
        t_out = self.teacher(fake)
        s_out = self.student(fake)
        loss_g = -kldiv(s_out, t_out)
        loss_g.backward()
        self.optim_g.step()
        # Student steps: minimize divergence on a fixed generated batch.
        with torch.no_grad():
            fake = self.generator(z)
            t_out = self.teacher(fake.detach())
        for _ in range(10):
            self.optim_s.zero_grad()
            s_out = self.student(fake.detach())
            loss_s = kldiv(s_out, t_out)
            loss_s.backward()
            self.optim_s.step()
        loss_dict = {
            'loss_g': loss_g,
            'loss_s': loss_s,
        }
        step_time = time.perf_counter() - start_time
        # Record training info for the history/visualization hooks.
        info = loss_dict
        info['step_time'] = step_time
        info['lr_s'] = float(self.optim_s.param_groups[0]['lr'])
        info['lr_g'] = float(self.optim_g.param_groups[0]['lr'])
        self.history.put_scalars(**info)

    def reset(self):
        self.history = None
        # BUG FIX: the original read the undefined bare name `train_loader`
        # (NameError at runtime); use the attribute set up in train().
        self._train_loader_iter = iter(self.train_loader)
        self.iter = self.start_iter
@@ -0,0 +1,86 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import KDLoss | |||||
import torch.nn as nn | |||||
import torch._ops | |||||
import time | |||||
class HintDistiller(KDDistiller):
    """FitNets-style hint distillation: adds an MSE loss between a regressed
    student feature map and the teacher feature map at ``hint_layer``."""

    def __init__(self, logger=None, tb_writer=None ):
        super(HintDistiller, self).__init__( logger, tb_writer )

    def setup(self,
              student, teacher, regressor, dataloader, optimizer,
              hint_layer=2, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fixes: mutable list defaults (``[]``) replaced by ``None`` to
        avoid cross-call sharing, and the regressor is moved to the device
        resolved by the parent (safe when ``device`` is None).
        """
        super( HintDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.regressor = regressor
        self._hint_layer = hint_layer
        self._beta = beta
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        # Parent setup resolved self.device even when device was None.
        self.regressor.to(self.device)

    def additional_kd_loss(self, engine, batch):
        """MSE between regressed student hint features and teacher features."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = self.regressor(feat_s[self._hint_layer])
        f_t = feat_t[self._hint_layer]
        return nn.functional.mse_loss(f_s, f_t)
class Regressor(nn.Module):
    """Convolutional regressor mapping a student feature map onto the
    teacher's feature-map shape (used by FitNet-style hint distillation).

    Args:
        s_shape: student feature shape (N, C, H, W).
        t_shape: teacher feature shape (N, C, H, W).
        is_relu: apply ReLU after BatchNorm when True.

    Raises:
        NotImplementedError: if the student map is smaller than the teacher
            map and not exactly half its height.

    @inproceedings{tian2019crd,
        title={Contrastive Representation Distillation},
        author={Yonglong Tian and Dilip Krishnan and Phillip Isola},
        booktitle={International Conference on Learning Representations},
        year={2020}
    }
    """
    def __init__(self, s_shape, t_shape, is_relu=True):
        super(Regressor, self).__init__()
        self.is_relu = is_relu
        _, s_C, s_H, s_W = s_shape
        _, t_C, t_H, t_W = t_shape
        if s_H == 2 * t_H:
            # Downsample by 2.
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
        elif s_H * 2 == t_H:
            # Upsample by 2.
            self.conv = nn.ConvTranspose2d(
                s_C, t_C, kernel_size=4, stride=2, padding=1)
        elif s_H >= t_H:
            # Valid conv whose kernel absorbs the size difference.
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
        else:
            # Bug fix: `NotImplemented` is a constant, not an exception class;
            # raising it produced a TypeError instead of the intended error.
            raise NotImplementedError(
                'student size {}, teacher size {}'.format(s_H, t_H))
        self.bn = nn.BatchNorm2d(t_C)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.is_relu:
            return self.relu(self.bn(x))
        else:
            return self.bn(x)
@@ -0,0 +1,90 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from kamal.core.engine.trainer import Engine | |||||
from kamal.core.tasks.loss import kldiv | |||||
import torch.nn.functional as F | |||||
from kamal.utils.logger import get_logger | |||||
from kamal.utils import set_mode, move_to_device | |||||
import weakref | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
import numpy as np | |||||
class KDDistiller(Engine):
    """Vanilla knowledge distillation: student mimics the teacher's softened
    outputs (KLD) while also fitting the hard labels (cross-entropy)."""

    def __init__( self,
                  logger=None,
                  tb_writer=None):
        super(KDDistiller, self).__init__(logger=logger, tb_writer=tb_writer)

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, device=None):
        """Store models, data and loss weights, then move models to device.

        T is the distillation temperature; alpha/beta/gamma weight the KLD,
        cross-entropy and additional losses respectively.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.student = student
        self.teacher = teacher
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.device = device
        self.T, self.alpha, self.beta, self.gamma = T, alpha, beta, gamma
        self.student.to(self.device)
        self.teacher.to(self.device)

    def run( self, max_iter, start_iter=0, epoch_length=None):
        """Train the student with the teacher frozen in eval mode."""
        with set_mode(self.student, training=True), \
             set_mode(self.teacher, training=False):
            super( KDDistiller, self ).run(
                self.step_fn, self.dataloader,
                start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)

    def additional_kd_loss(self, engine, batch):
        """Hook for subclasses; the base class contributes a zero loss."""
        return batch[0].new_zeros(1)

    def step_fn(self, engine, batch):
        """One optimization step; returns a dict of scalar metrics."""
        tic = time.perf_counter()
        batch = move_to_device(batch, self.device)
        inputs, targets = batch
        outputs = self.student(inputs)
        with torch.no_grad():
            soft_targets = self.teacher(inputs)
        losses = {
            "loss_kld": self.alpha * kldiv(outputs, soft_targets, T=self.T),
            "loss_ce": self.beta * F.cross_entropy( outputs, targets ),
            "loss_additional": self.gamma * self.additional_kd_loss(engine, batch),
        }
        total = sum(losses.values())
        self.optimizer.zero_grad()
        total.backward()
        self.optimizer.step()
        elapsed = time.perf_counter() - tic
        metrics = {name: value.item() for name, value in losses.items()}
        metrics['total_loss'] = total.item()
        metrics['step_time'] = elapsed
        metrics['lr'] = float(self.optimizer.param_groups[0]['lr'])
        return metrics
@@ -0,0 +1,44 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import KDLoss | |||||
from kamal.core.tasks.loss import NSTLoss | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
class NSTDistiller(KDDistiller):
    """Neuron Selectivity Transfer distiller: adds an NST loss over the
    intermediate (non-first, non-last) hooked feature maps."""

    def __init__(self, logger=None, tb_writer=None ):
        super(NSTDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fix: mutable list defaults (``[]``) replaced by ``None`` so the
        default lists are not shared across calls.
        """
        super( NSTDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        self._nst_loss = NSTLoss()

    def additional_kd_loss(self, engine, batch):
        """NST loss over intermediate hooked features (teacher detached)."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        g_s = feat_s[1:-1]
        g_t = feat_t[1:-1]
        return self._nst_loss(g_s, g_t)
@@ -0,0 +1,44 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import PKTLoss | |||||
from kamal.core.tasks.loss import KDLoss | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
class PKTDistiller(KDDistiller):
    """Probabilistic Knowledge Transfer distiller: adds a PKT loss on the
    last hooked feature pair."""

    def __init__(self, logger=None, tb_writer=None ):
        super(PKTDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fix: mutable list defaults (``[]``) replaced by ``None`` so the
        default lists are not shared across calls.
        """
        super( PKTDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        self._pkt_loss = PKTLoss()

    def additional_kd_loss(self, engine, batch):
        """PKT loss between the final hooked features (teacher detached)."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = feat_s[-1]
        f_t = feat_t[-1]
        return self._pkt_loss(f_s, f_t)
@@ -0,0 +1,45 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import RKDLoss | |||||
from kamal.core.tasks.loss import KDLoss | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
class RKDDistiller(KDDistiller):
    """Relational Knowledge Distillation distiller: adds an RKD loss on the
    last hooked feature pair."""

    def __init__(self, logger=None, tb_writer=None ):
        super(RKDDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fixes: ``beta`` was accepted but never forwarded to the parent
        setup (it was silently dropped), and the mutable list defaults
        (``[]``) are replaced by ``None`` to avoid cross-call sharing.
        """
        super( RKDDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        self._rkd_loss = RKDLoss()

    def additional_kd_loss(self, engine, batch):
        """RKD loss between the final hooked features (teacher detached)."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        f_s = feat_s[-1]
        f_t = feat_t[-1]
        return self._rkd_loss(f_s, f_t)
@@ -0,0 +1,44 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import SPLoss | |||||
from kamal.core.tasks.loss import KDLoss | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
class SPDistiller(KDDistiller):
    """Similarity-Preserving distiller: adds an SP loss on the penultimate
    hooked feature pair."""

    def __init__(self, logger=None, tb_writer=None ):
        super(SPDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fix: mutable list defaults (``[]``) replaced by ``None`` so the
        default lists are not shared across calls.
        """
        super( SPDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        self._sp_loss = SPLoss()

    def additional_kd_loss(self, engine, batch):
        """SP loss between the penultimate hooked features (teacher detached)."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        g_s = [feat_s[-2]]
        g_t = [feat_t[-2]]
        return self._sp_loss(g_s, g_t)
@@ -0,0 +1,45 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from .kd import KDDistiller | |||||
from kamal.core.tasks.loss import SVDLoss | |||||
from kamal.core.tasks.loss import KDLoss | |||||
import torch | |||||
import torch.nn as nn | |||||
import time | |||||
class SVDDistiller(KDDistiller):
    """SVD-based distiller: adds an SVD loss over the intermediate
    (non-first, non-last) hooked feature maps."""

    def __init__(self, logger=None, tb_writer=None ):
        super(SVDDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fix: mutable list defaults (``[]``) replaced by ``None`` so the
        default lists are not shared across calls.
        """
        super( SVDDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        self._svd_loss = SVDLoss()

    def additional_kd_loss(self, engine, batch):
        """SVD loss over intermediate hooked features (teacher detached)."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        g_s = feat_s[1:-1]
        g_t = feat_t[1:-1]
        return self._svd_loss( g_s, g_t )
@@ -0,0 +1,90 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import numpy as np | |||||
import time | |||||
import torch.nn as nn | |||||
import torch._ops | |||||
import torch.nn.functional as F | |||||
from .kd import KDDistiller | |||||
from kamal.utils import set_mode | |||||
from kamal.core.tasks.loss import KDLoss | |||||
class VIDDistiller(KDDistiller):
    """Variational Information Distillation distiller: adds a VID loss over
    the intermediate hooked feature maps via per-layer regressors."""

    def __init__(self, logger=None, tb_writer=None ):
        super(VIDDistiller, self).__init__( logger, tb_writer )

    def setup(self, student, teacher, dataloader, optimizer, regressor_l, T=1.0, alpha=1.0, beta=1.0, gamma=1.0,
              stu_hooks=None, tea_hooks=None, out_flags=None, device=None):
        """Configure the distiller.

        Bug fixes: ``self.regressor_l`` was assigned twice (the first, raw
        assignment was dead); mutable list defaults (``[]``) replaced by
        ``None`` to avoid cross-call sharing.
        """
        super( VIDDistiller, self ).setup(
            student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device )
        self.stu_hooks = stu_hooks if stu_hooks is not None else []
        self.tea_hooks = tea_hooks if tea_hooks is not None else []
        self.out_flags = out_flags if out_flags is not None else []
        # Move every regressor to the resolved device and set train mode.
        self.regressor_l = [regressor.to(self.device).train() for regressor in regressor_l]

    def additional_kd_loss(self, engine, batch):
        """Sum of per-layer VID losses over intermediate hooked features."""
        feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)]
        feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)]
        g_s = feat_s[1:-1]
        g_t = feat_t[1:-1]
        return sum([c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, self.regressor_l)])
class VIDRegressor(nn.Module):
    """Variational Information Distillation regressor.

    Predicts the teacher feature map from the student feature map with a
    stack of 1x1 convolutions and scores the fit with a Gaussian negative
    log-likelihood using a learned per-channel variance.
    """

    def __init__(self,
                 num_input_channels,
                 num_mid_channel,
                 num_target_channels,
                 init_pred_var=5.0,
                 eps=1e-5):
        super(VIDRegressor, self).__init__()

        def conv1x1(in_channels, out_channels, stride=1):
            # Bias-free pointwise convolution.
            return nn.Conv2d(
                in_channels, out_channels,
                kernel_size=1, padding=0,
                bias=False, stride=stride)

        self.regressor = nn.Sequential(
            conv1x1(num_input_channels, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_mid_channel),
            nn.ReLU(),
            conv1x1(num_mid_channel, num_target_channels),
        )
        # Inverse-softplus parameterization so the initial predicted
        # variance equals init_pred_var.
        self.log_scale = torch.nn.Parameter(
            np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
        )
        self.eps = eps

    def forward(self, input, target):
        """Return the mean Gaussian NLL of ``target`` under the prediction."""
        # Match spatial resolutions by average-pooling the larger map.
        s_H, t_H = input.shape[2], target.shape[2]
        if s_H > t_H:
            input = F.adaptive_avg_pool2d(input, (t_H, t_H))
        elif t_H > s_H:
            target = F.adaptive_avg_pool2d(target, (s_H, s_H))
        pred_mean = self.regressor(input)
        # Softplus keeps the variance positive; eps keeps it away from zero.
        pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
        pred_var = pred_var.view(1, -1, 1, 1)
        nll = 0.5*(
            (pred_mean-target)**2/pred_var+torch.log(pred_var)
        )
        return torch.mean(nll)
@@ -0,0 +1,2 @@ | |||||
from .pruner import Pruner | |||||
from .strategy import LNStrategy, RandomStrategy |
@@ -0,0 +1,37 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from copy import deepcopy | |||||
import os | |||||
import torch | |||||
class Pruner(object):
    """Applies a pruning strategy to a deep copy of a model and reports how
    many parameters were removed."""

    def __init__(self, strategy):
        self.strategy = strategy

    def prune(self, model, rate=0.1, example_inputs=None):
        """Return a pruned CPU deep copy of ``model``; the original model is
        left untouched."""
        ori_num_params = sum(torch.numel(p) for p in model.parameters())
        pruned = self._prune(deepcopy(model).cpu(), rate=rate, example_inputs=example_inputs)
        new_num_params = sum(torch.numel(p) for p in pruned.parameters())
        print( "%d=>%d, %.2f%% params were pruned"%( ori_num_params, new_num_params, 100*(ori_num_params-new_num_params)/ori_num_params ) )
        return pruned

    def _prune(self, model, **kargs):
        # Delegate the actual pruning to the configured strategy.
        return self.strategy( model, **kargs)
@@ -0,0 +1,85 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch_pruning as tp | |||||
import abc | |||||
import torch | |||||
import torch.nn as nn | |||||
import random | |||||
import numpy as np | |||||
# Module types that torch_pruning knows how to prune, taken from the
# DependencyGraph class at import time.
_PRUNABLE_MODULES= tp.DependencyGraph.PRUNABLE_MODULES
class BaseStrategy(abc.ABC):
    """Random-layer channel pruning driver.

    Repeatedly picks a prunable layer (weighted by its parameter count),
    asks the subclass-defined ``select`` for channel indices, and executes
    the dependency-aware pruning plan until roughly ``rate`` of all
    prunable parameters have been removed.
    """

    @abc.abstractmethod
    def select(self, layer_to_prune):
        # Subclasses return a list of channel indices to prune from the layer.
        pass

    def __call__(self, model, rate=0.1, example_inputs=None):
        """Prune ``model`` in place and return it.

        Args:
            model: the network to prune (modified in place).
            rate: target fraction of prunable parameters to remove.
            example_inputs: dummy input used to trace layer dependencies;
                defaults to a 1x3x256x256 tensor.
        """
        if example_inputs is None:
            example_inputs = torch.randn( 1,3,256,256 )
        DG = tp.DependencyGraph()
        DG.build_dependency(model, example_inputs=example_inputs)
        # Collect conv/linear layers and a running (cumulative) parameter
        # count so a random parameter index can be mapped back to its layer.
        prunable_layers = []
        total_params = 0
        num_accumulative_conv_params = [ 0, ]
        for m in model.modules():
            if isinstance(m, _PRUNABLE_MODULES ) :
                nparam = tp.utils.count_prunable_params( m )
                total_params += nparam
                if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)):
                    prunable_layers.append( m )
                    num_accumulative_conv_params.append( num_accumulative_conv_params[-1]+nparam )
        # Keep the final (usually output) layer intact.
        prunable_layers.pop(-1) # remove the last layer
        num_accumulative_conv_params.pop(-1) # remove the last layer
        num_conv_params = num_accumulative_conv_params[-1]
        # Turn the cumulative counts into [start, end) index ranges per layer.
        num_accumulative_conv_params = [ ( num_accumulative_conv_params[i], num_accumulative_conv_params[i+1] ) for i in range(len(num_accumulative_conv_params)-1) ]
        def map_param_idx_to_conv_layer(i):
            # Linear scan: the layer whose index range contains i.
            for l, accu in zip( prunable_layers, num_accumulative_conv_params ):
                if accu[0]<=i and i<accu[1]:
                    return l
        num_pruned = 0
        # Sample layers proportionally to their parameter count until the
        # requested fraction of parameters has been pruned.
        while num_pruned<total_params*rate:
            layer_to_prune = map_param_idx_to_conv_layer( random.randint( 0, num_conv_params-1 ) )
            if layer_to_prune.weight.shape[0]<1:
                continue
            idx = self.select( layer_to_prune )
            fn = tp.prune_conv if isinstance(layer_to_prune, nn.modules.conv._ConvNd) else tp.prune_linear
            plan = DG.get_pruning_plan( layer_to_prune, fn, idxs=idx )
            num_pruned += plan.exec()
        return model
class RandomStrategy(BaseStrategy):
    """Pick one output channel of the layer uniformly at random."""

    def select(self, layer_to_prune):
        num_out = layer_to_prune.weight.shape[0]
        return [random.randint(0, num_out - 1)]
class LNStrategy(BaseStrategy):
    """Pick the output channel with the smallest Ln norm of its weights."""

    def __init__(self, n=2):
        # Order of the norm (n=2 gives the usual L2 magnitude criterion).
        self.n = n

    def select(self, layer_to_prune):
        flat_weights = torch.flatten(layer_to_prune.weight, 1)
        channel_norms = torch.norm(flat_weights, p=self.n, dim=1)
        smallest = int(channel_norms.min(dim=0)[1].item())
        return [smallest]
@@ -0,0 +1,18 @@ | |||||
# Deep Model Transferability from Attribution Maps
- [*"Paper: Deep Model Transferability from Attribution Maps"*](https:), NeurIPS 2019. (link to be released soon)
J. Song, Y. Chen, X. Wang, C. Shen, M. Song | |||||
[Homepage of VIPA Group](https://www.vipazoo.cn/index_en.html), Zhejiang University, China | |||||
This repo was rewritten by zhfeing in `PyTorch`.
[Homepage of VIPA Group](https://www.vipazoo.cn/index_en.html), Zhejiang University, China | |||||
<div align="left"> | |||||
<img src="vipa-logo.png" width = "40%" height = "40%" alt="icon"/> | |||||
</div> | |||||
@@ -0,0 +1,20 @@ | |||||
from kamal.transferability.trans_graph import TransferabilityGraph | |||||
from kamal.transferability.trans_metric import AttrMapMetric | |||||
import kamal | |||||
from kamal.vision import sync_transforms as sT | |||||
import os | |||||
import torch | |||||
from PIL import Image | |||||
if __name__=='__main__':
    # Hard-coded paths to the local model zoo and probe datasets.
    zoo = '/tmp/pycharm_project_225/kamal/transferability/model2'
    probe_set_root = '/tmp/pycharm_project_225/kamal/transferability/probe_data'
    TG = TransferabilityGraph(zoo)
    # Build one attribution-map metric per probe dataset and export the
    # resulting transferability graph as JSON.
    for probe_set in os.listdir( probe_set_root ):
        print("Add %s"%(probe_set))
        probe_dir = os.path.join(probe_set_root, probe_set)
        images = [Image.open(os.path.join(probe_dir, fname))
                  for fname in os.listdir(probe_dir)]
        metric = AttrMapMetric(images, device=torch.device('cuda'))
        TG.add_metric( probe_set, metric)
        TG.export_to_json(probe_set, 'exported_metrics/%s.json'%(probe_set), topk=3, normalize=True)
@@ -0,0 +1,3 @@ | |||||
from .attribution_graph import get_attribution_graph, graph_similarity | |||||
from .attribution_map import attribution_map, attr_map_distance |
@@ -0,0 +1,184 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from typing import Type, Dict, Any | |||||
import itertools | |||||
import numpy as np | |||||
from numpy import Inf | |||||
import networkx | |||||
import torch | |||||
from torch.nn import Module | |||||
from torch.utils.data import Dataset | |||||
from .attribution_map import attribution_map, attr_map_similarity | |||||
def graph_to_array(graph: networkx.Graph):
    """Return the dense pairwise edge-weight matrix of ``graph``.

    Node pairs with no edge (including self-pairs) contribute weight 1.
    """
    node_list = list(graph.nodes)
    size = len(node_list)
    weight_matrix = np.zeros((size, size))
    for i, src in enumerate(node_list):
        for j, dst in enumerate(node_list):
            try:
                weight_matrix[i, j] = graph[src][dst]["weight"]
            except KeyError:
                # Missing edge: fall back to the default distance of 1.
                weight_matrix[i, j] = 1
    return weight_matrix
class FeatureMapExtractor():
    """Captures the inputs of every pooling layer of a module via forward
    hooks, so the deepest feature map can be recovered after a forward pass."""

    def __init__(self, module: Module):
        self.module = module
        # name -> {"feature": captured input tuple, "handle": hook handle}
        self.feature_pool: Dict[str, Dict[str, Any]] = dict()
        self.register_hooks()

    def register_hooks(self):
        """Attach a forward hook to every submodule whose name contains "pool"."""
        for name, m in self.module.named_modules():
            if "pool" not in name:
                continue
            m.name = name
            self.feature_pool[name] = dict()

            def hook(m: Module, input, output):
                # Store the *input* of the pooling layer (pre-pooled map).
                self.feature_pool[m.name]["feature"] = input

            self.feature_pool[name]["handle"] = m.register_forward_hook(hook)

    def _forward(self, x):
        self.module(x)

    def remove_hooks(self):
        """Detach all hooks and drop every captured feature."""
        for cfg in self.feature_pool.values():
            cfg["handle"].remove()
            cfg.clear()
        self.feature_pool.clear()

    def extract_final_map(self, x):
        """Run a forward pass and return the captured 4-D map with the most
        channels and the smallest (but >1) spatial extent, or None."""
        self._forward(x)
        best_map = None
        max_channel = 0
        min_size = Inf
        for cfg in self.feature_pool.values():
            f = cfg["feature"]
            if len(f) == 1 and isinstance(f[0], torch.Tensor):
                f = f[0]
                if f.dim() == 4:  # BxCxHxW
                    b, c, h, w = f.shape
                    if c >= max_channel and 1 < h * w <= min_size:
                        best_map = f
                        max_channel = c
                        min_size = h * w
        return best_map
def get_attribution_graph(
    model: Module,
    attribution_type: Type,
    with_noise: bool,
    probe_data: Dataset,
    device: torch.device,
    norm_square: bool = False,
):
    """Build a complete graph over probe samples.

    Each node stores the attribution map of one probe sample computed on
    the model's deepest pooled feature map; each edge weight is the
    pairwise similarity between the two maps.
    """
    attribution_graph = networkx.Graph()
    extractor = FeatureMapExtractor(model.to(device))
    for idx, sample in enumerate(probe_data):
        sample = sample.to(device)
        sample.requires_grad_()
        attr = attribution_map(
            func=lambda s: extractor.extract_final_map(s),
            attribution_type=attribution_type,
            with_noise=with_noise,
            probe_data=sample.unsqueeze(0),
            norm_square=norm_square
        )
        attribution_graph.add_node(idx, attribution_map=attr)
    # Connect every unordered node pair with its map similarity.
    node_ids = attribution_graph.nodes
    for i, j in itertools.product(node_ids, node_ids):
        if i < j:
            node_data = attribution_graph.nodes(data=True)
            weight = attr_map_similarity(
                node_data[i]["attribution_map"],
                node_data[j]["attribution_map"]
            )
            attribution_graph.add_edge(i, j, weight=weight)
    return attribution_graph
def edge_to_embedding(graph: networkx.Graph):
    """Flatten the graph's weight matrix into a 1-D edge-weight embedding.

    NOTE(review): ``np.tri(..., k=0)`` selects the *lower* triangle
    including the diagonal, despite the variable name — confirm intended.
    """
    weights = graph_to_array(graph)
    tri_mask = np.tri(*weights.shape[-2:], k=0, dtype=bool)
    return weights[tri_mask]
def embedding_to_rank(embedding: np.ndarray):
    """Map each value of ``embedding`` to its rank (0 = smallest) via the
    double-argsort idiom."""
    return embedding.argsort().argsort()
def graph_similarity(g1: networkx.Graph, g2: networkx.Graph, Lambda: float = 1.0):
    """Similarity of two attribution graphs with equal node counts.

    Returns mean node-wise attribution-map similarity plus ``Lambda`` times
    a Spearman-style rank correlation over edge weights.
    """
    nodes_1 = g1.nodes(data=True)
    nodes_2 = g2.nodes(data=True)
    assert len(nodes_1) == len(nodes_2)
    # calculate vertex similarity
    n = len(g1.nodes)
    total = 0
    for i in range(n):
        total += attr_map_similarity(
            map_1=g1.nodes(data=True)[i]["attribution_map"],
            map_2=g2.nodes(data=True)[i]["attribution_map"]
        )
    vertex_similarity = total / n
    # calculate edges similarity (Spearman's rho over edge-weight ranks)
    rank_1 = embedding_to_rank(edge_to_embedding(g1))
    rank_2 = embedding_to_rank(edge_to_embedding(g2))
    k = rank_1.shape[0]
    edge_similarity = 1 - 6 * np.sum(np.square(rank_1 - rank_2)) / (k ** 3 - k)
    return vertex_similarity + Lambda * edge_similarity
if __name__ == "__main__":
    # Smoke test: compare attribution graphs of two independently
    # initialized ResNet-34 models on random probe data (requires captum
    # and torchvision; CPU only).
    from captum.attr import InputXGradient
    from torchvision.models import resnet34
    model_1 = resnet34(num_classes=10)
    graph_1 = get_attribution_graph(
        model_1,
        attribution_type=InputXGradient,
        with_noise=False,
        probe_data=torch.rand(10, 3, 244, 244),
        device=torch.device("cpu")
    )
    # Second model with a different random initialization.
    model_2 = resnet34(num_classes=10)
    graph_2 = get_attribution_graph(
        model_2,
        attribution_type=InputXGradient,
        with_noise=False,
        probe_data=torch.rand(10, 3, 244, 244),
        device=torch.device("cpu")
    )
    print(graph_similarity(graph_1, graph_2))
@@ -0,0 +1,87 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
from typing import Type, Callable | |||||
from captum.attr import Attribution | |||||
from captum.attr import NoiseTunnel | |||||
def with_norm(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, square: bool = False):
    """Apply *func* to *x* and reduce the output to a per-sample L2 norm.

    The result of ``func`` is flattened from dim 1 on (dim 0 is the batch)
    and its L2 norm is taken per sample; with ``square=True`` the squared
    norm is returned instead.
    """
    out = func(x)
    norm = out.flatten(1).norm(p=2, dim=1)
    if square:
        norm = norm.pow(2)
    return norm
def attribution_map(
    func: Callable[[torch.Tensor], torch.Tensor],
    attribution_type: Type,
    with_noise: bool,
    probe_data: torch.Tensor,
    norm_square: bool = False,
    **attribution_kwargs
) -> torch.Tensor:
    """
    Calculate attribution map with given attribution type(algorithm).
    Args:
        func: forward function; its output is scalarized per sample via an L2 norm
        attribution_type: attribution algorithm, e.g. IntegratedGradients, InputXGradient, ...
        with_noise: whether to wrap the attribution in a NoiseTunnel
        probe_data: input data fed to func
        norm_square: if True, the L2 norm used to scalarize func's output is squared
        attribution_kwargs: other kwargs forwarded to the attribute() call
    Return: attribution map, detached from the autograd graph
    """
    attribution: Attribution = attribution_type(lambda x: with_norm(func, x, norm_square))
    if with_noise:
        attribution = NoiseTunnel(attribution)
    # target=None: the wrapped forward already returns one scalar per sample
    attr_map = attribution.attribute(
        inputs=probe_data,
        target=None,
        **attribution_kwargs
    )
    return attr_map.detach()
def attr_map_distance(map_1: torch.Tensor, map_2: torch.Tensor):
    """Cosine distance (1 - mean cosine similarity) between two attribution maps.

    When the shapes differ, map_1 is interpolated to map_2's spatial size
    before comparison.
    """
    if map_1.shape != map_2.shape:
        # align spatial resolution before flattening
        map_1 = torch.nn.functional.interpolate(map_1, size=map_2.shape[-2:])
    similarity = torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean()
    return (1 - similarity).item()
def attr_map_similarity(map_1: torch.Tensor, map_2: torch.Tensor):
    """Mean cosine similarity between two attribution maps of identical shape."""
    assert map_1.shape == map_2.shape
    flat_1, flat_2 = map_1.flatten(1), map_2.flatten(1)
    return torch.cosine_similarity(flat_1, flat_2).mean().item()
if __name__ == "__main__":
    # Smoke test: InputXGradient attribution of f(x) = x ** 2,
    # scalarized with a squared L2 norm (norm_square=True).
    import captum
    def ff(x):
        return x ** 2
    m = attribution_map(
        ff,
        captum.attr.InputXGradient,
        with_noise=False,
        probe_data=torch.tensor([[1, 2, 3, 4]], dtype=torch.float, requires_grad=True),
        norm_square=True
    )
    print(m)
@@ -0,0 +1,135 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
import networkx as nx | |||||
from . import depara | |||||
import os, abc | |||||
from typing import Callable | |||||
from kamal import hub | |||||
import json, numbers | |||||
from tqdm import tqdm | |||||
class Node(object):
    # Lightweight handle to one model spec inside a kamal hub. The heavy
    # artifacts (weights, tags, metadata) are loaded lazily via properties,
    # so holding many Nodes is cheap.
    def __init__(self, hub_root, entry_name, spec_name):
        self.hub_root = hub_root      # root directory of the hub
        self.entry_name = entry_name  # hubconf entry (model constructor) name
        self.spec_name = spec_name    # spec (pretrained variant) name
    @property
    def model(self):
        # Re-loads the model from the hub on every access; returned in eval mode.
        return hub.load( self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name ).eval()
    @property
    def tag(self):
        # Free-form tag dict attached to the spec.
        return hub.load_tags(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name)
    @property
    def metadata(self):
        # Structured metadata (name/task/dataset/url/input spec, per usage below).
        return hub.load_metadata(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name)
class TransferabilityGraph(object):
    """Directed graph over registered models; edges hold transferability distances."""

    def __init__(self, model_zoo_set):
        """
        Args:
            model_zoo_set: iterable of model-zoo directories to scan for hub specs.
        """
        self.model_zoo_set = model_zoo_set
        self._graphs = dict()   # metric_name -> nx.DiGraph
        self._models = dict()   # model name -> Node
        self._register_models()

    def _register_models(self):
        """Scan every zoo directory and register each (entry, spec) pair as a Node."""
        cnt = 0
        for model_zoo in self.model_zoo_set:
            model_zoo = os.path.abspath(os.path.expanduser(model_zoo))
            for hub_root in self._list_modelzoo(model_zoo):
                for entry_name, spec_name in hub.list_spec(hub_root):
                    node = Node(hub_root, entry_name, spec_name)
                    name = node.metadata['name']
                    self._models[name] = node
                    cnt += 1
        print("%d models has been registered!" % cnt)

    def _list_modelzoo(self, zoo_dir):
        """Return every sub-directory of *zoo_dir* that contains code/hubconf.py."""
        zoo_list = []
        def _traverse(path):
            for item in os.listdir(path):
                item_path = os.path.join(path, item)
                if os.path.isdir(item_path):
                    # a directory with code/hubconf.py is a hub root; do not recurse into it
                    if os.path.exists(os.path.join(item_path, 'code/hubconf.py')):
                        zoo_list.append(item_path)
                    else:
                        _traverse(item_path)
        _traverse(zoo_dir)
        return zoo_list

    def add_metric(self, metric_name, metric):
        """Build a complete directed graph under *metric_name*.

        Args:
            metric_name: key used to store/export the graph.
            metric: callable(node, node) -> float with a mutable ``device`` attribute.
        """
        self._graphs[metric_name] = g = nx.DiGraph()
        g.add_nodes_from(self._models.values())
        for n1 in self._models.values():
            for n2 in tqdm(self._models.values()):
                if n1 != n2 and not g.has_edge(n1, n2):
                    try:
                        g.add_edge(n1, n2, dist=metric(n1, n2))
                    except Exception:
                        # was a bare `except:` (also caught KeyboardInterrupt/SystemExit);
                        # on failure (presumably a device/OOM error) retry once on CPU
                        ori_device = metric.device
                        metric.device = torch.device('cpu')
                        g.add_edge(n1, n2, dist=metric(n1, n2))
                        metric.device = ori_device

    def export_to_json(self, metric_name, output_filename, topk=None, normalize=False):
        """Dump the nodes and edges of one metric graph to a JSON file.

        Args:
            metric_name: name previously passed to add_metric.
            output_filename: destination JSON path.
            topk: if an int, keep only edges whose distance is strictly below
                the source node's (topk+1)-th smallest outgoing distance.
            normalize: rescale the kept edge distances into [0, 1].
        """
        graph = self._graphs.get(metric_name, None)
        assert graph is not None
        graph_data = {
            'nodes': [],
            'edges': [],
        }
        node_to_idx = {}
        for i, node in enumerate(self._models.values()):
            tags = node.tag
            metadata = node.metadata
            # keep only scalar/string tags so the payload stays JSON-serializable
            node_data = {k: v for (k, v) in tags.items() if isinstance(v, (numbers.Number, str))}
            node_data['name'] = metadata['name']
            node_data['task'] = metadata['task']
            node_data['dataset'] = metadata['dataset']
            node_data['url'] = metadata['url']
            node_data['id'] = i
            graph_data['nodes'].append({'tags': node_data})
            node_to_idx[node] = i
        # record edges; collect per-source distances for the optional top-k filter
        edge_list = graph_data['edges']
        topk_dist = {idx: [] for idx in range(len(self._models))}
        for edge in graph.edges.data('dist'):
            s, t, d = int(node_to_idx[edge[0]]), int(node_to_idx[edge[1]]), float(edge[2])
            topk_dist[s].append(d)
            edge_list.append([s, t, d])  # source, target, distance
        if isinstance(topk, int):
            for i, dist in topk_dist.items():
                dist.sort()
                # guard: a node may have fewer than topk+1 edges (was an IndexError)
                topk_dist[i] = dist[topk] if topk < len(dist) else float('inf')
            graph_data['edges'] = [edge for edge in edge_list if edge[2] < topk_dist[edge[0]]]
        if normalize:
            edge_dist = [e[2] for e in graph_data['edges']]
            min_dist, max_dist = min(edge_dist), max(edge_dist)
            for e in graph_data['edges']:
                e[2] = (e[2] - min_dist + 1e-8) / (max_dist - min_dist + 1e-8)
        with open(output_filename, 'w') as fp:
            json.dump(graph_data, fp)
@@ -0,0 +1,109 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import abc | |||||
from . import depara | |||||
from captum.attr import InputXGradient | |||||
import torch | |||||
from kamal.core import hub | |||||
from kamal.vision import sync_transforms as sT | |||||
class TransMetric(abc.ABC):
    # Base interface for model-transferability metrics. Subclasses override
    # __call__(a, b) to return a scalar distance/score between two model nodes.
    def __init__(self):
        pass
    def __call__(self, a, b) -> float:
        # Placeholder no-op metric; concrete subclasses override this.
        return 0
class DeparaMetric(TransMetric):
    """DEPARA-style transferability metric based on attribution-graph similarity."""

    def __init__(self, data, device):
        """
        Args:
            data: iterable of raw probe images (transformed per-model in __call__).
            device: torch.device used to run the attribution.
        """
        self.data = data
        self.device = device
        self._cache = {}  # Node -> cached attribution graph

    def _get_transform(self, metadata):
        """Build the preprocessing pipeline described by a model's input metadata."""
        input_metadata = metadata['input']
        size = input_metadata['size']
        space = input_metadata['space']
        drange = input_metadata['range']
        normalize = input_metadata['normalize']
        if size is None:  # fix: was `size==None`
            size = 224    # fall back to the common 224x224 input size
        if isinstance(size, (list, tuple)):
            size = size[-1]
        transform = [
            sT.Resize(size),
            sT.CenterCrop(size),
        ]
        if space == 'bgr':
            transform.append(sT.FlipChannels())
        if list(drange) == [0, 1]:
            transform.append(sT.ToTensor())
        elif list(drange) == [0, 255]:
            transform.append(sT.ToTensor(normalize=False, dtype=torch.float))
        else:
            raise NotImplementedError
        if normalize is not None:
            transform.append(sT.Normalize(mean=normalize['mean'], std=normalize['std']))
        return sT.Compose(transform)

    def _get_attr_graph(self, n):
        """Compute the attribution graph of node *n* over the probe data."""
        transform = self._get_transform(n.metadata)
        data = torch.stack([transform(d) for d in self.data], dim=0)
        return depara.get_attribution_graph(
            n.model,
            attribution_type=InputXGradient,
            with_noise=False,
            probe_data=data,
            device=self.device
        )

    def __call__(self, n1, n2):
        """Return the graph similarity between nodes n1 and n2 (cached per node)."""
        attrgraph_1 = self._cache.get(n1, None)
        attrgraph_2 = self._cache.get(n2, None)
        if attrgraph_1 is None:
            # NOTE(review): assumes the object returned by _get_attr_graph
            # exposes a .cpu() method — confirm against depara's graph type.
            self._cache[n1] = attrgraph_1 = self._get_attr_graph(n1).cpu()
        if attrgraph_2 is None:
            self._cache[n2] = attrgraph_2 = self._get_attr_graph(n2).cpu()
        # BUG FIX: the original computed the similarity but never returned it,
        # so every call yielded None (cf. AttrMapMetric.__call__, which returns).
        return depara.graph_similarity(attrgraph_1, attrgraph_2)
class AttrMapMetric(DeparaMetric):
    """Transferability metric that compares raw attribution maps instead of graphs."""

    def _get_attr_map(self, n):
        """Compute the attribution map of node *n* over the probe data."""
        transform = self._get_transform(n.metadata)
        batch = torch.stack([transform(d).to(self.device) for d in self.data], dim=0)
        return depara.attribution_map(
            n.model.to(self.device),
            attribution_type=InputXGradient,
            with_noise=False,
            probe_data=batch,
        )

    def __call__(self, n1, n2):
        """Return the attribution-map distance between n1 and n2 (cached per node)."""
        for node in (n1, n2):
            if self._cache.get(node, None) is None:
                self._cache[node] = self._get_attr_map(node).cpu()
        return depara.attr_map_distance(self._cache[n1], self._cache[n2])
@@ -0,0 +1,2 @@ | |||||
from ._utils import * | |||||
from .logger import get_logger |
@@ -0,0 +1,153 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import numpy as np | |||||
import math | |||||
import torch | |||||
import random | |||||
from copy import deepcopy | |||||
import contextlib, hashlib | |||||
def split_batch(batch):
    """Split a batch into (inputs, targets).

    A list/tuple batch yields its first element as inputs and the rest as
    targets (unwrapped when there is exactly one). Any other object is
    treated as unlabeled data and paired with None.
    """
    if not isinstance(batch, (list, tuple)):
        return [batch, None]
    inputs = batch[0]
    targets = list(batch[1:])
    if len(targets) == 1:
        targets = targets[0]
    return inputs, targets
@contextlib.contextmanager
def set_mode(model, training=True):
    """Temporarily switch a module's train/eval mode, restoring it on exit.

    Args:
        model: object exposing ``training`` and ``train(mode)`` (e.g. nn.Module).
        training: mode to use inside the with-block.
    """
    ori_mode = model.training
    model.train(training)
    try:
        yield
    finally:
        # BUG FIX: restore the mode even when the with-body raises;
        # the original had no try/finally and leaked the temporary mode.
        model.train(ori_mode)
def move_to_device(obj, device):
    """Move a tensor, a list/tuple of tensors, or a module to *device*.

    Returns a list for list/tuple input. Unrecognized types are returned
    unchanged (the original fell through and returned None).
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device=device)
    elif isinstance(obj, (list, tuple)):
        return [o.to(device=device) for o in obj]
    elif isinstance(obj, torch.nn.Module):
        # BUG FIX: referenced bare `nn.Module`, but the file only imports
        # `torch`, so this branch raised NameError.
        return obj.to(device=device)
    return obj
def pack_images(images, col=None, channel_last=False):
    """Tile a batch of images (N, C, H, W) into one (C, H*rows, W*cols) grid.

    Args:
        images: ndarray of shape (N, C, H, W), or a list/tuple of (C, H, W)
            arrays; (N, H, W, C) when channel_last=True.
        col: number of grid columns; defaults to ceil(sqrt(N)).
        channel_last: input is NHWC and will be transposed to NCHW first.
    """
    if isinstance(images, (list, tuple)):
        images = np.stack(images, 0)
    if channel_last:
        images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
    assert isinstance(images, np.ndarray)
    assert len(images.shape) == 4
    n_img, channels, height, width = images.shape
    if col is None:
        col = int(math.ceil(math.sqrt(n_img)))
    row = int(math.ceil(n_img / col))
    canvas = np.zeros((channels, height * row, width * col), dtype=images.dtype)
    for idx, img in enumerate(images):
        top = (idx // col) * height
        left = (idx % col) * width
        canvas[:, top:top + height, left:left + width] = img
    return canvas
def normalize(tensor, mean, std, reverse=False):
    """Channel-wise (de)normalization of an NCHW tensor.

    reverse=False computes (tensor - mean) / std. reverse=True applies the
    inverse transform tensor * std + mean, expressed via the equivalent
    parameters (mean', std') = (-mean/std, 1/std).
    """
    if reverse:
        eff_mean = [-m / s for m, s in zip(mean, std)]
        eff_std = [1 / s for s in std]
    else:
        eff_mean, eff_std = mean, std
    eff_mean = torch.as_tensor(eff_mean, dtype=tensor.dtype, device=tensor.device)
    eff_std = torch.as_tensor(eff_std, dtype=tensor.dtype, device=tensor.device)
    return (tensor - eff_mean[None, :, None, None]) / eff_std[None, :, None, None]
class Normalizer(object):
    """Callable wrapper around `normalize`; reverse=True makes it denormalize."""

    def __init__(self, mean, std, reverse=False):
        self.mean = mean
        self.std = std
        self.reverse = reverse

    def __call__(self, x):
        return self.denormalize(x) if self.reverse else self.normalize(x)

    def normalize(self, x):
        """Apply (x - mean) / std channel-wise."""
        return normalize(x, self.mean, self.std)

    def denormalize(self, x):
        """Invert the normalization: x * std + mean."""
        return normalize(x, self.mean, self.std, reverse=True)
def colormap(N=256, normalized=False):
    """Generate the standard PASCAL-VOC style color map with N entries.

    Each color spreads the bits of the class index across the R/G/B
    channels. Returns uint8 values in [0, 255], or float32 values in
    [0, 1] when *normalized* is True.
    """
    def bitget(byteval, idx):
        return (byteval & (1 << idx)) != 0

    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for i in range(N):
        r = g = b = 0
        c = i
        for j in range(8):
            r |= bitget(c, 0) << (7 - j)
            g |= bitget(c, 1) << (7 - j)
            b |= bitget(c, 2) << (7 - j)
            c >>= 3
        cmap[i] = np.array([r, g, b])
    return cmap / 255 if normalized else cmap

DEFAULT_COLORMAP = colormap()
def flatten_dict(dic):
    """Flatten a nested dict into one level with '/'-joined keys.

    Example: {'a': 1, 'b': {'c': 2}} -> {'a': 1, 'b/c': 2}.
    """
    flattened = dict()

    def _flatten(prefix, d):
        for k, v in d.items():
            if isinstance(v, dict):
                # The original also had an unreachable `if prefix is None`
                # branch: the initial call always passes '', never None.
                _flatten(prefix + '%s/' % k, v)
            else:
                flattened[(prefix + '%s/' % k).strip('/')] = v

    _flatten('', dic)
    return flattened
def setup_seed(seed):
    """Seed Python, NumPy and torch (CPU + CUDA) RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
def count_parameters(model):
    """Total number of scalar parameters in *model* (trainable or not)."""
    return sum(p.numel() for p in model.parameters())
def md5(fname):
    """Return the hex MD5 digest of the file at *fname*, read in 4 KiB chunks."""
    digest = hashlib.md5()
    with open(fname, "rb") as fp:
        while True:
            chunk = fp.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
@@ -0,0 +1,56 @@ | |||||
import logging | |||||
import os, sys | |||||
from termcolor import colored | |||||
class _ColorfulFormatter(logging.Formatter): | |||||
def __init__(self, *args, **kwargs): | |||||
super(_ColorfulFormatter, self).__init__(*args, **kwargs) | |||||
def formatMessage(self, record): | |||||
log = super(_ColorfulFormatter, self).formatMessage(record) | |||||
if record.levelno == logging.WARNING: | |||||
prefix = colored("WARNING", "yellow", attrs=["blink"]) | |||||
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: | |||||
prefix = colored("ERROR", "red", attrs=["blink", "underline"]) | |||||
else: | |||||
return log | |||||
return prefix + " " + log | |||||
def get_logger(name='Kamal', output=None, color=True):
    """Create a logger with a stdout handler and an optional file handler.

    Args:
        name: logger name (the logging module caches loggers by name).
        output: None, a directory (log.txt is created inside), or a
            *.txt / *.log file path.
        color: colorize stdout output via _ColorfulFormatter.

    NOTE(review): calling this twice with the same name attaches duplicate
    handlers; consider returning early when logger.handlers is non-empty.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    # STDOUT
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    stdout_handler.setLevel(logging.DEBUG)
    plain_formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S")
    if color:
        formatter = _ColorfulFormatter(
            colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
            datefmt="%m/%d %H:%M:%S")
    else:
        formatter = plain_formatter
    stdout_handler.setFormatter(formatter)
    logger.addHandler(stdout_handler)
    # FILE
    if output is not None:
        if output.endswith('.txt') or output.endswith('.log'):
            parent = os.path.dirname(output)
            # BUG FIX: os.makedirs('') raises for a bare filename like "run.log"
            if parent:
                os.makedirs(parent, exist_ok=True)
            filename = output
        else:
            os.makedirs(output, exist_ok=True)
            filename = os.path.join(output, "log.txt")
        file_handler = logging.FileHandler(filename)
        file_handler.setFormatter(plain_formatter)
        file_handler.setLevel(logging.DEBUG)
        logger.addHandler(file_handler)
    return logger
@@ -0,0 +1,3 @@ | |||||
from . import models | |||||
from . import datasets | |||||
from . import sync_transforms |
@@ -0,0 +1,16 @@ | |||||
from .ade20k import ADE20K | |||||
from .caltech import Caltech101, Caltech256 | |||||
from .camvid import CamVid | |||||
from .cityscapes import Cityscapes | |||||
from .cub200 import CUB200 | |||||
from .fgvc_aircraft import FGVCAircraft | |||||
from .nyu import NYUv2 | |||||
from .stanford_cars import StanfordCars | |||||
from .stanford_dogs import StanfordDogs | |||||
from .sunrgbd import SunRGBD | |||||
from .voc import VOCClassification, VOCSegmentation | |||||
from .dataset import LabelConcatDataset | |||||
from torchvision import datasets as torchvision_datasets | |||||
from .unlabeled import UnlabeledDataset |
@@ -0,0 +1,70 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import collections | |||||
import torch | |||||
import torchvision | |||||
import numpy as np | |||||
from PIL import Image | |||||
import os | |||||
from torchvision.datasets import VisionDataset | |||||
from .utils import colormap | |||||
class ADE20K(VisionDataset):
    """ADE20K scene-parsing dataset (150 classes).

    Expects ``images/<split>`` and ``annotations/<split>`` under *root*.
    Labels are stored 1-based on disk and shifted to 0-based here.
    """
    cmap = colormap()

    def __init__(
        self,
        root,
        split="training",
        transform=None,
        target_transform=None,
        transforms=None,
    ):
        super(ADE20K, self).__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform)
        assert split in ['training', 'validation'], "split should be 'training' or 'validation'"
        self.root = os.path.expanduser(root)
        self.split = split
        self.num_classes = 150
        image_dir = os.path.join(self.root, 'images', self.split)
        label_dir = os.path.join(self.root, 'annotations', self.split)
        images, labels = [], []
        for fname in os.listdir(image_dir):
            images.append(os.path.join(image_dir, fname))
            # each label shares the image's basename, with a .png extension
            labels.append(os.path.join(label_dir, fname[:-3] + 'png'))
        self.img_list = images
        self.lbl_list = labels

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, index):
        image = Image.open(self.img_list[index])
        label = Image.open(self.lbl_list[index])
        if self.transforms:
            image, label = self.transforms(image, label)
        # shift 1..150 -> 0..149; 0 ("unlabeled") wraps to 255 in uint8
        label = np.array(label, dtype='uint8') - 1
        return image, label

    @classmethod
    def decode_seg_to_color(cls, mask):
        """decode semantic mask to RGB image"""
        return cls.cmap[mask + 1]
@@ -0,0 +1,226 @@ | |||||
# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py | |||||
from __future__ import print_function | |||||
from PIL import Image | |||||
import os | |||||
import os.path | |||||
from torchvision.datasets.vision import VisionDataset | |||||
from torchvision.datasets.utils import download_url | |||||
class Caltech101(VisionDataset):
    """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.

    This variant reads from a pre-split layout
    (``101_ObjectCategories_split/{train,test}``) selected via *train*.

    Args:
        root (string): Root directory of dataset where directory
            ``caltech101`` exists or will be saved to if download is set to True.
        target_type (string or list, optional): Type of target to use, ``category`` or
            ``annotation``. Can also be a list to output a tuple with all specified target types.
            ``category`` represents the target class, and ``annotation`` is a list of points
            from a hand-generated outline. Defaults to ``category``.
        train (bool, optional): select the train split when True, else the test split.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    def __init__(self, root, target_type="category", train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
        self.train = train
        self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test'
        # BUG FIX: was `os.makdirs(...)`, which raises AttributeError on every construction
        os.makedirs(self.root, exist_ok=True)
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
        self.categories.remove("BACKGROUND_Google")  # this is not a real class
        # For some reason, the category names in "101_ObjectCategories" and
        # "Annotations" do not always match. This is a manual map between the
        # two. Defaults to using same name, since most names are fine.
        name_map = {"Faces": "Faces_2",
                    "Faces_easy": "Faces_3",
                    "Motorbikes": "Motorbikes_16",
                    "airplanes": "Airplanes_Side_2"}
        self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
        # index holds per-sample file names; y holds the matching class ids
        self.index = []
        self.y = []
        for (i, c) in enumerate(self.categories):
            file_names = os.listdir(os.path.join(self.root, self.dir_name, c))
            n = len(file_names)
            self.index.extend(file_names)
            self.y.extend(n * [i])
        print(self.train, len(self.index))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """
        import scipy.io
        img = Image.open(os.path.join(self.root,
                                      self.dir_name,
                                      self.categories[self.y[index]],
                                      self.index[index])).convert("RGB")
        target = []
        for t in self.target_type:
            if t == "category":
                target.append(self.y[index])
            elif t == "annotation":
                # NOTE(review): self.index stores file-name strings in this split
                # layout, so the "{:04d}" format below would raise for this target
                # type — confirm whether "annotation" is still used here.
                data = scipy.io.loadmat(os.path.join(self.root,
                                                     "Annotations",
                                                     self.annotation_categories[self.y[index]],
                                                     "annotation_{:04d}.mat".format(self.index[index])))
                target.append(data["obj_contour"])
            else:
                raise ValueError("Target type \"{}\" is not recognized.".format(t))
        target = tuple(target) if len(target) > 1 else target[0]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def _check_integrity(self):
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))

    def __len__(self):
        return len(self.index)

    def download(self):
        """Download and extract the image and annotation archives into root."""
        import tarfile
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
                     self.root,
                     "101_ObjectCategories.tar.gz",
                     "b224c7392d521a49829488ab0f1120d9")
        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
                     self.root,
                     "101_Annotations.tar",
                     "6f83eeb1f24d99cab4eb377263132c91")
        # extract file
        with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
            tar.extractall(path=self.root)
        with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
            tar.extractall(path=self.root)

    def extra_repr(self):
        return "Target type: {target_type}".format(**self.__dict__)
class Caltech256(VisionDataset):
    """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directory
            ``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    def __init__(self, root,
                 transform=None, target_transform=None,
                 download=False):
        super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
        os.makedirs(self.root, exist_ok=True)
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        # categories: sorted class-directory names; index/y: per-sample
        # 1-based file number within the class and the class id.
        self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
        self.index = []
        self.y = []
        for (i, c) in enumerate(self.categories):
            n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
            self.index.extend(range(1, n + 1))
            self.y.extend(n * [i])
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # File names follow the archive's "<class:03d>_<img:04d>.jpg" scheme.
        img = Image.open(os.path.join(self.root,
                                      "256_ObjectCategories",
                                      self.categories[self.y[index]],
                                      "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))
        target = self.y[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
    def _check_integrity(self):
        # can be more robust and check hash of files
        return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
    def __len__(self):
        return len(self.index)
    def download(self):
        # Fetch the archive (skipped when the extracted directory exists)
        # and unpack it into self.root.
        import tarfile
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
                     self.root,
                     "256_ObjectCategories.tar",
                     "67b4f42ca05d46448c6bb8ecd2220f6d")
        # extract file
        with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
            tar.extractall(path=self.root)
@@ -0,0 +1,78 @@ | |||||
# Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py | |||||
import os | |||||
import torch.utils.data as data | |||||
from glob import glob | |||||
from PIL import Image | |||||
import numpy as np | |||||
from torchvision.datasets import VisionDataset | |||||
class CamVid(VisionDataset):
    """CamVid semantic-segmentation dataset loader.

    Expects the directory layout of
    https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid,
    i.e. images under ``<root>/<split>`` and annotations under
    ``<root>/<split>annot``.

    Args:
        root (string): dataset root directory.
        split (string): one of 'train', 'val', 'trainval' or 'test'.
        transform (callable, optional): transform applied to the PIL image. Default: None.
        target_transform (callable, optional): transform applied to the target. Default: None.
        transforms (callable, optional): joint transform applied to (image, target). Default: None.
    """

    # One RGB colour per class id (12 rows; the last row is used for void).
    cmap = np.array([
        (128, 128, 128),
        (128, 0, 0),
        (192, 192, 128),
        (128, 64, 128),
        (60, 40, 222),
        (128, 128, 0),
        (192, 128, 128),
        (64, 64, 128),
        (64, 0, 128),
        (64, 64, 0),
        (0, 128, 192),
        (0, 0, 0),
    ])

    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 target_transform=None,
                 transforms=None):
        assert split in ('train', 'val', 'test', 'trainval')
        super(CamVid, self).__init__(root=root, transforms=transforms,
                                     transform=transform,
                                     target_transform=target_transform)
        self.root = os.path.expanduser(root)
        self.split = split

        def _collect(part):
            # Images live in '<root>/<part>', labels in '<root>/<part>annot'.
            return (glob(os.path.join(self.root, part, '*.png')),
                    glob(os.path.join(self.root, part + 'annot', '*.png')))

        if split == 'trainval':
            train_imgs, train_lbls = _collect('train')
            val_imgs, val_lbls = _collect('val')
            self.images = train_imgs + val_imgs
            self.labels = train_lbls + val_lbls
        else:
            self.images, self.labels = _collect(split)
        # Sorting aligns each image with its annotation file by name.
        self.images.sort()
        self.labels.sort()

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at *idx*.

        When joint ``transforms`` are set they are expected to produce
        indexable (tensor-like) labels; class 11 is then remapped to the
        ignore index 255.
        """
        img = Image.open(self.images[idx])
        label = Image.open(self.labels[idx])
        if self.transforms is not None:
            img, label = self.transforms(img, label)
            label[label == 11] = 255  # ignore void
        return img, label.squeeze(0)

    def __len__(self):
        """Number of image/annotation pairs."""
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask):
        """decode semantic mask to RGB image"""
        mask[mask == 255] = 11  # map the ignore index back to the void colour
        return cls.cmap[mask]
@@ -0,0 +1,146 @@ | |||||
# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py | |||||
import json | |||||
import os | |||||
from collections import namedtuple | |||||
import torch | |||||
import torch.utils.data as data | |||||
from PIL import Image | |||||
import numpy as np | |||||
from torchvision.datasets import VisionDataset | |||||
class Cityscapes(VisionDataset):
    """Cityscapes <http://www.cityscapes-dataset.com/> Dataset.
    Args:
        root (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located.
        split (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val'
        mode (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
                                                     'has_instances', 'ignore_in_eval', 'color'])
    # Full label table: raw dataset id -> training id (255 = ignored in eval).
    classes = [
        CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
    ]
    # Colour per train_id (19 evaluated classes), with black appended so that
    # decode_fn can map the ignore value (remapped to index 19) to black.
    _TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
    _TRAIN_ID_TO_COLOR.append([0, 0, 0])
    _TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR)
    # Raw label id -> train id lookup table, indexed by dataset id.
    _ID_TO_TRAIN_ID = np.array([c.train_id for c in classes])
    def __init__(self, root, split='train', mode='gtFine', target_type='semantic', transform=None, target_transform=None, transforms=None):
        """Index all image/target file pairs for the requested split and mode.

        Raises:
            ValueError: if ``split`` is not 'train', 'test' or 'val'.
            RuntimeError: if the expected directories are missing under ``root``.
        """
        super(Cityscapes, self).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
        self.root = os.path.expanduser(root)
        self.mode = mode
        self.target_type = target_type
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.split = split
        self.images = []
        self.targets = []
        if split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode! Please use split="train", split="test"'
                             ' or split="val"')
        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')
        # Both images and targets are grouped into one sub-directory per city;
        # the target file name is derived from the image file name.
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                             self._get_target_suffix(self.mode, self.target_type))
                self.targets.append(os.path.join(target_dir, target_name))
    @classmethod
    def encode_target(cls, target):
        """Map raw Cityscapes label ids to contiguous train ids (255 = ignore)."""
        if isinstance( target, torch.Tensor ):
            return torch.from_numpy( cls._ID_TO_TRAIN_ID[np.array(target)] )
        else:
            return cls._ID_TO_TRAIN_ID[target]
    @classmethod
    def decode_fn(cls, target):
        """Map train ids to RGB colours.  Mutates *target* in place (255 -> 19)."""
        target[target == 255] = 19
        #target = target.astype('uint8') + 1
        return cls._TRAIN_ID_TO_COLOR[target]
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        if self.transforms:
            image, target = self.transforms(image, target)
        # NOTE(review): encode_target indexes a numpy lookup table with the
        # target, so it appears to expect an array/tensor here; when
        # self.transforms is None the target is still a PIL image — confirm
        # that callers always supply transforms.
        target = self.encode_target(target)
        return image, target
    def __len__(self):
        """Number of (image, target) pairs in the selected split."""
        return len(self.images)
    def _load_json(self, path):
        """Load and return a JSON file (used for 'polygon' targets)."""
        with open(path, 'r') as file:
            data = json.load(file)
        return data
    def _get_target_suffix(self, mode, target_type):
        """Return the target file-name suffix for *mode* and *target_type*.

        NOTE(review): implicitly returns None for an unrecognised target_type,
        which would produce a malformed file name in __init__.
        """
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        elif target_type == 'polygon':
            return '{}_polygons.json'.format(mode)
        elif target_type == 'depth':
            return '{}_disparity.png'.format(mode)
@@ -0,0 +1,70 @@ | |||||
# Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py | |||||
import os | |||||
import pandas as pd | |||||
from torchvision.datasets.folder import default_loader | |||||
from .utils import download_url | |||||
from torch.utils.data import Dataset | |||||
import shutil | |||||
class CUB200(Dataset):
    """CUB-200-2011 fine-grained bird-classification dataset.

    Expects (or downloads and extracts) the archive under
    ``<root>/CUB_200_2011``.

    Args:
        root (string): root directory of the dataset.
        split (string): 'train' selects the official training split; any other
            value selects the test split.
        transform (callable, optional): transform applied to the loaded image.
        target_transform (callable, optional): transform applied to the label.
        loader (callable): function that loads an image given its path.
        download (bool): if True, download and extract the archive into ``root``.
    """
    base_folder = 'images'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    def __init__(self, root, split='train', transform=None, target_transform=None, loader=default_loader, download=False):
        self.root = os.path.abspath(os.path.expanduser(root))
        self.transform = transform
        self.target_transform = target_transform
        # Bug fix: honour the caller-supplied loader; the original hard-coded
        # default_loader here, silently ignoring the argument.
        self.loader = loader
        self.split = split
        if download:
            self.download()
        self._load_metadata()
        # Folder names look like '001.Black_footed_Albatross'; drop the
        # 'NNN.' numeric prefix to get the readable category names.
        categories = os.listdir(os.path.join(
            self.root, 'CUB_200_2011', 'images'))
        categories.sort()
        self.object_categories = [c[4:] for c in categories]
        print('CUB200, Split: %s, Size: %d' % (self.split, self.__len__()))

    def _load_metadata(self):
        """Join image paths, labels and the official train/test flags into ``self.data``."""
        images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'], encoding='latin-1', engine='python')
        train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'], encoding='latin-1', engine='python')
        data = images.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        if self.split == 'train':
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

    def download(self):
        """Download (if needed) and extract the CUB-200-2011 archive into ``self.root``."""
        import tarfile
        os.makedirs(self.root, exist_ok=True)
        if not os.path.isfile(os.path.join(self.root, self.filename)):
            # Pass the expected md5 so download_url can verify the archive —
            # consistent with the other datasets in this package.
            download_url(self.url, self.root, self.filename, self.tgz_md5)
        print("Extracting %s..." % self.filename)
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at *idx*; labels are shifted to start at 0."""
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, 'CUB_200_2011',
                            self.base_folder, sample.filepath)
        lbl = sample.target - 1  # labels in the metadata files are 1-indexed
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl
@@ -0,0 +1,57 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
from torchvision.datasets import VisionDataset | |||||
from PIL import Image | |||||
import torch | |||||
class LabelConcatDataset(VisionDataset):
    """Dataset that concatenates the *labels* of several aligned datasets.

    Each wrapped dataset must describe the same underlying image collection
    (same length and ordering) with a different label type; item *idx* yields
    ``[image, target_0, target_1, ...]``.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its targets as entry and returns a transformed version.
    """

    def __init__(self, datasets, transforms=None, transform=None, target_transform=None):
        super(LabelConcatDataset, self).__init__(
            root=None, transforms=transforms, transform=transform, target_transform=target_transform)
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        # Strip per-dataset transforms so that augmentation is applied exactly
        # once — jointly, by this wrapper — instead of once per dataset.
        for d in self.datasets:
            for t in ('transform', 'transforms', 'target_transform'):
                if hasattr(d, t):
                    setattr(d, t, None)

    def __getitem__(self, idx):
        """Return ``[image, target_0, ..., target_{k-1}]`` for the *idx*-th sample."""
        targets_list = []
        for dst in self.datasets:
            image, target = dst[idx]
            targets_list.append(target)
        if self.transforms is not None:
            image, *targets_list = self.transforms(image, *targets_list)
        data = [image, *targets_list]
        return data

    def __len__(self):
        # Bug fix: delegate to the dataset's own __len__ instead of assuming
        # an ``images`` attribute, which not every dataset defines.
        return len(self.datasets[0])
@@ -0,0 +1,143 @@ | |||||
#Modified from https://github.com/pytorch/vision/pull/467/files | |||||
from __future__ import print_function | |||||
import torch.utils.data as data | |||||
from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader | |||||
from PIL import Image | |||||
import os | |||||
import numpy as np | |||||
from .utils import download_url, mkdir | |||||
def make_dataset(dir, image_ids, targets):
    """Build the list of ``(image_path, target)`` samples for FGVC-Aircraft.

    :param dir: dataset root directory (may start with ``~``).
    :param image_ids: image identifiers (file names without extension).
    :param targets: class indices aligned with *image_ids*.
    :return: list of ``(absolute_image_path, target)`` tuples.
    """
    assert(len(image_ids) == len(targets))
    base = os.path.join(os.path.expanduser(dir),
                        'fgvc-aircraft-2013b', 'data', 'images')
    return [(os.path.join(base, '%s.jpg' % img_id), target)
            for img_id, target in zip(image_ids, targets)]
def find_classes(classes_file):
    """Parse an FGVC-Aircraft ``images_<type>_<split>.txt`` listing.

    Each line is ``<image_id> <class name...>``.  Class names keep whatever
    trailing whitespace/newline the file contains (intentional: it is
    consistent across the file, so the index mapping is unaffected).

    :param classes_file: path to the listing file.
    :return: tuple ``(image_ids, targets, classes, class_to_idx)`` where
        *targets* are integer indices into the sorted unique *classes*.
    """
    image_ids = []
    raw_labels = []
    with open(classes_file, 'r') as fh:
        for line in fh:
            parts = line.split(' ')
            image_ids.append(parts[0])
            raw_labels.append(' '.join(parts[1:]))
    # np.unique returns the class names sorted, giving a stable index order.
    classes = np.unique(raw_labels)
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
    targets = [class_to_idx[c] for c in raw_labels]
    return (image_ids, targets, classes, class_to_idx)
class FGVCAircraft(data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.

    Args:
        root (string): Root directory path to dataset.
        class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
            to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
        split (string, optional): Dataset split, one of ``train``, ``val``, ``trainval`` or ``test``.
        transform (callable, optional): A function/transforms that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transforms that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', transform=None,
                 target_transform=None, loader=default_loader, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.root = root
        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))
        if download:
            self.download()
        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
        samples = make_dataset(self.root, image_ids, targets)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx
        # Human-readable variant names, one per line.
        with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f:
            self.object_categories = [
                line.strip('\n') for line in f.readlines()]
        print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__()))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        # Bug fix: the original read self.transforms / self.target_transforms,
        # attributes that are never set on this class, so repr() always raised
        # AttributeError. The attributes set in __init__ are self.transform and
        # self.target_transform.
        fmt_str += '{0}{1}\n'.format(
            tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(
            tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        """Return True if the extracted images and the split listing are present."""
        # Bug fix: the archive extracts to <root>/fgvc-aircraft-2013b/data/images,
        # not <root>/data/images (see make_dataset), so the original check could
        # never succeed.
        return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download the FGVC-Aircraft data if it doesn't exist already."""
        import tarfile
        mkdir(self.root)
        fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz')
        if not os.path.isfile(fpath):
            download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz')
        print("Extracting fgvc-aircraft-2013b.tar.gz...")
        with tarfile.open(fpath, "r:gz") as tar:
            tar.extractall(path=self.root)
@@ -0,0 +1,84 @@ | |||||
# Modified from https://github.com/VainF/nyuv2-python-toolkit | |||||
import os | |||||
import torch | |||||
import torch.utils.data as data | |||||
from PIL import Image | |||||
from scipy.io import loadmat | |||||
import numpy as np | |||||
import glob | |||||
from torchvision import transforms | |||||
from torchvision.datasets import VisionDataset | |||||
import random | |||||
from .utils import colormap | |||||
class NYUv2(VisionDataset):
    """NYUv2 dataset
    See https://github.com/VainF/nyuv2-python-toolkit for more details.

    Args:
        root (string): Root directory path.
        split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'.
        target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``.
        num_classes (int, optional): The number of classes, must be 40 or 13. Default:13.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version.
    """
    # Shared colour palette used by decode_fn.
    cmap = colormap()

    def __init__(self,
                 root,
                 split='train',
                 target_type='semantic',
                 num_classes=13,
                 transforms=None,
                 transform=None,
                 target_transform=None):
        super(NYUv2, self).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
        assert(split in ('train', 'test'))
        # Fail fast on bad arguments: the original silently left self.targets
        # undefined for an unknown target_type, surfacing only as a confusing
        # AttributeError at access time.
        if target_type not in ('semantic', 'depth', 'normal'):
            raise ValueError("target_type must be 'semantic', 'depth' or 'normal', got %r" % (target_type,))
        if target_type == 'semantic' and num_classes not in (13, 40):
            raise ValueError('num_classes must be 13 or 40, got %r' % (num_classes,))
        self.root = root
        self.split = split
        self.target_type = target_type
        self.num_classes = num_classes

        # splits.mat stores 1-based image indices for each split.
        # NOTE(review): idxs is computed but not used to filter the file lists
        # below — confirm whether the directory contents already match the split.
        split_mat = loadmat(os.path.join(self.root, 'splits.mat'))
        idxs = split_mat[self.split + 'Ndxs'].reshape(-1) - 1

        img_names = os.listdir(os.path.join(self.root, 'image', self.split))
        img_names.sort()
        images_dir = os.path.join(self.root, 'image', self.split)
        self.images = [os.path.join(images_dir, name) for name in img_names]

        # Target files mirror the image file names inside a per-type directory.
        self._is_depth = False
        if self.target_type == 'semantic':
            semantic_dir = os.path.join(self.root, 'seg%d' % self.num_classes, self.split)
            self.labels = [os.path.join(semantic_dir, name) for name in img_names]
            self.targets = self.labels
        if self.target_type == 'depth':
            depth_dir = os.path.join(self.root, 'depth', self.split)
            self.depths = [os.path.join(depth_dir, name) for name in img_names]
            self.targets = self.depths
            self._is_depth = True
        if self.target_type == 'normal':
            normal_dir = os.path.join(self.root, 'normal', self.split)
            self.normals = [os.path.join(normal_dir, name) for name in img_names]
            self.targets = self.normals

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair at *idx* after joint transforms."""
        image = Image.open(self.images[idx])
        target = Image.open(self.targets[idx])
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.images)

    @classmethod
    def decode_fn(cls, mask: np.ndarray):
        """decode semantic mask to RGB image"""
        mask = mask.astype('uint8') + 1  # 255 => 0
        return cls.cmap[mask]
@@ -0,0 +1,61 @@ | |||||
import os, sys | |||||
from glob import glob | |||||
import random | |||||
import argparse | |||||
from PIL import Image | |||||
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='../101_ObjectCategories')
    parser.add_argument('--test_split', type=float, default=0.3)
    args = parser.parse_args()

    # Output layout: <parent-of-data_root>/caltech101_data/{train,test}/<category>
    save_dir = os.path.join(os.path.dirname(args.data_root), 'caltech101_data')
    train_dir = os.path.join(save_dir, 'train')
    test_dir = os.path.join(save_dir, 'test')
    for directory in (save_dir, train_dir, test_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    categories = sorted(os.listdir(args.data_root))
    for category in categories:
        if category == 'Faces':  # this category is skipped entirely
            continue
        print('Processing %s' % (category))
        image_names = [os.path.split(p)[-1]
                       for p in glob(os.path.join(args.data_root, category, '*.jpg'))]
        random.shuffle(image_names)
        num_test = int(args.test_split * len(image_names))

        # The first num_test shuffled images form the test split, the rest train.
        for subset, base_dir in ((image_names[:num_test], test_dir),
                                 (image_names[num_test:], train_dir)):
            out_dir = os.path.join(base_dir, category)
            if not os.path.exists(out_dir):
                os.mkdir(out_dir)
            for name in subset:
                img = Image.open(os.path.join(args.data_root, category, name))
                img.save(os.path.join(out_dir, name))
@@ -0,0 +1,196 @@ | |||||
from __future__ import print_function | |||||
import sys | |||||
import os, sys, tarfile, errno | |||||
import numpy as np | |||||
import matplotlib.pyplot as plt | |||||
import argparse | |||||
if sys.version_info >= (3, 0, 0): | |||||
import urllib.request as urllib # ugly but works | |||||
else: | |||||
import urllib | |||||
try: | |||||
from imageio import imsave | |||||
except: | |||||
from scipy.misc import imsave | |||||
print(sys.version_info)
# image shape (STL-10 images are 96x96 RGB)
HEIGHT = 96
WIDTH = 96
DEPTH = 3
# size of a single image in bytes (96 * 96 * 3 = 27648 uint8 values)
SIZE = HEIGHT * WIDTH * DEPTH
# path to the directory with the data
#DATA_DIR = './data'
# url of the binary data
#DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'
# path to the binary train file with image data
#DATA_PATH = './data/stl10_binary/train_X.bin'
# path to the binary train file with labels
#LABEL_PATH = './data/stl10_binary/train_y.bin'
def read_labels(path_to_labels):
    """Read the STL-10 label file at *path_to_labels*.

    :param path_to_labels: path to the binary file containing labels from the STL-10 dataset
    :return: an array containing the labels
    """
    with open(path_to_labels, 'rb') as label_file:
        return np.fromfile(label_file, dtype=np.uint8)
def read_all_images(path_to_data):
    """Read every STL-10 image stored in the binary file *path_to_data*.

    :param path_to_data: the file containing the binary images from the STL-10 dataset
    :return: an array of shape (N, 96, 96, 3) containing all the images
    """
    with open(path_to_data, 'rb') as data_file:
        # read whole file in uint8 chunks
        raw = np.fromfile(data_file, dtype=np.uint8)

    # Images are stored "column-major": for each image the first 96*96 bytes
    # are the red channel, the next 96*96 green, and the last 96*96 blue.
    # The -1 lets numpy infer the number of images from the file size.
    chw_images = np.reshape(raw, (-1, 3, 96, 96))

    # Transpose to the (N, H, W, C) layout readable by e.g. matplotlib.imshow.
    # Reverse this if your pipeline (e.g. a CNN) wants channels first.
    return np.transpose(chw_images, (0, 3, 2, 1))
def read_single_image(image_file):
    """Read the next image from an already-open STL-10 binary file.

    CAREFUL! - this method uses a file as input instead of the path - so the
    position of the reader will be remembered outside of context of this method.

    :param image_file: the open file containing the images
    :return: a single image as a (96, 96, 3) array
    """
    # count limits the read to exactly one image's worth of uint8 values
    flat = np.fromfile(image_file, dtype=np.uint8, count=SIZE)
    chw = np.reshape(flat, (3, 96, 96))
    # Transpose to (H, W, C); reverse this if your pipeline (e.g. a CNN)
    # wants channels first.
    return np.transpose(chw, (2, 1, 0))
def plot_image(image):
    """Display a single image with matplotlib.

    :param image: the image to be plotted in a 3-D matrix format
    :return: None
    """
    plt.imshow(image)
    plt.show()
def save_image(image, name):
    """Write *image* to ``<name>.png`` via the imported ``imsave`` backend."""
    out_path = "%s.png" % name
    imsave(out_path, image, format="png")
def download_and_extract(DATA_DIR, data_url='http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz'):
    """Download and extract the STL-10 dataset into *DATA_DIR*.

    Bug fix: the original referenced a module-level DATA_URL that is commented
    out at the top of the file, so calling it raised NameError.  The URL is now
    a backward-compatible keyword parameter defaulting to the official archive.

    :param DATA_DIR: destination directory (created if missing)
    :param data_url: URL of the ``.tar.gz`` archive to fetch
    :return: None
    """
    dest_directory = DATA_DIR
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = data_url.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            # simple in-place percentage progress line
            sys.stdout.write('\rDownloading %s %.2f%%' % (filename,
                float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.urlretrieve(data_url, filepath, reporthook=_progress)
        print('Downloaded', filename)
    # Context manager closes the archive deterministically (the original
    # leaked the open file handle).
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dest_directory)
def save_images(images, labels, save_dir):
    """Save each image as ``<save_dir>[/<label>]/<index>.png``.

    :param images: iterable of image arrays
    :param labels: per-image labels used as subdirectory names, or None to
        write everything directly into save_dir
    :param save_dir: output root (created, with parents, if missing)
    :return: None
    """
    print("Saving images to disk")
    # makedirs also creates missing parents, unlike the old os.mkdir.
    os.makedirs(save_dir, exist_ok=True)
    for i, image in enumerate(images):
        if labels is not None:
            # Group images into one folder per class label.
            directory = os.path.join(save_dir, str(labels[i]))
            os.makedirs(directory, exist_ok=True)
        else:
            directory = save_dir
        filename = os.path.join(directory, str(i))
        print(filename)
        save_image(image, filename)
if __name__ == "__main__":
    # Convert the STL-10 binary blobs into per-split folders of PNG files.
    # download data if needed
    # download_and_extract()
    parser = argparse.ArgumentParser()
    parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary')
    args = parser.parse_args()
    # All output lands under <DATA_DIR>/stl10_data/{train,test,unlabeled}.
    BASE_PATH = os.path.join( args.DATA_DIR, 'stl10_data' )
    if not os.path.exists(BASE_PATH):
        os.mkdir(BASE_PATH)
    # Train
    DATA_PATH = os.path.join( args.DATA_DIR, 'train_X.bin')
    LABEL_PATH = os.path.join( args.DATA_DIR, 'train_y.bin')
    print('Preparing Train')
    images = read_all_images(DATA_PATH)
    print(images.shape)
    labels = read_labels(LABEL_PATH)
    print(labels.shape)
    save_images(images, labels, os.path.join(BASE_PATH, 'train'))
    # Test
    DATA_PATH = os.path.join( args.DATA_DIR, 'test_X.bin')
    LABEL_PATH = os.path.join( args.DATA_DIR, 'test_y.bin')
    #with open(DATA_PATH) as f:
    #    image = read_single_image(f)
    #    plot_image(image)
    print('Preparing Test')
    images = read_all_images(DATA_PATH)
    print(images.shape)
    labels = read_labels(LABEL_PATH)
    print(labels.shape)
    save_images(images, labels, os.path.join(BASE_PATH, 'test'))
    # Unlabeled
    # The unlabeled split has no label file; images go into a single folder.
    print('Preparing Unlabeled')
    DATA_PATH = os.path.join( args.DATA_DIR, 'unlabeled_X.bin')
    images = read_all_images(DATA_PATH)
    save_images(images, None, os.path.join(BASE_PATH, 'unlabeled'))
@@ -0,0 +1,53 @@ | |||||
import argparse | |||||
import os | |||||
from PIL import Image | |||||
import numpy as np | |||||
SIZE=(480, 320) # W, H | |||||
def is_image(path: str):
    """True when *path* ends with a png/jpg/jpeg suffix (no dot required)."""
    return path.endswith(('png', 'jpg', 'jpeg'))
def copy_and_resize( src_dir, dst_dir, resize_fn ):
    """Mirror src_dir into dst_dir, resizing every image with resize_fn.

    Subdirectories are recreated recursively; non-image files are skipped.
    """
    for entry in os.listdir(src_dir):
        src = os.path.join(src_dir, entry)
        dst = os.path.join(dst_dir, entry)
        if os.path.isdir(src):
            # Recurse, mirroring the tree layout.
            os.mkdir(dst)
            copy_and_resize(src, dst, resize_fn)
        elif is_image(src):
            print(src, ' -> ', dst)
            resize_fn(Image.open(src)).save(dst)
def resize_input( image: Image.Image ):
    """Resize an input frame to the fixed (W, H) SIZE with bicubic filtering."""
    resized = image.resize(SIZE, resample=Image.BICUBIC)
    return resized
def resize_seg( image: Image.Image ):
    """Resize a segmentation mask to SIZE; nearest keeps label ids exact."""
    resized = image.resize(SIZE, resample=Image.NEAREST)
    return resized
def main():
    """Build a resized copy of the dataset under ``<root>/<W>_<H>/``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    ROOT = parser.parse_args().root
    # BUG FIX: "'%d_%d'%(*SIZE)" was a SyntaxError — a starred expression is
    # not valid there; the tuple can be used with % directly.
    NEW_ROOT = os.path.join(ROOT, '%d_%d' % SIZE)
    os.mkdir(NEW_ROOT)
    for split in ['train', 'val', 'test']:
        IMG_DIR = os.path.join(ROOT, split)
        GT_DIR = os.path.join(ROOT, split + 'annot')
        NEW_IMG_DIR = os.path.join(NEW_ROOT, split)
        NEW_GT_DIR = os.path.join(NEW_ROOT, split + 'annot')
        os.mkdir(NEW_IMG_DIR)
        os.mkdir(NEW_GT_DIR)
        # Bicubic for photos, nearest for masks (keeps label ids exact).
        copy_and_resize(IMG_DIR, NEW_IMG_DIR, resize_input)
        copy_and_resize(GT_DIR, NEW_GT_DIR, resize_seg)
if __name__=='__main__': | |||||
main() |
@@ -0,0 +1,65 @@ | |||||
import argparse | |||||
import os | |||||
from PIL import Image | |||||
import numpy as np | |||||
SIZE=(640, 320) | |||||
def is_image(path: str):
    """Return True for paths with a png/jpg/jpeg suffix (dot not required)."""
    return path.endswith(('png', 'jpg', 'jpeg'))
def copy_and_resize( src_dir, dst_dir, resize_fn ):
    """Recursively mirror src_dir to dst_dir, resizing images via resize_fn.

    Non-image entries that are not directories are skipped.
    """
    for name in os.listdir(src_dir):
        source = os.path.join(src_dir, name)
        target = os.path.join(dst_dir, name)
        if os.path.isdir(source):
            os.mkdir(target)
            copy_and_resize(source, target, resize_fn)
        elif is_image(source):
            print(source, ' -> ', target)
            resize_fn(Image.open(source)).save(target)
def resize_input( image: Image.Image ):
    """Bicubic-resize an RGB frame to the fixed SIZE."""
    out = image.resize(SIZE, resample=Image.BICUBIC)
    return out
def resize_seg( image: Image.Image ):
    """Nearest-resize a label mask to SIZE so class ids are preserved."""
    out = image.resize(SIZE, resample=Image.NEAREST)
    return out
def resize_depth( image: Image.Image ):
    """Nearest-resize a disparity map to SIZE (no value interpolation)."""
    out = image.resize(SIZE, resample=Image.NEAREST)
    return out
def main():
    """Create a resized Cityscapes copy (images, gtFine, disparity)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    ROOT = parser.parse_args().root
    IMG_DIR = os.path.join(ROOT, 'leftImg8bit')
    GT_DIR = os.path.join(ROOT, 'gtFine')
    DEPTH_DIR = os.path.join(ROOT, 'disparity')
    # BUG FIX: "'%d_%d'%(*SIZE)" was a SyntaxError — a starred expression is
    # not valid there; the tuple can be used with % directly.
    NEW_ROOT = os.path.join(ROOT, '%d_%d' % SIZE)
    NEW_IMG_DIR = os.path.join(NEW_ROOT, 'leftImg8bit')
    NEW_GT_DIR = os.path.join(NEW_ROOT, 'gtFine')
    NEW_DEPTH_DIR = os.path.join(NEW_ROOT, 'disparity')
    if os.path.exists(NEW_ROOT):
        print("Directory %s existed, please remove it before running this script"%NEW_ROOT)
        return
    os.mkdir(NEW_ROOT)
    os.mkdir(NEW_IMG_DIR)
    os.mkdir(NEW_GT_DIR)
    os.mkdir(NEW_DEPTH_DIR)
    # Bicubic for photos; nearest for masks/disparity to avoid blending values.
    copy_and_resize(IMG_DIR, NEW_IMG_DIR, resize_input)
    copy_and_resize(GT_DIR, NEW_GT_DIR, resize_seg)
    copy_and_resize(DEPTH_DIR, NEW_DEPTH_DIR, resize_depth)
if __name__=='__main__': | |||||
main() |
@@ -0,0 +1,59 @@ | |||||
import argparse | |||||
import os | |||||
from PIL import Image | |||||
import numpy as np | |||||
from torchvision import transforms | |||||
import shutil | |||||
SIZE=240 | |||||
def is_image(path: str):
    """True when the path's name ends in png/jpg/jpeg (dot not required)."""
    return path.endswith(('png', 'jpg', 'jpeg'))
def copy_and_resize( src_dir, dst_dir, resize_fn ):
    """Walk src_dir recursively, writing a resized copy of every image
    into the same relative location under dst_dir."""
    for item in os.listdir(src_dir):
        src_path = os.path.join(src_dir, item)
        dst_path = os.path.join(dst_dir, item)
        if os.path.isdir(src_path):
            os.mkdir(dst_path)
            copy_and_resize(src_path, dst_path, resize_fn)
        elif is_image(src_path):
            print(src_path, ' -> ', dst_path)
            resize_fn(Image.open(src_path)).save(dst_path)
def resize_input( image: Image.Image ):
    """Bilinear-resize *image* to target size SIZE via torchvision."""
    resized = transforms.functional.resize(image, SIZE, interpolation=Image.BILINEAR)
    return resized
def resize_seg( image: Image.Image ):
    """Nearest-resize a mask to SIZE so label ids are not interpolated."""
    resized = transforms.functional.resize(image, SIZE, interpolation=Image.NEAREST)
    return resized
def main():
    """Create <root>/<SIZE>/ with resized JPEGImages and SegmentationClass."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    root = parser.parse_args().root
    src_img = os.path.join(root, 'JPEGImages')
    src_gt = os.path.join(root, 'SegmentationClass')
    out_root = os.path.join(root, str(SIZE))
    out_img = os.path.join(out_root, 'JPEGImages')
    out_gt = os.path.join(out_root, 'SegmentationClass')
    # Refuse to overwrite a previous run.
    if os.path.exists(out_root):
        print("Directory %s existed, please remove it before running this script"%out_root)
        return
    for d in (out_root, out_img, out_gt):
        os.mkdir(d)
    # Bilinear for photos, nearest for masks so class ids survive.
    copy_and_resize(src_img, out_img, resize_input)
    copy_and_resize(src_gt, out_gt, resize_seg)
    # Split lists are copied verbatim.
    shutil.copytree(os.path.join(root, 'ImageSets'), os.path.join(out_root, 'ImageSets'))
if __name__=='__main__': | |||||
main() |
@@ -0,0 +1,57 @@ | |||||
import argparse | |||||
import os | |||||
from PIL import Image | |||||
import numpy as np | |||||
from torchvision import transforms | |||||
import shutil | |||||
def is_image(path: str):
    """Return True for png/jpg/jpeg-suffixed paths (dot not required)."""
    return path.endswith(('png', 'jpg', 'jpeg'))
def copy_and_resize( src_dir, dst_dir, resize_fn ):
    """Recursively duplicate src_dir under dst_dir with every image
    passed through resize_fn; anything that is not a dir or image is skipped."""
    for child in os.listdir(src_dir):
        from_path = os.path.join(src_dir, child)
        to_path = os.path.join(dst_dir, child)
        if os.path.isdir(from_path):
            os.mkdir(to_path)
            copy_and_resize(from_path, to_path, resize_fn)
        elif is_image(from_path):
            print(from_path, ' -> ', to_path)
            resize_fn(Image.open(from_path)).save(to_path)
def resize_input( image: Image.Image ):
    """Bilinear-resize *image* to target size 240 via torchvision."""
    target = 240
    return transforms.functional.resize(image, target, interpolation=Image.BILINEAR)
def resize_seg( image: Image.Image ):
    """Nearest-resize a mask to target size 240 (labels stay exact)."""
    target = 240
    return transforms.functional.resize(image, target, interpolation=Image.NEAREST)
def main():
    """Create <root>/240/ with resized JPEGImages and SegmentationClass."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, required=True)
    root = parser.parse_args().root
    src_img = os.path.join(root, 'JPEGImages')
    src_gt = os.path.join(root, 'SegmentationClass')
    out_root = os.path.join(root, '240')
    out_img = os.path.join(out_root, 'JPEGImages')
    out_gt = os.path.join(out_root, 'SegmentationClass')
    # Refuse to overwrite a previous run.
    if os.path.exists(out_root):
        print("Directory %s existed, please remove it before running this script"%out_root)
        return
    for d in (out_root, out_img, out_gt):
        os.mkdir(d)
    # Bilinear for photos, nearest for masks so class ids survive.
    copy_and_resize(src_img, out_img, resize_input)
    copy_and_resize(src_gt, out_gt, resize_seg)
    # Split lists are copied verbatim.
    shutil.copytree(os.path.join(root, 'ImageSets'), os.path.join(out_root, 'ImageSets'))
if __name__=='__main__': | |||||
main() |
@@ -0,0 +1,80 @@ | |||||
import os | |||||
import glob | |||||
from PIL import Image | |||||
import numpy as np | |||||
from scipy.io import loadmat | |||||
from torch.utils import data | |||||
from .utils import download_url, mkdir | |||||
from shutil import copyfile | |||||
class StanfordCars(data.Dataset):
    """Dataset for Stanford Cars

    Args:
        root (str): directory holding (or receiving) the dataset files
        split (str): 'train' or 'test'; selects images and annotations
        download (bool): if True, fetch and extract the archives first
        transform (callable, optional): applied to each PIL image
        target_transform (callable, optional): applied to each label
    """
    # Archive / annotation files and the URLs they are fetched from.
    urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz',
            'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz',
            'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz',
            'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'}
    def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
        self.root = os.path.abspath( os.path.expanduser(root) )
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()
        # Train annotations ship inside the devkit; test labels come from the
        # separately downloaded cars_test_annos_withlabels.mat (see download()).
        if self.split == 'train':
            annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat')
        else:
            annos = os.path.join(self.root, 'devkit',
                                 'cars_test_annos_withlabels.mat')
        annos = loadmat(annos)
        size = len(annos['annotations'][0])  # NOTE(review): unused local
        # Sorted glob order is assumed to line up with the annotation record
        # order — TODO confirm against the devkit layout.
        self.files = glob.glob(os.path.join(
            self.root, 'cars_'+self.split, '*.jpg'))
        self.files.sort()
        # Field 4 of each annotation record is the 1-based class id; shift to 0-based.
        self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]])
        # cars_meta.mat maps class ids to human-readable category names.
        lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat'))
        self.object_categories = [str(c[0])
                                  for c in lbl_annos['class_names'][0]]
        print('Stanford Cars, Split: %s, Size: %d' %
              (self.split, self.__len__()))
    def __len__(self):
        # Number of images found on disk for the selected split.
        return len(self.files)
    def __getitem__(self, idx):
        # NOTE(review): self.files holds paths produced by glob (rooted at
        # self.root), so if they are absolute this join resolves to the glob
        # path itself and the 'Images' component is ignored — verify intent.
        img = Image.open(os.path.join(self.root, 'Images',
                                      self.files[idx])).convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform(lbl)
        return img, lbl
    def download(self):
        """Download any missing archives into root and extract the tarballs."""
        import tarfile
        mkdir(self.root)
        for fname, url in self.urls.items():
            if not os.path.isfile(os.path.join(self.root, fname)):
                download_url(url, self.root, fname)
                if fname.endswith('tgz'):
                    print("Extracting %s..." % fname)
                    with tarfile.open(os.path.join(self.root, fname), "r:gz") as tar:
                        tar.extractall(path=self.root)
        # Place the test labels where __init__ expects them (devkit dir).
        copyfile(os.path.join(self.root, 'cars_test_annos_withlabels.mat'),
                 os.path.join(self.root, 'devkit', 'cars_test_annos_withlabels.mat'))
@@ -0,0 +1,58 @@ | |||||
import os | |||||
import numpy as np | |||||
from PIL import Image | |||||
from scipy.io import loadmat | |||||
from torch.utils import data | |||||
from .utils import download_url | |||||
from shutil import move | |||||
class StanfordDogs(data.Dataset):
    """Dataset for Stanford Dogs

    Args:
        root (str): directory holding (or receiving) the dataset files
        split (str): 'train' or 'test'; picks ``<split>_list.mat``
        download (bool): if True, fetch and extract the archives first
        transform (callable, optional): applied to each PIL image
        target_transform (callable, optional): applied to each label
    """
    # Archive files and the URLs they are fetched from.
    urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar",
            "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar",
            "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"}
    def __init__(self, root, split='train', download=False, transform=None, target_transform=None):
        self.root = os.path.abspath( os.path.expanduser(root) )
        self.split = split
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()
        # <split>_list.mat (from lists.tar) enumerates files and labels.
        list_file = os.path.join(self.root, self.split+'_list.mat')
        mat_file = loadmat(list_file)
        size = len(mat_file['file_list'])
        self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)]
        # Labels in the .mat are 1-based; shift to 0-based here.
        self.labels = np.array(
            [mat_file['labels'][i][0]-1 for i in range(size)])
        categories = os.listdir(os.path.join(self.root, 'Images'))
        categories.sort()
        # c[10:] presumably strips a fixed-width WordNet-id prefix from each
        # breed directory name — TODO confirm the prefix is 10 characters.
        self.object_categories = [c[10:] for c in categories]
        print('Stanford Dogs, Split: %s, Size: %d' %
              (self.split, self.__len__()))
    def __len__(self):
        # Number of entries listed in the split's .mat file.
        return len(self.files)
    def __getitem__(self, idx):
        # File names in the list are relative to the Images/ directory.
        img = Image.open(os.path.join(self.root, 'Images',
                                      self.files[idx])).convert("RGB")
        lbl = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            lbl = self.target_transform( lbl )
        return img, lbl
    def download(self):
        """Download any missing archives into root and extract them."""
        import tarfile
        os.makedirs(self.root, exist_ok=True)
        for fname, url in self.urls.items():
            if not os.path.isfile(os.path.join(self.root, fname)):
                download_url(url, self.root, fname)
                # extract file
                print("Extracting %s..." % fname)
                with tarfile.open(os.path.join(self.root, fname), "r") as tar:
                    tar.extractall(path=self.root)
@@ -0,0 +1,65 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import os | |||||
from glob import glob | |||||
from PIL import Image | |||||
from .utils import colormap | |||||
from torchvision.datasets import VisionDataset | |||||
class SunRGBD(VisionDataset):
    """SUN RGB-D semantic-segmentation dataset (13-class label images).

    Expects ``SUNRGBD-<split>_images`` (jpg) and ``<split>13labels`` (png)
    directories under *root*; image/label pairs are matched by sorted
    filename order.
    """
    cmap = colormap()
    def __init__(self,
                 root,
                 split='train',
                 transform=None,
                 target_transform=None,
                 transforms=None):
        super( SunRGBD, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
        self.root = root
        self.split = split
        self.images = glob(os.path.join(self.root, 'SUNRGBD-%s_images'%self.split, '*.jpg'))
        self.labels = glob(os.path.join(self.root, '%s13labels'%self.split, '*.png'))
        # Sorting both lists aligns each image with its label file.
        self.images.sort()
        self.labels.sort()
    def __getitem__(self, idx):
        """
        Args:
        - index (``int``): index of the item in the dataset

        Returns:
        A tuple of ``PIL.Image`` (image, label) where label is the ground-truth
        of the image.
        """
        img, label = Image.open(self.images[idx]), Image.open(self.labels[idx])
        if self.transform is not None:
            # assumes self.transform is a joint transform that returns the
            # label as an array/tensor (a PIL image does not support the
            # subtraction below) — TODO confirm.
            img, label = self.transform(img, label)
            label = label-1 # void 0=>255
        return img, label
    def __len__(self):
        return len(self.images)
    @classmethod
    def decode_fn(cls, mask):
        """decode semantic mask to RGB image"""
        # The +1 undoes the -1 shift applied in __getitem__ before indexing
        # the shared colormap.
        return cls.cmap[mask.astype('uint8')+1]
@@ -0,0 +1,67 @@ | |||||
""" | |||||
Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
============================================================= | |||||
""" | |||||
import torch | |||||
from torch.utils.data import Dataset | |||||
import os | |||||
from PIL import Image | |||||
import random | |||||
from copy import deepcopy | |||||
def _collect_all_images(root, postfix=['png', 'jpg', 'jpeg', 'JPEG']): | |||||
images = [] | |||||
if isinstance( postfix, str): | |||||
postfix = [ postfix ] | |||||
for dirpath, dirnames, files in os.walk(root): | |||||
for pos in postfix: | |||||
for f in files: | |||||
if f.endswith( pos ): | |||||
images.append( os.path.join( dirpath, f ) ) | |||||
return images | |||||
def get_train_val_set(root, val_size=0.3):
    """Split the images found under each root into train/val path lists.

    If a ``test`` subdirectory exists it becomes the val set; otherwise a
    random ``val_size`` fraction is sampled out of the collected images.

    :param root: a directory or a list/tuple of directories
    :param val_size: fraction held out when no 'test' subdir exists
    :return: (train_set, val_set) lists of image file paths
    """
    if not isinstance(root, (list, tuple)):
        root = [root]
    train_set = []
    val_set = []
    for _root in root:
        _part_train_set = _collect_all_images( _root)
        if os.path.isdir( os.path.join(_root, 'test') ):
            # NOTE(review): the recursive walk above also picks up images
            # under <root>/test, so they remain in the train list too —
            # confirm whether this overlap is intentional.
            _part_val_set = _collect_all_images( os.path.join(_root, 'test') )
        else:
            _val_size = int( len(_part_train_set) * val_size )
            _part_val_set = random.sample( _part_train_set, k=_val_size )
            # Set membership makes the filter O(n) instead of O(n^2).
            _held_out = set(_part_val_set)
            _part_train_set = [ d for d in _part_train_set if d not in _held_out ]
        train_set.extend(_part_train_set)
        val_set.extend(_part_val_set)
    return train_set, val_set
class UnlabeledDataset(Dataset):
    """A flat dataset over a list of image file paths, with no labels.

    :param data: list of image file paths
    :param transform: optional callable applied to each loaded PIL image
    :param postfix: kept for API compatibility; not read by this class
    """
    def __init__(self, data, transform=None, postfix=['png', 'jpg', 'jpeg', 'JPEG']):
        self.transform = transform
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = Image.open( self.data[idx] )
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
@@ -0,0 +1,161 @@ | |||||
# Modified from https://github.com/pytorch/vision | |||||
import os | |||||
import os.path | |||||
import hashlib | |||||
import errno | |||||
from tqdm import tqdm | |||||
import numpy as np | |||||
import torch | |||||
import random | |||||
def mkdir(dir):
    """Create directory *dir* if it does not already exist.

    Unlike the previous ``os.mkdir`` call this also creates missing parent
    directories and avoids the isdir/mkdir race (exist_ok).
    """
    os.makedirs(dir, exist_ok=True)
def colormap(N=256, normalized=False):
    """Build the N-entry PASCAL-VOC style color map as an (N, 3) array.

    Each label index is expanded bit-by-bit into an RGB triple; with
    normalized=True the values are float32 in [0, 1] instead of uint8.
    """
    dtype = 'float32' if normalized else 'uint8'
    cmap = np.zeros((N, 3), dtype=dtype)
    for idx in range(N):
        rgb = [0, 0, 0]
        bits = idx
        # Consume three bits per round, placing them from the MSB down.
        for shift in range(7, -1, -1):
            for channel in range(3):
                rgb[channel] |= ((bits >> channel) & 1) << shift
            bits >>= 3
        cmap[idx] = np.array(rgb)
    return cmap / 255 if normalized else cmap
DEFAULT_COLORMAP = colormap() | |||||
def gen_bar_updater(pbar):
    """Adapt a tqdm-style progress bar to urlretrieve's reporthook API."""
    def bar_update(count, block_size, total_size):
        # Lazily set the total once a content length is reported.
        if pbar.total is None and total_size:
            pbar.total = total_size
        downloaded = count * block_size
        # The bar tracks an absolute position in .n, so feed it the delta.
        pbar.update(downloaded - pbar.n)
    return bar_update
def check_integrity(fpath, md5=None):
    """Return True when *fpath* matches *md5* (or when no md5 is given).

    A missing file only fails the check when an md5 was supplied.
    """
    if md5 is None:
        return True
    if not os.path.isfile(fpath):
        return False
    hasher = hashlib.md5()
    with open(fpath, 'rb') as f:
        # Hash in 1MB chunks to keep memory bounded on big archives.
        while True:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            hasher.update(chunk)
    return hasher.hexdigest() == md5
def makedir_exist_ok(dirpath):
    """
    Python2 support for os.makedirs(.., exist_ok=True)
    """
    try:
        os.makedirs(dirpath)
    except OSError as e:
        # An already-existing path is fine; anything else propagates.
        if e.errno != errno.EEXIST:
            raise
def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str): Name to save the file under. If None, use the basename of the URL
        md5 (str): MD5 checksum of the download. If None, do not check
    """
    from six.moves import urllib
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)
    makedir_exist_ok(root)
    # downloads file
    # Skip the download when a previously fetched copy passes the md5 check.
    if os.path.isfile(fpath) and check_integrity(fpath, md5):
        print('Using downloaded and verified file: ' + fpath)
    else:
        try:
            print('Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(
                url, fpath,
                reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
            )
        except OSError:
            # Some mirrors reject https; retry once over plain http.
            # NOTE(review): when the url is not https the OSError is
            # swallowed here with no retry and no re-raise — confirm intended.
            if url[:5] == 'https':
                url = url.replace('https:', 'http:')
                print('Failed download. Trying https -> http instead.'
                      ' Downloading ' + url + ' to ' + fpath)
                urllib.request.urlretrieve(
                    url, fpath,
                    reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
                )
def list_dir(root, prefix=False):
    """List all directories at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    names = [p for p in os.listdir(root)
             if os.path.isdir(os.path.join(root, p))]
    if prefix is True:
        names = [os.path.join(root, d) for d in names]
    return names
def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root

    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    matches = [p for p in os.listdir(root)
               if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)]
    if prefix is True:
        matches = [os.path.join(root, d) for d in matches]
    return matches
def set_seed(random_seed): | |||||
torch.manual_seed(random_seed) | |||||
torch.cuda.manual_seed(random_seed) | |||||
np.random.seed(random_seed) | |||||
random.seed(random_seed) |
@@ -0,0 +1,209 @@ | |||||
# Modified from https://github.com/pytorch/vision | |||||
import os | |||||
import sys | |||||
import tarfile | |||||
import collections | |||||
import torch.utils.data as data | |||||
import shutil | |||||
import numpy as np | |||||
from .utils import colormap | |||||
from torchvision.datasets import VisionDataset | |||||
import torch | |||||
from PIL import Image | |||||
from torchvision.datasets.utils import download_url, check_integrity | |||||
# Per-year VOC archive metadata: download url, local archive name, md5 of
# the archive, and the directory the tarball unpacks into.
DATASET_YEAR_DICT = {
    '2012aug': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # BUG FIX: this entry previously reused the 2012 tar name
        # ('VOCtrainval_11-May-2012.tar'), which did not match its own url
        # and md5; the filename now matches the url's basename.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}
class VOCSegmentation(VisionDataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """
    cmap = colormap()
    def __init__(self,
                 root,
                 year='2012',
                 image_set='train',
                 download=False,
                 transform=None,
                 target_transform=None,
                 transforms=None,
                 ):
        super( VOCSegmentation, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms )
        # '2012aug' means: 2012 images plus the manually prepared
        # SegmentationClassAug masks (see the README referenced below).
        is_aug=False
        if year=='2012aug':
            is_aug = True
            year = '2012'
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.image_set = image_set
        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        if is_aug and image_set=='train':
            mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
            assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
            # The augmented split list lives directly under root, not in the devkit tree.
            split_f = os.path.join( self.root, 'train_aug.txt')
        else:
            mask_dir = os.path.join(voc_root, 'SegmentationClass')
            splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
            split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
        if not os.path.exists(split_f):
            raise ValueError(
                'Wrong image_set entered! Please use image_set="train" '
                'or image_set="trainval" or image_set="val"')
        # The split file lists one image basename per line.
        with open(os.path.join(split_f), "r") as f:
            file_names = [x.strip() for x in f.readlines()]
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        # assumes self.transforms yields a tensor target with a leading
        # channel dim; squeeze(0) would fail on a plain PIL image — TODO confirm.
        return img, target.squeeze(0)
    def __len__(self):
        return len(self.images)
    @classmethod
    def decode_fn(cls, mask):
        """decode semantic mask to RGB image"""
        return cls.cmap[mask]
def download_extract(url, root, filename, md5):
    """Download *url* into root/filename (md5-verified) and untar it in root."""
    download_url(url, root, filename, md5)
    archive = os.path.join(root, filename)
    with tarfile.open(archive, "r") as tar:
        tar.extractall(path=root)
CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', | |||||
'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] | |||||
class VOCClassification(data.Dataset):
    """Pascal VOC multi-label classification dataset.

    Args:
        root (string): directory containing ``VOC<year>``
        year (string, optional): dataset year, e.g. '2010'
        split (string, optional): name of the ImageSets/Main split file
        download (bool, optional): accepted but unused here; the data
            must already exist on disk (NOTE(review): confirm intended)
        transforms (callable, optional): applied to each PIL image
        target_transforms (callable, optional): applied to the label list
    """
    def __init__(self,
                 root,
                 year='2010',
                 split='train',
                 download=False,
                 transforms=None,
                 target_transforms=None):
        voc_root = os.path.join(root, 'VOC{}'.format(year))
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        self.transforms = transforms
        self.target_transforms = target_transforms
        image_dir = os.path.join(voc_root, 'JPEGImages')
        label_dir = os.path.join(voc_root, 'ImageSets/Main')
        # One presence list per class, aligned with self.images by position.
        self.labels_list = []
        fname = os.path.join(label_dir, '{}.txt'.format(split))
        with open(fname) as f:
            self.images = [os.path.join(image_dir, line.split()[0]+'.jpg') for line in f]
        for clas in CLASSES:
            labels = []
            # <class>_<split>.txt lines are "<image id> <presence flag>".
            with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f:
                labels = [int(line.split()[1]) for line in f]
            self.labels_list.append(labels)
        assert (len(self.images) == len(self.labels_list[0]))
    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
        """
        img = Image.open(self.images[index]).convert('RGB')
        # Gather this image's flag from every per-class list (multi-label).
        labels = [labels[index] for labels in self.labels_list]
        if self.transforms is not None:
            img = self.transforms(img)
        if self.target_transforms is not None:
            labels = self.target_transforms(labels)
        return img, labels
    def __len__(self):
        return len(self.images)
@@ -0,0 +1,3 @@ | |||||
from . import classification, segmentation | |||||
from torchvision import models as torchvision_models |
@@ -0,0 +1,7 @@ | |||||
from .darknet import * | |||||
from .mobilenetv2 import * | |||||
from .resnet import * | |||||
from .vgg import * | |||||
from . import cifar | |||||
from .alexnet import alexnet |
@@ -0,0 +1,63 @@ | |||||
# Modified from https://github.com/pytorch/vision | |||||
import torch.nn as nn | |||||
import torch.utils.model_zoo as model_zoo | |||||
# Public API of this module.
__all__ = ['AlexNet', 'alexnet']
# Download locations for ImageNet-pretrained weights (official PyTorch mirror).
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}
class AlexNet(nn.Module):
    """AlexNet ("One weird trick" variant) for 224x224 RGB inputs.

    Args:
        num_classes: Size of the final classification layer (default 1000).
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # Convolutional feature extractor. Layer order (and thus Sequential
        # indices) must stay fixed so pretrained state dicts keep matching.
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)
        # Fully-connected head over the flattened 256x6x6 feature map.
        head_stack = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head_stack)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), 256 * 6 * 6)
        return self.classifier(flat)
def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        num_classes (int, optional): Target number of classes. With
            ``pretrained=True`` the ImageNet head is loaded first and then
            replaced by a fresh ``Linear(4096, num_classes)`` when the value
            differs from 1000.

    Returns:
        AlexNet: the constructed (and optionally pretrained) model.
    """
    num_classes = kwargs.pop('num_classes', None)
    if not pretrained:
        # No pretrained weights to preserve: build the head at the target
        # size directly instead of building a 1000-way head and replacing it.
        if num_classes is not None:
            kwargs['num_classes'] = num_classes
        return AlexNet(**kwargs)
    model = AlexNet(**kwargs)
    model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
    if num_classes is not None and num_classes != 1000:
        # Swap the ImageNet classifier head for a freshly-initialized one.
        model.classifier[-1] = nn.Linear(4096, num_classes)
    return model
@@ -0,0 +1 @@ | |||||
from . import wrn |
@@ -0,0 +1,108 @@ | |||||
#Adapted from https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py | |||||
import math | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
__all__ = ['wrn'] | |||||
class BasicBlock(nn.Module):
    """Pre-activation residual block (BN-ReLU-Conv twice) for WideResNet.

    When the input and output widths differ, a 1x1 convolution projects the
    shortcut; otherwise the identity is used.
    """

    def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0):
        super(BasicBlock, self).__init__()
        # Module creation order mirrors the reference implementation so the
        # parameter-init RNG sequence and state-dict layout are unchanged.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.equalInOut = (in_planes == out_planes)
        # 1x1 projection only when the residual cannot be added directly.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride,
            padding=0, bias=False)

    def forward(self, x):
        if self.equalInOut:
            out = self.relu1(self.bn1(x))
            out = self.relu2(self.bn2(self.conv1(out)))
            shortcut = x
        else:
            # Pre-activate the input itself; the shortcut projects it.
            x = self.relu1(self.bn1(x))
            out = self.relu2(self.bn2(self.conv1(x)))
            shortcut = self.convShortcut(x)
        out = self.conv2(self.dropout(out))
        return shortcut + out
class NetworkBlock(nn.Module):
    """A group of ``nb_layers`` stacked residual blocks.

    Only the first block may change the channel count and apply a stride;
    the rest keep ``out_planes`` channels at stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes,
                                      nb_layers, stride, dropout_rate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate):
        # First block handles the width change / downsampling; the rest are
        # width-preserving, stride-1 blocks.
        stacked = [
            block(in_planes if i == 0 else out_planes,
                  out_planes,
                  stride if i == 0 else 1,
                  dropout_rate)
            for i in range(nb_layers)
        ]
        return nn.Sequential(*stacked)

    def forward(self, x):
        return self.layer(x)
class WideResNet(nn.Module):
    """Wide Residual Network (Zagoruyko & Komodakis) for 32x32 inputs.

    Args:
        depth: Total depth; must satisfy depth = 6n + 4.
        num_classes: Size of the final classifier.
        widen_factor: Channel multiplier for the three residual groups.
        dropout_rate: Dropout probability inside each BasicBlock.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0):
        super(WideResNet, self).__init__()
        widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
        blocks_per_group = (depth - 4) // 6
        block = BasicBlock
        # Stem convolution ahead of the residual groups.
        self.conv1 = nn.Conv2d(3, widths[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # Three residual groups; groups 2 and 3 halve the resolution.
        self.block1 = NetworkBlock(blocks_per_group, widths[0], widths[1], block, 1, dropout_rate)
        self.block2 = NetworkBlock(blocks_per_group, widths[1], widths[2], block, 2, dropout_rate)
        self.block3 = NetworkBlock(blocks_per_group, widths[2], widths[3], block, 2, dropout_rate)
        # Final BN/ReLU, then global average pooling and a linear classifier.
        self.bn1 = nn.BatchNorm2d(widths[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(widths[3], num_classes)
        self.nChannels = widths[3]
        # He-style init for convs, unit-gain BN, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, return_features=False):
        h = self.conv1(x)
        h = self.block3(self.block2(self.block1(h)))
        h = self.relu(self.bn1(h))
        # 8x8 average pool assumes a 32x32 input (feature map is 8x8 here).
        h = F.avg_pool2d(h, 8)
        features = h.view(-1, self.nChannels)
        logits = self.fc(features)
        return (logits, features) if return_features else logits
def _build_wrn(depth, widen_factor, num_classes, dropout_rate):
    """Construct a WideResNet with the given depth and width multiplier."""
    return WideResNet(depth=depth, num_classes=num_classes,
                      widen_factor=widen_factor, dropout_rate=dropout_rate)


def wrn_16_1(num_classes, dropout_rate=0):
    """WRN-16-1: depth 16, width x1."""
    return _build_wrn(16, 1, num_classes, dropout_rate)


def wrn_16_2(num_classes, dropout_rate=0):
    """WRN-16-2: depth 16, width x2."""
    return _build_wrn(16, 2, num_classes, dropout_rate)


def wrn_40_1(num_classes, dropout_rate=0):
    """WRN-40-1: depth 40, width x1."""
    return _build_wrn(40, 1, num_classes, dropout_rate)


def wrn_40_2(num_classes, dropout_rate=0):
    """WRN-40-2: depth 40, width x2."""
    return _build_wrn(40, 2, num_classes, dropout_rate)