| @@ -0,0 +1,2 @@ | |||
| /.idea/ | |||
| *.iml | |||
| @@ -0,0 +1,7 @@ | |||
| FROM tensorflow/tensorflow:2.4.1 | |||
| WORKDIR /app | |||
| RUN pip install web.py tf2onnx | |||
| COPY . /app | |||
| ENTRYPOINT ["python3", "main.py"] | |||
| @@ -0,0 +1,86 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import json | |||
| import os | |||
| import subprocess | |||
| import logging | |||
| import web | |||
| from subprocess import PIPE | |||
| urls = ( | |||
| '/hello', 'Hello', | |||
| '/model_convert', 'ModelConvert' | |||
| ) | |||
| logging.basicConfig(filename='onnx.log', level=logging.DEBUG) | |||
| class Hello(object): | |||
| def GET(self): | |||
| return 'service alive' | |||
| class ModelConvert(object): | |||
| def POST(self): | |||
| data = web.data() | |||
| web.header('Content-Type', 'application/json') | |||
| try: | |||
| json_data = json.loads(data) | |||
| model_path = json_data['model_path'] | |||
| output_path = json_data['output_path'] | |||
| if not os.path.isdir(model_path): | |||
| msg = 'model_path is not a dir: %s' % model_path | |||
| logging.error(msg) | |||
| return json.dumps({'code': 501, 'msg': msg, 'data': ''}) | |||
| if not output_path.endswith('/'): | |||
| msg = 'output_path must be a directory path ending with "/": %s' % output_path | |||
| logging.error(msg) | |||
| return json.dumps({'code': 502, 'msg': msg, 'data': ''}) | |||
| exist_flag = exist(model_path) | |||
| if not exist_flag: | |||
| msg = 'SavedModel file does not exist at: %s' % model_path | |||
| logging.error(msg) | |||
| return json.dumps({'code': 503, 'msg': msg, 'data': ''}) | |||
| convert_flag, msg = convert(model_path, output_path) | |||
| if not convert_flag: | |||
| return json.dumps({'code': 504, 'msg': msg, 'data': ''}) | |||
| except Exception as e: | |||
| logging.error(str(e)) | |||
| return json.dumps({'code': 505, 'msg': str(e), 'data': ''}) | |||
| return json.dumps({'code': 200, 'msg': 'ok', 'data': msg}) | |||
| def exist(model_path): | |||
| for file in os.listdir(model_path): | |||
| if file == 'saved_model.pbtxt' or file == 'saved_model.pb': | |||
| return True | |||
| return False | |||
| def convert(model_path, output_path): | |||
| output_path = output_path+'model.onnx' | |||
| try: | |||
| logging.info('model_path=%s, output_path=%s' % (model_path, output_path)) | |||
| result = subprocess.run(["python", "-m", "tf2onnx.convert", "--saved-model", model_path, "--output", output_path], stdout=PIPE, stderr=PIPE) | |||
| logging.info(repr(result)) | |||
| if result.returncode != 0: | |||
| return False, result.stderr.decode('utf-8', errors='ignore') | |||
| except Exception as e: | |||
| logging.error(str(e)) | |||
| return False, str(e) | |||
| return True, output_path | |||
| if __name__ == '__main__': | |||
| app = web.application(urls, globals()) | |||
| app.run() | |||
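| # Example request (a sketch, assuming the service listens on web.py's default | |||
| # port 8080 and that /data/saved_model/ contains a TensorFlow SavedModel): | |||
| # | |||
| #   curl -X POST http://localhost:8080/model_convert \ | |||
| #        -d '{"model_path": "/data/saved_model/", "output_path": "/data/onnx/"}' | |||
| # | |||
| # A successful response is {"code": 200, "msg": "ok", "data": "/data/onnx/model.onnx"}. | |||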
| @@ -0,0 +1,130 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| dependencies = ['torch', 'kamal'] # packages required by this hubconf | |||
| import json | |||
| import traceback | |||
| import os | |||
| import torch | |||
| import web | |||
| import kamal | |||
| from PIL import Image | |||
| from kamal.transferability import TransferabilityGraph | |||
| from kamal.transferability.trans_metric import AttrMapMetric | |||
| from torchvision.models import * | |||
| urls = ( | |||
| '/model_measure/measure', 'Measure', | |||
| '/model_measure/package', 'Package', | |||
| ) | |||
| app = web.application(urls, globals()) | |||
| class Package: | |||
| def POST(self): | |||
| req = web.data() | |||
| save_path_list = [] | |||
| for json_data in json.loads(req): | |||
| try: | |||
| metadata = json_data['metadata'] | |||
| except Exception as e: | |||
| traceback.print_exc() | |||
| return json.dumps(Response(506, 'Failed to parse request metadata, Error: %s' % (traceback.format_exc(limit=1)), save_path_list).__dict__) | |||
| entry_name = json_data['entry_name'] | |||
| matched = False | |||
| for name, fn in kamal.hub.list_entry(__file__): | |||
| if entry_name in name: | |||
| try: | |||
| dataset = metadata['dataset'] | |||
| model = fn(pretrained=False, num_classes=metadata['entry_args']['num_classes']) | |||
| num_params = sum( [ torch.numel(p) for p in model.parameters() ] ) | |||
| save_path_for_measure = '%sfinegraind_%s/' % (json_data['ckpt'], entry_name) | |||
| save_path = save_path_for_measure+'%s' % (metadata['name']) | |||
| ckpt = self.file_name(json_data['ckpt']) | |||
| if ckpt == '': | |||
| return json.dumps( | |||
| Response(506, 'Failed to package model [%s]: No .pth file was found in directory ckpt' % (entry_name), save_path_list).__dict__) | |||
| model.load_state_dict(torch.load(ckpt), strict=False) | |||
| kamal.hub.save( # packages the user's pytorch model into the format above and stores it at the given location | |||
| model, # the nn.Module to be saved | |||
| save_path=save_path, | |||
| # name of the export directory | |||
| entry_name=entry_name, # entry function name; must match the entry function above | |||
| spec_name=None, # concrete weight version name; if None, an md5 hash is used instead | |||
| code_path=__file__, # code the model depends on; may be a directory (which must contain hubconf.py), | |||
| # or the current hubconf.py; this example reuses a model implementation from the dependencies, so this file suffices | |||
| metadata=metadata, | |||
| tags=dict( | |||
| num_params=num_params, | |||
| metadata=metadata, | |||
| name=metadata['name'], | |||
| url=metadata['url'], | |||
| dataset=dataset, | |||
| img_size=metadata['input']['size'], | |||
| readme=json_data['readme']) | |||
| ) | |||
| save_path_list.append(save_path_for_measure) | |||
| matched = True | |||
| break | |||
| except Exception: | |||
| traceback.print_exc() | |||
| return json.dumps(Response(506,'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__) | |||
| if not matched: | |||
| return json.dumps(Response(506, 'Failed to package model [%s]: no matching hub entry was found' % (entry_name), save_path_list).__dict__) | |||
| return json.dumps(Response(200, 'Success', save_path_list).__dict__) | |||
| def file_name(self, file_dir): | |||
| for root, dirs, files in os.walk(file_dir): | |||
| for file in files: | |||
| if file.endswith('.pth'): | |||
| return os.path.join(root, file) | |||
| return '' | |||
| class Measure: | |||
| def POST(self): | |||
| req = web.data() | |||
| json_data = json.loads(req) | |||
| print(json_data) | |||
| output_filename_list = [] | |||
| probe_set_root = '' | |||
| try: | |||
| measure_name = 'measure' | |||
| zoo_set = json_data['zoo_set'] | |||
| probe_set_root = json_data['probe_set_root'] | |||
| export_path = json_data['export_path'] | |||
| TG = TransferabilityGraph(zoo_set) | |||
| print("Add %s" % (probe_set_root)) | |||
| imgs_set = list(os.listdir(os.path.join(probe_set_root))) | |||
| images = [Image.open(os.path.join(probe_set_root, img)) for img in imgs_set] | |||
| metric = AttrMapMetric(images, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) | |||
| TG.add_metric(probe_set_root, metric) | |||
| # create the export directory if it does not exist | |||
| os.makedirs(export_path, exist_ok=True) | |||
| output_filename = export_path+'%s.json' % (measure_name) | |||
| TG.export_to_json(probe_set_root, output_filename, topk=3, normalize=True) | |||
| output_filename_list.append(output_filename) | |||
| except Exception: | |||
| traceback.print_exc() | |||
| return json.dumps(Response(506, 'Failed to generate measurement file of [%s], Error: %s' % (probe_set_root, traceback.format_exc(limit=1)), output_filename_list).__dict__) | |||
| return json.dumps(Response(200, 'Success', output_filename_list).__dict__) | |||
| class Response: | |||
| def __init__(self, code, msg, data): | |||
| self.code = code | |||
| self.msg = msg | |||
| self.data = data | |||
| if __name__ == "__main__": | |||
| app.run() | |||
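| # Example payload for /model_measure/package (a sketch; paths and metadata | |||
| # values are illustrative, field names follow the parsing code above): | |||
| # | |||
| # [{"entry_name": "resnet18", | |||
| #   "ckpt": "/data/ckpt/", | |||
| #   "readme": "model card text", | |||
| #   "metadata": {"name": "resnet18_cifar10", "url": "", "dataset": "cifar10", | |||
| #                "entry_args": {"num_classes": 10}, | |||
| #                "input": {"size": [32, 32]}}}] | |||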
| @@ -0,0 +1,5 @@ | |||
| from .core import tasks, metrics, engine, callbacks, hub | |||
| from . import amalgamation, slim, vision, transferability | |||
| from .core import load, save | |||
| @@ -0,0 +1,4 @@ | |||
| from .layerwise_amalgamation import LayerWiseAmalgamator | |||
| from .common_feature import CommonFeatureAmalgamator | |||
| from .task_branching import TaskBranchingAmalgamator, JointSegNet | |||
| from .recombination import RecombinationAmalgamator, CombinedModel | |||
| @@ -0,0 +1,291 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from kamal.core.engine.engine import Engine | |||
| from kamal.core.engine.hooks import FeatureHook | |||
| from kamal.core import tasks | |||
| from kamal.utils import set_mode, move_to_device | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import typing, time | |||
| import numpy as np | |||
| def conv3x3(in_planes, out_planes, stride=1): | |||
| """3x3 convolution with padding""" | |||
| return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||
| padding=1, bias=False) | |||
| class ResBlock(nn.Module): | |||
| """ Residual Blocks | |||
| """ | |||
| def __init__(self, inplanes, planes, stride=1, momentum=0.1): | |||
| super(ResBlock, self).__init__() | |||
| self.conv1 = conv3x3(inplanes, planes, stride) | |||
| self.bn1 = nn.BatchNorm2d(planes, momentum=momentum) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.conv2 = conv3x3(planes, planes) | |||
| self.bn2 = nn.BatchNorm2d(planes, momentum=momentum) | |||
| if stride > 1 or inplanes != planes: | |||
| self.downsample = nn.Sequential( | |||
| nn.Conv2d(inplanes, planes, kernel_size=1, | |||
| stride=stride, bias=False), | |||
| nn.BatchNorm2d(planes, momentum=momentum) | |||
| ) | |||
| else: | |||
| self.downsample = None | |||
| self.stride = stride | |||
| def forward(self, x): | |||
| residual = x | |||
| out = self.conv1(x) | |||
| out = self.bn1(out) | |||
| out = self.relu(out) | |||
| out = self.conv2(out) | |||
| out = self.bn2(out) | |||
| if self.downsample is not None: | |||
| residual = self.downsample(x) | |||
| out += residual | |||
| out = self.relu(out) | |||
| return out | |||
| class CFL_FCBlock(nn.Module): | |||
| """Common Feature Blocks for Fully-Connected layer | |||
| This module captures the common features of multiple teachers and computes the MMD loss against the student's features. | |||
| **Parameters:** | |||
| - cs (int): channel number of the student features | |||
| - cts (list or tuple): channel numbers of the teacher features | |||
| - ch (int): channel number of the hidden features | |||
| """ | |||
| def __init__(self, cs, cts, ch, k_size=5): | |||
| super(CFL_FCBlock, self).__init__() | |||
| self.align_t = nn.ModuleList() | |||
| for ct in cts: | |||
| self.align_t.append( | |||
| nn.Sequential( | |||
| nn.Linear(ct, ch), | |||
| nn.ReLU(inplace=True) | |||
| ) | |||
| ) | |||
| self.align_s = nn.Sequential( | |||
| nn.Linear(cs, ch), | |||
| nn.ReLU(inplace=True), | |||
| ) | |||
| self.extractor = nn.Sequential( | |||
| nn.Linear(ch, ch), | |||
| nn.ReLU(), | |||
| nn.Linear(ch, ch), | |||
| ) | |||
| self.dec_t = nn.ModuleList() | |||
| for ct in cts: | |||
| self.dec_t.append( | |||
| nn.Sequential( | |||
| nn.Linear(ch, ct), | |||
| nn.ReLU(inplace=True), | |||
| nn.Linear(ct, ct) | |||
| ) | |||
| ) | |||
| def init_weights(self): | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Linear): | |||
| torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| m.weight.data.fill_(1) | |||
| m.bias.data.zero_() | |||
| def forward(self, fs, fts): | |||
| aligned_t = [self.align_t[i](fts[i]) for i in range(len(fts))] | |||
| aligned_s = self.align_s(fs) | |||
| hts = [self.extractor(f) for f in aligned_t] | |||
| hs = self.extractor(aligned_s) | |||
| _fts = [self.dec_t[i](hts[i]) for i in range(len(hts))] | |||
| return (hs, hts), (_fts, fts) | |||
| class CFL_ConvBlock(nn.Module): | |||
| """Common Feature Blocks for Convolutional layer | |||
| This module captures the common features of multiple teachers and computes the MMD loss against the student's features. | |||
| **Parameters:** | |||
| - cs (int): channel number of the student features | |||
| - cts (list or tuple): channel numbers of the teacher features | |||
| - ch (int): channel number of the hidden features | |||
| """ | |||
| def __init__(self, cs, cts, ch, k_size=5): | |||
| super(CFL_ConvBlock, self).__init__() | |||
| self.align_t = nn.ModuleList() | |||
| for ct in cts: | |||
| self.align_t.append( | |||
| nn.Sequential( | |||
| nn.Conv2d(in_channels=ct, out_channels=ch, | |||
| kernel_size=1), | |||
| nn.BatchNorm2d(ch), | |||
| nn.ReLU(inplace=True) | |||
| ) | |||
| ) | |||
| self.align_s = nn.Sequential( | |||
| nn.Conv2d(in_channels=cs, out_channels=ch, | |||
| kernel_size=1), | |||
| nn.BatchNorm2d(ch), | |||
| nn.ReLU(inplace=True), | |||
| ) | |||
| self.extractor = nn.Sequential( | |||
| ResBlock(inplanes=ch, planes=ch, stride=1), | |||
| ResBlock(inplanes=ch, planes=ch, stride=1), | |||
| ) | |||
| self.dec_t = nn.ModuleList() | |||
| for ct in cts: | |||
| self.dec_t.append( | |||
| nn.Sequential( | |||
| nn.Conv2d(ch, ch, kernel_size=1, stride=1), | |||
| nn.BatchNorm2d(ch), | |||
| nn.ReLU(inplace=True), | |||
| nn.Conv2d(ch, ct, kernel_size=1, stride=1) | |||
| ) | |||
| ) | |||
| def init_weights(self): | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| m.weight.data.fill_(1) | |||
| m.bias.data.zero_() | |||
| def forward(self, fs, fts): | |||
| aligned_t = [self.align_t[i](fts[i]) for i in range(len(fts))] | |||
| aligned_s = self.align_s(fs) | |||
| hts = [self.extractor(f) for f in aligned_t] | |||
| hs = self.extractor(aligned_s) | |||
| _fts = [self.dec_t[i](hts[i]) for i in range(len(hts))] | |||
| return (hs, hts), (_fts, fts) | |||
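| # A minimal shape check for CFL_ConvBlock (a sketch; channel sizes are | |||
| # illustrative: one student with 64 channels, two teachers with 128/256): | |||
| # | |||
| # block = CFL_ConvBlock(cs=64, cts=[128, 256], ch=16) | |||
| # fs = torch.randn(2, 64, 8, 8) | |||
| # fts = [torch.randn(2, 128, 8, 8), torch.randn(2, 256, 8, 8)] | |||
| # (hs, hts), (_fts, fts) = block(fs, fts)  # hs/hts live in the 16-channel common space | |||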
| class CommonFeatureAmalgamator(Engine): | |||
| def setup( | |||
| self, | |||
| student, | |||
| teachers, | |||
| layer_groups: typing.Sequence[typing.Sequence], | |||
| layer_channels: typing.Sequence[typing.Sequence], | |||
| dataloader: torch.utils.data.DataLoader, | |||
| optimizer: torch.optim.Optimizer, | |||
| weights = [1.0, 1.0, 1.0], | |||
| on_layer_input=False, | |||
| device = None, | |||
| ): | |||
| self._dataloader = dataloader | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self._device = device | |||
| self._model = self._student = student.to(self._device) | |||
| self._teachers = nn.ModuleList(teachers).to(self._device) | |||
| self._optimizer = optimizer | |||
| self._weights = weights | |||
| self._on_layer_input = on_layer_input | |||
| amal_blocks = [] | |||
| for group, C in zip( layer_groups, layer_channels ): | |||
| hooks = [ FeatureHook(layer) for layer in group ] | |||
| if isinstance(group[0], nn.Linear): | |||
| amal_block = CFL_FCBlock( cs=C[0], cts=C[1:], ch=C[0]//4 ).to(self._device).train() | |||
| print("Building FC Blocks") | |||
| else: | |||
| amal_block = CFL_ConvBlock(cs=C[0], cts=C[1:], ch=C[0]//4).to(self._device).train() | |||
| print("Building Conv Blocks") | |||
| amal_blocks.append( (amal_block, hooks, C) ) | |||
| self._amal_blocks = amal_blocks | |||
| self._cfl_criterion = tasks.loss.CFLLoss( sigmas=[0.001, 0.01, 0.05, 0.1, 0.2, 1, 2] ) | |||
| @property | |||
| def device(self): | |||
| return self._device | |||
| def run(self, max_iter, start_iter=0, epoch_length=None): | |||
| block_params = [] | |||
| for block, _, _ in self._amal_blocks: | |||
| block_params.extend( list(block.parameters()) ) | |||
| if isinstance( self._optimizer, torch.optim.SGD ): | |||
| self._amal_optimizer = torch.optim.SGD( block_params, lr=self._optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 ) | |||
| else: | |||
| self._amal_optimizer = torch.optim.Adam( block_params, lr=self._optimizer.param_groups[0]['lr'], weight_decay=1e-4 ) | |||
| self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimizer, T_max=max_iter ) | |||
| with set_mode(self._student, training=True), \ | |||
| set_mode(self._teachers, training=False): | |||
| super( CommonFeatureAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
| def step_fn(self, engine, batch): | |||
| start_time = time.perf_counter() | |||
| batch = move_to_device(batch, self._device) | |||
| data = batch[0] | |||
| s_out = self._student( data ) | |||
| with torch.no_grad(): | |||
| t_out = [ teacher( data ) for teacher in self._teachers ] | |||
| loss_amal = 0 | |||
| loss_recons = 0 | |||
| for amal_block, hooks, C in self._amal_blocks: | |||
| features = [ h.feat_in if self._on_layer_input else h.feat_out for h in hooks ] | |||
| fs, fts = features[0], features[1:] | |||
| (hs, hts), (_fts, fts) = amal_block( fs, fts ) | |||
| _loss_amal, _loss_recons = self._cfl_criterion( hs, hts, _fts, fts ) | |||
| loss_amal += _loss_amal | |||
| loss_recons += _loss_recons | |||
| loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) ) | |||
| loss_dict = { | |||
| 'loss_kd': self._weights[0]*loss_kd, | |||
| 'loss_amal': self._weights[1]*loss_amal, | |||
| 'loss_recons': self._weights[2]*loss_recons | |||
| } | |||
| loss = sum(loss_dict.values()) | |||
| self._optimizer.zero_grad() | |||
| self._amal_optimizer.zero_grad() | |||
| loss.backward() | |||
| self._optimizer.step() | |||
| self._amal_optimizer.step() | |||
| self._amal_scheduler.step() | |||
| step_time = time.perf_counter() - start_time | |||
| metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
| metrics.update({ | |||
| 'total_loss': loss.item(), | |||
| 'step_time': step_time, | |||
| 'lr': float( self._optimizer.param_groups[0]['lr'] ) | |||
| }) | |||
| return metrics | |||
| @@ -0,0 +1,131 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from kamal.core.engine.engine import Engine | |||
| from kamal.core.engine.hooks import FeatureHook | |||
| from kamal.core import tasks | |||
| import typing | |||
| import time | |||
| from kamal.utils import move_to_device, set_mode | |||
| class AmalBlock(nn.Module): | |||
| def __init__(self, cs, cts): | |||
| super( AmalBlock, self ).__init__() | |||
| self.cs, self.cts = cs, cts | |||
| self.enc = nn.Conv2d( in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True ) | |||
| self.fam = nn.Conv2d( in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True ) | |||
| self.dec = nn.Conv2d( in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True ) | |||
| def forward(self, fs, fts): | |||
| rep = self.enc( torch.cat( fts, dim=1 ) ) | |||
| _fts = self.dec( rep ) | |||
| _fts = torch.split( _fts, self.cts, dim=1 ) | |||
| _fs = self.fam( fs ) | |||
| return rep, _fs, _fts | |||
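| # A minimal shape check for AmalBlock (a sketch; channel sizes are illustrative): | |||
| # | |||
| # block = AmalBlock(cs=64, cts=[128, 256]) | |||
| # fs = torch.randn(2, 64, 8, 8) | |||
| # fts = [torch.randn(2, 128, 8, 8), torch.randn(2, 256, 8, 8)] | |||
| # rep, _fs, _fts = block(fs, fts)  # rep and _fs: (2, 64, 8, 8); _fts match fts shapes | |||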
| class LayerWiseAmalgamator(Engine): | |||
| def setup( | |||
| self, | |||
| student, | |||
| teachers, | |||
| layer_groups: typing.Sequence[typing.Sequence], | |||
| layer_channels: typing.Sequence[typing.Sequence], | |||
| dataloader: torch.utils.data.DataLoader, | |||
| optimizer: torch.optim.Optimizer, | |||
| weights = [1., 1., 1.], | |||
| device=None, | |||
| ): | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self._device = device | |||
| self._dataloader = dataloader | |||
| self.model = self.student = student.to(self.device) | |||
| self.teachers = nn.ModuleList(teachers).to(self.device) | |||
| self.optimizer = optimizer | |||
| self._weights = weights | |||
| amal_blocks = [] | |||
| for group, C in zip(layer_groups, layer_channels): | |||
| hooks = [ FeatureHook(layer) for layer in group ] | |||
| amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train() | |||
| amal_blocks.append( (amal_block, hooks, C) ) | |||
| self._amal_blocks = amal_blocks | |||
| @property | |||
| def device(self): | |||
| return self._device | |||
| def run(self, max_iter, start_iter=0, epoch_length=None ): | |||
| block_params = [] | |||
| for block, _, _ in self._amal_blocks: | |||
| block_params.extend( list(block.parameters()) ) | |||
| if isinstance( self.optimizer, torch.optim.SGD ): | |||
| self._amal_optimizer = torch.optim.SGD( block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 ) | |||
| else: | |||
| self._amal_optimizer = torch.optim.Adam( block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4 ) | |||
| self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimizer, T_max=max_iter ) | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teachers, training=False): | |||
| super( LayerWiseAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
| def step_fn(self, engine, batch): | |||
| start_time = time.perf_counter() | |||
| batch = move_to_device(batch, self._device) | |||
| data = batch[0] | |||
| s_out = self.student( data ) | |||
| with torch.no_grad(): | |||
| t_out = [ teacher( data ) for teacher in self.teachers ] | |||
| loss_amal = 0 | |||
| loss_recons = 0 | |||
| for amal_block, hooks, C in self._amal_blocks: | |||
| features = [ h.feat_out for h in hooks ] | |||
| fs, fts = features[0], features[1:] | |||
| rep, _fs, _fts = amal_block( fs, fts ) | |||
| loss_amal += F.mse_loss( _fs, rep.detach() ) | |||
| loss_recons += sum( [ F.mse_loss( _ft, ft ) for (_ft, ft) in zip( _fts, fts ) ] ) | |||
| loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) ) | |||
| #loss_kd = F.mse_loss( s_out, torch.cat( t_out, dim=1 ) ) | |||
| loss_dict = { "loss_kd": self._weights[0] * loss_kd, | |||
| "loss_amal": self._weights[1] * loss_amal, | |||
| "loss_recons": self._weights[2] * loss_recons } | |||
| loss = sum(loss_dict.values()) | |||
| self.optimizer.zero_grad() | |||
| self._amal_optimizer.zero_grad() | |||
| loss.backward() | |||
| self.optimizer.step() | |||
| self._amal_optimizer.step() | |||
| self._amal_scheduler.step() | |||
| step_time = time.perf_counter() - start_time | |||
| metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
| metrics.update({ | |||
| 'total_loss': loss.item(), | |||
| 'step_time': step_time, | |||
| 'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
| }) | |||
| return metrics | |||
| @@ -0,0 +1,209 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from copy import deepcopy | |||
| from typing import Callable | |||
| from kamal.core.engine.engine import Engine | |||
| from kamal.core.engine.trainer import KDTrainer | |||
| from kamal.core.engine.hooks import FeatureHook | |||
| from kamal.core import tasks | |||
| import math | |||
| from kamal.slim.prunning import Pruner, strategy | |||
| def _assert_same_type(layers, layer_type=None): | |||
| if layer_type is None: | |||
| layer_type = type(layers[0]) | |||
| assert all(isinstance(l, layer_type) for l in layers), 'Model architectures must be the same' | |||
| def _get_layers(model_list): | |||
| submodel = [ model.modules() for model in model_list ] | |||
| for layers in zip(*submodel): | |||
| _assert_same_type(layers) | |||
| yield layers | |||
| def bn_combine_fn(layers): | |||
| """Combine 2D Batch Normalization Layers | |||
| **Parameters:** | |||
| - **layers** (BatchNorm2D): Batch Normalization Layers. | |||
| """ | |||
| _assert_same_type(layers, nn.BatchNorm2d) | |||
| num_features = sum(l.num_features for l in layers) | |||
| combined_bn = nn.BatchNorm2d(num_features=num_features, | |||
| eps=layers[0].eps, | |||
| momentum=layers[0].momentum, | |||
| affine=layers[0].affine, | |||
| track_running_stats=layers[0].track_running_stats) | |||
| combined_bn.running_mean = torch.cat( | |||
| [l.running_mean for l in layers], dim=0).clone() | |||
| combined_bn.running_var = torch.cat( | |||
| [l.running_var for l in layers], dim=0).clone() | |||
| if combined_bn.affine: | |||
| combined_bn.weight = torch.nn.Parameter( | |||
| torch.cat([l.weight.data.clone() for l in layers], dim=0).clone()) | |||
| combined_bn.bias = torch.nn.Parameter( | |||
| torch.cat([l.bias.data.clone() for l in layers], dim=0).clone()) | |||
| return combined_bn | |||
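| # Example (a sketch): combining two BatchNorm2d layers stacks their running | |||
| # statistics and affine parameters along the channel dimension. | |||
| # | |||
| # bn = bn_combine_fn([nn.BatchNorm2d(16), nn.BatchNorm2d(32)]) | |||
| # assert bn.num_features == 16 + 32 | |||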
| def conv2d_combine_fn(layers): | |||
| """Combine 2D Conv Layers | |||
| **Parameters:** | |||
| - **layers** (Conv2d): Conv Layers. | |||
| """ | |||
| _assert_same_type(layers, nn.Conv2d) | |||
| CO, CI = 0, 0 | |||
| for l in layers: | |||
| O, I, H, W = l.weight.shape | |||
| CO += O | |||
| CI += I | |||
| dtype = layers[0].weight.dtype | |||
| device = layers[0].weight.device | |||
| combined_weight = torch.nn.Parameter( | |||
| torch.zeros(CO, CI, H, W, dtype=dtype, device=device)) | |||
| if layers[0].bias is not None: | |||
| combined_bias = torch.nn.Parameter( | |||
| torch.zeros(CO, dtype=dtype, device=device)) | |||
| else: | |||
| combined_bias = None | |||
| co_offset = 0 | |||
| ci_offset = 0 | |||
| for idx, l in enumerate(layers): | |||
| co_len, ci_len = l.weight.shape[0], l.weight.shape[1] | |||
| combined_weight[co_offset: co_offset+co_len, | |||
| ci_offset: ci_offset+ci_len, :, :] = l.weight.clone() | |||
| if combined_bias is not None: | |||
| combined_bias[co_offset: co_offset+co_len] = l.bias.clone() | |||
| co_offset += co_len | |||
| ci_offset += ci_len | |||
| combined_conv2d = nn.Conv2d(in_channels=CI, | |||
| out_channels=CO, | |||
| kernel_size=layers[0].weight.shape[-2:], | |||
| stride=layers[0].stride, | |||
| padding=layers[0].padding, | |||
| bias=layers[0].bias is not None) | |||
| combined_conv2d.weight.data = combined_weight | |||
| if combined_bias is not None: | |||
| combined_conv2d.bias.data = combined_bias | |||
| for p in combined_conv2d.parameters(): | |||
| p.requires_grad = True | |||
| return combined_conv2d | |||
| def combine_models(models): | |||
| """Combine modules with parser | |||
| **Parameters:** | |||
| - **models** (nn.Module): modules to be combined. | |||
| - **combine_parser** (function): layer selector | |||
| """ | |||
| def _recursively_combine(module): | |||
| module_output = module | |||
| if isinstance( module, nn.Conv2d ): | |||
| combined_module = conv2d_combine_fn( layer_mapping[module] ) | |||
| elif isinstance( module, nn.BatchNorm2d ): | |||
| combined_module = bn_combine_fn( layer_mapping[module] ) | |||
| else: | |||
| combined_module = module | |||
| if combined_module is not None: | |||
| module_output = combined_module | |||
| for name, child in module.named_children(): | |||
| module_output.add_module(name, _recursively_combine(child)) | |||
| return module_output | |||
| models = deepcopy(models) | |||
| combined_model = deepcopy(models[0]) # copy the model architecture and modify it with _recursively_combine | |||
| layer_mapping = {} | |||
| for combined_layer, layers in zip(combined_model.modules(), _get_layers(models)): | |||
| layer_mapping[combined_layer] = layers # link to teachers | |||
| combined_model = _recursively_combine(combined_model) | |||
| return combined_model | |||
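| # Example (a sketch with two identical toy models; names are illustrative): | |||
| # the combined network consumes the concatenated inputs of both models and | |||
| # emits their concatenated outputs. | |||
| # | |||
| # m1 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)) | |||
| # m2 = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)) | |||
| # combined = combine_models([m1, m2]) | |||
| # x = torch.randn(1, 3, 32, 32).repeat(1, 2, 1, 1)  # 3+3 input channels | |||
| # y = combined(x)  # shape: (1, 16, 32, 32) | |||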
| class CombinedModel(nn.Module): | |||
| def __init__(self, models): | |||
| super( CombinedModel, self ).__init__() | |||
| self.combined_model = combine_models( models ) | |||
| self.expand = len(models) | |||
| def forward(self, x): | |||
| x = x.repeat( 1, self.expand, 1, 1 ) # replicate the input channels once per combined model | |||
| return self.combined_model(x) | |||
| class PruningKDTrainer(KDTrainer): | |||
| def setup( | |||
| self, | |||
| student, | |||
| teachers, | |||
| task, | |||
| dataloader: torch.utils.data.DataLoader, | |||
| get_optimizer_and_scheduler:Callable=None, | |||
| pruning_rounds=5, | |||
| device=None, | |||
| ): | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self._device = device | |||
| self._dataloader = dataloader | |||
| self.model = self.student = student.to(self.device) | |||
| self.teachers = nn.ModuleList(teachers).to(self.device) | |||
| self.task = task | |||
| self.get_optimizer_and_scheduler = get_optimizer_and_scheduler | |||
| @property | |||
| def device(self): | |||
| return self._device | |||
| def run(self, max_iter, start_iter=0, epoch_length=None, pruning_rounds=3, target_model_size=0.6 ): | |||
| pruning_size_per_round = 1 - math.pow( target_model_size, 1/pruning_rounds ) | |||
| pruner = Pruner( strategy.LNStrategy(n=1) ) | |||
| for pruning_round in range(pruning_rounds): | |||
| pruner.prune( self.student, rate=pruning_size_per_round, example_inputs=torch.randn(1,3,240,240) ) | |||
| self.student.to(self.device) | |||
| if self.get_optimizer_and_scheduler: | |||
| self.optimizer, self.scheduler = self.get_optimizer_and_scheduler( self.student ) | |||
| else: | |||
| self.optimizer = torch.optim.Adam( self.student.parameters(), lr=1e-4, weight_decay=1e-5 ) | |||
| self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max= (max_iter-start_iter)//pruning_rounds ) | |||
| step_iter = (max_iter - start_iter)//pruning_rounds | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teachers, training=False): | |||
| super( PruningKDTrainer, self ).run(self.step_fn, self._dataloader, | |||
| start_iter=start_iter+step_iter*pruning_round , max_iter=start_iter+step_iter*(pruning_round+1), epoch_length=epoch_length) | |||
| def step_fn(self, engine, batch): | |||
| metrics = super(PruningKDTrainer, self).step_fn( engine, batch ) | |||
| self.scheduler.step() | |||
| return metrics | |||
| class RecombinationAmalgamator(PruningKDTrainer): | |||
| pass | |||
| @@ -0,0 +1,295 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from kamal.core.engine.engine import Engine | |||
| from kamal.core.engine.hooks import FeatureHook | |||
| from kamal.core import tasks | |||
| from kamal.utils import move_to_device, set_mode | |||
| from kamal.core.hub import meta | |||
| from kamal import vision | |||
| import kamal | |||
| from kamal.utils import set_mode | |||
| import typing | |||
| import time | |||
| from copy import deepcopy | |||
| import random | |||
| import numpy as np | |||
| from collections import defaultdict | |||
| import numbers | |||
| class BranchySegNet(nn.Module): | |||
| def __init__(self, out_channels, segnet_fn=vision.models.segmentation.segnet_vgg16_bn): | |||
| super(BranchySegNet, self).__init__() | |||
| channels=[512, 512, 256, 128, 64] | |||
| self.register_buffer( 'branch_indices', torch.zeros((len(out_channels),), dtype=torch.long) ) | |||
| self.student_b_decoders_list = nn.ModuleList() | |||
| self.student_adaptors_list = nn.ModuleList() | |||
| ses = [] | |||
| for i in range(5): | |||
| se = int(channels[i]/4) | |||
| ses.append(16 if se < 16 else se) | |||
| for oc in out_channels: | |||
| segnet = self.get_segnet( oc, segnet_fn ) | |||
| decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:])) | |||
| adaptors = nn.ModuleList() | |||
| for i in range(5): | |||
| adaptor = nn.Sequential( | |||
| nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False), | |||
| nn.ReLU(), | |||
| nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False), | |||
| nn.Sigmoid() | |||
| ) | |||
| adaptors.append(adaptor) | |||
| self.student_b_decoders_list.append(decoders) | |||
| self.student_adaptors_list.append(adaptors) | |||
| self.student_encoders = nn.ModuleList(deepcopy(list(segnet.children())[0:5])) | |||
| self.student_decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:])) | |||
| def set_branch(self, branch_indices): | |||
| assert len(branch_indices)==len(self.student_b_decoders_list) | |||
| self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device) | |||
| def get_segnet(self, oc, segnet_fn): | |||
| return segnet_fn( num_classes=oc, pretrained_backbone=True ) | |||
| def forward(self, inputs): | |||
| output_list = [] | |||
| down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs) | |||
| down2, indices_2, unpool_shape2 = self.student_encoders[1](down1) | |||
| down3, indices_3, unpool_shape3 = self.student_encoders[2](down2) | |||
| down4, indices_4, unpool_shape4 = self.student_encoders[3](down3) | |||
| down5, indices_5, unpool_shape5 = self.student_encoders[4](down4) | |||
| up5 = self.student_decoders[0](down5, indices_5, unpool_shape5) | |||
| up4 = self.student_decoders[1](up5, indices_4, unpool_shape4) | |||
| up3 = self.student_decoders[2](up4, indices_3, unpool_shape3) | |||
| up2 = self.student_decoders[3](up3, indices_2, unpool_shape2) | |||
| up1 = self.student_decoders[4](up2, indices_1, unpool_shape1) | |||
| decoder_features = [down5, up5, up4, up3, up2] | |||
| decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1] | |||
| decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1] | |||
| # Mimic teachers. | |||
| for i in range(len(self.branch_indices)): | |||
| out_idx = self.branch_indices[i] | |||
| output = decoder_features[out_idx] | |||
| output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:])) | |||
| for j in range(out_idx, 5): | |||
| output = self.student_b_decoders_list[i][j]( | |||
| output, | |||
| decoder_indices[j], | |||
| decoder_shapes[j] | |||
| ) | |||
| output_list.append( output ) | |||
| return output_list | |||
| class JointSegNet(nn.Module): | |||
| """The online student model to learn from any number of single teacher with 'SegNet' structure. | |||
| **Parameters:** | |||
| - **teachers** (list of 'Module' object): Teachers with 'SegNet' structure to learn from. | |||
| - **student** ('Module' object, optional): Student backbone; defaults to a copy of the first teacher. | |||
| - **channels** (list of int, optional): Channel numbers used to build the adaptor modules, corresponding to those of 'SegNet'. | |||
| """ | |||
| def __init__(self, teachers, student=None, channels=[512, 512, 256, 128, 64]): | |||
| super(JointSegNet, self).__init__() | |||
| self.register_buffer( 'branch_indices', torch.zeros((2,), dtype=torch.long) ) | |||
| if student is None: | |||
| student = teachers[0] | |||
| self.student_encoders = nn.ModuleList(deepcopy(list(student.children())[0:5])) | |||
| self.student_decoders = nn.ModuleList(deepcopy(list(student.children())[5:])) | |||
| self.student_b_decoders_list = nn.ModuleList() | |||
| self.student_adaptors_list = nn.ModuleList() | |||
| ses = [] | |||
| for i in range(5): | |||
| se = int(channels[i]/4) | |||
| ses.append(16 if se < 16 else se) | |||
| for teacher in teachers: | |||
| decoders = nn.ModuleList(deepcopy(list(teacher.children())[5:])) | |||
| adaptors = nn.ModuleList() | |||
| for i in range(5): | |||
| adaptor = nn.Sequential( | |||
| nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False), | |||
| nn.ReLU(), | |||
| nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False), | |||
| nn.Sigmoid() | |||
| ) | |||
| adaptors.append(adaptor) | |||
| self.student_b_decoders_list.append(decoders) | |||
| self.student_adaptors_list.append(adaptors) | |||
| def set_branch(self, branch_indices): | |||
| assert len(branch_indices)==len(self.student_b_decoders_list) | |||
| self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device) | |||
| def forward(self, inputs): | |||
| output_list = [] | |||
| down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs) | |||
| down2, indices_2, unpool_shape2 = self.student_encoders[1](down1) | |||
| down3, indices_3, unpool_shape3 = self.student_encoders[2](down2) | |||
| down4, indices_4, unpool_shape4 = self.student_encoders[3](down3) | |||
| down5, indices_5, unpool_shape5 = self.student_encoders[4](down4) | |||
| up5 = self.student_decoders[0](down5, indices_5, unpool_shape5) | |||
| up4 = self.student_decoders[1](up5, indices_4, unpool_shape4) | |||
| up3 = self.student_decoders[2](up4, indices_3, unpool_shape3) | |||
| up2 = self.student_decoders[3](up3, indices_2, unpool_shape2) | |||
| up1 = self.student_decoders[4](up2, indices_1, unpool_shape1) | |||
| decoder_features = [down5, up5, up4, up3, up2] | |||
| decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1] | |||
| decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1] | |||
| # Mimic teachers. | |||
| for i in range(len(self.branch_indices)): | |||
| out_idx = self.branch_indices[i] | |||
| output = decoder_features[out_idx] | |||
| output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:])) | |||
| for j in range(out_idx, 5): | |||
| output = self.student_b_decoders_list[i][j]( | |||
| output, | |||
| decoder_indices[j], | |||
| decoder_shapes[j] | |||
| ) | |||
| output_list.append( output ) | |||
| return output_list | |||
| class TaskBranchingAmalgamator(Engine): | |||
| def setup( | |||
| self, | |||
| joint_student: JointSegNet, | |||
| teachers, | |||
| tasks, | |||
| dataloader: torch.utils.data.DataLoader, | |||
| optimizer: torch.optim.Optimizer, | |||
| device=None, | |||
| ): | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self._device = device | |||
| self._dataloader = dataloader | |||
| self.student = self.model = joint_student.to(self._device) | |||
| self.teachers = nn.ModuleList(teachers).to(self._device) | |||
| self.tasks = tasks | |||
| self.optimizer = optimizer | |||
| self.is_finetuning=False | |||
| @property | |||
| def device(self): | |||
| return self._device | |||
| def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None ): | |||
| # Branching | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teachers, training=False): | |||
| super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter//2, epoch_length=epoch_length) | |||
| branch = self.find_the_best_branch( self._dataloader ) | |||
| self.logger.info("[Task Branching] the best branch indices: %s"%(branch)) | |||
| if stage_callback is not None: | |||
| stage_callback() | |||
| # Finetuning | |||
| self.is_finetuning = True | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teachers, training=False): | |||
| super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=max_iter-max_iter//2, max_iter=max_iter, epoch_length=epoch_length) | |||
| def find_the_best_branch(self, dataloader): | |||
| with set_mode(self.student, training=False), \ | |||
| set_mode(self.teachers, training=False), \ | |||
| torch.no_grad(): | |||
| n_blocks = len(self.student.student_decoders) | |||
| branch_loss = { task: [0 for _ in range(n_blocks)] for task in self.tasks } | |||
| for batch in dataloader: | |||
| batch = move_to_device(batch, self.device) | |||
| data = batch if isinstance(batch, torch.Tensor) else batch[0] | |||
| for candidate_branch in range( n_blocks ): | |||
| self.student.set_branch( [candidate_branch for _ in range(len(self.teachers))] ) | |||
| s_out_list = self.student( data ) | |||
| t_out_list = [ teacher( data ) for teacher in self.teachers ] | |||
| for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ): | |||
| task_loss = task.get_loss( s_out, t_out ) | |||
| branch_loss[task][candidate_branch] += sum(task_loss.values()) | |||
| best_branch = [] | |||
| for task in self.tasks: | |||
| best_branch.append( int(np.argmin( branch_loss[task] )) ) | |||
| self.student.set_branch(best_branch) | |||
| return best_branch | |||
| def step_fn(self, engine, batch): | |||
| start_time = time.perf_counter() | |||
| batch = move_to_device(batch, self._device) | |||
| data = batch if isinstance(batch, torch.Tensor) else batch[0] | |||
| n_blocks = len(self.student.student_decoders) | |||
| if not self.is_finetuning: | |||
| rand_branch_indices = [ random.randint(0, n_blocks-1) for _ in range(len(self.teachers)) ] | |||
| self.student.set_branch(rand_branch_indices) | |||
| s_out_list = self.student( data ) | |||
| with torch.no_grad(): | |||
| t_out_list = [ teacher( data ) for teacher in self.teachers ] | |||
| loss_dict = {} | |||
| for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ): | |||
| task_loss = task.get_loss( s_out, t_out ) | |||
| loss_dict.update( task_loss ) | |||
| loss = sum(loss_dict.values()) | |||
| self.optimizer.zero_grad() | |||
| loss.backward() | |||
| self.optimizer.step() | |||
| step_time = time.perf_counter() - start_time | |||
| metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
| metrics.update({ | |||
| 'total_loss': loss.item(), | |||
| 'step_time': step_time, | |||
| 'lr': float( self.optimizer.param_groups[0]['lr'] ), | |||
| 'branch': self.student.branch_indices.cpu().numpy().tolist() | |||
| }) | |||
| return metrics | |||
| @@ -0,0 +1,4 @@ | |||
| from . import engine, tasks, metrics, callbacks, exceptions, hub | |||
| from .attach import AttachTo | |||
| from .hub import load, save | |||
| @@ -0,0 +1,45 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from typing import Sequence, Callable | |||
| from numbers import Number | |||
| from kamal.core import exceptions | |||
| class AttachTo(object): | |||
| """ Attach task, metrics or visualizer to specified model outputs | |||
| """ | |||
| def __init__(self, attach_to=None): | |||
| if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable) ): | |||
| raise exceptions.InvalidMapping | |||
| self._attach_to = attach_to | |||
| def __call__(self, *tensors): | |||
| if self._attach_to is not None: | |||
| if isinstance(self._attach_to, Callable): | |||
| return self._attach_to( *tensors ) | |||
| if isinstance(self._attach_to, Sequence): | |||
| _attach_to = self._attach_to | |||
| else: | |||
| _attach_to = [ self._attach_to for _ in range(len(tensors)) ] | |||
| _attach_to = _attach_to[:len(tensors)] | |||
| tensors = [ tensor[index] for (tensor, index) in zip( tensors, _attach_to ) ] | |||
| if len(tensors)==1: | |||
| tensors = tensors[0] | |||
| return tensors | |||
| def __repr__(self): | |||
| rep = "AttachTo: %s"%(self._attach_to) | |||
| @@ -0,0 +1,5 @@ | |||
| from .logging import MetricsLogging, ProgressCallback | |||
| from .base import Callback | |||
| from .eval_and_ckpt import EvalAndCkpt | |||
| from .scheduler import LRSchedulerCallback | |||
| from .visualize import VisualizeOutputs, VisualizeSegmentation, VisualizeDepth | |||
| @@ -0,0 +1,28 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import abc | |||
| class Callback(abc.ABC): | |||
| r""" Base Class for Callbacks | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| @abc.abstractmethod | |||
| def __call__(self, engine): | |||
| pass | |||
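| # Example (a sketch): a minimal concrete callback that logs the current | |||
| # iteration whenever the engine invokes it. | |||
| # | |||
| # class PrintIterCallback(Callback): | |||
| #     def __call__(self, engine): | |||
| #         print('iter:', engine.state.iter) | |||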
| @@ -0,0 +1,145 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .base import Callback | |||
| import weakref | |||
| from kamal import utils | |||
| from typing import Sequence, Optional | |||
| import numbers | |||
| import os, shutil | |||
| import torch | |||
| class EvalAndCkpt(Callback): | |||
| def __init__(self, | |||
| model, | |||
| evaluator, | |||
| metric_name:str, | |||
| metric_mode:str ='max', | |||
| save_type:Optional[Sequence]=('best', 'latest'), | |||
| ckpt_dir:str ='checkpoints', | |||
| ckpt_prefix:str =None, | |||
| log_tag:str ='model', | |||
| weights_only:bool =True, | |||
| verbose:bool =False,): | |||
| super(EvalAndCkpt, self).__init__() | |||
| self.metric_name = metric_name | |||
| assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'" | |||
| self._metric_mode = metric_mode | |||
| self._model = weakref.ref( model ) | |||
| self._evaluator = evaluator | |||
| self._ckpt_dir = ckpt_dir | |||
| self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix+'_') | |||
| if isinstance(save_type, str): | |||
| save_type = ( save_type, ) | |||
| self._save_type = save_type | |||
| if self._save_type is not None: | |||
| for save_type in self._save_type: | |||
| assert save_type in ('best', 'latest', 'interval', 'all'), \ | |||
| 'save_type should be None or a subset of ("best", "latest", "interval", "all")' | |||
| self._log_tag = log_tag | |||
| self._weights_only = weights_only | |||
| self._verbose = verbose | |||
| self._best_score = -float('inf') if self._metric_mode=='max' else float('inf') | |||
| self._best_ckpt = None | |||
| self._latest_ckpt = None | |||
| @property | |||
| def best_ckpt(self): | |||
| return self._best_ckpt | |||
| @property | |||
| def latest_ckpt(self): | |||
| return self._latest_ckpt | |||
| @property | |||
| def best_score(self): | |||
| return self._best_score | |||
| def __call__(self, trainer): | |||
| model = self._model() | |||
| results = self._evaluator.eval( model, device=trainer.device ) | |||
| results = utils.flatten_dict(results) | |||
| current_score = results[self.metric_name] | |||
| scalar_results = { k: float(v) for (k, v) in results.items() if isinstance(v, numbers.Number) or len(v.shape)==0 } | |||
| if trainer.logger is not None: | |||
| trainer.logger.info( "[Eval %s] Iter %d/%d: %s"%(self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results) ) | |||
| trainer.state.metrics.update( scalar_results ) | |||
| # Visualize results if trainer.tb_writer is not None | |||
| if trainer.tb_writer is not None: | |||
| for k, v in scalar_results.items(): | |||
| log_tag = "%s:%s"%(self._log_tag, k) | |||
| trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter) | |||
| if self._save_type is not None: | |||
| pth_path_list = [] | |||
| # interval model | |||
| if 'interval' in self._save_type: | |||
| pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f.pth" | |||
| % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) | |||
| pth_path_list.append(pth_path) | |||
| # the latest model | |||
| if 'latest' in self._save_type: | |||
| pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f.pth" | |||
| % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) | |||
| # remove the old ckpt | |||
| if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt): | |||
| os.remove(self._latest_ckpt) | |||
| pth_path_list.append(pth_path) | |||
| self._latest_ckpt = pth_path | |||
| # the best model | |||
| if 'best' in self._save_type: | |||
| if (current_score >= self._best_score and self._metric_mode=='max') or \ | |||
| (current_score <= self._best_score and self._metric_mode=='min'): | |||
| pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f.pth" % | |||
| (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) | |||
| # remove the old ckpt | |||
| if self._best_ckpt is not None and os.path.exists(self._best_ckpt): | |||
| os.remove(self._best_ckpt) | |||
| pth_path_list.append(pth_path) | |||
| self._best_score = current_score | |||
| self._best_ckpt = pth_path | |||
| # save model | |||
| if self._verbose and trainer.logger: | |||
| trainer.logger.info("Model saved as:") | |||
| obj = model.state_dict() if self._weights_only else model | |||
| os.makedirs( self._ckpt_dir, exist_ok=True ) | |||
| for pth_path in pth_path_list: | |||
| torch.save(obj, pth_path) | |||
| if self._verbose and trainer.logger: | |||
| trainer.logger.info("\t%s" % (pth_path)) | |||
| def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False): | |||
| if ckpt_dir is None: | |||
| ckpt_dir = self._ckpt_dir | |||
| if ckpt_prefix is None: | |||
| ckpt_prefix = self._ckpt_prefix | |||
| if self._save_type is not None: | |||
| #if 'latest' in self._save_type and self._latest_ckpt is not None: | |||
| # os.makedirs(ckpt_dir, exist_ok=True) | |||
| # save_name = "%slatest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._latest_ckpt)) | |||
| # shutil.copy2(self._latest_ckpt, os.path.join(ckpt_dir, save_name)) | |||
| if 'best' in self._save_type and self._best_ckpt is not None: | |||
| os.makedirs(ckpt_dir, exist_ok=True) | |||
| save_name = "%sbest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._best_ckpt)) | |||
| shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name)) | |||
| @@ -0,0 +1,61 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .base import Callback | |||
| import numbers | |||
| from tqdm import tqdm | |||
| class MetricsLogging(Callback): | |||
| def __init__(self, keys): | |||
| super(MetricsLogging, self).__init__() | |||
| self._keys = keys | |||
| def __call__(self, engine): | |||
| if engine.logger is None: | |||
| return | |||
| state = engine.state | |||
| content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)"%( | |||
| state.iter, state.max_iter, | |||
| state.current_epoch, state.max_epoch, | |||
| state.current_batch_index, state.max_batch_index | |||
| ) | |||
| for key in self._keys: | |||
| value = state.metrics.get(key, None) | |||
| if value is not None: | |||
| if isinstance(value, numbers.Number): | |||
| content += " %s=%.4f"%(key, value) | |||
| if engine.tb_writer is not None: | |||
| engine.tb_writer.add_scalar(key, value, global_step=state.iter) | |||
| elif isinstance(value, (list, tuple)): | |||
| content += " %s=%s"%(key, value) | |||
| engine.logger.info(content) | |||
| class ProgressCallback(Callback): | |||
| def __init__(self, max_iter=100, tag=None): | |||
| super(ProgressCallback, self).__init__() | |||
| self._max_iter = max_iter | |||
| self._tag = tag | |||
| self._pbar = tqdm(total=self._max_iter, desc=self._tag) # create the bar up front so __call__ works before reset() | |||
| def __call__(self, engine): | |||
| self._pbar.update(1) | |||
| if self._pbar.n==self._max_iter: | |||
| self._pbar.close() | |||
| def reset(self): | |||
| self._pbar = tqdm(total=self._max_iter, desc=self._tag) | |||
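| # Illustrative sketch (the trainer object is assumed): logging a subset of | |||
| # metrics every 10 steps and showing a progress bar on every step. | |||
| # | |||
| #   from kamal.core.engine.events import DefaultEvents | |||
| #   trainer.add_callback(DefaultEvents.AFTER_STEP(every=10), | |||
| #                        callbacks=MetricsLogging(keys=('total_loss', 'lr'))) | |||
| #   trainer.add_callback(DefaultEvents.AFTER_STEP, | |||
| #                        callbacks=ProgressCallback(max_iter=10000, tag='train')) | |||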
| @@ -0,0 +1,34 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .base import Callback | |||
| from typing import Sequence | |||
| class LRSchedulerCallback(Callback): | |||
| r""" LR scheduler callback | |||
| """ | |||
| def __init__(self, schedulers=None): | |||
| super(LRSchedulerCallback, self).__init__() | |||
| if schedulers is not None and not isinstance(schedulers, Sequence): | |||
| schedulers = ( schedulers, ) # wrap a single scheduler, but leave None alone so __call__ can skip it | |||
| self._schedulers = schedulers | |||
| def __call__(self, trainer): | |||
| if self._schedulers is None: | |||
| return | |||
| for sched in self._schedulers: | |||
| sched.step() | |||
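| # Illustrative sketch (optimizer/trainer assumed): stepping a standard torch | |||
| # scheduler once per training step through the callback system. | |||
| # | |||
| #   import torch | |||
| #   sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1) | |||
| #   trainer.add_callback(DefaultEvents.AFTER_STEP, | |||
| #                        callbacks=LRSchedulerCallback(schedulers=[sched])) | |||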
| @@ -0,0 +1,161 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .base import Callback | |||
| from typing import Callable, Union, Sequence | |||
| import weakref | |||
| import random | |||
| from kamal.utils import move_to_device, set_mode, split_batch, colormap | |||
| from kamal.core.attach import AttachTo | |||
| import torch | |||
| import numpy as np | |||
| import matplotlib | |||
| matplotlib.use('agg') # select the non-interactive backend before pyplot is imported | |||
| import matplotlib.pyplot as plt | |||
| import math | |||
| import numbers | |||
| class VisualizeOutputs(Callback): | |||
| def __init__(self, | |||
| model, | |||
| dataset: torch.utils.data.Dataset, | |||
| idx_list_or_num_vis: Union[int, Sequence]=5, | |||
| normalizer: Callable=None, | |||
| prepare_fn: Callable=None, | |||
| decode_fn: Callable=None, # decode targets and preds | |||
| tag: str='viz'): | |||
| super(VisualizeOutputs, self).__init__() | |||
| self._dataset = dataset | |||
| self._model = weakref.ref(model) | |||
| if isinstance(idx_list_or_num_vis, int): | |||
| self.idx_list = self._get_vis_idx_list(self._dataset, idx_list_or_num_vis) | |||
| elif isinstance(idx_list_or_num_vis, Sequence): | |||
| self.idx_list = idx_list_or_num_vis | |||
| self._normalizer = normalizer | |||
| self._decode_fn = decode_fn | |||
| if prepare_fn is None: | |||
| prepare_fn = VisualizeOutputs.get_prepare_fn() | |||
| self._prepare_fn = prepare_fn | |||
| self._tag = tag | |||
| def _get_vis_idx_list(self, dataset, num_vis): | |||
| return random.sample(list(range(len(dataset))), num_vis) | |||
| @torch.no_grad() | |||
| def __call__(self, trainer): | |||
| if trainer.tb_writer is None: | |||
| trainer.logger.warning("summary writer was not found in trainer") | |||
| return | |||
| device = trainer.device | |||
| model = self._model() | |||
| with torch.no_grad(), set_mode(model, training=False): | |||
| for img_id, idx in enumerate(self.idx_list): | |||
| batch = move_to_device(self._dataset[idx], device) | |||
| batch = [ d.unsqueeze(0) for d in batch ] | |||
| inputs, targets, preds = self._prepare_fn(model, batch) | |||
| if self._normalizer is not None: | |||
| inputs = self._normalizer(inputs) | |||
| inputs = inputs.detach().cpu().numpy() | |||
| preds = preds.detach().cpu().numpy() | |||
| targets = targets.detach().cpu().numpy() | |||
| if self._decode_fn: # to RGB 0~1 NCHW | |||
| preds = self._decode_fn(preds) | |||
| targets = self._decode_fn(targets) | |||
| inputs = inputs[0] | |||
| preds = preds[0] | |||
| targets = targets[0] | |||
| trainer.tb_writer.add_images("%s-%d"%(self._tag, img_id), np.stack( [inputs, targets, preds], axis=0), global_step=trainer.state.iter) | |||
| @staticmethod | |||
| def get_prepare_fn(attach_to=None, pred_fn=lambda x: x): | |||
| attach_to = AttachTo(attach_to) | |||
| def wrapper(model, batch): | |||
| inputs, targets = split_batch(batch) | |||
| outputs = model(inputs) | |||
| outputs, targets = attach_to(outputs, targets) | |||
| return inputs, targets, pred_fn(outputs) | |||
| return wrapper | |||
| @staticmethod | |||
| def get_seg_decode_fn(cmap=colormap(), index_transform=lambda x: x+1): # shift labels by 1 so the uint8 ignore index 255 wraps to 0 | |||
| def wrapper(preds): | |||
| if len(preds.shape)>3: | |||
| preds = preds.squeeze(1) | |||
| out = cmap[ index_transform(preds.astype('uint8')) ] | |||
| out = out.transpose(0, 3, 1, 2) / 255 | |||
| return out | |||
| return wrapper | |||
| @staticmethod | |||
| def get_depth_decode_fn(max_depth, log_scale=True, cmap=plt.get_cmap('jet')): | |||
| def wrapper(preds): | |||
| if log_scale: | |||
| _max_depth = np.log( max_depth ) | |||
| preds = np.log( preds ) | |||
| else: | |||
| _max_depth = max_depth | |||
| if len(preds.shape)>3: | |||
| preds = preds.squeeze(1) | |||
| out = (cmap(preds.clip(0, _max_depth)/_max_depth)).transpose(0, 3, 1, 2)[:, :3] | |||
| return out | |||
| return wrapper | |||
| class VisualizeSegmentation(VisualizeOutputs): | |||
| def __init__( | |||
| self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5, | |||
| cmap = colormap(), | |||
| attach_to=None, | |||
| normalizer: Callable=None, | |||
| prepare_fn: Callable=None, | |||
| decode_fn: Callable=None, | |||
| tag: str='seg' | |||
| ): | |||
| if prepare_fn is None: | |||
| prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x.max(1)[1]) | |||
| if decode_fn is None: | |||
| decode_fn = VisualizeOutputs.get_seg_decode_fn(cmap=cmap, index_transform=lambda x: x+1) | |||
| super(VisualizeSegmentation, self).__init__( | |||
| model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis, | |||
| normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn, | |||
| tag=tag | |||
| ) | |||
| class VisualizeDepth(VisualizeOutputs): | |||
| def __init__( | |||
| self, model, dataset: torch.utils.data.Dataset, | |||
| idx_list_or_num_vis: Union[int, Sequence]=5, | |||
| max_depth = 10, | |||
| log_scale = True, | |||
| attach_to = None, | |||
| normalizer: Callable=None, | |||
| prepare_fn: Callable=None, | |||
| decode_fn: Callable=None, | |||
| tag: str='depth' | |||
| ): | |||
| if prepare_fn is None: | |||
| prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x) | |||
| if decode_fn is None: | |||
| decode_fn = VisualizeOutputs.get_depth_decode_fn(max_depth=max_depth, log_scale=log_scale) | |||
| super(VisualizeDepth, self).__init__( | |||
| model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis, | |||
| normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn, | |||
| tag=tag | |||
| ) | |||
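| # Illustrative sketch (model/val_dataset/trainer assumed): pushing a few | |||
| # decoded predictions to TensorBoard at the end of every epoch. | |||
| # | |||
| #   viz = VisualizeSegmentation(model, val_dataset, idx_list_or_num_vis=4, tag='seg') | |||
| #   trainer.add_callback(DefaultEvents.AFTER_EPOCH, callbacks=viz) | |||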
| @@ -0,0 +1,7 @@ | |||
| from . import evaluator | |||
| from . import trainer | |||
| from . import lr_finder | |||
| from . import hooks | |||
| from . import events | |||
| from .engine import DefaultEvents, Event | |||
| @@ -0,0 +1,190 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| import abc, math, weakref, typing, time | |||
| from typing import Any, Callable, Optional, Sequence | |||
| import numpy as np | |||
| from kamal.core.engine.events import DefaultEvents, Event | |||
| from kamal.core import tasks | |||
| from kamal.utils import set_mode, move_to_device, get_logger | |||
| from collections import defaultdict | |||
| import numbers | |||
| import contextlib | |||
| class State(object): | |||
| def __init__(self): | |||
| self.iter = 0 | |||
| self.max_iter = None | |||
| self.epoch_length = None | |||
| self.dataloader = None | |||
| self.seed = None | |||
| self.metrics=dict() | |||
| self.batch=None | |||
| @property | |||
| def current_epoch(self): | |||
| if self.epoch_length is not None: | |||
| return self.iter // self.epoch_length | |||
| return None | |||
| @property | |||
| def max_epoch(self): | |||
| if self.epoch_length is not None: | |||
| return self.max_iter // self.epoch_length | |||
| return None | |||
| @property | |||
| def current_batch_index(self): | |||
| if self.epoch_length is not None: | |||
| return self.iter % self.epoch_length | |||
| return None | |||
| @property | |||
| def max_batch_index(self): | |||
| return self.epoch_length | |||
| def __repr__(self): | |||
| rep = "State:\n" | |||
| for attr, value in self.__dict__.items(): | |||
| if not isinstance(value, (numbers.Number, str, dict)): | |||
| value = type(value) | |||
| rep += "\t{}: {}\n".format(attr, value) | |||
| return rep | |||
| class Engine(abc.ABC): | |||
| def __init__(self, logger=None, tb_writer=None): | |||
| self._logger = logger if logger else get_logger(name='kamal', color=True) | |||
| self._tb_writer = tb_writer | |||
| self._callbacks = defaultdict(list) | |||
| self._allowed_events = [ *DefaultEvents ] | |||
| self._state = State() | |||
| def reset(self): | |||
| self._state = State() | |||
| def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None): | |||
| self.state.iter = self._state.start_iter = start_iter | |||
| self.state.max_iter = max_iter | |||
| self.state.epoch_length = epoch_length if epoch_length else len(dataloader) | |||
| self.state.dataloader = dataloader | |||
| self.state.dataloader_iter = iter(dataloader) | |||
| self.state.step_fn = step_fn | |||
| self.trigger_events(DefaultEvents.BEFORE_RUN) | |||
| for self.state.iter in range( start_iter, max_iter ): | |||
| if self.state.epoch_length is not None and \ | |||
| self.state.iter%self.state.epoch_length==0: # Epoch Start | |||
| self.trigger_events(DefaultEvents.BEFORE_EPOCH) | |||
| self.trigger_events(DefaultEvents.BEFORE_STEP) | |||
| self.state.batch = self._get_batch() | |||
| step_output = step_fn(self, self.state.batch) | |||
| if isinstance(step_output, dict): | |||
| self.state.metrics.update(step_output) | |||
| self.trigger_events(DefaultEvents.AFTER_STEP) | |||
| if self.state.epoch_length is not None and \ | |||
| (self.state.iter+1)%self.state.epoch_length==0: # Epoch End | |||
| self.trigger_events(DefaultEvents.AFTER_EPOCH) | |||
| self.trigger_events(DefaultEvents.AFTER_RUN) | |||
| def _get_batch(self): | |||
| try: | |||
| batch = next( self.state.dataloader_iter ) | |||
| except StopIteration: | |||
| self.state.dataloader_iter = iter(self.state.dataloader) # reset iterator | |||
| batch = next( self.state.dataloader_iter ) | |||
| if not isinstance(batch, (list, tuple)): | |||
| batch = [ batch, ] # no targets | |||
| return batch | |||
| @property | |||
| def state(self): | |||
| return self._state | |||
| @property | |||
| def logger(self): | |||
| return self._logger | |||
| @property | |||
| def tb_writer(self): | |||
| return self._tb_writer | |||
| def add_callback(self, event: Event, callbacks ): | |||
| if not isinstance(callbacks, Sequence): | |||
| callbacks = [callbacks] | |||
| if event in self._allowed_events: | |||
| for callback in callbacks: | |||
| if callback not in self._callbacks[event]: | |||
| if event.trigger!=event.default_trigger: | |||
| callback = self._trigger_wrapper(self, event.trigger, callback ) | |||
| self._callbacks[event].append( callback ) | |||
| callbacks = [ RemovableCallback(self, event, c) for c in callbacks ] | |||
| return ( callbacks[0] if len(callbacks)==1 else callbacks ) | |||
| def remove_callback(self, event, callback): | |||
| for c in self._callbacks[event]: | |||
| if c==callback: | |||
| self._callbacks[event].remove(c) # remove from this event's list, not from the defaultdict itself | |||
| return True | |||
| return False | |||
| @staticmethod | |||
| def _trigger_wrapper(engine, trigger, callback): | |||
| def wrapper(*args, **kwargs) -> Any: | |||
| if trigger(engine): | |||
| return callback(engine) | |||
| return wrapper | |||
| def trigger_events(self, *events): | |||
| for e in events: | |||
| if e in self._allowed_events: | |||
| for callback in self._callbacks[e]: | |||
| callback(self) | |||
| def register_events(self, *events): | |||
| for e in events: | |||
| if e not in self._allowed_events: | |||
| self._allowed_events.append( e ) | |||
| @contextlib.contextmanager | |||
| def save_current_callbacks(self): | |||
| temp = self._callbacks | |||
| self._callbacks = defaultdict(list) | |||
| yield | |||
| self._callbacks = temp | |||
| class RemovableCallback: | |||
| def __init__(self, engine, event, callback): | |||
| self._engine = weakref.ref(engine) | |||
| self._callback = weakref.ref(callback) | |||
| self._event = weakref.ref(event) | |||
| @property | |||
| def callback(self): | |||
| return self._callback() | |||
| def remove(self): | |||
| engine = self._engine() | |||
| callback = self._callback() | |||
| event = self._event() | |||
| return engine.remove_callback(event, callback) | |||
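| # Illustrative sketch: add_callback returns a RemovableCallback handle, so a | |||
| # temporary hook can be detached later. The engine instance is assumed. | |||
| # | |||
| #   def log_iter(eng): | |||
| #       eng.logger.info('iter=%d' % eng.state.iter) | |||
| #   handle = engine.add_callback(DefaultEvents.AFTER_STEP(every=100), callbacks=log_iter) | |||
| #   ... | |||
| #   handle.remove() | |||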
| @@ -0,0 +1,130 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import abc, sys | |||
| import torch | |||
| from tqdm import tqdm | |||
| from kamal.core import metrics | |||
| from kamal.utils import set_mode | |||
| from typing import Any, Callable | |||
| from .engine import Engine | |||
| from .events import DefaultEvents | |||
| from kamal.core import callbacks | |||
| import weakref | |||
| from kamal.utils import move_to_device, split_batch | |||
| class BasicEvaluator(Engine): | |||
| def __init__(self, | |||
| dataloader: torch.utils.data.DataLoader, | |||
| metric: metrics.MetricCompose, | |||
| eval_fn: Callable=None, | |||
| tag: str='Eval', | |||
| progress: bool=False ): | |||
| super( BasicEvaluator, self ).__init__() | |||
| self.dataloader = dataloader | |||
| self.metric = metric | |||
| self.progress = progress | |||
| if progress: | |||
| self.progress_callback = self.add_callback( | |||
| DefaultEvents.AFTER_STEP, callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag)) | |||
| self._model = None | |||
| self._tag = tag | |||
| if eval_fn is None: | |||
| eval_fn = BasicEvaluator.default_eval_fn | |||
| self.eval_fn = eval_fn | |||
| def eval(self, model, device=None): | |||
| device = device if device is not None else \ | |||
| torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self._model = weakref.ref(model) # use weakref here | |||
| self.device = device | |||
| self.metric.reset() | |||
| model.to(device) | |||
| if self.progress: | |||
| self.progress_callback.callback.reset() | |||
| with torch.no_grad(), set_mode(model, training=False): | |||
| super(BasicEvaluator, self).run( self.step_fn, self.dataloader, max_iter=len(self.dataloader) ) | |||
| return self.metric.get_results() | |||
| @property | |||
| def model(self): | |||
| if self._model is not None: | |||
| return self._model() | |||
| return None | |||
| def step_fn(self, engine, batch): | |||
| batch = move_to_device(batch, self.device) | |||
| self.eval_fn( engine, batch ) | |||
| @staticmethod | |||
| def default_eval_fn(evaluator, batch): | |||
| model = evaluator.model | |||
| inputs, targets = split_batch(batch) | |||
| outputs = model( inputs ) | |||
| evaluator.metric.update( outputs, targets ) | |||
| class TeacherEvaluator(BasicEvaluator): | |||
| def __init__(self, | |||
| dataloader: torch.utils.data.DataLoader, | |||
| teacher: torch.nn.Module, | |||
| task, | |||
| metric: metrics.MetricCompose, | |||
| eval_fn: Callable=None, | |||
| tag: str='Eval', | |||
| progress: bool=False ): | |||
| if eval_fn is None: | |||
| eval_fn = TeacherEvaluator.default_eval_fn | |||
| super( TeacherEvaluator, self ).__init__(dataloader=dataloader, metric=metric, eval_fn=eval_fn, tag=tag, progress=progress) | |||
| self._teacher = teacher | |||
| self.task = task | |||
| def eval(self, model, device=None): | |||
| self.teacher.to(device) | |||
| with set_mode(self.teacher, training=False): | |||
| return super(TeacherEvaluator, self).eval( model, device=device ) | |||
| @property | |||
| def model(self): | |||
| if self._model is not None: | |||
| return self._model() | |||
| return None | |||
| @property | |||
| def teacher(self): | |||
| return self._teacher | |||
| def step_fn(self, engine, batch): | |||
| batch = move_to_device(batch, self.device) | |||
| self.eval_fn( engine, batch ) | |||
| @staticmethod | |||
| def default_eval_fn(evaluator, batch): | |||
| model = evaluator.model | |||
| teacher = evaluator.teacher | |||
| inputs, targets = split_batch(batch) | |||
| outputs = model( inputs ) | |||
| # get teacher outputs | |||
| if isinstance(teacher, torch.nn.ModuleList): | |||
| targets = [ task.predict(tea(inputs)) for (tea, task) in zip(teacher, evaluator.task) ] | |||
| else: | |||
| t_outputs = teacher(inputs) | |||
| targets = evaluator.task.predict( t_outputs ) | |||
| evaluator.metric.update( outputs, targets ) | |||
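| # Illustrative sketch (the MetricCompose construction is an assumption about | |||
| # the kamal.core.metrics API): a standalone evaluation pass over a loader. | |||
| # | |||
| #   from kamal.core import metrics | |||
| #   metric = metrics.MetricCompose({'acc': metrics.Accuracy()}) | |||
| #   evaluator = BasicEvaluator(val_loader, metric, progress=True) | |||
| #   results = evaluator.eval(model)  # e.g. {'acc': tensor(0.93)} | |||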
| @@ -0,0 +1,92 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from typing import Callable, Optional | |||
| from enum import Enum | |||
| class Event(object): | |||
| def __init__(self, value: str, event_trigger: Optional[Callable]=None ): | |||
| if event_trigger is None: | |||
| event_trigger = Event.default_trigger | |||
| self._trigger = event_trigger | |||
| self._name_ = self._value_ = value | |||
| @property | |||
| def trigger(self): | |||
| return self._trigger | |||
| @property | |||
| def name(self): | |||
| """The name of the Enum member.""" | |||
| return self._name_ | |||
| @property | |||
| def value(self): | |||
| """The value of the Enum member.""" | |||
| return self._value_ | |||
| @staticmethod | |||
| def default_trigger(engine): | |||
| return True | |||
| @staticmethod | |||
| def once_trigger(): | |||
| is_triggered = False | |||
| def wrapper(engine): | |||
| nonlocal is_triggered # required: the assignment below would otherwise make is_triggered local and raise UnboundLocalError | |||
| if is_triggered: | |||
| return False | |||
| is_triggered = True | |||
| return True | |||
| return wrapper | |||
| @staticmethod | |||
| def every_trigger(every: int): | |||
| def wrapper(engine): | |||
| return every>0 and (engine.state.iter % every)==0 | |||
| return wrapper | |||
| def __call__(self, every: Optional[int]=None, once: Optional[bool]=None ): | |||
| if every is not None: | |||
| assert once is None | |||
| return Event(self.value, event_trigger=Event.every_trigger(every) ) | |||
| if once is not None: | |||
| return Event(self.value, event_trigger=Event.once_trigger() ) | |||
| return Event(self.value) | |||
| def __hash__(self): | |||
| return hash(self._name_) | |||
| def __eq__(self, other): | |||
| if hasattr(other, 'value'): | |||
| return self.value==other.value | |||
| else: | |||
| return NotImplemented | |||
| class DefaultEvents(Event, Enum): | |||
| BEFORE_RUN = "before_train" | |||
| AFTER_RUN = "after_train" | |||
| BEFORE_EPOCH = "before_epoch" | |||
| AFTER_EPOCH = "after_epoch" | |||
| BEFORE_STEP = "before_step" | |||
| AFTER_STEP = "after_step" | |||
| BEFORE_GET_BATCH = "before_get_batch" | |||
| AFTER_GET_BATCH = "after_get_batch" | |||
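| # Illustrative sketch: DefaultEvents members are callable and return a copy of | |||
| # the event with a modified trigger. | |||
| # | |||
| #   every_100 = DefaultEvents.AFTER_STEP(every=100)  # fires when engine.state.iter % 100 == 0 | |||
| #   only_once = DefaultEvents.BEFORE_RUN(once=True)  # fires a single time, then stays silent | |||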
| @@ -0,0 +1,34 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| class FeatureHook(): | |||
| def __init__(self, module): | |||
| self.module = module | |||
| self.feat_in = None | |||
| self.feat_out = None | |||
| self.register() | |||
| def register(self): | |||
| self._hook = self.module.register_forward_hook(self.hook_fn_forward) | |||
| def remove(self): | |||
| self._hook.remove() | |||
| def hook_fn_forward(self, module, fea_in, fea_out): | |||
| self.feat_in = fea_in[0] | |||
| self.feat_out = fea_out | |||
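| # Illustrative sketch (torchvision model assumed, not part of this module): | |||
| # capturing an intermediate activation with a forward hook. | |||
| # | |||
| #   import torch | |||
| #   from torchvision.models import resnet18 | |||
| #   model = resnet18() | |||
| #   hook = FeatureHook(model.layer4)          # attach to any nn.Module | |||
| #   _ = model(torch.randn(2, 3, 224, 224))    # forward pass populates the hook | |||
| #   print(hook.feat_out.shape)                # torch.Size([2, 512, 7, 7]) | |||
| #   hook.remove()                             # detach when done | |||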
| @@ -0,0 +1,214 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import os | |||
| import tempfile, uuid | |||
| import contextlib | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| from kamal.core import callbacks | |||
| from kamal.core.engine import evaluator | |||
| from kamal.core.engine.events import DefaultEvents | |||
| class _ProgressCallback(callbacks.Callback): | |||
| def __init__(self, max_iter, tag): | |||
| self._tag = tag | |||
| self._max_iter = max_iter | |||
| self._pbar = tqdm(total=self._max_iter, desc=self._tag) | |||
| def __call__(self, engine): | |||
| self._pbar.update(1) | |||
| def reset(self): | |||
| self._pbar = tqdm(total=self._max_iter, desc=self._tag) | |||
| class _LinearLRScheduler(torch.optim.lr_scheduler._LRScheduler): | |||
| """ Linear Scheduler | |||
| """ | |||
| def __init__(self, | |||
| optimizer, | |||
| lr_range, | |||
| max_iter, | |||
| last_epoch=-1): | |||
| self.lr_range = lr_range | |||
| self.max_iter = max_iter | |||
| super(_LinearLRScheduler, self).__init__(optimizer, last_epoch) | |||
| def get_lr(self): | |||
| r = self.last_epoch / self.max_iter | |||
| return [self.lr_range[0] + (self.lr_range[1]-self.lr_range[0]) * r for base_lr in self.base_lrs] | |||
| class _ExponentialLRScheduler(torch.optim.lr_scheduler._LRScheduler): | |||
| def __init__(self, optimizer, lr_range, max_iter, last_epoch=-1): | |||
| self.lr_range = lr_range | |||
| self.max_iter = max_iter | |||
| super(_ExponentialLRScheduler, self).__init__(optimizer, last_epoch) | |||
| def get_lr(self): | |||
| r = self.last_epoch / (self.max_iter - 1) | |||
| return [self.lr_range[0] * (self.lr_range[1] / self.lr_range[0]) ** r for base_lr in self.base_lrs] | |||
| class _LRFinderCallback(object): | |||
| def __init__(self, | |||
| model, | |||
| optimizer, | |||
| metric_name, | |||
| evaluator, | |||
| smooth_momentum): | |||
| self._model = model | |||
| self._optimizer = optimizer | |||
| self._metric_name = metric_name | |||
| self._evaluator = evaluator | |||
| self._smooth_momentum = smooth_momentum | |||
| self._records = [] | |||
| @property | |||
| def records(self): | |||
| return self._records | |||
| def reset(self): | |||
| self._records = [] | |||
| def __call__(self, trainer): | |||
| model = self._model | |||
| optimizer = self._optimizer | |||
| if self._evaluator is not None: | |||
| results = self._evaluator.eval( model ) | |||
| score = float(results[ self._metric_name ]) | |||
| else: | |||
| score = float(trainer.state.metrics[self._metric_name]) | |||
| if self._smooth_momentum>0 and len(self._records)>0: | |||
| score = self._records[-1][1] * self._smooth_momentum + score * (1-self._smooth_momentum) | |||
| lr = optimizer.param_groups[0]['lr'] | |||
| self._records.append( ( lr, score ) ) | |||
| class LRFinder(object): | |||
| def _reset(self): | |||
| init_state = torch.load( self._temp_file ) | |||
| self.trainer.model.load_state_dict( init_state['model'] ) | |||
| self.trainer.optimizer.load_state_dict( init_state['optimizer'] ) | |||
| try: | |||
| self.trainer.reset() | |||
| except Exception: pass | |||
| def adjust_learning_rate(self, optimizer, lr): | |||
| for group in optimizer.param_groups: | |||
| group['lr'] = lr | |||
| def _get_default_lr_range(self, optimizer): | |||
| if isinstance( optimizer, torch.optim.Adam ): | |||
| return ( 1e-5, 1e-2 ) | |||
| elif isinstance( optimizer, torch.optim.SGD ): | |||
| return ( 1e-3, 0.2 ) | |||
| else: | |||
| return ( 1e-5, 0.5) | |||
| def plot(self, polyfit: int=3, log_x=True): | |||
| import matplotlib.pyplot as plt | |||
| lrs = [ rec[0] for rec in self.records ] | |||
| scores = [ rec[1] for rec in self.records ] | |||
| if polyfit is not None: | |||
| z = np.polyfit( lrs, scores, deg=polyfit ) | |||
| fitted_score = np.polyval( z, lrs ) | |||
| fig, ax = plt.subplots() | |||
| ax.plot(lrs, scores, label='score') | |||
| ax.plot(lrs, fitted_score, label='polyfit') | |||
| if log_x: | |||
| plt.xscale('log') | |||
| ax.set_xlabel("Learning rate") | |||
| ax.set_ylabel("Score") | |||
| return fig | |||
| def suggest( self, mode='min', skip_begin=10, skip_end=5, polyfit=None ): | |||
| scores = np.array( [ self.records[i][1] for i in range( len(self.records) ) ] ) | |||
| lrs = np.array( [ self.records[i][0] for i in range( len(self.records) ) ] ) | |||
| if polyfit is not None: | |||
| z = np.polyfit( lrs, scores, deg=polyfit ) | |||
| scores = np.polyval( z, lrs ) | |||
| grad = np.gradient( scores )[skip_begin:-skip_end] | |||
| index = grad.argmin() if mode=='min' else grad.argmax() | |||
| index = skip_begin + index | |||
| return index, self.records[index][0] | |||
| def find(self, | |||
| optimizer, | |||
| model, | |||
| trainer, | |||
| metric_name='total_loss', | |||
| metric_mode='min', | |||
| evaluator=None, | |||
| lr_range=[1e-4, 0.1], | |||
| max_iter=100, | |||
| num_eval=None, | |||
| smooth_momentum=0.9, | |||
| scheduler='exp', # exp | |||
| polyfit=None, # None | |||
| skip=[10, 5], | |||
| progress=True): | |||
| self.optimizer = optimizer | |||
| self.model = model | |||
| self.trainer = trainer | |||
| # save init state | |||
| _filename = str(uuid.uuid4())+'.pth' | |||
| _tempdir = tempfile.gettempdir() | |||
| self._temp_file = os.path.join(_tempdir, _filename) | |||
| init_state = { | |||
| 'optimizer': optimizer.state_dict(), | |||
| 'model': model.state_dict() | |||
| } | |||
| torch.save(init_state, self._temp_file) | |||
| if num_eval is None or num_eval > max_iter: | |||
| num_eval = max_iter | |||
| if lr_range is None: | |||
| lr_range = self._get_default_lr_range(optimizer) | |||
| interval = max_iter // num_eval | |||
| if scheduler=='exp': | |||
| lr_sched = _ExponentialLRScheduler(optimizer, lr_range, max_iter=max_iter) | |||
| else: | |||
| lr_sched = _LinearLRScheduler(optimizer, lr_range, max_iter=max_iter) | |||
| self._lr_callback = callbacks.LRSchedulerCallback(schedulers=[lr_sched]) | |||
| self._finder_callback = _LRFinderCallback(model, optimizer, metric_name, evaluator, smooth_momentum) | |||
| self.adjust_learning_rate( self.optimizer, lr_range[0] ) | |||
| with self.trainer.save_current_callbacks(): | |||
| trainer.add_callback( | |||
| DefaultEvents.AFTER_STEP, callbacks=[ | |||
| self._lr_callback, | |||
| _ProgressCallback(max_iter, '[LR Finder]') ]) | |||
| trainer.add_callback( | |||
| DefaultEvents.AFTER_STEP(interval), callbacks=self._finder_callback) | |||
| self.trainer.run( start_iter=0, max_iter=max_iter ) | |||
| self.records = self._finder_callback.records # get records | |||
| index, best_lr = self.suggest(mode=metric_mode, skip_begin=skip[0], skip_end=skip[1], polyfit=polyfit) | |||
| self._reset() | |||
| del self.model, self.optimizer, self.trainer | |||
| return best_lr | |||
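| # Illustrative sketch (trainer already configured via BasicTrainer.setup): | |||
| # search for a learning rate by sweeping [1e-5, 1e-1] over 200 steps. | |||
| # | |||
| #   finder = LRFinder() | |||
| #   best_lr = finder.find(optimizer, model, trainer, | |||
| #                         metric_name='total_loss', metric_mode='min', | |||
| #                         lr_range=[1e-5, 1e-1], max_iter=200) | |||
| #   finder.plot()                                   # inspect the loss-vs-lr curve | |||
| #   finder.adjust_learning_rate(optimizer, best_lr) | |||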
| @@ -0,0 +1,129 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn as nn | |||
| from kamal.core.engine.engine import Engine, Event, DefaultEvents, State | |||
| from kamal.core import tasks | |||
| from kamal.utils import set_mode, move_to_device, get_logger, split_batch | |||
| from typing import Callable, Mapping, Any, Sequence | |||
| import time | |||
| import weakref | |||
| class BasicTrainer(Engine): | |||
| def __init__( self, | |||
| logger=None, | |||
| tb_writer=None): | |||
| super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer) | |||
| def setup(self, | |||
| model: torch.nn.Module, | |||
| task: tasks.Task, | |||
| dataloader: torch.utils.data.DataLoader, | |||
| optimizer: torch.optim.Optimizer, | |||
| device: torch.device=None): | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self.device = device | |||
| if isinstance(task, Sequence): | |||
| task = tasks.TaskCompose(task) | |||
| self.task = task | |||
| self.model = model | |||
| self.dataloader = dataloader | |||
| self.optimizer = optimizer | |||
| return self | |||
| def run( self, max_iter, start_iter=0, epoch_length=None): | |||
| self.model.to(self.device) | |||
| with set_mode(self.model, training=True): | |||
| super( BasicTrainer, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
| def step_fn(self, engine, batch): | |||
| model = self.model | |||
| start_time = time.perf_counter() | |||
| batch = move_to_device(batch, self.device) | |||
| inputs, targets = split_batch(batch) | |||
| outputs = model(inputs) | |||
| loss_dict = self.task.get_loss(outputs, targets) # get loss | |||
| loss = sum( loss_dict.values() ) | |||
| self.optimizer.zero_grad() | |||
| loss.backward() | |||
| self.optimizer.step() | |||
| step_time = time.perf_counter() - start_time | |||
| metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
| metrics.update({ | |||
| 'total_loss': loss.item(), | |||
| 'step_time': step_time, | |||
| 'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
| }) | |||
| return metrics | |||
| class KDTrainer(BasicTrainer): | |||
| def setup(self, | |||
| student: torch.nn.Module, | |||
| teacher: torch.nn.Module, | |||
| task: tasks.Task, | |||
| dataloader: torch.utils.data.DataLoader, | |||
| optimizer: torch.optim.Optimizer, | |||
| device: torch.device=None): | |||
| super(KDTrainer, self).setup( | |||
| model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device) | |||
| if isinstance(teacher, (list, tuple)): | |||
| if len(teacher)==1: | |||
| teacher=teacher[0] | |||
| else: | |||
| teacher = nn.ModuleList(teacher) | |||
| self.student = self.model | |||
| self.teacher = teacher | |||
| return self | |||
| def run( self, max_iter, start_iter=0, epoch_length=None): | |||
| self.student.to(self.device) | |||
| self.teacher.to(self.device) | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teacher, training=False): | |||
| super( BasicTrainer, self ).run( | |||
| self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
| def step_fn(self, engine, batch): | |||
| model = self.model | |||
| start_time = time.perf_counter() | |||
| batch = move_to_device(batch, self.device) | |||
| inputs, targets = split_batch(batch) | |||
| outputs = model(inputs) | |||
| if isinstance(self.teacher, nn.ModuleList): | |||
| soft_targets = [ t(inputs) for t in self.teacher ] | |||
| else: | |||
| soft_targets = self.teacher(inputs) | |||
| loss_dict = self.task.get_loss(outputs, soft_targets) # get loss | |||
| loss = sum( loss_dict.values() ) | |||
| self.optimizer.zero_grad() | |||
| loss.backward() | |||
| self.optimizer.step() | |||
| step_time = time.perf_counter() - start_time | |||
| metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
| metrics.update({ | |||
| 'total_loss': loss.item(), | |||
| 'step_time': step_time, | |||
| 'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
| }) | |||
| return metrics | |||
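| # Illustrative sketch (task/train_loader assumed): the minimal training loop. | |||
| # | |||
| #   import torch | |||
| #   from kamal.core import callbacks | |||
| #   trainer = BasicTrainer() | |||
| #   trainer.setup(model=model, task=task, dataloader=train_loader, | |||
| #                 optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)) | |||
| #   trainer.add_callback(DefaultEvents.AFTER_STEP(every=10), | |||
| #                        callbacks=callbacks.MetricsLogging(keys=('total_loss', 'step_time', 'lr'))) | |||
| #   trainer.run(max_iter=10000) | |||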
| @@ -0,0 +1,22 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| class DataTypeError(Exception): | |||
| pass | |||
| class InvalidMapping(Exception): | |||
| pass | |||
| @@ -0,0 +1,2 @@ | |||
| from ._hub import * | |||
| from . import meta | |||
| @@ -0,0 +1,288 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from pip._internal import main as pipmain | |||
| from typing import Sequence, Optional | |||
| import torch | |||
| import os, sys, re | |||
| import shutil, inspect | |||
| import importlib.util | |||
| from glob import glob | |||
| from ruamel.yaml import YAML | |||
| from ruamel.yaml import comments | |||
| from ._module_mapping import PACKAGE_NAME_TO_IMPORT_NAME | |||
| import hashlib | |||
| import inspect | |||
| _DEFAULT_PROTOCOL = 2 | |||
| _DEPENDENCY_FILE = "requirements.txt" | |||
| _CODE_DIR = "code" | |||
| _WEIGHTS_DIR = "weight" | |||
| _TAGS_DIR = "tag" | |||
| yaml = YAML() | |||
| def _replace_invalid_char(name): | |||
| return name.replace('-', '_').replace(' ', '_') | |||
| def save( | |||
| model: torch.nn.Module, | |||
| save_path: str, | |||
| entry_name: str, | |||
| spec_name: str, | |||
| code_path: str, # path to SOURCE_CODE_DIR or hubconf.py | |||
| metadata: dict, | |||
| tags: dict = None, | |||
| ignore_files: Sequence = None, | |||
| save_arch: bool = False, | |||
| overwrite: bool = False, | |||
| ): | |||
| entry_name = _replace_invalid_char(entry_name) | |||
| if spec_name is not None: | |||
| spec_name = _replace_invalid_char(spec_name) | |||
| if not os.path.isdir(save_path): | |||
| overwrite = True | |||
| save_path = os.path.abspath(save_path) | |||
| export_code_path = os.path.join(save_path, _CODE_DIR) | |||
| export_weights_path = os.path.join(save_path, _WEIGHTS_DIR) | |||
| os.makedirs(save_path, exist_ok=True) | |||
| os.makedirs(export_code_path, exist_ok=True) | |||
| os.makedirs(export_weights_path, exist_ok=True) | |||
| code_path = os.path.abspath(code_path) | |||
| if os.path.isdir( code_path ): | |||
| if overwrite: | |||
| shutil.rmtree(export_code_path) | |||
| _copy_file_or_tree(src=code_path, dst=export_code_path) # overwrite old files | |||
| elif code_path.endswith('.py'): | |||
| if overwrite: | |||
| shutil.copy2(src=code_path, dst=os.path.join( export_code_path, 'hubconf.py' )) # overwrite old files | |||
| if hasattr(model, 'SETUP_INFO'): | |||
| del model.SETUP_INFO | |||
| if hasattr(model, 'METADATA'): | |||
| del model.METADATA | |||
| model_and_metadata = { | |||
| 'model': model if save_arch else model.state_dict(), | |||
| 'metadata': metadata, | |||
| } | |||
| temp_pth = os.path.join(export_weights_path, 'temp.pth') | |||
| torch.save(model_and_metadata, temp_pth) | |||
| if spec_name is None: | |||
| _md5 = md5( temp_pth ) | |||
| spec_name = _md5 | |||
| shutil.move( temp_pth, os.path.join(export_weights_path, '%s-%s.pth'%(entry_name, spec_name)) ) | |||
| if tags is not None: | |||
| save_tags( tags, save_path, entry_name, spec_name ) | |||
| def list_entry(path): | |||
| path = os.path.abspath( os.path.expanduser( path ) ) | |||
| if path.endswith('.py'): | |||
| code_dir = os.path.dirname( path ) | |||
| entry_file = path | |||
| elif os.path.isdir(path): | |||
| code_dir = os.path.join( path, _CODE_DIR ) | |||
| entry_file = os.path.join(path, _CODE_DIR, 'app.py') | |||
| else: | |||
| raise NotImplementedError | |||
| sys.path.insert(0, code_dir) | |||
| spec = importlib.util.spec_from_file_location('app', entry_file) | |||
| module = importlib.util.module_from_spec(spec) | |||
| spec.loader.exec_module(module) | |||
| entry_list = [ ( f, getattr(module, f) ) for f in dir(module) if callable(getattr(module, f)) and not f.startswith('_') ] | |||
| sys.path.remove(code_dir) | |||
| return entry_list | |||
| def list_spec(path, entry_name=None): | |||
| path = os.path.abspath( os.path.expanduser( path ) ) | |||
| weight_dir = os.path.join( path, _WEIGHTS_DIR ) | |||
| spec_list = [ f.split('-') for f in os.listdir( weight_dir ) if f.endswith('.pth')] | |||
| spec_list = [ (f[0], f[1][:-4]) for f in spec_list ] | |||
| if entry_name is not None: | |||
| spec_list = [ s for s in spec_list if s[0]==entry_name ] | |||
| return spec_list | |||
| def load(path: str, entry_name:str=None, spec_name: str=None, pretrained=True, **kwargs): | |||
| """ | |||
| check dependencies and load pytorch models. | |||
| Args: | |||
| path: path to the model package | |||
| pretrained: load the pretrained model | |||
| """ | |||
| path = os.path.abspath(path) | |||
| if not os.path.exists(path): | |||
| raise FileNotFoundError | |||
| code_dir = os.path.join(path, _CODE_DIR) | |||
| if os.path.isdir(path): | |||
| cwd = os.getcwd() | |||
| os.chdir(code_dir) | |||
| sys.path.insert(0, code_dir) | |||
| spec = importlib.util.spec_from_file_location( | |||
| 'hubconf', os.path.join(code_dir, 'hubconf.py')) | |||
| module = importlib.util.module_from_spec(spec) | |||
| spec.loader.exec_module(module) | |||
| if hasattr( module, 'dependencies' ): | |||
| deps = getattr( module, 'dependencies') | |||
| for dep in deps: | |||
| _import_with_auto_install(dep) | |||
| if entry_name is None: # auto select | |||
| pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) )] | |||
| assert len(pth_file)==1, "Auto-selecting a model requires exactly one weight file (.pth); found %d"%len(pth_file) | |||
| pth_file = pth_file[0] | |||
| entry_name = pth_file.split('-')[0] | |||
| entry_fn = getattr( module, entry_name ) | |||
| else: | |||
| entry_fn = getattr( module, entry_name ) | |||
| if spec_name is None: | |||
| pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) ) if pth.startswith(entry_name) ] | |||
| assert len(pth_file)==1, "Expected exactly one weight file (.pth) for entry '%s'; found %d"%(entry_name, len(pth_file)) | |||
| pth_file = pth_file[0] | |||
| else: | |||
| pth_file = '%s-%s.pth'%(entry_name, spec_name) | |||
| try: | |||
| model_and_metadata = torch.load(os.path.join(path, _WEIGHTS_DIR, pth_file), map_location='cpu' ) | |||
| except Exception: raise FileNotFoundError("weight file not found: %s"%pth_file) | |||
| if isinstance( model_and_metadata['model'], torch.nn.Module ): | |||
| model = model_and_metadata['model'] | |||
| else: | |||
| entry_args = model_and_metadata['metadata']['entry_args'] | |||
| if entry_args is None: | |||
| entry_args = dict() | |||
| model = entry_fn( **entry_args ) | |||
| if pretrained: | |||
| model.load_state_dict( model_and_metadata['model'], False) | |||
| # setup metadata and atlas info | |||
| model.METADATA = model_and_metadata['metadata'] | |||
| model.SETUP_INFO = {"path": path, "entry_name": entry_name} | |||
| sys.path.pop(0) | |||
| os.chdir(cwd) | |||
| return model | |||
| raise NotImplementedError | |||
| def load_metadata(path, entry_name=None, spec_name=None): | |||
| path = os.path.abspath(path) | |||
| if os.path.isdir(path): | |||
| if entry_name is None: # auto select | |||
| pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) )] | |||
| assert len(pth_file)==1, "Auto-selecting a model requires exactly one weight file (.pth); found %d"%len(pth_file) | |||
| pth_file = pth_file[0] | |||
| else: | |||
| if spec_name is None: | |||
| pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) ) if pth.startswith(entry_name) ] | |||
| assert len(pth_file)==1, "Expected exactly one weight file (.pth) for entry '%s'; found %d"%(entry_name, len(pth_file)) | |||
| pth_file = pth_file[0] | |||
| else: | |||
| pth_file = '%s-%s.pth'%(entry_name, spec_name) | |||
| try: | |||
| metadata = torch.load( os.path.join( path, _WEIGHTS_DIR, pth_file ), map_location='cpu' )['metadata'] | |||
| except Exception: | |||
| raise FileNotFoundError("weight file not found: %s"%pth_file) | |||
| return metadata | |||
| def load_tags(path, entry_name, spec_name): | |||
| path = os.path.abspath(path) | |||
| tags_path = os.path.join(path, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name )) | |||
| if os.path.isfile(tags_path): | |||
| return _to_python_type(_yaml_load(tags_path)) | |||
| return dict() | |||
| def save_tags(tags, path, entry_name, spec_name): | |||
| path = os.path.abspath(path) | |||
| if tags is None: | |||
| tags = {} | |||
| tags_path = os.path.join(path, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name )) | |||
| os.makedirs( os.path.join( path, _TAGS_DIR ), exist_ok=True ) | |||
| _yaml_dump(tags_path, tags) | |||
| def _yaml_dump(f, obj): | |||
| with open(f, 'w') as f: | |||
| yaml.dump(obj, f) | |||
| def _yaml_load(f): | |||
| with open(f, 'r') as f: | |||
| return yaml.load(f) | |||
| def _to_python_type(data): | |||
| if isinstance(data, dict): | |||
| for k, v in data.items(): | |||
| data[k] = _to_python_type(v) | |||
| return dict(data) | |||
| elif isinstance(data, comments.CommentedSeq ): | |||
| for idx, v in enumerate(data): | |||
| data[idx] = _to_python_type(v) | |||
| return list(data) | |||
| return data | |||
| def md5(fname): | |||
| hash_md5 = hashlib.md5() | |||
| with open(fname, "rb") as f: | |||
| for chunk in iter(lambda: f.read(4096), b""): | |||
| hash_md5.update(chunk) | |||
| return hash_md5.hexdigest() | |||
| def _get_package_name_and_version(package): | |||
| _version_sym_list = ('==', '>=', '<=') | |||
| for sym in _version_sym_list: | |||
| if sym in package: | |||
| return package.split(sym) | |||
| return package, None | |||
| def _import_with_auto_install(package): | |||
| package_name, version = _get_package_name_and_version(package) | |||
| package_name = package_name.strip() | |||
| import_name = PACKAGE_NAME_TO_IMPORT_NAME.get( | |||
| package_name, package_name).replace('-', '_') | |||
| try: | |||
| return __import__(import_name) | |||
| except ImportError: | |||
| pipmain(['install', package]) # pipmain is already pip's main() entry point, so call it directly | |||
| return __import__(import_name) | |||
| from distutils.dir_util import copy_tree | |||
| def _copy_file_or_tree(src, dst): | |||
| if os.path.isfile(src): | |||
| shutil.copy2(src=src, dst=dst) | |||
| else: | |||
| copy_tree(src=src, dst=dst) | |||
| def _glob_list(path_list, recursive=False): | |||
| results = [] | |||
| for path in path_list: | |||
| if '*' in path: | |||
| path = list(glob(path, recursive=recursive)) | |||
| results.extend(path) | |||
| else: | |||
| results.append(path) | |||
| return results | |||
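| # Illustrative sketch of a save/load round trip; the entry name, paths and | |||
| # metadata fields below are example assumptions. | |||
| # | |||
| #   save(model, save_path='exported/resnet18_cifar10', | |||
| #        entry_name='resnet18', spec_name='cifar10', | |||
| #        code_path='hubconf.py', | |||
| #        metadata=dict(entry_args=dict(num_classes=10))) | |||
| #   model = load('exported/resnet18_cifar10', | |||
| #                entry_name='resnet18', spec_name='cifar10', pretrained=True) | |||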
| @@ -0,0 +1,6 @@ | |||
| PACKAGE_NAME_TO_IMPORT_NAME = { | |||
| 'opencv-python': 'cv2', | |||
| 'pillow': 'PIL', | |||
| 'scikit-learn': 'sklearn', | |||
| 'scikit-image': 'skimage', | |||
| } | |||
| @@ -0,0 +1,22 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| CLASSIFICATION = 'classification' | |||
| SEGMENTATION = 'segmentation' | |||
| DETECTION = 'detection' | |||
| DEPTH = 'depth' | |||
| AUTOENCODER = 'autoencoder' | |||
| @@ -0,0 +1,3 @@ | |||
| from .meta import Metadata | |||
| from .input import ImageInput | |||
| from . import TASK | |||
| @@ -0,0 +1,31 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from typing import Union, Sequence | |||
| # INPUT Metadata | |||
| def ImageInput( size: Union[Sequence[int], int], | |||
| range: Union[Sequence[int], int], | |||
| space: str, | |||
| normalize: Sequence=None): | |||
| assert space in ['rgb', 'gray', 'bgr', 'rgbd'] | |||
| return dict( | |||
| size=size, | |||
| range=range, | |||
| space=space, | |||
| normalize=normalize | |||
| ) | |||
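| # Illustrative example (ImageNet-style statistics assumed; the (mean, std) | |||
| # layout of normalize is a convention, not enforced here): | |||
| # | |||
| #   ImageInput(size=224, range=[0, 1], space='rgb', | |||
| #              normalize=[(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]) | |||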
| @@ -0,0 +1,44 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from copy import deepcopy | |||
| import abc, os | |||
| from . import TASK | |||
| import torch | |||
| __all__ = ['Metadata'] # only Metadata is defined in this file | |||
| # Model Metadata | |||
| def Metadata( name: str, | |||
| dataset: str, | |||
| task: str, | |||
| url: str, | |||
| input: dict, | |||
| entry_args: dict, | |||
| other_metadata: dict): | |||
| if task in [TASK.SEGMENTATION, TASK.CLASSIFICATION]: | |||
| assert 'num_classes' in other_metadata | |||
| return dict( | |||
| name=name, | |||
| dataset=dataset, | |||
| task=task, | |||
| url=url, | |||
| input=dict(input), | |||
| entry_args=entry_args, | |||
| other_metadata=other_metadata | |||
| ) | |||
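| # Illustrative example (field values are assumptions; ImageInput comes from | |||
| # the sibling input module): | |||
| # | |||
| #   meta = Metadata(name='resnet18', dataset='cifar10', | |||
| #                   task=TASK.CLASSIFICATION, url='', | |||
| #                   input=ImageInput(size=32, range=[0, 1], space='rgb'), | |||
| #                   entry_args=dict(num_classes=10), | |||
| #                   other_metadata=dict(num_classes=10)) | |||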
| @@ -0,0 +1,8 @@ | |||
| from .stream_metrics import Metric, MetricCompose | |||
| from .accuracy import Accuracy, TopkAccuracy | |||
| from .confusion_matrix import ConfusionMatrix, IoU, mIoU | |||
| from .regression import * | |||
| from .average import AverageMetric | |||
| @@ -0,0 +1,118 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| from kamal.core.metrics.stream_metrics import Metric | |||
| from typing import Callable | |||
| __all__=['Accuracy', 'TopkAccuracy'] | |||
| class Accuracy(Metric): | |||
| def __init__(self, attach_to=None): | |||
| super(Accuracy, self).__init__(attach_to=attach_to) | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| outputs = outputs.max(1)[1] | |||
| self._correct += ( outputs.view(-1)==targets.view(-1) ).sum() | |||
| self._cnt += torch.numel( targets ) | |||
| def get_results(self): | |||
| return (self._correct / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._correct = self._cnt = 0.0 | |||
| class TopkAccuracy(Metric): | |||
| def __init__(self, topk=5, attach_to=None): | |||
| super(TopkAccuracy, self).__init__(attach_to=attach_to) | |||
| self._topk = topk | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| _, outputs = outputs.topk(self._topk, dim=1, largest=True, sorted=True) | |||
| correct = outputs.eq( targets.view(-1, 1).expand_as(outputs) ) | |||
| self._correct += correct[:, :self._topk].view(-1).float().sum(0).item() | |||
| self._cnt += len(targets) | |||
| def get_results(self): | |||
| return self._correct / self._cnt | |||
| def reset(self): | |||
| self._correct = 0.0 | |||
| self._cnt = 0.0 | |||
| class StreamCEMAPMetrics(): | |||
| @property | |||
| def PRIMARY_METRIC(self): | |||
| return "eap" | |||
| def __init__(self): | |||
| self.reset() | |||
| def update(self, logits, targets): | |||
| # targets: -1 negative, 0 difficult, 1 positive | |||
| # keep the raw per-class scores: get_results indexes preds as an (nTest, nLabel) matrix, so taking argmax here would break it | |||
| if isinstance(logits, torch.Tensor): | |||
| logits = logits.cpu().numpy() | |||
| targets = targets.cpu().numpy() | |||
| self.preds = logits if self.preds is None else np.append(self.preds, logits, axis=0) | |||
| self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0) | |||
| def get_results(self): | |||
| nTest = self.targets.shape[0] | |||
| nLabel = self.targets.shape[1] | |||
| eap = np.zeros(nTest) | |||
| for i in range(0,nTest): | |||
| R = np.sum(self.targets[i,:]==1) | |||
| for j in range(0,nLabel): | |||
| if self.targets[i,j]==1: | |||
| r = np.sum(self.preds[i,np.nonzero(self.targets[i,:]!=0)]>=self.preds[i,j]) | |||
| rb = np.sum(self.preds[i,np.nonzero(self.targets[i,:]==1)] >= self.preds[i,j]) | |||
| eap[i] = eap[i] + rb/(r*1.0) | |||
| eap[i] = eap[i]/R | |||
| # emap = np.nanmean(ap) | |||
| cap = np.zeros(nLabel) | |||
| for i in range(0,nLabel): | |||
| R = np.sum(self.targets[:,i]==1) | |||
| for j in range(0,nTest): | |||
| if self.targets[j,i]==1: | |||
| r = np.sum(self.preds[np.nonzero(self.targets[:,i]!=0),i] >= self.preds[j,i]) | |||
| rb = np.sum(self.preds[np.nonzero(self.targets[:,i]==1),i] >= self.preds[j,i]) | |||
| cap[i] = cap[i] + rb/(r*1.0) | |||
| cap[i] = cap[i]/R | |||
| # cmap = np.nanmean(ap) | |||
| return { | |||
| 'eap': eap, | |||
| 'emap': np.nanmean(eap), | |||
| 'cap': cap, | |||
| 'cmap': np.nanmean(cap), | |||
| } | |||
| def reset(self): | |||
| self.preds = None | |||
| self.targets = None | |||
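| # Illustrative sketch: both metrics consume raw logits and integer labels. | |||
| # | |||
| #   import torch | |||
| #   logits = torch.randn(8, 10) | |||
| #   labels = torch.randint(0, 10, (8,)) | |||
| #   acc = Accuracy(); acc.update(logits, labels) | |||
| #   top5 = TopkAccuracy(topk=5); top5.update(logits, labels) | |||
| #   print(acc.get_results(), top5.get_results()) | |||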
| @@ -0,0 +1,49 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| from kamal.core.metrics.stream_metrics import Metric | |||
| from typing import Callable | |||
| __all__=['AverageMetric'] | |||
| class AverageMetric(Metric): | |||
| def __init__(self, fn:Callable, attach_to=None): | |||
| super(AverageMetric, self).__init__(attach_to=attach_to) | |||
| self._fn = fn | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| m = self._fn( outputs, targets ) | |||
| if m.ndim > 1: | |||
| self._cnt += m.shape[0] | |||
| self._accum += m.sum(0) | |||
| else: | |||
| self._cnt += 1 | |||
| self._accum += m | |||
| def get_results(self): | |||
| return (self._accum / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._accum = 0. | |||
| self._cnt = 0. | |||
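| # Usage sketch (the callable and tensors below are illustrative, not part of this repo): | |||
| #   mae = AverageMetric(fn=lambda outputs, targets: (outputs - targets).abs().mean()) | |||
| #   mae.update(outputs, targets) | |||
| #   value = mae.get_results() | |||
| # Batched (ndim > 1) results are summed and averaged per sample; scalars per update call. | |||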
| @@ -0,0 +1,68 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .stream_metrics import Metric | |||
| import torch | |||
| from typing import Callable | |||
| class ConfusionMatrix(Metric): | |||
| def __init__(self, num_classes, ignore_idx=None, attach_to=None): | |||
| super(ConfusionMatrix, self).__init__(attach_to=attach_to) | |||
| self._num_classes = num_classes | |||
| self._ignore_idx = ignore_idx | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| if self.confusion_matrix.device != outputs.device: | |||
| self.confusion_matrix = self.confusion_matrix.to(device=outputs.device) | |||
| preds = outputs.max(1)[1].flatten() | |||
| targets = targets.flatten() | |||
| mask = (preds<self._num_classes) & (preds>=0) | |||
| if self._ignore_idx is not None: | |||
| mask = mask & (targets!=self._ignore_idx) | |||
| preds, targets = preds[mask], targets[mask] | |||
| hist = torch.bincount( self._num_classes * targets + preds, | |||
| minlength=self._num_classes ** 2 ).view(self._num_classes, self._num_classes) | |||
| self.confusion_matrix += hist | |||
| def get_results(self): | |||
| return self.confusion_matrix.detach().cpu() | |||
| def reset(self): | |||
| self._cnt = 0 | |||
| self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False) | |||
| class IoU(Metric): | |||
| def __init__(self, confusion_matrix: ConfusionMatrix, attach_to=None): | |||
| super(IoU, self).__init__(attach_to=attach_to) | |||
| self._confusion_matrix = confusion_matrix | |||
| def update(self, outputs, targets): | |||
| pass # update will be done in confusion matrix | |||
| def reset(self): | |||
| pass | |||
| def get_results(self): | |||
| cm = self._confusion_matrix.get_results() | |||
| iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9) | |||
| return iou | |||
| class mIoU(IoU): | |||
| def get_results(self): | |||
| return super(mIoU, self).get_results().mean() | |||
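| # Usage sketch for segmentation (names are illustrative): mIoU reads from a shared | |||
| # ConfusionMatrix, so only the matrix needs updating: | |||
| #   cm = ConfusionMatrix(num_classes=21, ignore_idx=255) | |||
| #   miou = mIoU(cm)                 # IoU/mIoU.update are intentional no-ops | |||
| #   cm.update(outputs, targets) | |||
| #   score = miou.get_results() | |||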
| @@ -0,0 +1,86 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| from kamal.core.metrics.stream_metrics import Metric | |||
| class NormalPredictionMetrics(Metric): | |||
| @property | |||
| def PRIMARY_METRIC(self): | |||
| return 'mean angle' | |||
| def __init__(self, thresholds): | |||
| super(NormalPredictionMetrics, self).__init__() | |||
| self.thresholds = thresholds | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, preds, targets, masks): | |||
| """ | |||
| **Type**: numpy.ndarray or torch.Tensor | |||
| **Shape:** | |||
| - **preds**: $(N, 3, H, W)$. | |||
| - **targets**: $(N, 3, H, W)$. | |||
| - **masks**: $(N, 1, H, W)$. | |||
| """ | |||
| if isinstance(preds, torch.Tensor): | |||
| preds = preds.cpu().numpy() | |||
| targets = targets.cpu().numpy() | |||
| masks = masks.cpu().numpy() | |||
| self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0) | |||
| self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0) | |||
| self.masks = masks if self.masks is None else np.append(self.masks, masks, axis=0) | |||
| def get_results(self): | |||
| """ | |||
| **Returns:** | |||
| - **mean angle** | |||
| - **median angle** | |||
| - **percents within thresholds** | |||
| """ | |||
| products = np.sum(self.preds * self.targets, axis=1) | |||
| angles = np.arccos(np.clip(products, -1.0, 1.0)) / np.pi * 180 | |||
| self.masks = self.masks.squeeze(1) | |||
| angles = angles[self.masks == 1] | |||
| mean_angle = np.mean(angles) | |||
| median_angle = np.median(angles) | |||
| count = self.masks.sum() | |||
| threshold_percents = {} | |||
| for threshold in self.thresholds: | |||
| # threshold_percents[threshold] = np.sum((angles < threshold)) / count | |||
| threshold_percents[threshold] = np.mean(angles < threshold) | |||
| return { | |||
| 'mean angle': mean_angle, | |||
| 'median angle': median_angle, | |||
| 'percents within thresholds': threshold_percents | |||
| } | |||
| def reset(self): | |||
| self.preds = None | |||
| self.targets = None | |||
| self.masks = None | |||
| @@ -0,0 +1,199 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import numpy as np | |||
| import torch | |||
| from kamal.core.metrics.stream_metrics import Metric | |||
| from typing import Callable | |||
| __all__=['MeanSquaredError', 'RootMeanSquaredError', 'MeanAbsoluteError', | |||
| 'ScaleInveriantMeanSquaredError', 'RelativeDifference', | |||
| 'AbsoluteRelativeDifference', 'SquaredRelativeDifference', 'Threshold' ] | |||
| class MeanSquaredError(Metric): | |||
| def __init__(self, log_scale=False, attach_to=None): | |||
| super(MeanSquaredError, self).__init__( attach_to=attach_to ) | |||
| self.reset() | |||
| self.log_scale=log_scale | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| if self.log_scale: | |||
| diff = torch.sum((torch.log(outputs+1e-8) - torch.log(targets+1e-8))**2) | |||
| else: | |||
| diff = torch.sum((outputs - targets)**2) | |||
| self._accum_sq_diff += diff | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return (self._accum_sq_diff / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._accum_sq_diff = 0. | |||
| self._cnt = 0. | |||
| class RootMeanSquaredError(MeanSquaredError): | |||
| def get_results(self): | |||
| return torch.sqrt( (self._accum_sq_diff / self._cnt) ).detach().cpu() | |||
| class MeanAbsoluteError(Metric): | |||
| def __init__(self, attach_to=None ): | |||
| super(MeanAbsoluteError, self).__init__( attach_to=attach_to ) | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| diff = torch.sum((outputs - targets).abs()) | |||
| self._accum_abs_diff += diff | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return (self._accum_abs_diff / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._accum_abs_diff = 0. | |||
| self._cnt = 0. | |||
| class ScaleInveriantMeanSquaredError(Metric): | |||
| def __init__(self, attach_to=None ): | |||
| super(ScaleInveriantMeanSquaredError, self).__init__( attach_to=attach_to ) | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| diff_log = torch.log( outputs+1e-8 ) - torch.log( targets+1e-8 ) | |||
| self._accum_log_diff += diff_log.sum() | |||
| self._accum_sq_log_diff += (diff_log**2).sum() | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return ( self._accum_sq_log_diff / self._cnt - 0.5 * (self._accum_log_diff**2 / self._cnt**2) ).detach().cpu() | |||
| def reset(self): | |||
| self._accum_log_diff = 0. | |||
| self._accum_sq_log_diff = 0. | |||
| self._cnt = 0. | |||
| class RelativeDifference(Metric): | |||
| def __init__(self, attach_to=None ): | |||
| super(RelativeDifference, self).__init__( attach_to=attach_to ) | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| diff = (outputs - targets).abs() | |||
| self._accum_abs_rel += (diff/targets).sum() | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return (self._accum_abs_rel / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._accum_abs_rel = 0. | |||
| self._cnt = 0. | |||
| class AbsoluteRelativeDifference(Metric): | |||
| def __init__(self, attach_to=None ): | |||
| super(AbsoluteRelativeDifference, self).__init__( attach_to=attach_to ) | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| diff = (outputs - targets).abs() | |||
| self._accum_abs_rel += (diff/targets).sum() | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return (self._accum_abs_rel / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._accum_abs_rel = 0. | |||
| self._cnt = 0. | |||
| class SquaredRelativeDifference(Metric): | |||
| def __init__(self, attach_to=None ): | |||
| super(SquaredRelativeDifference, self).__init__( attach_to=attach_to ) | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| diff = (outputs - targets)**2 | |||
| self._accum_sq_rel += (diff/targets).sum() | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return (self._accum_sq_rel / self._cnt).detach().cpu() | |||
| def reset(self): | |||
| self._accum_sq_rel = 0. | |||
| self._cnt = 0. | |||
| class Threshold(Metric): | |||
| def __init__(self, thresholds=[1.25, 1.25**2, 1.25**3], attach_to=None ): | |||
| super(Threshold, self).__init__( attach_to=attach_to ) | |||
| self.thresholds = thresholds | |||
| self.reset() | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| sigma = torch.max(outputs / targets, targets / outputs) | |||
| for thres in self.thresholds: | |||
| self._accum_thres[thres]+=torch.sum( sigma<thres ) | |||
| self._cnt += torch.numel(outputs) | |||
| def get_results(self): | |||
| return { thres: (self._accum_thres[thres] / self._cnt).detach().cpu() for thres in self.thresholds } | |||
| def reset(self): | |||
| self._cnt = 0. | |||
| self._accum_thres = {thres: 0. for thres in self.thresholds} | |||
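| # Note: Threshold reports the delta-accuracies common in depth estimation, i.e. the | |||
| # fraction of pixels with max(pred/gt, gt/pred) < 1.25**k for k = 1, 2, 3. | |||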
| @@ -0,0 +1,82 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from __future__ import division | |||
| import torch | |||
| import numpy as np | |||
| from abc import ABC, abstractmethod | |||
| from typing import Callable, Union, Any, Mapping, Sequence | |||
| import numbers | |||
| from kamal.core.attach import AttachTo | |||
| class Metric(ABC): | |||
| def __init__(self, attach_to=None): | |||
| self._attach = AttachTo(attach_to) | |||
| @abstractmethod | |||
| def update(self, pred, target): | |||
| """ Overridden by subclasses """ | |||
| raise NotImplementedError() | |||
| @abstractmethod | |||
| def get_results(self): | |||
| """ Overridden by subclasses """ | |||
| raise NotImplementedError() | |||
| @abstractmethod | |||
| def reset(self): | |||
| """ Overridden by subclasses """ | |||
| raise NotImplementedError() | |||
| class MetricCompose(dict): | |||
| def __init__(self, metric_dict: Mapping): | |||
| self._metric_dict = metric_dict | |||
| def add_metrics( self, metric_dict: Mapping): | |||
| if isinstance(metric_dict, MetricCompose): | |||
| metric_dict = metric_dict.metrics | |||
| self._metric_dict.update(metric_dict) | |||
| return self | |||
| @property | |||
| def metrics(self): | |||
| return self._metric_dict | |||
| @torch.no_grad() | |||
| def update(self, outputs, targets): | |||
| for key, metric in self._metric_dict.items(): | |||
| if isinstance(metric, Metric): | |||
| metric.update(outputs, targets) | |||
| def get_results(self): | |||
| results = {} | |||
| for key, metric in self._metric_dict.items(): | |||
| if isinstance(metric, Metric): | |||
| results[key] = metric.get_results() | |||
| return results | |||
| def reset(self): | |||
| for key, metric in self._metric_dict.items(): | |||
| if isinstance(metric, Metric): | |||
| metric.reset() | |||
| def __getitem__(self, name): | |||
| return self._metric_dict[name] | |||
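| # Usage sketch (metric names are illustrative): | |||
| #   composed = MetricCompose({'acc': Accuracy(), 'cm': ConfusionMatrix(num_classes=10)}) | |||
| #   composed.update(outputs, targets) | |||
| #   results = composed.get_results()   # {'acc': ..., 'cm': ...} | |||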
| @@ -0,0 +1,3 @@ | |||
| from .task import StandardTask, StandardMetrics, GeneralTask, Task, TaskCompose | |||
| from . import loss | |||
| @@ -0,0 +1,2 @@ | |||
| from . import functional, loss | |||
| from .loss import * | |||
| @@ -0,0 +1,107 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import torch.nn as nn | |||
| def kldiv(logits, targets, T=1.0): | |||
| """ Cross Entropy for soft targets | |||
| Parameters: | |||
| - logits (Tensor): logits score (e.g. outputs of fc layer) | |||
| - targets (Tensor): logits of soft targets | |||
| - T (float): temperature of distill | |||
| - reduction: reduction to the output | |||
| """ | |||
| p_targets = F.softmax(targets/T, dim=1) | |||
| logp_logits = F.log_softmax(logits/T, dim=1) | |||
| kld = F.kl_div(logp_logits, p_targets, reduction='none') * (T**2) | |||
| return kld.sum(1).mean() | |||
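| # Note: kldiv returns T**2 * KL( softmax(targets/T) || softmax(logits/T) ), summed over | |||
| # classes and averaged over the batch; the T**2 factor keeps gradient magnitudes | |||
| # comparable across temperatures, as in Hinton et al.'s distillation formulation. | |||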
| def jsdiv(logits, targets, T=1.0, reduction='mean'): | |||
| p = F.softmax(logits, dim=1) | |||
| q = F.softmax(targets, dim=1) | |||
| log_m = torch.log( (p+q) / 2 ) | |||
| return 0.5* ( F.kl_div( log_m, p, reduction=reduction) + F.kl_div( log_m, q, reduction=reduction) ) | |||
| def mmd_loss(f1, f2, sigmas, normalized=False): | |||
| if len(f1.shape) != 2: | |||
| N, C, H, W = f1.shape | |||
| f1 = f1.view(N, -1) | |||
| N, C, H, W = f2.shape | |||
| f2 = f2.view(N, -1) | |||
| if normalized == True: | |||
| f1 = F.normalize(f1, p=2, dim=1) | |||
| f2 = F.normalize(f2, p=2, dim=1) | |||
| return _mmd_rbf2(f1, f2, sigmas=sigmas) | |||
| def psnr(img1, img2, size_average=True, data_range=255): | |||
| N = img1.shape[0] | |||
| mse = torch.mean(((img1-img2)**2).view(N, -1), dim=1) | |||
| psnr = torch.clamp(torch.log10(data_range**2 / mse) * 10, 0.0, 99.99) | |||
| if size_average == True: | |||
| psnr = psnr.mean() | |||
| return psnr | |||
| def soft_cross_entropy(logits, targets, T=1.0, size_average=True): | |||
| """ Cross Entropy for soft targets | |||
| **Parameters:** | |||
| - **logits** (Tensor): logits score (e.g. outputs of fc layer) | |||
| - **targets** (Tensor): logits of soft targets | |||
| - **T** (float): temperature of distill | |||
| - **size_average**: average the outputs | |||
| """ | |||
| p_targets = F.softmax(targets/T, dim=1) | |||
| logp_pred = F.log_softmax(logits/T, dim=1) | |||
| ce = torch.sum(-p_targets * logp_pred, dim=1) | |||
| if size_average: | |||
| return ce.mean() * T * T | |||
| else: | |||
| return ce * T * T | |||
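| # Note: soft_cross_entropy differs from kldiv only by the entropy of the soft targets, | |||
| # which does not depend on the student logits, so both terms produce the same gradients. | |||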
| def _mmd_rbf2(x, y, sigmas=None): | |||
| N, _ = x.shape | |||
| xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) | |||
| rx = (xx.diag().unsqueeze(0).expand_as(xx)) | |||
| ry = (yy.diag().unsqueeze(0).expand_as(yy)) | |||
| K = L = P = 0.0 | |||
| XX2 = rx.t() + rx - 2*xx | |||
| YY2 = ry.t() + ry - 2*yy | |||
| XY2 = rx.t() + ry - 2*zz | |||
| if sigmas is None: | |||
| sigma2 = torch.mean((XX2.detach()+YY2.detach()+2*XY2.detach()) / 4) | |||
| sigmas2 = [sigma2/4, sigma2/2, sigma2, sigma2*2, sigma2*4] | |||
| alphas = [1.0 / (2 * sigma2) for sigma2 in sigmas2] | |||
| else: | |||
| alphas = [1.0 / (2 * sigma**2) for sigma in sigmas] | |||
| for alpha in alphas: | |||
| K += torch.exp(- alpha * (XX2.clamp(min=1e-12))) | |||
| L += torch.exp(- alpha * (YY2.clamp(min=1e-12))) | |||
| P += torch.exp(- alpha * (XY2.clamp(min=1e-12))) | |||
| beta = (1./(N*(N))) | |||
| gamma = (2./(N*N)) | |||
| return F.relu(beta * (torch.sum(K)+torch.sum(L)) - gamma * torch.sum(P)) | |||
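| # Note: _mmd_rbf2 is a biased estimate of squared MMD with a mixture of RBF kernels, | |||
| # MMD^2 = E[k(x,x')] + E[k(y,y')] - 2*E[k(x,y)]; bandwidths come from `sigmas` or, if | |||
| # None, from a heuristic based on the mean pairwise squared distance. | |||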
| @@ -0,0 +1,386 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import torch.nn as nn | |||
| import numpy as np | |||
| #from pytorch_msssim import ssim, ms_ssim, MS_SSIM, SSIM | |||
| from .functional import * | |||
| class KLDiv(object): | |||
| def __init__(self, T=1.0): | |||
| self.T = T | |||
| def __call__(self, logits, targets): | |||
| return kldiv( logits, targets, T=self.T ) | |||
| class KDLoss(nn.Module): | |||
| """ KD Loss Function | |||
| """ | |||
| def __init__(self, T=1.0, alpha=1.0, use_kldiv=False): | |||
| super(KDLoss, self).__init__() | |||
| self.T = T | |||
| self.alpha = alpha | |||
| self.kdloss = kldiv if use_kldiv else soft_cross_entropy | |||
| def forward(self, logits, targets, hard_targets=None): | |||
| loss = self.kdloss(logits, targets, T=self.T) | |||
| if hard_targets is not None and self.alpha != 0.0: | |||
| loss += self.alpha*F.cross_entropy(logits, hard_targets) | |||
| return loss | |||
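| # Usage sketch (tensors are illustrative): | |||
| #   criterion = KDLoss(T=4.0, alpha=0.5) | |||
| #   loss = criterion(student_logits, teacher_logits, hard_targets=labels) | |||
| # combines the softened-target term with an alpha-weighted cross entropy on hard labels. | |||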
| class CFLLoss(nn.Module): | |||
| """ Common Feature Learning Loss | |||
| CFL Loss = MMD + MSE | |||
| """ | |||
| def __init__(self, sigmas, normalized=True): | |||
| super(CFLLoss, self).__init__() | |||
| self.sigmas = sigmas | |||
| self.normalized = normalized | |||
| def forward(self, hs, hts, fts_, fts): | |||
| mmd = mse = 0.0 | |||
| for ht_i in hts: | |||
| mmd += mmd_loss(hs, ht_i, sigmas=self.sigmas, normalized=self.normalized) | |||
| for i in range(len(fts_)): | |||
| mse += F.mse_loss(fts_[i], fts[i]) | |||
| return mmd, mse | |||
| class PSNR_Loss(nn.Module): | |||
| def __init__(self, data_range=1.0, size_average=True): | |||
| super(PSNR_Loss, self).__init__() | |||
| self.data_range = data_range | |||
| self.size_average = size_average | |||
| def forward(self, img1, img2): | |||
| return 100 - psnr(img1, img2, size_average=self.size_average, data_range=self.data_range) | |||
| #class MS_SSIM_Loss(MS_SSIM): | |||
| # def forward(self, img1, img2): | |||
| # return 100*(1 - super(MS_SSIM_Loss, self).forward(img1, img2)) | |||
| class ScaleInvariantLoss(nn.Module): | |||
| """This criterion is used in depth prediction task. | |||
| **Parameters:** | |||
| - **la** (int, optional): Default value is 0.5. No need to change. | |||
| - **ignore_index** (int, optional): Value to ignore. | |||
| **Shape:** | |||
| - **inputs**: $(N, H, W)$. | |||
| - **targets**: $(N, H, W)$. | |||
| - **output**: scalar. | |||
| """ | |||
| def __init__(self, la=0.5, ignore_index=0): | |||
| super(ScaleInvariantLoss, self).__init__() | |||
| self.la = la | |||
| self.ignore_index = ignore_index | |||
| def forward(self, inputs, targets): | |||
| size = inputs.size() | |||
| if len(size) == 3: | |||
| inputs = inputs.view(size[0], -1) | |||
| targets = targets.view(size[0], -1) | |||
| inv_mask = targets.eq(self.ignore_index) | |||
| nums = (1-inv_mask.float()).sum(1) | |||
| log_d = torch.log(inputs) - torch.log(targets) | |||
| log_d[inv_mask] = 0 | |||
| loss = torch.div(torch.pow(log_d, 2).sum(1), nums) - \ | |||
| self.la * torch.pow(torch.div(log_d.sum(1), nums), 2) | |||
| return loss.mean() | |||
| class AngleLoss(nn.Module): | |||
| """This criterion is used in surface normal prediction task. | |||
| **Shape:** | |||
| - **inputs**: $(N, 3, H, W)$. Predicted space vector for each pixel. Must be formalized before. | |||
| - **targets**: $(N, 3, H, W)$. Ground truth. Must be formalized before. | |||
| - **masks**: $(N, 1, H, W)$. One for valid pixels, else zero. | |||
| - **output**: scalar. | |||
| """ | |||
| def forward(self, inputs, targets, masks): | |||
| nums = masks.sum(dim=[1,2,3]) | |||
| product = (inputs * targets).sum(1, keepdim=True) | |||
| loss = -torch.div((product * masks.float()).sum([1,2,3]), nums) | |||
| return loss.mean() | |||
| class FocalLoss(nn.Module): | |||
| def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255): | |||
| super(FocalLoss, self).__init__() | |||
| self.alpha = alpha | |||
| self.gamma = gamma | |||
| self.ignore_index = ignore_index | |||
| self.size_average = size_average | |||
| def forward(self, inputs, targets): | |||
| ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index) | |||
| pt = torch.exp(-ce_loss) | |||
| focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss | |||
| if self.size_average: | |||
| return focal_loss.mean() | |||
| else: | |||
| return focal_loss.sum() | |||
| class AttentionLoss(nn.Module): | |||
| """ Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer""" | |||
| def __init__(self, p=2): | |||
| super(AttentionLoss, self).__init__() | |||
| self.p = p | |||
| def forward(self, g_s, g_t): | |||
| return sum([self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) | |||
| def at_loss(self, f_s, f_t): | |||
| s_H, t_H = f_s.shape[2], f_t.shape[2] | |||
| if s_H > t_H: | |||
| f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) | |||
| elif s_H < t_H: | |||
| f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) | |||
| else: | |||
| pass | |||
| return (self.at(f_s) - self.at(f_t)).pow(2).mean() | |||
| def at(self, f): | |||
| return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1)) | |||
| class NSTLoss(nn.Module): | |||
| """like what you like: knowledge distill via neuron selectivity transfer""" | |||
| def __init__(self): | |||
| super(NSTLoss, self).__init__() | |||
| pass | |||
| def forward(self, g_s, g_t): | |||
| return sum([self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) | |||
| def nst_loss(self, f_s, f_t): | |||
| s_H, t_H = f_s.shape[2], f_t.shape[2] | |||
| if s_H > t_H: | |||
| f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) | |||
| elif s_H < t_H: | |||
| f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) | |||
| else: | |||
| pass | |||
| f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1) | |||
| f_s = F.normalize(f_s, dim=2) | |||
| f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1) | |||
| f_t = F.normalize(f_t, dim=2) | |||
| # set full_loss as False to avoid unnecessary computation | |||
| full_loss = True | |||
| if full_loss: | |||
| return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean() | |||
| - 2 * self.poly_kernel(f_s, f_t).mean()) | |||
| else: | |||
| return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean() | |||
| def poly_kernel(self, a, b): | |||
| a = a.unsqueeze(1) | |||
| b = b.unsqueeze(2) | |||
| res = (a * b).sum(-1).pow(2) | |||
| return res | |||
| class SPLoss(nn.Module): | |||
| """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author""" | |||
| def __init__(self): | |||
| super(SPLoss, self).__init__() | |||
| def forward(self, g_s, g_t): | |||
| return sum([self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) | |||
| def similarity_loss(self, f_s, f_t): | |||
| bsz = f_s.shape[0] | |||
| f_s = f_s.view(bsz, -1) | |||
| f_t = f_t.view(bsz, -1) | |||
| G_s = torch.mm(f_s, torch.t(f_s)) | |||
| # G_s = G_s / G_s.norm(2) | |||
| G_s = torch.nn.functional.normalize(G_s) | |||
| G_t = torch.mm(f_t, torch.t(f_t)) | |||
| # G_t = G_t / G_t.norm(2) | |||
| G_t = torch.nn.functional.normalize(G_t) | |||
| G_diff = G_t - G_s | |||
| loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) | |||
| return loss | |||
| class RKDLoss(nn.Module): | |||
| """Relational Knowledge Disitllation, CVPR2019""" | |||
| def __init__(self, w_d=25, w_a=50): | |||
| super(RKDLoss, self).__init__() | |||
| self.w_d = w_d | |||
| self.w_a = w_a | |||
| def forward(self, f_s, f_t): | |||
| student = f_s.view(f_s.shape[0], -1) | |||
| teacher = f_t.view(f_t.shape[0], -1) | |||
| # RKD distance loss | |||
| with torch.no_grad(): | |||
| t_d = self.pdist(teacher, squared=False) | |||
| mean_td = t_d[t_d > 0].mean() | |||
| t_d = t_d / mean_td | |||
| d = self.pdist(student, squared=False) | |||
| mean_d = d[d > 0].mean() | |||
| d = d / mean_d | |||
| loss_d = F.smooth_l1_loss(d, t_d) | |||
| # RKD Angle loss | |||
| with torch.no_grad(): | |||
| td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) | |||
| norm_td = F.normalize(td, p=2, dim=2) | |||
| t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) | |||
| sd = (student.unsqueeze(0) - student.unsqueeze(1)) | |||
| norm_sd = F.normalize(sd, p=2, dim=2) | |||
| s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) | |||
| loss_a = F.smooth_l1_loss(s_angle, t_angle) | |||
| loss = self.w_d * loss_d + self.w_a * loss_a | |||
| return loss | |||
| @staticmethod | |||
| def pdist(e, squared=False, eps=1e-12): | |||
| e_square = e.pow(2).sum(dim=1) | |||
| prod = e @ e.t() | |||
| res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) | |||
| if not squared: | |||
| res = res.sqrt() | |||
| res = res.clone() | |||
| res[range(len(e)), range(len(e))] = 0 | |||
| return res | |||
| class PKTLoss(nn.Module): | |||
| """Probabilistic Knowledge Transfer for deep representation learning""" | |||
| def __init__(self): | |||
| super(PKTLoss, self).__init__() | |||
| def forward(self, f_s, f_t): | |||
| return self.cosine_similarity_loss(f_s, f_t) | |||
| @staticmethod | |||
| def cosine_similarity_loss(output_net, target_net, eps=0.0000001): | |||
| # Normalize each vector by its norm | |||
| output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True)) | |||
| output_net = output_net / (output_net_norm + eps) | |||
| output_net[output_net != output_net] = 0 | |||
| target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True)) | |||
| target_net = target_net / (target_net_norm + eps) | |||
| target_net[target_net != target_net] = 0 | |||
| # Calculate the cosine similarity | |||
| model_similarity = torch.mm(output_net, output_net.transpose(0, 1)) | |||
| target_similarity = torch.mm(target_net, target_net.transpose(0, 1)) | |||
| # Scale cosine similarity to 0..1 | |||
| model_similarity = (model_similarity + 1.0) / 2.0 | |||
| target_similarity = (target_similarity + 1.0) / 2.0 | |||
| # Transform them into probabilities | |||
| model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True) | |||
| target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True) | |||
| # Calculate the KL-divergence | |||
| loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps))) | |||
| return loss | |||
| class SVDLoss(nn.Module): | |||
| """ | |||
| Self-supervised Knowledge Distillation using Singular Value Decomposition | |||
| """ | |||
| def __init__(self, k=1): | |||
| super(SVDLoss, self).__init__() | |||
| self.k = k | |||
| def forward(self, g_s, g_t): | |||
| v_sb = None | |||
| v_tb = None | |||
| losses = [] | |||
| for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t): | |||
| u_t, s_t, v_t = self.svd(f_t, self.k) | |||
| u_s, s_s, v_s = self.svd(f_s, self.k + 3) | |||
| v_s, v_t = self.align_rsv(v_s, v_t) | |||
| s_t = s_t.unsqueeze(1) | |||
| v_t = v_t * s_t | |||
| v_s = v_s * s_t | |||
| if i > 0: | |||
| s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8) | |||
| t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8) | |||
| l2loss = (s_rbf - t_rbf.detach()).pow(2) | |||
| l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss)) | |||
| losses.append(l2loss.sum()) | |||
| v_tb = v_t | |||
| v_sb = v_s | |||
| bsz = g_s[0].shape[0] | |||
| losses = [l / bsz for l in losses] | |||
| return sum(losses) | |||
| def svd(self, feat, n=1): | |||
| size = feat.shape | |||
| assert len(size) == 4 | |||
| x = feat.view(-1, size[1], size[2] * size[3]).transpose(-2, -1) | |||
| u, s, v = torch.svd(x) | |||
| u = self.removenan(u) | |||
| s = self.removenan(s) | |||
| v = self.removenan(v) | |||
| if n > 0: | |||
| u = F.normalize(u[:, :, :n], dim=1) | |||
| s = F.normalize(s[:, :n], dim=1) | |||
| v = F.normalize(v[:, :, :n], dim=1) | |||
| return u, s, v | |||
| @staticmethod | |||
| def removenan(x): | |||
| x = torch.where(torch.isfinite(x), x, torch.zeros_like(x)) | |||
| return x | |||
| @staticmethod | |||
| def align_rsv(a, b): | |||
| cosine = torch.matmul(a.transpose(-2, -1), b) | |||
| max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True) | |||
| mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)), | |||
| torch.sign(cosine), torch.zeros_like(cosine)) | |||
| a = torch.matmul(a, mask) | |||
| return a, b | |||
| @@ -0,0 +1,186 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import abc | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import sys | |||
| import typing | |||
| from typing import Callable, Dict, List, Any | |||
| from collections.abc import Mapping, Sequence | |||
| from . import loss | |||
| from kamal.core import metrics, exceptions | |||
| from kamal.core.attach import AttachTo | |||
| class Task(object): | |||
| def __init__(self, name): | |||
| self.name = name | |||
| @abc.abstractmethod | |||
| def get_loss( self, outputs, targets ) -> Dict: | |||
| pass | |||
| @abc.abstractmethod | |||
| def predict(self, outputs) -> Any: | |||
| pass | |||
| class GeneralTask(Task): | |||
| def __init__(self, | |||
| name: str, | |||
| loss_fn: Callable, | |||
| scaling:float=1.0, | |||
| pred_fn: Callable=lambda x: x, | |||
| attach_to=None): | |||
| super(GeneralTask, self).__init__(name) | |||
| self._attach = AttachTo(attach_to) | |||
| self.loss_fn = loss_fn | |||
| self.pred_fn = pred_fn | |||
| self.scaling = scaling | |||
| def get_loss(self, outputs, targets): | |||
| outputs, targets = self._attach(outputs, targets) | |||
| return { self.name: self.loss_fn( outputs, targets ) * self.scaling } | |||
| def predict(self, outputs): | |||
| outputs = self._attach(outputs) | |||
| return self.pred_fn(outputs) | |||
| def __repr__(self): | |||
| rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach) | |||
| return rep | |||
| class TaskCompose(list): | |||
| def __init__(self, tasks: list): | |||
| for task in tasks: | |||
| if isinstance(task, Task): | |||
| self.append(task) | |||
| def get_loss(self, outputs, targets): | |||
| loss_dict = {} | |||
| for task in self: | |||
| loss_dict.update( task.get_loss( outputs, targets ) ) | |||
| return loss_dict | |||
| def predict(self, outputs): | |||
| results = [] | |||
| for task in self: | |||
| results.append( task.predict( outputs ) ) | |||
| return results | |||
| def __repr__(self): | |||
| rep="TaskCompose: \n" | |||
| for task in self: | |||
| rep+="\t%s\n"%task | |||
| class StandardTask: | |||
| @staticmethod | |||
| def classification(name='ce', scaling=1.0, attach_to=None): | |||
| return GeneralTask( name=name, | |||
| loss_fn=nn.CrossEntropyLoss(), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x.max(1)[1], | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def binary_classification(name='bce', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=F.binary_cross_entropy_with_logits, | |||
| scaling=scaling, | |||
| pred_fn=lambda x: (x>0.5), | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def regression(name='mse', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=nn.MSELoss(), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x, | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def segmentation(name='ce', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=nn.CrossEntropyLoss(ignore_index=255), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x.max(1)[1], | |||
| attach_to=attach_to ) | |||
| @staticmethod | |||
| def monocular_depth(name='l1', scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=nn.L1Loss(), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x, | |||
| attach_to=attach_to) | |||
| @staticmethod | |||
| def detection(): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None): | |||
| return GeneralTask(name=name, | |||
| loss_fn=loss.KLDiv(T=T), | |||
| scaling=scaling, | |||
| pred_fn=lambda x: x.max(1)[1], | |||
| attach_to=attach_to) | |||
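| # Usage sketch (names are illustrative): each factory pairs a loss with a pred_fn, e.g. | |||
| #   task = StandardTask.classification() | |||
| #   loss_dict = task.get_loss(model_outputs, labels)   # {'ce': tensor} | |||
| #   preds = task.predict(model_outputs)                # argmax over classes | |||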
| class StandardMetrics(object): | |||
| @staticmethod | |||
| def classification(attach_to=None): | |||
| return metrics.MetricCompose( | |||
| metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)} | |||
| ) | |||
| @staticmethod | |||
| def regression(attach_to=None): | |||
| return metrics.MetricCompose( | |||
| metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)} | |||
| ) | |||
| @staticmethod | |||
| def segmentation(num_classes, ignore_idx=255, attach_to=None): | |||
| confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to) | |||
| return metrics.MetricCompose( | |||
| metric_dict={'acc': metrics.Accuracy(attach_to=attach_to), | |||
| 'confusion_matrix': confusion_matrix , | |||
| 'miou': metrics.mIoU(confusion_matrix)} | |||
| ) | |||
| @staticmethod | |||
| def monocular_depth(attach_to=None): | |||
| return metrics.MetricCompose( | |||
| metric_dict={ | |||
| 'rmse': metrics.RootMeanSquaredError(attach_to=attach_to), | |||
| 'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ), | |||
| 'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to), | |||
| 'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to), | |||
| 'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to), | |||
| 'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to ) | |||
| } | |||
| ) | |||
| @staticmethod | |||
| def loss_metric(loss_fn): | |||
| return metrics.MetricCompose( | |||
| metric_dict={ | |||
| 'loss': metrics.AverageMetric( loss_fn ) | |||
| } | |||
| ) | |||
| @@ -0,0 +1,2 @@ | |||
| from .prunning import Pruner, strategy | |||
| from .distillation import * | |||
| @@ -0,0 +1,12 @@ | |||
| from .kd import KDDistiller | |||
| from .hint import * | |||
| from .attention import AttentionDistiller | |||
| from .nst import NSTDistiller | |||
| from .sp import SPDistiller | |||
| from .rkd import RKDDistiller | |||
| from .pkt import PKTDistiller | |||
| from .svd import SVDDistiller | |||
| from .cc import * | |||
| from .vid import * | |||
| from . import data_free | |||
| @@ -0,0 +1,47 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import AttentionLoss | |||
| from kamal.core.tasks.loss import KDLoss | |||
| from kamal.utils import set_mode, move_to_device | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| class AttentionDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(AttentionDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, | |||
| student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, | |||
| stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super(AttentionDistiller, self).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self._at_loss = AttentionLoss() | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| g_s = feat_s[1:-1] | |||
| g_t = feat_t[1:-1] | |||
| return self._at_loss(g_s, g_t) | |||
| @@ -0,0 +1,55 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import KDLoss | |||
| import torch.nn as nn | |||
| import torch | |||
| import time | |||
| class CCDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(CCDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, embed_s, embed_t, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None ): | |||
| super(CCDistiller, self).setup( | |||
| student, teacher, dataloader, optimizer, T=T, gamma=gamma, alpha=alpha, device=device) | |||
| self.embed_s = embed_s.to(self.device).train() | |||
| self.embed_t = embed_t.to(self.device).train() | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| f_s = self.embed_s(feat_s[-1]) | |||
| f_t = self.embed_t(feat_t[-1]) | |||
| return torch.mean((torch.abs(f_s-f_t)[:-1] * torch.abs(f_s-f_t)[1:]).sum(1)) | |||
| class LinearEmbed(nn.Module): | |||
| def __init__(self, dim_in=1024, dim_out=128): | |||
| super(LinearEmbed, self).__init__() | |||
| self.linear = nn.Linear(dim_in, dim_out) | |||
| def forward(self, x): | |||
| x = x.view(x.shape[0], -1) | |||
| x = self.linear(x) | |||
| return x | |||
| @@ -0,0 +1 @@ | |||
| from .zskt import ZSKTDistiller | |||
| @@ -0,0 +1,99 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import time | |||
| import torch.nn.functional as F | |||
| from kamal.slim.distillation.kd import KDDistiller | |||
| from kamal.utils import set_mode | |||
| from kamal.core.tasks.loss import kldiv | |||
| from kamal.core import hpo  # assumed location of the HPO helper used by search_optimizer below | |||
| class ZSKTDistiller(KDDistiller): | |||
| def __init__( self, | |||
| student, | |||
| teacher, | |||
| generator, | |||
| z_dim, | |||
| logger=None, | |||
| viz=None): | |||
| super(ZSKTDistiller, self).__init__(logger, viz) | |||
| self.teacher = teacher | |||
| self.model = self.student = student | |||
| self.generator = generator | |||
| self.z_dim = z_dim | |||
| def train(self, start_iter, max_iter, optim_s, optim_g, device=None): | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self.device = device | |||
| self.optim_s, self.optim_g = optim_s, optim_g | |||
| self.model.to(self.device) | |||
| self.teacher.to(self.device) | |||
| self.generator.to(self.device) | |||
| self.train_loader = [0, ] | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teacher, training=False), \ | |||
| set_mode(self.generator, training=True): | |||
| super( ZSKTDistiller, self ).train( start_iter, max_iter ) | |||
| def search_optimizer(self, evaluator, train_loader, hpo_space=None, mode='min', max_evals=20, max_iters=400): | |||
| optimizer = hpo.search_optimizer(self, train_loader, evaluator=evaluator, hpo_space=hpo_space, mode=mode, max_evals=max_evals, max_iters=max_iters) | |||
| return optimizer | |||
| def step(self): | |||
| start_time = time.perf_counter() | |||
| # Adv | |||
| z = torch.randn( self.z_dim ).to(self.device) | |||
| fake = self.generator( z ) | |||
| self.optim_g.zero_grad() | |||
| t_out = self.teacher( fake ) | |||
| s_out = self.student( fake ) | |||
| loss_g = -kldiv( s_out, t_out ) | |||
| loss_g.backward() | |||
| self.optim_g.step() | |||
| with torch.no_grad(): | |||
| fake = self.generator( z ) | |||
| t_out = self.teacher( fake.detach() ) | |||
| for _ in range(10): | |||
| self.optim_s.zero_grad() | |||
| s_out = self.student( fake.detach() ) | |||
| loss_s = kldiv( s_out, t_out ) | |||
| loss_s.backward() | |||
| self.optim_s.step() | |||
| loss_dict = { | |||
| 'loss_g': loss_g, | |||
| 'loss_s': loss_s, | |||
| } | |||
| step_time = time.perf_counter() - start_time | |||
| # record training info | |||
| info = loss_dict | |||
| info['step_time'] = step_time | |||
| info['lr_s'] = float( self.optim_s.param_groups[0]['lr'] ) | |||
| info['lr_g'] = float( self.optim_g.param_groups[0]['lr'] ) | |||
| self.history.put_scalars( **info ) | |||
| def reset(self): | |||
| self.history = None | |||
| self._train_loader_iter = iter(self.train_loader) | |||
| self.iter = self.start_iter | |||
| @@ -0,0 +1,86 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import KDLoss | |||
| import torch.nn as nn | |||
| import torch | |||
| import time | |||
| class HintDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(HintDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, | |||
| student, teacher, regressor, dataloader, optimizer, | |||
| hint_layer=2, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, | |||
| stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( HintDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.regressor = regressor | |||
| self._hint_layer = hint_layer | |||
| self._beta = beta | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self.regressor.to(self.device) | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| f_s = self.regressor(feat_s[self._hint_layer]) | |||
| f_t = feat_t[self._hint_layer] | |||
| return nn.functional.mse_loss(f_s, f_t) | |||
| class Regressor(nn.Module): | |||
| """ | |||
| Convolutional regression for FitNet | |||
| @inproceedings{tian2019crd, | |||
| title={Contrastive Representation Distillation}, | |||
| author={Yonglong Tian and Dilip Krishnan and Phillip Isola}, | |||
| booktitle={International Conference on Learning Representations}, | |||
| year={2020} | |||
| } | |||
| """ | |||
| def __init__(self, s_shape, t_shape, is_relu=True): | |||
| super(Regressor, self).__init__() | |||
| self.is_relu = is_relu | |||
| _, s_C, s_H, s_W = s_shape | |||
| _, t_C, t_H, t_W = t_shape | |||
| if s_H == 2 * t_H: | |||
| self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) | |||
| elif s_H * 2 == t_H: | |||
| self.conv = nn.ConvTranspose2d( | |||
| s_C, t_C, kernel_size=4, stride=2, padding=1) | |||
| elif s_H >= t_H: | |||
| self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) | |||
| else: | |||
| raise NotImplementedError( | |||
| 'student size {}, teacher size {}'.format(s_H, t_H)) | |||
| self.bn = nn.BatchNorm2d(t_C) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| if self.is_relu: | |||
| return self.relu(self.bn(x)) | |||
| else: | |||
| return self.bn(x) | |||
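| # Shape note: the stride-2 conv halves s_H when s_H == 2*t_H, the transposed conv | |||
| # (k=4, s=2, p=1) doubles it when 2*s_H == t_H, and the (1+s_H-t_H, 1+s_W-t_W) kernel | |||
| # shrinks s_H down to t_H in the remaining s_H >= t_H case. | |||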
| @@ -0,0 +1,90 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from kamal.core.engine.trainer import Engine | |||
| from kamal.core.tasks.loss import kldiv | |||
| import torch.nn.functional as F | |||
| from kamal.utils.logger import get_logger | |||
| from kamal.utils import set_mode, move_to_device | |||
| import weakref | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| import numpy as np | |||
| class KDDistiller(Engine): | |||
| def __init__( self, | |||
| logger=None, | |||
| tb_writer=None): | |||
| super(KDDistiller, self).__init__(logger=logger, tb_writer=tb_writer) | |||
| def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, device=None): | |||
| self.model = self.student = student | |||
| self.teacher = teacher | |||
| self.dataloader = dataloader | |||
| self.optimizer = optimizer | |||
| if device is None: | |||
| device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) | |||
| self.device = device | |||
| self.T = T | |||
| self.gamma = gamma | |||
| self.alpha = alpha | |||
| self.beta = beta | |||
| self.student.to(self.device) | |||
| self.teacher.to(self.device) | |||
| def run( self, max_iter, start_iter=0, epoch_length=None): | |||
| with set_mode(self.student, training=True), \ | |||
| set_mode(self.teacher, training=False): | |||
| super( KDDistiller, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) | |||
| def additional_kd_loss(self, engine, batch): | |||
| return batch[0].new_zeros(1) | |||
| def step_fn(self, engine, batch): | |||
| student = self.student | |||
| teacher = self.teacher | |||
| start_time = time.perf_counter() | |||
| batch = move_to_device(batch, self.device) | |||
| inputs, targets = batch | |||
| outputs = student(inputs) | |||
| with torch.no_grad(): | |||
| soft_targets = teacher(inputs) | |||
| loss_dict = { "loss_kld": self.alpha * kldiv(outputs, soft_targets, T=self.T), | |||
| "loss_ce": self.beta * F.cross_entropy( outputs, targets ), | |||
| "loss_additional": self.gamma * self.additional_kd_loss(engine, batch) } | |||
| loss = sum( loss_dict.values() ) | |||
| self.optimizer.zero_grad() | |||
| loss.backward() | |||
| self.optimizer.step() | |||
| step_time = time.perf_counter() - start_time | |||
| metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } | |||
| metrics.update({ | |||
| 'total_loss': loss.item(), | |||
| 'step_time': step_time, | |||
| 'lr': float( self.optimizer.param_groups[0]['lr'] ) | |||
| }) | |||
| return metrics | |||
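| # Usage sketch (models, loader and optimizer are illustrative): | |||
| #   distiller = KDDistiller() | |||
| #   distiller.setup(student, teacher, train_loader, optimizer, T=4.0, alpha=1.0, beta=1.0) | |||
| #   distiller.run(max_iter=10000) | |||
| # step_fn sums the temperature-scaled KLD, the cross-entropy on hard labels, and any | |||
| # additional_kd_loss contributed by subclasses (hint/attention/nst/etc.). | |||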
| @@ -0,0 +1,44 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import KDLoss | |||
| from kamal.core.tasks.loss import NSTLoss | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| class NSTDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(NSTDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( NSTDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self._nst_loss = NSTLoss() | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| g_s = feat_s[1:-1] | |||
| g_t = feat_t[1:-1] | |||
| return self._nst_loss(g_s, g_t) | |||
| @@ -0,0 +1,44 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import PKTLoss | |||
| from kamal.core.tasks.loss import KDLoss | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| class PKTDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(PKTDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( PKTDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self._pkt_loss = PKTLoss() | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| f_s = feat_s[-1] | |||
| f_t = feat_t[-1] | |||
| return self._pkt_loss(f_s, f_t) | |||
| @@ -0,0 +1,45 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import RKDLoss | |||
| from kamal.core.tasks.loss import KDLoss | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| class RKDDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(RKDDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( RKDDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self._rkd_loss = RKDLoss() | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| f_s = feat_s[-1] | |||
| f_t = feat_t[-1] | |||
| return self._rkd_loss(f_s, f_t) | |||
| @@ -0,0 +1,44 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import SPLoss | |||
| from kamal.core.tasks.loss import KDLoss | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| class SPDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(SPDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( SPDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self._sp_loss = SPLoss() | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| g_s = [feat_s[-2]] | |||
| g_t = [feat_t[-2]] | |||
| return self._sp_loss(g_s, g_t) | |||
| @@ -0,0 +1,45 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from .kd import KDDistiller | |||
| from kamal.core.tasks.loss import SVDLoss | |||
| from kamal.core.tasks.loss import KDLoss | |||
| import torch | |||
| import torch.nn as nn | |||
| import time | |||
| class SVDDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(SVDDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( SVDDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self._svd_loss = SVDLoss() | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| g_s = feat_s[1:-1] | |||
| g_t = feat_t[1:-1] | |||
| return self._svd_loss( g_s, g_t ) | |||
| @@ -0,0 +1,90 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import numpy as np | |||
| import time | |||
| import torch.nn as nn | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from .kd import KDDistiller | |||
| from kamal.utils import set_mode | |||
| from kamal.core.tasks.loss import KDLoss | |||
| class VIDDistiller(KDDistiller): | |||
| def __init__(self, logger=None, tb_writer=None ): | |||
| super(VIDDistiller, self).__init__( logger, tb_writer ) | |||
| def setup(self, student, teacher, dataloader, optimizer, regressor_l, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): | |||
| super( VIDDistiller, self ).setup( | |||
| student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) | |||
| self.stu_hooks = stu_hooks | |||
| self.tea_hooks = tea_hooks | |||
| self.out_flags = out_flags | |||
| self.regressor_l = [regressor.to(self.device).train() for regressor in regressor_l] | |||
| def additional_kd_loss(self, engine, batch): | |||
| feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] | |||
| feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] | |||
| g_s = feat_s[1:-1] | |||
| g_t = feat_t[1:-1] | |||
| return sum([c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, self.regressor_l)]) | |||
| class VIDRegressor(nn.Module): | |||
| def __init__(self, | |||
| num_input_channels, | |||
| num_mid_channel, | |||
| num_target_channels, | |||
| init_pred_var=5.0, | |||
| eps=1e-5): | |||
| super(VIDRegressor, self).__init__() | |||
| def conv1x1(in_channels, out_channels, stride=1): | |||
| return nn.Conv2d( | |||
| in_channels, out_channels, | |||
| kernel_size=1, padding=0, | |||
| bias=False, stride=stride) | |||
| self.regressor = nn.Sequential( | |||
| conv1x1(num_input_channels, num_mid_channel), | |||
| nn.ReLU(), | |||
| conv1x1(num_mid_channel, num_mid_channel), | |||
| nn.ReLU(), | |||
| conv1x1(num_mid_channel, num_target_channels), | |||
| ) | |||
| self.log_scale = torch.nn.Parameter( | |||
| np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels) | |||
| ) | |||
| self.eps = eps | |||
| def forward(self, input, target): | |||
| # pool for dimension match | |||
| s_H, t_H = input.shape[2], target.shape[2] | |||
| if s_H > t_H: | |||
| input = F.adaptive_avg_pool2d(input, (t_H, t_H)) | |||
| elif s_H < t_H: | |||
| target = F.adaptive_avg_pool2d(target, (s_H, s_H)) | |||
| pred_mean = self.regressor(input) | |||
| pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps | |||
| pred_var = pred_var.view(1, -1, 1, 1) | |||
| neg_log_prob = 0.5*( | |||
| (pred_mean-target)**2/pred_var+torch.log(pred_var) | |||
| ) | |||
| loss = torch.mean(neg_log_prob) | |||
| return loss | |||
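| # Quick smoke test of VIDRegressor above (shapes are arbitrary): the loss is the | |||
| # mean Gaussian negative log-likelihood of the teacher feature under the | |||
| # predicted mean and a learned, softplus-parameterized per-channel variance. | |||
| if __name__ == "__main__": | |||
|     reg = VIDRegressor(num_input_channels=64, num_mid_channel=128, num_target_channels=32) | |||
|     f_s = torch.randn(2, 64, 8, 8)  # student feature | |||
|     f_t = torch.randn(2, 32, 8, 8)  # teacher feature | |||
|     print(reg(f_s, f_t).item()) | |||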
| @@ -0,0 +1,2 @@ | |||
| from .pruner import Pruner | |||
| from .strategy import LNStrategy, RandomStrategy | |||
| @@ -0,0 +1,37 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from copy import deepcopy | |||
| import os | |||
| import torch | |||
| class Pruner(object): | |||
| def __init__(self, strategy): | |||
| self.strategy = strategy | |||
| def prune(self, model, rate=0.1, example_inputs=None): | |||
| ori_num_params = sum( [ torch.numel(p) for p in model.parameters() ] ) | |||
| model = deepcopy(model).cpu() | |||
| model = self._prune( model, rate=rate, example_inputs=example_inputs ) | |||
| new_num_params = sum( [ torch.numel(p) for p in model.parameters() ] ) | |||
| print( "%d=>%d, %.2f%% params were pruned"%( ori_num_params, new_num_params, 100*(ori_num_params-new_num_params)/ori_num_params ) ) | |||
| return model | |||
| def _prune(self, model, **kwargs): | |||
| return self.strategy( model, **kwargs ) | |||
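| # Hedged usage sketch: a "strategy" is any callable with the signature | |||
| # strategy(model, rate=..., example_inputs=...) -> model. The identity strategy | |||
| # below is a stand-in to show the flow; real use would pass RandomStrategy or | |||
| # LNStrategy from .strategy. | |||
| if __name__ == "__main__": | |||
|     import torch.nn as nn | |||
|     def identity_strategy(model, rate=0.1, example_inputs=None): | |||
|         return model  # no-op: leaves the model unchanged | |||
|     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3)) | |||
|     pruned = Pruner(identity_strategy).prune(net, rate=0.2, example_inputs=torch.randn(1, 3, 32, 32)) | |||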
| @@ -0,0 +1,85 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch_pruning as tp | |||
| import abc | |||
| import torch | |||
| import torch.nn as nn | |||
| import random | |||
| import numpy as np | |||
| _PRUNABLE_MODULES= tp.DependencyGraph.PRUNABLE_MODULES | |||
| class BaseStrategy(abc.ABC): | |||
| @abc.abstractmethod | |||
| def select(self, layer_to_prune): | |||
| pass | |||
| def __call__(self, model, rate=0.1, example_inputs=None): | |||
| if example_inputs is None: | |||
| example_inputs = torch.randn( 1,3,256,256 ) | |||
| DG = tp.DependencyGraph() | |||
| DG.build_dependency(model, example_inputs=example_inputs) | |||
| prunable_layers = [] | |||
| total_params = 0 | |||
| num_accumulative_conv_params = [ 0, ] | |||
| for m in model.modules(): | |||
| if isinstance(m, _PRUNABLE_MODULES ) : | |||
| nparam = tp.utils.count_prunable_params( m ) | |||
| total_params += nparam | |||
| if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)): | |||
| prunable_layers.append( m ) | |||
| num_accumulative_conv_params.append( num_accumulative_conv_params[-1]+nparam ) | |||
| prunable_layers.pop(-1) # keep the final layer (e.g. the classifier) unpruned | |||
| num_accumulative_conv_params.pop(-1) # drop its boundary accordingly | |||
| num_conv_params = num_accumulative_conv_params[-1] | |||
| num_accumulative_conv_params = [ ( num_accumulative_conv_params[i], num_accumulative_conv_params[i+1] ) for i in range(len(num_accumulative_conv_params)-1) ] | |||
| def map_param_idx_to_conv_layer(i): | |||
| for l, accu in zip( prunable_layers, num_accumulative_conv_params ): | |||
| if accu[0]<=i and i<accu[1]: | |||
| return l | |||
| num_pruned = 0 | |||
| while num_pruned<total_params*rate: | |||
| layer_to_prune = map_param_idx_to_conv_layer( random.randint( 0, num_conv_params-1 ) ) | |||
| if layer_to_prune.weight.shape[0]<1: | |||
| continue | |||
| idx = self.select( layer_to_prune ) | |||
| fn = tp.prune_conv if isinstance(layer_to_prune, nn.modules.conv._ConvNd) else tp.prune_linear | |||
| plan = DG.get_pruning_plan( layer_to_prune, fn, idxs=idx ) | |||
| num_pruned += plan.exec() | |||
| return model | |||
| class RandomStrategy(BaseStrategy): | |||
| def select(self, layer_to_prune): | |||
| return [ random.randint( 0, layer_to_prune.weight.shape[0]-1 ) ] | |||
| class LNStrategy(BaseStrategy): | |||
| def __init__(self, n=2): | |||
| self.n = n | |||
| def select(self, layer_to_prune): | |||
| w = torch.flatten( layer_to_prune.weight, 1 ) | |||
| norm = torch.norm(w, p=self.n, dim=1) | |||
| idx = [ int(norm.min(dim=0)[1].item()) ] | |||
| return idx | |||
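| # Hedged usage sketch (assumes the torch_pruning version this module targets): | |||
| # prune roughly 30% of the parameters of a tiny CNN with the L2-norm strategy. | |||
| if __name__ == "__main__": | |||
|     net = nn.Sequential( | |||
|         nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), | |||
|         nn.Conv2d(16, 8, 3, padding=1), nn.ReLU(), | |||
|         nn.Conv2d(8, 4, 1), | |||
|     ) | |||
|     net = LNStrategy(n=2)(net, rate=0.3, example_inputs=torch.randn(1, 3, 32, 32)) | |||
|     print(net) | |||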
| @@ -0,0 +1,18 @@ | |||
| # Deep Model Transferability from Attribution Maps | |||
| - [*"Paper: Deep Model Transferability from Attribution Maps"*](https:), NeurIPS 2019. (released soon) | |||
| J. Song, Y. Chen, X. Wang, C. Shen, M. Song | |||
| [Homepage of VIPA Group](https://www.vipazoo.cn/index_en.html), Zhejiang University, China | |||
| This repo was rewritten by zhfeing in `PyTorch`. | |||
| <div align="left"> | |||
| <img src="vipa-logo.png" width = "40%" height = "40%" alt="icon"/> | |||
| </div> | |||
| @@ -0,0 +1,20 @@ | |||
| from kamal.transferability.trans_graph import TransferabilityGraph | |||
| from kamal.transferability.trans_metric import AttrMapMetric | |||
| import kamal | |||
| from kamal.vision import sync_transforms as sT | |||
| import os | |||
| import torch | |||
| from PIL import Image | |||
| if __name__=='__main__': | |||
| zoo = '/tmp/pycharm_project_225/kamal/transferability/model2' | |||
| TG = TransferabilityGraph(zoo) | |||
| probe_set_root = '/tmp/pycharm_project_225/kamal/transferability/probe_data' | |||
| for probe_set in os.listdir( probe_set_root ): | |||
| print("Add %s"%(probe_set)) | |||
| imgs_set = list( os.listdir( os.path.join( probe_set_root, probe_set ) ) ) | |||
| images = [ Image.open( os.path.join(probe_set_root, probe_set, img) ) for img in imgs_set ] | |||
| metric = AttrMapMetric(images, device=torch.device('cuda')) | |||
| TG.add_metric( probe_set, metric) | |||
| TG.export_to_json(probe_set, 'exported_metrics/%s.json'%(probe_set), topk=3, normalize=True) | |||
| @@ -0,0 +1,3 @@ | |||
| from .attribution_graph import get_attribution_graph, graph_similarity | |||
| from .attribution_map import attribution_map, attr_map_distance | |||
| @@ -0,0 +1,184 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from typing import Type, Dict, Any | |||
| import itertools | |||
| import numpy as np | |||
| import networkx | |||
| import torch | |||
| from torch.nn import Module | |||
| from torch.utils.data import Dataset | |||
| from .attribution_map import attribution_map, attr_map_similarity | |||
| def graph_to_array(graph: networkx.Graph): | |||
| weight_matrix = np.zeros((len(graph.nodes), len(graph.nodes))) | |||
| for i, n1 in enumerate(graph.nodes): | |||
| for j, n2 in enumerate(graph.nodes): | |||
| try: | |||
| dist = graph[n1][n2]["weight"] | |||
| except KeyError: | |||
| dist = 1 | |||
| weight_matrix[i, j] = dist | |||
| return weight_matrix | |||
| class FeatureMapExtractor(): | |||
| def __init__(self, module: Module): | |||
| self.module = module | |||
| self.feature_pool: Dict[str, Dict[str, Any]] = dict() | |||
| self.register_hooks() | |||
| def register_hooks(self): | |||
| for name, m in self.module.named_modules(): | |||
| if "pool" in name: | |||
| m.name = name | |||
| self.feature_pool[name] = dict() | |||
| def hook(m: Module, input, output): | |||
| self.feature_pool[m.name]["feature"] = input | |||
| self.feature_pool[name]["handle"] = m.register_forward_hook(hook) | |||
| def _forward(self, x): | |||
| self.module(x) | |||
| def remove_hooks(self): | |||
| for name, cfg in self.feature_pool.items(): | |||
| cfg["handle"].remove() | |||
| cfg.clear() | |||
| self.feature_pool.clear() | |||
| def extract_final_map(self, x): | |||
| self._forward(x) | |||
| feature_map = None | |||
| max_channel = 0 | |||
| min_size = np.inf | |||
| for name, cfg in self.feature_pool.items(): | |||
| f = cfg["feature"] | |||
| if len(f) == 1 and isinstance(f[0], torch.Tensor): | |||
| f = f[0] | |||
| if f.dim() == 4: # BxCxHxW | |||
| b, c, h, w = f.shape | |||
| if c >= max_channel and 1 < h * w <= min_size: | |||
| feature_map = f | |||
| max_channel = c | |||
| min_size = h * w | |||
| return feature_map | |||
| def get_attribution_graph( | |||
| model: Module, | |||
| attribution_type: Type, | |||
| with_noise: bool, | |||
| probe_data: Dataset, | |||
| device: torch.device, | |||
| norm_square: bool = False, | |||
| ): | |||
| attribution_graph = networkx.Graph() | |||
| model = model.to(device) | |||
| extractor = FeatureMapExtractor(model) | |||
| for i, x in enumerate(probe_data): | |||
| x = x.to(device) | |||
| x.requires_grad_() | |||
| attribution = attribution_map( | |||
| func=lambda x: extractor.extract_final_map(x), | |||
| attribution_type=attribution_type, | |||
| with_noise=with_noise, | |||
| probe_data=x.unsqueeze(0), | |||
| norm_square=norm_square | |||
| ) | |||
| attribution_graph.add_node(i, attribution_map=attribution) | |||
| nodes = attribution_graph.nodes | |||
| for i, j in itertools.product(nodes, nodes): | |||
| if i < j: | |||
| weight = attr_map_similarity( | |||
| attribution_graph.nodes(data=True)[i]["attribution_map"], | |||
| attribution_graph.nodes(data=True)[j]["attribution_map"] | |||
| ) | |||
| attribution_graph.add_edge(i, j, weight=weight) | |||
| return attribution_graph | |||
| def edge_to_embedding(graph: networkx.Graph): | |||
| adj = graph_to_array(graph) | |||
| # np.tri selects the lower triangle (incl. diagonal); since the adjacency is symmetric, each edge weight is taken once | |||
| tri_mask = np.tri(*adj.shape[-2:], k=0, dtype=bool) | |||
| return adj[tri_mask] | |||
| def embedding_to_rank(embedding: np.ndarray): | |||
| order = embedding.argsort() | |||
| ranks = order.argsort() | |||
| return ranks | |||
| def graph_similarity(g1: networkx.Graph, g2: networkx.Graph, Lambda: float = 1.0): | |||
| nodes_1 = g1.nodes(data=True) | |||
| nodes_2 = g2.nodes(data=True) | |||
| assert len(nodes_1) == len(nodes_2) | |||
| # calculate vertex similarity | |||
| v_s = 0 | |||
| n = len(g1.nodes) | |||
| for i in range(n): | |||
| v_s += attr_map_similarity( | |||
| map_1=g1.nodes(data=True)[i]["attribution_map"], | |||
| map_2=g2.nodes(data=True)[i]["attribution_map"] | |||
| ) | |||
| vertex_similarity = v_s / n | |||
| # calculate edges similarity | |||
| emb_1 = edge_to_embedding(g1) | |||
| emb_2 = edge_to_embedding(g2) | |||
| rank_1 = embedding_to_rank(emb_1) | |||
| rank_2 = embedding_to_rank(emb_2) | |||
| k = emb_1.shape[0] | |||
| edge_similarity = 1 - 6 * np.sum(np.square(rank_1 - rank_2)) / (k ** 3 - k) | |||
| return vertex_similarity + Lambda * edge_similarity | |||
| if __name__ == "__main__": | |||
| from captum.attr import InputXGradient | |||
| from torchvision.models import resnet34 | |||
| model_1 = resnet34(num_classes=10) | |||
| graph_1 = get_attribution_graph( | |||
| model_1, | |||
| attribution_type=InputXGradient, | |||
| with_noise=False, | |||
| probe_data=torch.rand(10, 3, 244, 244), | |||
| device=torch.device("cpu") | |||
| ) | |||
| model_2 = resnet34(num_classes=10) | |||
| graph_2 = get_attribution_graph( | |||
| model_2, | |||
| attribution_type=InputXGradient, | |||
| with_noise=False, | |||
| probe_data=torch.rand(10, 3, 244, 244), | |||
| device=torch.device("cpu") | |||
| ) | |||
| print(graph_similarity(graph_1, graph_2)) | |||
| @@ -0,0 +1,87 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| from typing import Type, Callable | |||
| from captum.attr import Attribution | |||
| from captum.attr import NoiseTunnel | |||
| def with_norm(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, square: bool = False): | |||
| x = func(x) | |||
| x = torch.norm(x.flatten(1), dim=1, p=2) | |||
| if square: | |||
| x = torch.pow(x, 2) | |||
| return x | |||
| def attribution_map( | |||
| func: Callable[[torch.Tensor], torch.Tensor], | |||
| attribution_type: Type, | |||
| with_noise: bool, | |||
| probe_data: torch.Tensor, | |||
| norm_square: bool = False, | |||
| **attribution_kwargs | |||
| ) -> torch.Tensor: | |||
| """ | |||
| Calculate an attribution map with the given attribution type (algorithm). | |||
| Args: | |||
| func: a callable mapping an input tensor to an output tensor (e.g. a model's forward) | |||
| attribution_type: attribution algorithm, e.g. IntegratedGradients, InputXGradient, ... | |||
| with_noise: whether to wrap the attribution with a noise tunnel | |||
| probe_data: input data passed to func | |||
| norm_square: whether to square the L2 norm of the output | |||
| attribution_kwargs: other kwargs for the attribution method | |||
| Return: attribution map | |||
| """ | |||
| attribution: Attribution = attribution_type(lambda x: with_norm(func, x, norm_square)) | |||
| if with_noise: | |||
| attribution = NoiseTunnel(attribution) | |||
| attr_map = attribution.attribute( | |||
| inputs=probe_data, | |||
| target=None, | |||
| **attribution_kwargs | |||
| ) | |||
| return attr_map.detach() | |||
| def attr_map_distance(map_1: torch.Tensor, map_2: torch.Tensor): | |||
| if map_1.shape != map_2.shape: | |||
| map_1 = torch.nn.functional.interpolate( map_1, size=map_2.shape[-2:] ) | |||
| dist = 1 - torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean() | |||
| return dist.item() | |||
| def attr_map_similarity(map_1: torch.Tensor, map_2: torch.Tensor): | |||
| assert map_1.shape == map_2.shape | |||
| dist = torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean() | |||
| return dist.item() | |||
| if __name__ == "__main__": | |||
| import captum | |||
| def ff(x): | |||
| return x ** 2 | |||
| m = attribution_map( | |||
| ff, | |||
| captum.attr.InputXGradient, | |||
| with_noise=False, | |||
| probe_data=torch.tensor([[1, 2, 3, 4]], dtype=torch.float, requires_grad=True), | |||
| norm_square=True | |||
| ) | |||
| print(m) | |||
| @@ -0,0 +1,135 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| import networkx as nx | |||
| from . import depara | |||
| import os, abc | |||
| from typing import Callable | |||
| from kamal import hub | |||
| import json, numbers | |||
| from tqdm import tqdm | |||
| class Node(object): | |||
| def __init__(self, hub_root, entry_name, spec_name): | |||
| self.hub_root = hub_root | |||
| self.entry_name = entry_name | |||
| self.spec_name = spec_name | |||
| @property | |||
| def model(self): | |||
| return hub.load( self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name ).eval() | |||
| @property | |||
| def tag(self): | |||
| return hub.load_tags(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name) | |||
| @property | |||
| def metadata(self): | |||
| return hub.load_metadata(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name) | |||
| class TransferabilityGraph(object): | |||
| def __init__(self, model_zoo_set): | |||
| self.model_zoo_set = model_zoo_set | |||
| # self.model_zoo = os.path.abspath( os.path.expanduser( model_zoo ) ) | |||
| self._graphs = dict() | |||
| self._models = dict() | |||
| self._register_models() | |||
| def _register_models(self): | |||
| cnt = 0 | |||
| for model_zoo in self.model_zoo_set: | |||
| model_zoo = os.path.abspath(os.path.expanduser(model_zoo)) | |||
| for hub_root in self._list_modelzoo(model_zoo): | |||
| for entry_name, spec_name in hub.list_spec(hub_root): | |||
| node = Node( hub_root, entry_name, spec_name ) | |||
| name = node.metadata['name'] | |||
| self._models[name] = node | |||
| cnt += 1 | |||
| print("%d models has been registered!"%cnt) | |||
| def _list_modelzoo(self, zoo_dir): | |||
| zoo_list = [] | |||
| def _traverse(path): | |||
| for item in os.listdir(path): | |||
| item_path = os.path.join(path, item) | |||
| if os.path.isdir(item_path): | |||
| if os.path.exists(os.path.join( item_path, 'code/hubconf.py' )): | |||
| zoo_list.append(item_path) | |||
| else: | |||
| _traverse( item_path ) | |||
| _traverse(zoo_dir) | |||
| return zoo_list | |||
| def add_metric(self, metric_name, metric): | |||
| self._graphs[metric_name] = g = nx.DiGraph() | |||
| g.add_nodes_from( self._models.values() ) | |||
| for n1 in self._models.values(): | |||
| for n2 in tqdm(self._models.values()): | |||
| if n1!=n2 and not g.has_edge(n1, n2): | |||
| try: | |||
| g.add_edge(n1, n2, dist=metric( n1, n2 )) | |||
| except Exception: | |||
| # retry on CPU (e.g. if the metric runs out of GPU memory) | |||
| ori_device = metric.device | |||
| metric.device = torch.device('cpu') | |||
| g.add_edge(n1, n2, dist=metric( n1, n2 )) | |||
| metric.device = ori_device | |||
| def export_to_json(self, metric_name, output_filename, topk=None, normalize=False): | |||
| graph = self._graphs.get( metric_name, None ) | |||
| assert graph is not None | |||
| graph_data={ | |||
| 'nodes': [], | |||
| 'edges': [], | |||
| } | |||
| node_to_idx = {} | |||
| for i, node in enumerate(self._models.values()): | |||
| tags = node.tag | |||
| metadata = node.metadata | |||
| node_data = { k:v for (k, v) in tags.items() if isinstance(v, (numbers.Number, str) ) } | |||
| node_data['name'] = metadata['name'] | |||
| node_data['task'] = metadata['task'] | |||
| node_data['dataset'] = metadata['dataset'] | |||
| node_data['url'] = metadata['url'] | |||
| node_data['id'] = i | |||
| graph_data['nodes'].append({'tags': node_data}) | |||
| node_to_idx[node] = i | |||
| # record Edges | |||
| edge_list = graph_data['edges'] | |||
| topk_dist = { idx: [] for idx in range(len( self._models )) } | |||
| for i, edge in enumerate(graph.edges.data('dist')): | |||
| s, t, d = int( node_to_idx[edge[0]] ), int( node_to_idx[edge[1]] ), float(edge[2]) | |||
| topk_dist[s].append(d) | |||
| edge_list.append([ | |||
| s, t, d # source, target, distance | |||
| ]) | |||
| if isinstance(topk, int): | |||
| for i, dist in topk_dist.items(): | |||
| dist.sort() | |||
| topk_dist[i] = dist[topk] | |||
| graph_data['edges'] = [ edge for edge in edge_list if edge[2] < topk_dist[edge[0]] ] | |||
| if normalize: | |||
| edge_dist = [e[2] for e in graph_data['edges']] | |||
| min_dist, max_dist = min(edge_dist), max(edge_dist) | |||
| for e in graph_data['edges']: | |||
| e[2] = (e[2] - min_dist+1e-8) / (max_dist - min_dist+1e-8) | |||
| with open(output_filename, 'w') as fp: | |||
| json.dump(graph_data, fp) | |||
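| # Standalone illustration of the edge filtering above (hypothetical numbers): | |||
| # with topk=1, each source node keeps only edges strictly below its 2nd-smallest | |||
| # outgoing distance, and normalize rescales the kept distances into (0, 1]. | |||
| if __name__ == "__main__": | |||
|     edges = [[0, 1, 0.2], [0, 2, 0.5], [1, 0, 0.3], [1, 2, 0.1]] | |||
|     topk = 1 | |||
|     topk_dist = {0: [0.2, 0.5], 1: [0.3, 0.1]}  # outgoing distances per source | |||
|     for i, dist in topk_dist.items(): | |||
|         dist.sort() | |||
|         topk_dist[i] = dist[topk] | |||
|     kept = [e for e in edges if e[2] < topk_dist[e[0]]] | |||
|     print(kept)  # -> [[0, 1, 0.2], [1, 2, 0.1]] | |||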
| @@ -0,0 +1,109 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import abc | |||
| from . import depara | |||
| from captum.attr import InputXGradient | |||
| import torch | |||
| from kamal.core import hub | |||
| from kamal.vision import sync_transforms as sT | |||
| class TransMetric(abc.ABC): | |||
| def __init__(self): | |||
| pass | |||
| def __call__(self, a, b) -> float: | |||
| return 0 | |||
| class DeparaMetric(TransMetric): | |||
| def __init__(self, data, device): | |||
| self.data = data | |||
| self.device = device | |||
| self._cache = {} | |||
| def _get_transform(self, metadata): | |||
| input_metadata = metadata['input'] | |||
| size = input_metadata['size'] | |||
| space = input_metadata['space'] | |||
| drange = input_metadata['range'] | |||
| normalize = input_metadata['normalize'] | |||
| if size is None: | |||
| size = 224 | |||
| if isinstance(size, (list, tuple)): | |||
| size = size[-1] | |||
| transform = [ | |||
| sT.Resize(size), | |||
| sT.CenterCrop(size), | |||
| ] | |||
| if space=='bgr': | |||
| transform.append(sT.FlipChannels()) | |||
| if list(drange)==[0, 1]: | |||
| transform.append( sT.ToTensor() ) | |||
| elif list(drange)==[0, 255]: | |||
| transform.append( sT.ToTensor(normalize=False, dtype=torch.float) ) | |||
| else: | |||
| raise NotImplementedError | |||
| if normalize is not None: | |||
| transform.append(sT.Normalize( mean=normalize['mean'], std=normalize['std'] )) | |||
| return sT.Compose(transform) | |||
| def _get_attr_graph(self, n): | |||
| transform = self._get_transform(n.metadata) | |||
| data = torch.stack( [ transform( d ) for d in self.data ], dim=0 ) | |||
| return depara.get_attribution_graph( | |||
| n.model, | |||
| attribution_type=InputXGradient, | |||
| with_noise=False, | |||
| probe_data=data, | |||
| device=self.device | |||
| ) | |||
| def __call__(self, n1, n2): | |||
| attrgraph_1 = self._cache.get(n1, None) | |||
| attrgraph_2 = self._cache.get(n2, None) | |||
| if attrgraph_1 is None: | |||
| # networkx graphs have no .cpu(); cache the graph object itself | |||
| self._cache[n1] = attrgraph_1 = self._get_attr_graph(n1) | |||
| if attrgraph_2 is None: | |||
| self._cache[n2] = attrgraph_2 = self._get_attr_graph(n2) | |||
| result = depara.graph_similarity(attrgraph_1, attrgraph_2) | |||
| return result | |||
| class AttrMapMetric(DeparaMetric): | |||
| def _get_attr_map(self, n): | |||
| transform = self._get_transform(n.metadata) | |||
| data = torch.stack( [ transform( d ).to(self.device) for d in self.data ], dim=0 ) | |||
| return depara.attribution_map( | |||
| n.model.to(self.device), | |||
| attribution_type=InputXGradient, | |||
| with_noise=False, | |||
| probe_data=data, | |||
| ) | |||
| def __call__(self, n1, n2): | |||
| attrgraph_1 = self._cache.get(n1, None) | |||
| attrgraph_2 = self._cache.get(n2, None) | |||
| if attrgraph_1 is None: | |||
| self._cache[n1] = attrgraph_1 = self._get_attr_map(n1).cpu() | |||
| if attrgraph_2 is None: | |||
| self._cache[n2] = attrgraph_2 = self._get_attr_map(n2).cpu() | |||
| result = depara.attr_map_distance(attrgraph_1, attrgraph_2) | |||
| return result | |||
| @@ -0,0 +1,2 @@ | |||
| from ._utils import * | |||
| from .logger import get_logger | |||
| @@ -0,0 +1,153 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import numpy as np | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import random | |||
| from copy import deepcopy | |||
| import contextlib, hashlib | |||
| def split_batch(batch): | |||
| if isinstance(batch, (list, tuple)): | |||
| inputs, *targets = batch | |||
| if len(targets)==1: | |||
| targets = targets[0] | |||
| return inputs, targets | |||
| else: | |||
| return [batch, None] | |||
| @contextlib.contextmanager | |||
| def set_mode(model, training=True): | |||
| ori_mode = model.training | |||
| model.train(training) | |||
| yield | |||
| model.train(ori_mode) | |||
| def move_to_device(obj, device): | |||
| if isinstance(obj, torch.Tensor): | |||
| return obj.to(device=device) | |||
| elif isinstance( obj, (list, tuple) ): | |||
| return [ o.to(device=device) for o in obj ] | |||
| elif isinstance(obj, nn.Module): | |||
| return obj.to(device=device) | |||
| def pack_images(images, col=None, channel_last=False): | |||
| # N, C, H, W | |||
| if isinstance(images, (list, tuple) ): | |||
| images = np.stack(images, 0) | |||
| if channel_last: | |||
| images = images.transpose(0,3,1,2) # make it channel first | |||
| assert len(images.shape)==4 | |||
| assert isinstance(images, np.ndarray) | |||
| N,C,H,W = images.shape | |||
| if col is None: | |||
| col = int(math.ceil(math.sqrt(N))) | |||
| row = int(math.ceil(N / col)) | |||
| pack = np.zeros( (C, H*row, W*col), dtype=images.dtype ) | |||
| for idx, img in enumerate(images): | |||
| h = (idx//col) * H | |||
| w = (idx% col) * W | |||
| pack[:, h:h+H, w:w+W] = img | |||
| return pack | |||
| def normalize(tensor, mean, std, reverse=False): | |||
| if reverse: | |||
| _mean = [ -m / s for m, s in zip(mean, std) ] | |||
| _std = [ 1/s for s in std ] | |||
| else: | |||
| _mean = mean | |||
| _std = std | |||
| _mean = torch.as_tensor(_mean, dtype=tensor.dtype, device=tensor.device) | |||
| _std = torch.as_tensor(_std, dtype=tensor.dtype, device=tensor.device) | |||
| tensor = (tensor - _mean[None, :, None, None]) / (_std[None, :, None, None]) | |||
| return tensor | |||
| class Normalizer(object): | |||
| def __init__(self, mean, std, reverse=False): | |||
| self.mean = mean | |||
| self.std = std | |||
| self.reverse = reverse | |||
| def __call__(self, x): | |||
| if self.reverse: | |||
| return self.denormalize(x) | |||
| else: | |||
| return self.normalize(x) | |||
| def normalize(self, x): | |||
| return normalize( x, self.mean, self.std ) | |||
| def denormalize(self, x): | |||
| return normalize( x, self.mean, self.std, reverse=True ) | |||
| def colormap(N=256, normalized=False): | |||
| def bitget(byteval, idx): | |||
| return ((byteval & (1 << idx)) != 0) | |||
| dtype = 'float32' if normalized else 'uint8' | |||
| cmap = np.zeros((N, 3), dtype=dtype) | |||
| for i in range(N): | |||
| r = g = b = 0 | |||
| c = i | |||
| for j in range(8): | |||
| r = r | (bitget(c, 0) << 7-j) | |||
| g = g | (bitget(c, 1) << 7-j) | |||
| b = b | (bitget(c, 2) << 7-j) | |||
| c = c >> 3 | |||
| cmap[i] = np.array([r, g, b]) | |||
| cmap = cmap/255 if normalized else cmap | |||
| return cmap | |||
| DEFAULT_COLORMAP = colormap() | |||
| def flatten_dict(dic): | |||
| flattened = dict() | |||
| def _flatten(prefix, d): | |||
| for k, v in d.items(): | |||
| if isinstance(v, dict): | |||
| _flatten( prefix+'%s/'%k, v ) # descend with the extended prefix | |||
| else: | |||
| flattened[ (prefix+'%s/'%k).strip('/') ] = v | |||
| _flatten('', dic) | |||
| return flattened | |||
| def setup_seed(seed): | |||
| torch.manual_seed(seed) | |||
| torch.cuda.manual_seed(seed) | |||
| np.random.seed(seed) | |||
| random.seed(seed) | |||
| def count_parameters(model): | |||
| return sum( [ p.numel() for p in model.parameters() ] ) | |||
| def md5(fname): | |||
| hash_md5 = hashlib.md5() | |||
| with open(fname, "rb") as f: | |||
| for chunk in iter(lambda: f.read(4096), b""): | |||
| hash_md5.update(chunk) | |||
| return hash_md5.hexdigest() | |||
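| # Quick checks of helpers above (values are arbitrary): pack_images tiles five | |||
| # CHW images into a 2x3 grid, Normalizer round-trips a tensor, and flatten_dict | |||
| # joins nested keys with '/'. | |||
| if __name__ == "__main__": | |||
|     imgs = np.random.rand(5, 3, 8, 8).astype('float32') | |||
|     print(pack_images(imgs).shape)  # (3, 16, 24): 2 rows x 3 cols of 8x8 tiles | |||
|     normalizer = Normalizer(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) | |||
|     x = torch.rand(2, 3, 4, 4) | |||
|     print(torch.allclose(normalizer.denormalize(normalizer(x)), x, atol=1e-6)) | |||
|     print(flatten_dict({'a': {'b': 1}, 'c': 2}))  # {'a/b': 1, 'c': 2} | |||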
| @@ -0,0 +1,56 @@ | |||
| import logging | |||
| import os, sys | |||
| from termcolor import colored | |||
| class _ColorfulFormatter(logging.Formatter): | |||
| def __init__(self, *args, **kwargs): | |||
| super(_ColorfulFormatter, self).__init__(*args, **kwargs) | |||
| def formatMessage(self, record): | |||
| log = super(_ColorfulFormatter, self).formatMessage(record) | |||
| if record.levelno == logging.WARNING: | |||
| prefix = colored("WARNING", "yellow", attrs=["blink"]) | |||
| elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: | |||
| prefix = colored("ERROR", "red", attrs=["blink", "underline"]) | |||
| else: | |||
| return log | |||
| return prefix + " " + log | |||
| def get_logger(name='Kamal', output=None, color=True): | |||
| logger = logging.getLogger(name) | |||
| logger.setLevel(logging.DEBUG) | |||
| logger.propagate = False | |||
| # STDOUT | |||
| stdout_handler = logging.StreamHandler( stream=sys.stdout ) | |||
| stdout_handler.setLevel( logging.DEBUG ) | |||
| plain_formatter = logging.Formatter( | |||
| "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) | |||
| if color: | |||
| formatter = _ColorfulFormatter( | |||
| colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", | |||
| datefmt="%m/%d %H:%M:%S") | |||
| else: | |||
| formatter = plain_formatter | |||
| stdout_handler.setFormatter(formatter) | |||
| logger.addHandler(stdout_handler) | |||
| # FILE | |||
| if output is not None: | |||
| if output.endswith('.txt') or output.endswith('.log'): | |||
| os.makedirs(os.path.dirname(output), exist_ok=True) | |||
| filename = output | |||
| else: | |||
| os.makedirs(output, exist_ok=True) | |||
| filename = os.path.join(output, "log.txt") | |||
| file_handler = logging.FileHandler(filename) | |||
| file_handler.setFormatter(plain_formatter) | |||
| file_handler.setLevel(logging.DEBUG) | |||
| logger.addHandler(file_handler) | |||
| return logger | |||
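| # Hedged usage example (the output path is hypothetical): messages go to a | |||
| # colored stdout stream and, in plain format, to run/log.txt. | |||
| if __name__ == "__main__": | |||
|     logger = get_logger('Kamal', output='run/log.txt', color=True) | |||
|     logger.info('training started') | |||
|     logger.warning('this line gets a colored WARNING prefix') | |||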
| @@ -0,0 +1,3 @@ | |||
| from . import models | |||
| from . import datasets | |||
| from . import sync_transforms | |||
| @@ -0,0 +1,16 @@ | |||
| from .ade20k import ADE20K | |||
| from .caltech import Caltech101, Caltech256 | |||
| from .camvid import CamVid | |||
| from .cityscapes import Cityscapes | |||
| from .cub200 import CUB200 | |||
| from .fgvc_aircraft import FGVCAircraft | |||
| from .nyu import NYUv2 | |||
| from .stanford_cars import StanfordCars | |||
| from .stanford_dogs import StanfordDogs | |||
| from .sunrgbd import SunRGBD | |||
| from .voc import VOCClassification, VOCSegmentation | |||
| from .dataset import LabelConcatDataset | |||
| from torchvision import datasets as torchvision_datasets | |||
| from .unlabeled import UnlabeledDataset | |||
| @@ -0,0 +1,70 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import collections | |||
| import torch | |||
| import torchvision | |||
| import numpy as np | |||
| from PIL import Image | |||
| import os | |||
| from torchvision.datasets import VisionDataset | |||
| from .utils import colormap | |||
| class ADE20K(VisionDataset): | |||
| cmap = colormap() | |||
| def __init__( | |||
| self, | |||
| root, | |||
| split="training", | |||
| transform=None, | |||
| target_transform=None, | |||
| transforms=None, | |||
| ): | |||
| super( ADE20K, self ).__init__( root=root, transforms=transforms, transform=transform, target_transform=target_transform ) | |||
| assert split in ['training', 'validation'], "split should be \'training\' or \'validation\'" | |||
| self.root = os.path.expanduser(root) | |||
| self.split = split | |||
| self.num_classes = 150 | |||
| img_list = [] | |||
| lbl_list = [] | |||
| img_dir = os.path.join( self.root, 'images', self.split ) | |||
| lbl_dir = os.path.join( self.root, 'annotations', self.split ) | |||
| for img_name in os.listdir( img_dir ): | |||
| img_list.append( os.path.join( img_dir, img_name ) ) | |||
| lbl_list.append( os.path.join( lbl_dir, img_name[:-3]+'png') ) | |||
| self.img_list = img_list | |||
| self.lbl_list = lbl_list | |||
| def __len__(self): | |||
| return len(self.img_list) | |||
| def __getitem__(self, index): | |||
| img = Image.open( self.img_list[index] ) | |||
| lbl = Image.open( self.lbl_list[index] ) | |||
| if self.transforms: | |||
| img, lbl = self.transforms(img, lbl) | |||
| lbl = np.array(lbl, dtype='uint8')-1 # shift labels 1-150 to 0-149; label 0 underflows to 255 and is ignored | |||
| return img, lbl | |||
| @classmethod | |||
| def decode_seg_to_color(cls, mask): | |||
| """decode semantic mask to RGB image""" | |||
| return cls.cmap[mask+1] | |||
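| # Hedged usage (the root below is hypothetical; it must hold the standard | |||
| # images/{training,validation} and annotations/{training,validation} layout): | |||
| if __name__ == "__main__": | |||
|     ds = ADE20K(root='~/data/ADEChallengeData2016', split='validation') | |||
|     img, lbl = ds[0]  # PIL image and an HxW uint8 label map with 255 = ignore | |||
|     print(len(ds), lbl.shape) | |||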
| @@ -0,0 +1,226 @@ | |||
| # Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py | |||
| from __future__ import print_function | |||
| from PIL import Image | |||
| import os | |||
| import os.path | |||
| from torchvision.datasets.vision import VisionDataset | |||
| from torchvision.datasets.utils import download_url | |||
| class Caltech101(VisionDataset): | |||
| """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset. | |||
| Args: | |||
| root (string): Root directory of dataset where directory | |||
| ``caltech101`` exists or will be saved to if download is set to True. | |||
| target_type (string or list, optional): Type of target to use, ``category`` or | |||
| ``annotation``. Can also be a list to output a tuple with all specified target types. | |||
| ``category`` represents the target class, and ``annotation`` is a list of points | |||
| from a hand-generated outline. Defaults to ``category``. | |||
| transform (callable, optional): A function/transform that takes in an PIL image | |||
| and returns a transformed version. E.g, ``transforms.RandomCrop`` | |||
| target_transform (callable, optional): A function/transform that takes in the | |||
| target and transforms it. | |||
| download (bool, optional): If true, downloads the dataset from the internet and | |||
| puts it in root directory. If dataset is already downloaded, it is not | |||
| downloaded again. | |||
| """ | |||
| def __init__(self, root, target_type="category", train=True, | |||
| transform=None, target_transform=None, | |||
| download=False): | |||
| super(Caltech101, self).__init__(os.path.join(root, 'caltech101')) | |||
| self.train = train | |||
| self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test' | |||
| os.makedirs(self.root, exist_ok=True) | |||
| if isinstance(target_type, list): | |||
| self.target_type = target_type | |||
| else: | |||
| self.target_type = [target_type] | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| if download: | |||
| self.download() | |||
| if not self._check_integrity(): | |||
| raise RuntimeError('Dataset not found or corrupted.' + | |||
| ' You can use download=True to download it') | |||
| self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) | |||
| self.categories.remove("BACKGROUND_Google") # this is not a real class | |||
| # For some reason, the category names in "101_ObjectCategories" and | |||
| # "Annotations" do not always match. This is a manual map between the | |||
| # two. Defaults to using same name, since most names are fine. | |||
| name_map = {"Faces": "Faces_2", | |||
| "Faces_easy": "Faces_3", | |||
| "Motorbikes": "Motorbikes_16", | |||
| "airplanes": "Airplanes_Side_2"} | |||
| self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) | |||
| self.index = [] | |||
| self.y = [] | |||
| for (i, c) in enumerate(self.categories): | |||
| file_names = os.listdir(os.path.join(self.root, self.dir_name, c)) | |||
| n = len(file_names) | |||
| self.index.extend( file_names ) | |||
| self.y.extend(n * [i]) | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (image, target) where the type of target specified by target_type. | |||
| """ | |||
| import scipy.io | |||
| img = Image.open(os.path.join(self.root, | |||
| self.dir_name, | |||
| self.categories[self.y[index]], | |||
| self.index[index])).convert("RGB") | |||
| target = [] | |||
| for t in self.target_type: | |||
| if t == "category": | |||
| target.append(self.y[index]) | |||
| elif t == "annotation": | |||
| data = scipy.io.loadmat(os.path.join(self.root, | |||
| "Annotations", | |||
| self.annotation_categories[self.y[index]], | |||
| "annotation_{:04d}.mat".format(self.index[index]))) | |||
| target.append(data["obj_contour"]) | |||
| else: | |||
| raise ValueError("Target type \"{}\" is not recognized.".format(t)) | |||
| target = tuple(target) if len(target) > 1 else target[0] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| target = self.target_transform(target) | |||
| return img, target | |||
| def _check_integrity(self): | |||
| # can be more robust and check hash of files | |||
| return os.path.exists(os.path.join(self.root, "101_ObjectCategories")) | |||
| def __len__(self): | |||
| return len(self.index) | |||
| def download(self): | |||
| import tarfile | |||
| if self._check_integrity(): | |||
| print('Files already downloaded and verified') | |||
| return | |||
| download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", | |||
| self.root, | |||
| "101_ObjectCategories.tar.gz", | |||
| "b224c7392d521a49829488ab0f1120d9") | |||
| download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", | |||
| self.root, | |||
| "101_Annotations.tar", | |||
| "6f83eeb1f24d99cab4eb377263132c91") | |||
| # extract file | |||
| with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar: | |||
| tar.extractall(path=self.root) | |||
| with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar: | |||
| tar.extractall(path=self.root) | |||
| def extra_repr(self): | |||
| return "Target type: {target_type}".format(**self.__dict__) | |||
| class Caltech256(VisionDataset): | |||
| """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset. | |||
| Args: | |||
| root (string): Root directory of dataset where directory | |||
| ``caltech256`` exists or will be saved to if download is set to True. | |||
| transform (callable, optional): A function/transform that takes in an PIL image | |||
| and returns a transformed version. E.g, ``transforms.RandomCrop`` | |||
| target_transform (callable, optional): A function/transform that takes in the | |||
| target and transforms it. | |||
| download (bool, optional): If true, downloads the dataset from the internet and | |||
| puts it in root directory. If dataset is already downloaded, it is not | |||
| downloaded again. | |||
| """ | |||
| def __init__(self, root, | |||
| transform=None, target_transform=None, | |||
| download=False): | |||
| super(Caltech256, self).__init__(os.path.join(root, 'caltech256')) | |||
| os.makedirs(self.root, exist_ok=True) | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| if download: | |||
| self.download() | |||
| if not self._check_integrity(): | |||
| raise RuntimeError('Dataset not found or corrupted.' + | |||
| ' You can use download=True to download it') | |||
| self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) | |||
| self.index = [] | |||
| self.y = [] | |||
| for (i, c) in enumerate(self.categories): | |||
| n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c))) | |||
| self.index.extend(range(1, n + 1)) | |||
| self.y.extend(n * [i]) | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (image, target) where target is index of the target class. | |||
| """ | |||
| img = Image.open(os.path.join(self.root, | |||
| "256_ObjectCategories", | |||
| self.categories[self.y[index]], | |||
| "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]))) | |||
| target = self.y[index] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| target = self.target_transform(target) | |||
| return img, target | |||
| def _check_integrity(self): | |||
| # can be more robust and check hash of files | |||
| return os.path.exists(os.path.join(self.root, "256_ObjectCategories")) | |||
| def __len__(self): | |||
| return len(self.index) | |||
| def download(self): | |||
| import tarfile | |||
| if self._check_integrity(): | |||
| print('Files already downloaded and verified') | |||
| return | |||
| download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", | |||
| self.root, | |||
| "256_ObjectCategories.tar", | |||
| "67b4f42ca05d46448c6bb8ecd2220f6d") | |||
| # extract file | |||
| with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar: | |||
| tar.extractall(path=self.root) | |||
| @@ -0,0 +1,78 @@ | |||
| # Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py | |||
| import os | |||
| import torch.utils.data as data | |||
| from glob import glob | |||
| from PIL import Image | |||
| import numpy as np | |||
| from torchvision.datasets import VisionDataset | |||
| class CamVid(VisionDataset): | |||
| """CamVid dataset loader where the dataset is arranged as in https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid. | |||
| Args: | |||
| root (string): | |||
| split (string): The type of dataset: 'train', 'val', 'trainval', or 'test' | |||
| transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default: None. | |||
| target_transform (callable, optional): A function/transform that takes in the target and transform it. Default: None. | |||
| transforms (callable, optional): A function/transform that takes in both the image and target and transform them. Default: None. | |||
| """ | |||
| cmap = np.array([ | |||
| (128, 128, 128), | |||
| (128, 0, 0), | |||
| (192, 192, 128), | |||
| (128, 64, 128), | |||
| (60, 40, 222), | |||
| (128, 128, 0), | |||
| (192, 128, 128), | |||
| (64, 64, 128), | |||
| (64, 0, 128), | |||
| (64, 64, 0), | |||
| (0, 128, 192), | |||
| (0, 0, 0), | |||
| ]) | |||
| def __init__(self, | |||
| root, | |||
| split='train', | |||
| transform=None, | |||
| target_transform=None, | |||
| transforms=None): | |||
| assert split in ('train', 'val', 'test', 'trainval') | |||
| super( CamVid, self ).__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform) | |||
| self.root = os.path.expanduser(root) | |||
| self.split = split | |||
| if split == 'trainval': | |||
| self.images = glob(os.path.join(self.root, 'train', '*.png')) + glob(os.path.join(self.root, 'val', '*.png')) | |||
| self.labels = glob(os.path.join(self.root, 'trainannot', '*.png')) + glob(os.path.join(self.root, 'valannot', '*.png')) | |||
| else: | |||
| self.images = glob(os.path.join(self.root, self.split, '*.png')) | |||
| self.labels = glob(os.path.join(self.root, self.split+'annot', '*.png')) | |||
| self.images.sort() | |||
| self.labels.sort() | |||
| def __getitem__(self, idx): | |||
| """ | |||
| Args: | |||
| - index (``int``): index of the item in the dataset | |||
| Returns: | |||
| A tuple of ``PIL.Image`` (image, label) where label is the ground-truth | |||
| of the image. | |||
| """ | |||
| img, label = Image.open(self.images[idx]), Image.open(self.labels[idx]) | |||
| if self.transforms is not None: | |||
| img, label = self.transforms(img, label) | |||
| label[label == 11] = 255 # ignore void | |||
| return img, label.squeeze(0) | |||
| def __len__(self): | |||
| return len(self.images) | |||
| @classmethod | |||
| def decode_fn(cls, mask): | |||
| """decode semantic mask to RGB image""" | |||
| mask[mask == 255] = 11 | |||
| return cls.cmap[mask] | |||
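| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the loader): iterating CamVid with a | |||
| # minimal paired transform. The data root './CamVid' is an assumption, and the | |||
| # paired transform must return tensors so the void-class remapping above works. | |||
| if __name__ == '__main__': | |||
|     import torch | |||
|     from torchvision.transforms import functional as F | |||
|     def to_tensor_pair(img, label): | |||
|         # image -> float tensor in [0, 1]; mask -> int64 tensor of class ids | |||
|         return F.to_tensor(img), torch.from_numpy(np.array(label, dtype='int64')) | |||
|     train_set = CamVid('./CamVid', split='train', transforms=to_tensor_pair) | |||
|     img, label = train_set[0] | |||
|     rgb_mask = CamVid.decode_fn(label.numpy())  # (H, W, 3) uint8 color mask | |||
|     print(img.shape, label.shape, rgb_mask.shape) | |||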
| @@ -0,0 +1,146 @@ | |||
| # Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py | |||
| import json | |||
| import os | |||
| from collections import namedtuple | |||
| import torch | |||
| import torch.utils.data as data | |||
| from PIL import Image | |||
| import numpy as np | |||
| from torchvision.datasets import VisionDataset | |||
| class Cityscapes(VisionDataset): | |||
| """Cityscapes <http://www.cityscapes-dataset.com/> Dataset. | |||
| Args: | |||
| root (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located. | |||
| split (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val' | |||
| mode (string, optional): The quality mode to use: 'gtFine' or 'gtCoarse'. | |||
| target_type (string, optional): Type of target to use: 'instance', 'semantic', 'color', 'polygon' or 'depth'. | |||
| transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` | |||
| target_transform (callable, optional): A function/transform that takes in the target and transforms it. | |||
| """ | |||
| # Based on https://github.com/mcordts/cityscapesScripts | |||
| CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', | |||
| 'has_instances', 'ignore_in_eval', 'color']) | |||
| classes = [ | |||
| CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), | |||
| CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), | |||
| CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), | |||
| CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), | |||
| CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), | |||
| CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), | |||
| CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), | |||
| CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), | |||
| CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), | |||
| CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), | |||
| CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), | |||
| CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), | |||
| CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), | |||
| CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), | |||
| CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), | |||
| CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), | |||
| CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), | |||
| CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), | |||
| CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), | |||
| CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), | |||
| CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), | |||
| CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), | |||
| CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), | |||
| CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), | |||
| CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), | |||
| CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), | |||
| CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), | |||
| CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), | |||
| CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), | |||
| CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), | |||
| CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), | |||
| CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), | |||
| CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), | |||
| CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), | |||
| CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)), | |||
| ] | |||
| _TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)] | |||
| _TRAIN_ID_TO_COLOR.append([0, 0, 0]) | |||
| _TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR) | |||
| _ID_TO_TRAIN_ID = np.array([c.train_id for c in classes]) | |||
| def __init__(self, root, split='train', mode='gtFine', target_type='semantic', transform=None, target_transform=None, transforms=None): | |||
| super(Cityscapes, self).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms ) | |||
| self.root = os.path.expanduser(root) | |||
| self.mode = mode | |||
| self.target_type = target_type | |||
| self.images_dir = os.path.join(self.root, 'leftImg8bit', split) | |||
| self.targets_dir = os.path.join(self.root, self.mode, split) | |||
| self.split = split | |||
| self.images = [] | |||
| self.targets = [] | |||
| if split not in ['train', 'test', 'val']: | |||
| raise ValueError('Invalid split for mode! Please use split="train", split="test"' | |||
| ' or split="val"') | |||
| if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): | |||
| raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' | |||
| ' specified "split" and "mode" are inside the "root" directory') | |||
| for city in os.listdir(self.images_dir): | |||
| img_dir = os.path.join(self.images_dir, city) | |||
| target_dir = os.path.join(self.targets_dir, city) | |||
| for file_name in os.listdir(img_dir): | |||
| self.images.append(os.path.join(img_dir, file_name)) | |||
| target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], | |||
| self._get_target_suffix(self.mode, self.target_type)) | |||
| self.targets.append(os.path.join(target_dir, target_name)) | |||
| @classmethod | |||
| def encode_target(cls, target): | |||
| if isinstance( target, torch.Tensor ): | |||
| return torch.from_numpy( cls._ID_TO_TRAIN_ID[np.array(target)] ) | |||
| else: | |||
| return cls._ID_TO_TRAIN_ID[target] | |||
| @classmethod | |||
| def decode_fn(cls, target): | |||
| target[target == 255] = 19 | |||
| #target = target.astype('uint8') + 1 | |||
| return cls._TRAIN_ID_TO_COLOR[target] | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (image, target) where target is a tuple of all target types if target_type is a list with more | |||
| than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. | |||
| """ | |||
| image = Image.open(self.images[index]).convert('RGB') | |||
| target = Image.open(self.targets[index]) | |||
| if self.transforms: | |||
| image, target = self.transforms(image, target) | |||
| target = self.encode_target(target) | |||
| return image, target | |||
| def __len__(self): | |||
| return len(self.images) | |||
| def _load_json(self, path): | |||
| with open(path, 'r') as file: | |||
| data = json.load(file) | |||
| return data | |||
| def _get_target_suffix(self, mode, target_type): | |||
| if target_type == 'instance': | |||
| return '{}_instanceIds.png'.format(mode) | |||
| elif target_type == 'semantic': | |||
| return '{}_labelIds.png'.format(mode) | |||
| elif target_type == 'color': | |||
| return '{}_color.png'.format(mode) | |||
| elif target_type == 'polygon': | |||
| return '{}_polygons.json'.format(mode) | |||
| elif target_type == 'depth': | |||
| return '{}_disparity.png'.format(mode) | |||
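| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the loader): how the id <-> train_id | |||
| # mapping behaves. Raw label ids (0..33) map to the 19 training classes, and | |||
| # ignored ids become 255, which decode_fn renders as black. | |||
| if __name__ == '__main__': | |||
|     raw = np.array([[7, 8, 26], [0, 24, 33]])     # road, sidewalk, car / unlabeled, person, bicycle | |||
|     train_ids = Cityscapes.encode_target(raw)     # [[0, 1, 13], [255, 11, 18]] | |||
|     rgb = Cityscapes.decode_fn(train_ids.copy())  # decode_fn mutates its input, hence the copy | |||
|     print(train_ids, rgb.shape)                   # rgb has shape (2, 3, 3) | |||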
| @@ -0,0 +1,70 @@ | |||
| # Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py | |||
| import os | |||
| import pandas as pd | |||
| from torchvision.datasets.folder import default_loader | |||
| from .utils import download_url | |||
| from torch.utils.data import Dataset | |||
| import shutil | |||
| class CUB200(Dataset): | |||
| base_folder = 'images' | |||
| url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' | |||
| filename = 'CUB_200_2011.tgz' | |||
| tgz_md5 = '97eceeb196236b17998738112f37df78' | |||
| def __init__(self, root, split='train', transform=None, target_transform=None, loader=default_loader, download=False): | |||
| self.root = os.path.abspath( os.path.expanduser( root ) ) | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| self.loader = loader # honor the loader argument instead of always using default_loader | |||
| self.split = split | |||
| if download: | |||
| self.download() | |||
| self._load_metadata() | |||
| categories = os.listdir(os.path.join( | |||
| self.root, 'CUB_200_2011', 'images')) | |||
| categories.sort() | |||
| self.object_categories = [c[4:] for c in categories] | |||
| print('CUB200, Split: %s, Size: %d' % (self.split, self.__len__())) | |||
| def _load_metadata(self): | |||
| images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', | |||
| names=['img_id', 'filepath']) | |||
| image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), | |||
| sep=' ', names=['img_id', 'target'], encoding='latin-1', engine='python') | |||
| train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), | |||
| sep=' ', names=['img_id', 'is_training_img'], encoding='latin-1', engine='python') | |||
| data = images.merge(image_class_labels, on='img_id') | |||
| self.data = data.merge(train_test_split, on='img_id') | |||
| if self.split == 'train': | |||
| self.data = self.data[self.data.is_training_img == 1] | |||
| else: | |||
| self.data = self.data[self.data.is_training_img == 0] | |||
| def download(self): | |||
| import tarfile | |||
| os.makedirs(self.root, exist_ok=True) | |||
| if not os.path.isfile(os.path.join(self.root, self.filename)): | |||
| download_url(self.url, self.root, self.filename, self.tgz_md5) # verify the archive checksum | |||
| print("Extracting %s..." % self.filename) | |||
| with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: | |||
| tar.extractall(path=self.root) | |||
| def __len__(self): | |||
| return len(self.data) | |||
| def __getitem__(self, idx): | |||
| sample = self.data.iloc[idx] | |||
| path = os.path.join(self.root, 'CUB_200_2011', | |||
| self.base_folder, sample.filepath) | |||
| lbl = sample.target - 1 | |||
| img = self.loader(path) | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| lbl = self.target_transform(lbl) | |||
| return img, lbl | |||
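| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the loader); the data root './data' | |||
| # is an assumption. | |||
| if __name__ == '__main__': | |||
|     from torch.utils.data import DataLoader | |||
|     from torchvision import transforms | |||
|     train_set = CUB200('./data', split='train', download=True, | |||
|                        transform=transforms.Compose([transforms.Resize((224, 224)), | |||
|                                                      transforms.ToTensor()])) | |||
|     images, labels = next(iter(DataLoader(train_set, batch_size=32, shuffle=True))) | |||
|     print(images.shape, labels.shape)  # [32, 3, 224, 224] and [32] | |||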
| @@ -0,0 +1,57 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| from torchvision.datasets import VisionDataset | |||
| from PIL import Image | |||
| import torch | |||
| class LabelConcatDataset(VisionDataset): | |||
| """Dataset as a concatenation of dataset's lables. | |||
| This class is useful to assemble the same dataset's labels. | |||
| Arguments: | |||
| datasets (sequence): List of datasets to be concatenated | |||
| tasks (list) : List of teacher tasks | |||
| transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. | |||
| target_transform (callable, optional): A function/transform that takes in the target and transforms it. | |||
| transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. | |||
| """ | |||
| def __init__(self, datasets, transforms=None, transform=None, target_transform=None): | |||
| super(LabelConcatDataset, self).__init__( | |||
| root=None, transforms=transforms, transform=transform, target_transform=target_transform) | |||
| assert len(datasets) > 0, 'datasets should not be an empty iterable' | |||
| self.datasets = list(datasets) | |||
| for d in self.datasets: | |||
| for t in ('transform', 'transforms', 'target_transform'): | |||
| if hasattr( d, t ): | |||
| setattr( d, t, None ) | |||
| def __getitem__(self, idx): | |||
| targets_list = [] | |||
| for dst in self.datasets: | |||
| image, target = dst[idx] | |||
| targets_list.append(target) | |||
| if self.transforms is not None: | |||
| image, *targets_list = self.transforms( image, *targets_list ) | |||
| data = [ image, *targets_list ] | |||
| return data | |||
| def __len__(self): | |||
| return len(self.datasets[0].images) | |||
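| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the dataset): pairing two label types | |||
| # of the same images, e.g. semantic masks and depth maps. The stand-in datasets | |||
| # below are assumptions; in practice pass two aligned loaders such as | |||
| # NYUv2(..., target_type='semantic') and NYUv2(..., target_type='depth'). | |||
| if __name__ == '__main__': | |||
|     class _Aligned(VisionDataset): | |||
|         """Stand-in dataset: same images, different targets.""" | |||
|         def __init__(self, offset): | |||
|             super(_Aligned, self).__init__(root=None) | |||
|             self.images = list(range(4))  # __len__ above reads datasets[0].images | |||
|             self.offset = offset | |||
|         def __getitem__(self, idx): | |||
|             return torch.zeros(3, 8, 8), torch.tensor(idx + self.offset) | |||
|         def __len__(self): | |||
|             return len(self.images) | |||
|     both = LabelConcatDataset([_Aligned(0), _Aligned(100)]) | |||
|     image, label_a, label_b = both[0] | |||
|     print(image.shape, label_a.item(), label_b.item())  # [3, 8, 8] 0 100 | |||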
| @@ -0,0 +1,143 @@ | |||
| #Modified from https://github.com/pytorch/vision/pull/467/files | |||
| from __future__ import print_function | |||
| import torch.utils.data as data | |||
| from torchvision.datasets.folder import default_loader | |||
| from PIL import Image | |||
| import os | |||
| import numpy as np | |||
| from .utils import download_url, mkdir | |||
| def make_dataset(dir, image_ids, targets): | |||
| assert(len(image_ids) == len(targets)) | |||
| images = [] | |||
| dir = os.path.expanduser(dir) | |||
| for i in range(len(image_ids)): | |||
| item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images', | |||
| '%s.jpg' % image_ids[i]), targets[i]) | |||
| images.append(item) | |||
| return images | |||
| def find_classes(classes_file): | |||
| # read classes file, separating out image IDs and class names | |||
| image_ids = [] | |||
| targets = [] | |||
| with open(classes_file, 'r') as f: | |||
| for line in f: | |||
| split_line = line.split(' ') | |||
| image_ids.append(split_line[0]) | |||
| targets.append(' '.join(split_line[1:]).strip()) # strip the trailing newline from the class name | |||
| # index class names | |||
| classes = np.unique(targets) | |||
| class_to_idx = {classes[i]: i for i in range(len(classes))} | |||
| targets = [class_to_idx[c] for c in targets] | |||
| return (image_ids, targets, classes, class_to_idx) | |||
| class FGVCAircraft(data.Dataset): | |||
| """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset. | |||
| Args: | |||
| root (string): Root directory path to dataset. | |||
| class_type (string, optional): The level of FGVC-Aircraft fine-grained classification | |||
| to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). | |||
| transform (callable, optional): A function/transform that takes in a PIL image | |||
| and returns a transformed version. E.g. ``transforms.RandomCrop`` | |||
| target_transform (callable, optional): A function/transform that takes in the | |||
| target and transforms it. | |||
| loader (callable, optional): A function to load an image given its path. | |||
| download (bool, optional): If true, downloads the dataset from the internet and | |||
| puts it in the root directory. If dataset is already downloaded, it is not | |||
| downloaded again. | |||
| """ | |||
| url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' | |||
| class_types = ('variant', 'family', 'manufacturer') | |||
| splits = ('train', 'val', 'trainval', 'test') | |||
| def __init__(self, root, class_type='variant', split='train', transform=None, | |||
| target_transform=None, loader=default_loader, download=False): | |||
| if split not in self.splits: | |||
| raise ValueError('Split "{}" not found. Valid splits are: {}'.format( | |||
| split, ', '.join(self.splits), | |||
| )) | |||
| if class_type not in self.class_types: | |||
| raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( | |||
| class_type, ', '.join(self.class_types), | |||
| )) | |||
| self.root = root | |||
| self.class_type = class_type | |||
| self.split = split | |||
| self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', | |||
| 'images_%s_%s.txt' % (self.class_type, self.split)) | |||
| if download: | |||
| self.download() | |||
| (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) | |||
| samples = make_dataset(self.root, image_ids, targets) | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| self.loader = loader | |||
| self.samples = samples | |||
| self.classes = classes | |||
| self.class_to_idx = class_to_idx | |||
| with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f: | |||
| self.object_categories = [ | |||
| line.strip('\n') for line in f.readlines()] | |||
| print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__())) | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (sample, target) where target is class_index of the target class. | |||
| """ | |||
| path, target = self.samples[index] | |||
| sample = self.loader(path) | |||
| if self.transform is not None: | |||
| sample = self.transform(sample) | |||
| if self.target_transform is not None: | |||
| target = self.target_transform(target) | |||
| return sample, target | |||
| def __len__(self): | |||
| return len(self.samples) | |||
| def __repr__(self): | |||
| fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' | |||
| fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) | |||
| fmt_str += ' Root Location: {}\n'.format(self.root) | |||
| tmp = ' Transforms (if any): ' | |||
| fmt_str += '{0}{1}\n'.format( | |||
| tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) | |||
| tmp = ' Target Transforms (if any): ' | |||
| fmt_str += '{0}{1}'.format( | |||
| tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) | |||
| return fmt_str | |||
| def _check_exists(self): | |||
| return os.path.exists(os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', 'images')) and \ | |||
| os.path.exists(self.classes_file) | |||
| def download(self): | |||
| """Download the FGVC-Aircraft data if it doesn't exist already.""" | |||
| import tarfile | |||
| mkdir(self.root) | |||
| fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz') | |||
| if not os.path.isfile(fpath): | |||
| download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz') | |||
| print("Extracting fgvc-aircraft-2013b.tar.gz...") | |||
| with tarfile.open(fpath, "r:gz") as tar: | |||
| tar.extractall(path=self.root) | |||
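| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the loader): choosing the annotation | |||
| # level; the data root './data' is an assumption. | |||
| if __name__ == '__main__': | |||
|     from torchvision import transforms | |||
|     data = FGVCAircraft('./data', class_type='family', split='trainval', | |||
|                         transform=transforms.ToTensor(), download=True) | |||
|     img, target = data[0] | |||
|     print(len(data.classes), data.classes[target]) | |||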
| @@ -0,0 +1,84 @@ | |||
| # Modified from https://github.com/VainF/nyuv2-python-toolkit | |||
| import os | |||
| import torch | |||
| import torch.utils.data as data | |||
| from PIL import Image | |||
| from scipy.io import loadmat | |||
| import numpy as np | |||
| import glob | |||
| from torchvision import transforms | |||
| from torchvision.datasets import VisionDataset | |||
| import random | |||
| from .utils import colormap | |||
| class NYUv2(VisionDataset): | |||
| """NYUv2 dataset | |||
| See https://github.com/VainF/nyuv2-python-toolkit for more details. | |||
| Args: | |||
| root (string): Root directory path. | |||
| split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'. | |||
| target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``. | |||
| num_classes (int, optional): The number of classes, must be 40 or 13. Default: 13. | |||
| transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. | |||
| target_transform (callable, optional): A function/transform that takes in the target and transforms it. | |||
| transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. | |||
| """ | |||
| cmap = colormap() | |||
| def __init__(self, | |||
| root, | |||
| split='train', | |||
| target_type='semantic', | |||
| num_classes=13, | |||
| transforms=None, | |||
| transform=None, | |||
| target_transform=None): | |||
| super( NYUv2, self ).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform) | |||
| assert(split in ('train', 'test')) | |||
| self.root = root | |||
| self.split = split | |||
| self.target_type = target_type | |||
| self.num_classes = num_classes | |||
| split_mat = loadmat(os.path.join(self.root, 'splits.mat')) | |||
| idxs = split_mat[self.split+'Ndxs'].reshape(-1) - 1 | |||
| img_names = os.listdir( os.path.join(self.root, 'image', self.split) ) | |||
| img_names.sort() | |||
| images_dir = os.path.join(self.root, 'image', self.split) | |||
| self.images = [os.path.join(images_dir, name) for name in img_names] | |||
| self._is_depth = False | |||
| if self.target_type=='semantic': | |||
| semantic_dir = os.path.join(self.root, 'seg%d'%self.num_classes, self.split) | |||
| self.labels = [os.path.join(semantic_dir, name) for name in img_names] | |||
| self.targets = self.labels | |||
| if self.target_type=='depth': | |||
| depth_dir = os.path.join(self.root, 'depth', self.split) | |||
| self.depths = [os.path.join(depth_dir, name) for name in img_names] | |||
| self.targets = self.depths | |||
| self._is_depth = True | |||
| if self.target_type=='normal': | |||
| normal_dir = os.path.join(self.root, 'normal', self.split) | |||
| self.normals = [os.path.join(normal_dir, name) for name in img_names] | |||
| self.targets = self.normals | |||
| def __getitem__(self, idx): | |||
| image = Image.open(self.images[idx]) | |||
| target = Image.open(self.targets[idx]) | |||
| if self.transforms is not None: | |||
| image, target = self.transforms( image, target ) | |||
| return image, target | |||
| def __len__(self): | |||
| return len(self.images) | |||
| @classmethod | |||
| def decode_fn(cls, mask: np.ndarray): | |||
| """decode semantic mask to RGB image""" | |||
| mask = mask.astype('uint8') + 1 # +1 wraps the ignore index 255 back to 0 (void color) | |||
| return cls.cmap[mask] | |||
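| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the loader): the same images can be | |||
| # paired with semantic, depth or normal targets. './NYUv2' is an assumption. | |||
| if __name__ == '__main__': | |||
|     seg_set = NYUv2('./NYUv2', split='train', target_type='semantic', num_classes=13) | |||
|     image, seg = seg_set[0]               # PIL images when no transform is given | |||
|     rgb = NYUv2.decode_fn(np.array(seg))  # colorized mask for visualization | |||
|     depth_set = NYUv2('./NYUv2', split='train', target_type='depth') | |||
|     image, depth = depth_set[0] | |||
|     print(rgb.shape, depth.mode) | |||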
| @@ -0,0 +1,61 @@ | |||
| import os, sys | |||
| from glob import glob | |||
| import random | |||
| import argparse | |||
| from PIL import Image | |||
| if __name__=='__main__': | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--data_root', type=str, default='../101_ObjectCategories') | |||
| parser.add_argument('--test_split', type=float, default=0.3) | |||
| args = parser.parse_args() | |||
| SAVE_DIR = os.path.join( os.path.dirname(args.data_root), 'caltech101_data' ) | |||
| if not os.path.exists(SAVE_DIR): | |||
| os.mkdir(SAVE_DIR) | |||
| # Train | |||
| TRAIN_DIR = os.path.join( SAVE_DIR, 'train' ) | |||
| if not os.path.exists(TRAIN_DIR): | |||
| os.mkdir(TRAIN_DIR) | |||
| # Test | |||
| TEST_DIR = os.path.join( SAVE_DIR, 'test' ) | |||
| if not os.path.exists(TEST_DIR): | |||
| os.mkdir(TEST_DIR) | |||
| img_folders = os.listdir(args.data_root) | |||
| img_folders.sort() | |||
| for folder in img_folders: | |||
| if folder=='Faces': | |||
| continue | |||
| print('Processing %s'%(folder)) | |||
| img_paths = glob(os.path.join( args.data_root, folder, '*.jpg') ) | |||
| img_name = [os.path.split(p)[-1] for p in img_paths] | |||
| random.shuffle(img_name) | |||
| img_n = len(img_name) | |||
| test_n = int(args.test_split * img_n) | |||
| test_set = img_name[:test_n] | |||
| train_set = img_name[test_n:] | |||
| # test | |||
| dst_path = os.path.join(TEST_DIR, folder) | |||
| if not os.path.exists(dst_path): | |||
| os.mkdir(dst_path) | |||
| for test_name in test_set: | |||
| img = Image.open(os.path.join( args.data_root, folder, test_name )) | |||
| img.save( os.path.join(dst_path, test_name ) ) | |||
| # train | |||
| dst_path = os.path.join(TRAIN_DIR, folder) | |||
| if not os.path.exists(dst_path): | |||
| os.mkdir(dst_path) | |||
| for train_name in train_set: | |||
| img = Image.open(os.path.join( args.data_root, folder, train_name )) | |||
| img.save( os.path.join(dst_path, train_name ) ) | |||
| @@ -0,0 +1,196 @@ | |||
| from __future__ import print_function | |||
| import os, sys, tarfile, errno | |||
| import numpy as np | |||
| import matplotlib.pyplot as plt | |||
| import argparse | |||
| if sys.version_info >= (3, 0, 0): | |||
| import urllib.request as urllib # ugly but works | |||
| else: | |||
| import urllib | |||
| try: | |||
| from imageio import imsave | |||
| except ImportError: | |||
| from scipy.misc import imsave | |||
| print(sys.version_info) | |||
| # image shape | |||
| HEIGHT = 96 | |||
| WIDTH = 96 | |||
| DEPTH = 3 | |||
| # size of a single image in bytes | |||
| SIZE = HEIGHT * WIDTH * DEPTH | |||
| # url of the binary data; download_and_extract() below relies on this constant | |||
| DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' | |||
| def read_labels(path_to_labels): | |||
| """ | |||
| :param path_to_labels: path to the binary file containing labels from the STL-10 dataset | |||
| :return: an array containing the labels | |||
| """ | |||
| with open(path_to_labels, 'rb') as f: | |||
| labels = np.fromfile(f, dtype=np.uint8) | |||
| return labels | |||
| def read_all_images(path_to_data): | |||
| """ | |||
| :param path_to_data: the file containing the binary images from the STL-10 dataset | |||
| :return: an array containing all the images | |||
| """ | |||
| with open(path_to_data, 'rb') as f: | |||
| # read whole file in uint8 chunks | |||
| everything = np.fromfile(f, dtype=np.uint8) | |||
| # We force the data into 3x96x96 chunks, since the | |||
| # images are stored in "column-major order", meaning | |||
| # that "the first 96*96 values are the red channel, | |||
| # the next 96*96 are green, and the last are blue." | |||
| # The -1 is since the size of the pictures depends | |||
| # on the input file, and this way numpy determines | |||
| # the size on its own. | |||
| images = np.reshape(everything, (-1, 3, 96, 96)) | |||
| # Now transpose the images into a standard image format | |||
| # readable by, for example, matplotlib.imshow | |||
| # You might want to comment this line or reverse the shuffle | |||
| # if you will use a learning algorithm like CNN, since they like | |||
| # their channels separated. | |||
| images = np.transpose(images, (0, 3, 2, 1)) | |||
| return images | |||
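| def _layout_sanity_check(): | |||
|     """Illustrative check of the layout logic above (not used by the pipeline): | |||
|     a fake buffer whose first 96*96 bytes are 'red' must decode to an image | |||
|     whose red channel is set and whose green channel is empty.""" | |||
|     fake = np.zeros(SIZE, dtype=np.uint8) | |||
|     fake[:HEIGHT * WIDTH] = 255  # fill only the red plane | |||
|     img = np.transpose(np.reshape(fake, (3, 96, 96)), (2, 1, 0)) | |||
|     assert img.shape == (96, 96, 3) | |||
|     assert img[..., 0].all() and not img[..., 1].any() | |||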
| def read_single_image(image_file): | |||
| """ | |||
| CAREFUL! - this method uses a file as input instead of the path - so the | |||
| position of the reader will be remembered outside of context of this method. | |||
| :param image_file: the open file containing the images | |||
| :return: a single image | |||
| """ | |||
| # read a single image, count determines the number of uint8's to read | |||
| image = np.fromfile(image_file, dtype=np.uint8, count=SIZE) | |||
| # force into image matrix | |||
| image = np.reshape(image, (3, 96, 96)) | |||
| # transpose to standard format | |||
| # You might want to comment this line or reverse the shuffle | |||
| # if you will use a learning algorithm like CNN, since they like | |||
| # their channels separated. | |||
| image = np.transpose(image, (2, 1, 0)) | |||
| return image | |||
| def plot_image(image): | |||
| """ | |||
| :param image: the image to be plotted in a 3-D matrix format | |||
| :return: None | |||
| """ | |||
| plt.imshow(image) | |||
| plt.show() | |||
| def save_image(image, name): | |||
| imsave("%s.png" % name, image, format="png") | |||
| def download_and_extract(DATA_DIR): | |||
| """ | |||
| Download and extract the STL-10 dataset | |||
| :return: None | |||
| """ | |||
| dest_directory = DATA_DIR | |||
| if not os.path.exists(dest_directory): | |||
| os.makedirs(dest_directory) | |||
| filename = DATA_URL.split('/')[-1] | |||
| filepath = os.path.join(dest_directory, filename) | |||
| if not os.path.exists(filepath): | |||
| def _progress(count, block_size, total_size): | |||
| sys.stdout.write('\rDownloading %s %.2f%%' % (filename, | |||
| float(count * block_size) / float(total_size) * 100.0)) | |||
| sys.stdout.flush() | |||
| filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) | |||
| print('Downloaded', filename) | |||
| with tarfile.open(filepath, 'r:gz') as tar: | |||
| tar.extractall(dest_directory) | |||
| def save_images(images, labels, save_dir): | |||
| print("Saving images to disk") | |||
| i = 0 | |||
| if not os.path.exists(save_dir): | |||
| os.mkdir(save_dir) | |||
| for image in images: | |||
| if labels is not None: | |||
| label = labels[i] | |||
| directory = os.path.join( save_dir, str(label) ) | |||
| if not os.path.exists(directory): | |||
| os.makedirs(directory) | |||
| else: | |||
| directory = save_dir | |||
| filename = os.path.join( directory, str(i) ) | |||
| print(filename) | |||
| save_image(image, filename) | |||
| i = i+1 | |||
| if __name__ == "__main__": | |||
| # download data if needed | |||
| # download_and_extract() | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary') | |||
| args = parser.parse_args() | |||
| BASE_PATH = os.path.join( args.DATA_DIR, 'stl10_data' ) | |||
| if not os.path.exists(BASE_PATH): | |||
| os.mkdir(BASE_PATH) | |||
| # Train | |||
| DATA_PATH = os.path.join( args.DATA_DIR, 'train_X.bin') | |||
| LABEL_PATH = os.path.join( args.DATA_DIR, 'train_y.bin') | |||
| print('Preparing Train') | |||
| images = read_all_images(DATA_PATH) | |||
| print(images.shape) | |||
| labels = read_labels(LABEL_PATH) | |||
| print(labels.shape) | |||
| save_images(images, labels, os.path.join(BASE_PATH, 'train')) | |||
| # Test | |||
| DATA_PATH = os.path.join( args.DATA_DIR, 'test_X.bin') | |||
| LABEL_PATH = os.path.join( args.DATA_DIR, 'test_y.bin') | |||
| #with open(DATA_PATH) as f: | |||
| # image = read_single_image(f) | |||
| # plot_image(image) | |||
| print('Preparing Test') | |||
| images = read_all_images(DATA_PATH) | |||
| print(images.shape) | |||
| labels = read_labels(LABEL_PATH) | |||
| print(labels.shape) | |||
| save_images(images, labels, os.path.join(BASE_PATH, 'test')) | |||
| # Unlabeled | |||
| print('Preparing Unlabeled') | |||
| DATA_PATH = os.path.join( args.DATA_DIR, 'unlabeled_X.bin') | |||
| images = read_all_images(DATA_PATH) | |||
| save_images(images, None, os.path.join(BASE_PATH, 'unlabeled')) | |||
| @@ -0,0 +1,53 @@ | |||
| import argparse | |||
| import os | |||
| from PIL import Image | |||
| import numpy as np | |||
| SIZE=(480, 320) # W, H | |||
| def is_image(path: str): | |||
| return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) | |||
| def copy_and_resize( src_dir, dst_dir, resize_fn ): | |||
| for file_or_dir in os.listdir( src_dir ): | |||
| src = os.path.join( src_dir, file_or_dir ) | |||
| dst = os.path.join( dst_dir, file_or_dir ) | |||
| if os.path.isdir( src ): | |||
| os.mkdir( dst ) | |||
| copy_and_resize( src, dst, resize_fn ) | |||
| elif is_image( src ): | |||
| print(src, ' -> ', dst) | |||
| image = Image.open( src ) | |||
| resized_image = resize_fn(image) | |||
| resized_image.save( dst ) | |||
| def resize_input( image: Image.Image ): | |||
| return image.resize( SIZE, resample=Image.BICUBIC ) | |||
| def resize_seg( image: Image.Image ): | |||
| return image.resize( SIZE, resample=Image.NEAREST ) | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--root', type=str, required=True) | |||
| ROOT = parser.parse_args().root | |||
| NEW_ROOT = os.path.join( ROOT, '%d_%d' % SIZE ) # e.g. '480_320' | |||
| os.mkdir(NEW_ROOT) | |||
| for split in ['train', 'val', 'test']: | |||
| IMG_DIR = os.path.join( ROOT, split ) | |||
| GT_DIR = os.path.join( ROOT, split+'annot' ) | |||
| NEW_IMG_DIR = os.path.join( NEW_ROOT, split ) | |||
| NEW_GT_DIR = os.path.join( NEW_ROOT, split+'annot' ) | |||
| os.mkdir( NEW_IMG_DIR ) | |||
| os.mkdir( NEW_GT_DIR ) | |||
| copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) | |||
| copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) | |||
| if __name__=='__main__': | |||
| main() | |||
| @@ -0,0 +1,65 @@ | |||
| import argparse | |||
| import os | |||
| from PIL import Image | |||
| import numpy as np | |||
| SIZE=(640, 320) | |||
| def is_image(path: str): | |||
| return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) | |||
| def copy_and_resize( src_dir, dst_dir, resize_fn ): | |||
| for file_or_dir in os.listdir( src_dir ): | |||
| src = os.path.join( src_dir, file_or_dir ) | |||
| dst = os.path.join( dst_dir, file_or_dir ) | |||
| if os.path.isdir( src ): | |||
| os.mkdir( dst ) | |||
| copy_and_resize( src, dst, resize_fn ) | |||
| elif is_image( src ): | |||
| print(src, ' -> ', dst) | |||
| image = Image.open( src ) | |||
| resized_image = resize_fn(image) | |||
| resized_image.save( dst ) | |||
| def resize_input( image: Image.Image ): | |||
| return image.resize( SIZE, resample=Image.BICUBIC ) | |||
| def resize_seg( image: Image.Image ): | |||
| return image.resize( SIZE, resample=Image.NEAREST ) | |||
| def resize_depth( image: Image.Image ): | |||
| return image.resize( SIZE, resample=Image.NEAREST ) | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--root', type=str, required=True) | |||
| ROOT = parser.parse_args().root | |||
| IMG_DIR = os.path.join( ROOT, 'leftImg8bit' ) | |||
| GT_DIR = os.path.join( ROOT, 'gtFine' ) | |||
| DEPTH_DIR = os.path.join( ROOT, 'disparity' ) | |||
| NEW_ROOT = os.path.join( ROOT, '%d_%d' % SIZE ) # e.g. '640_320' | |||
| NEW_IMG_DIR = os.path.join( NEW_ROOT, 'leftImg8bit' ) | |||
| NEW_GT_DIR = os.path.join( NEW_ROOT, 'gtFine' ) | |||
| NEW_DEPTH_DIR = os.path.join( NEW_ROOT, 'disparity' ) | |||
| if os.path.exists(NEW_ROOT): | |||
| print("Directory %s existed, please remove it before running this script"%NEW_ROOT) | |||
| return | |||
| os.mkdir(NEW_ROOT) | |||
| os.mkdir( NEW_IMG_DIR ) | |||
| os.mkdir( NEW_GT_DIR ) | |||
| os.mkdir( NEW_DEPTH_DIR ) | |||
| copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) | |||
| copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) | |||
| copy_and_resize( DEPTH_DIR, NEW_DEPTH_DIR, resize_depth ) | |||
| if __name__=='__main__': | |||
| main() | |||
| @@ -0,0 +1,59 @@ | |||
| import argparse | |||
| import os | |||
| from PIL import Image | |||
| import numpy as np | |||
| from torchvision import transforms | |||
| import shutil | |||
| SIZE=240 | |||
| def is_image(path: str): | |||
| return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) | |||
| def copy_and_resize( src_dir, dst_dir, resize_fn ): | |||
| for file_or_dir in os.listdir( src_dir ): | |||
| src = os.path.join( src_dir, file_or_dir ) | |||
| dst = os.path.join( dst_dir, file_or_dir ) | |||
| if os.path.isdir( src ): | |||
| os.mkdir( dst ) | |||
| copy_and_resize( src, dst, resize_fn ) | |||
| elif is_image( src ): | |||
| print(src, ' -> ', dst) | |||
| image = Image.open( src ) | |||
| resized_image = resize_fn(image) | |||
| resized_image.save( dst ) | |||
| def resize_input( image: Image.Image ): | |||
| return transforms.functional.resize( image, SIZE, interpolation=Image.BILINEAR ) | |||
| def resize_seg( image: Image.Image ): | |||
| return transforms.functional.resize( image, SIZE, interpolation=Image.NEAREST) | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--root', type=str, required=True) | |||
| ROOT = parser.parse_args().root | |||
| IMG_DIR = os.path.join( ROOT, 'JPEGImages' ) | |||
| GT_DIR = os.path.join( ROOT, 'SegmentationClass' ) | |||
| NEW_ROOT = os.path.join( ROOT, str(SIZE) ) | |||
| NEW_IMG_DIR = os.path.join( NEW_ROOT, 'JPEGImages' ) | |||
| NEW_GT_DIR = os.path.join( NEW_ROOT, 'SegmentationClass' ) | |||
| if os.path.exists(NEW_ROOT): | |||
| print("Directory %s existed, please remove it before running this script"%NEW_ROOT) | |||
| return | |||
| os.mkdir(NEW_ROOT) | |||
| os.mkdir( NEW_IMG_DIR ) | |||
| os.mkdir( NEW_GT_DIR ) | |||
| copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) | |||
| copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) | |||
| shutil.copytree( os.path.join( ROOT, 'ImageSets'), os.path.join( NEW_ROOT, 'ImageSets' ) ) | |||
| if __name__=='__main__': | |||
| main() | |||
| @@ -0,0 +1,57 @@ | |||
| import argparse | |||
| import os | |||
| from PIL import Image | |||
| import numpy as np | |||
| from torchvision import transforms | |||
| import shutil | |||
| def is_image(path: str): | |||
| return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) | |||
| def copy_and_resize( src_dir, dst_dir, resize_fn ): | |||
| for file_or_dir in os.listdir( src_dir ): | |||
| src = os.path.join( src_dir, file_or_dir ) | |||
| dst = os.path.join( dst_dir, file_or_dir ) | |||
| if os.path.isdir( src ): | |||
| os.mkdir( dst ) | |||
| copy_and_resize( src, dst, resize_fn ) | |||
| elif is_image( src ): | |||
| print(src, ' -> ', dst) | |||
| image = Image.open( src ) | |||
| resized_image = resize_fn(image) | |||
| resized_image.save( dst ) | |||
| def resize_input( image: Image.Image ): | |||
| return transforms.functional.resize( image, 240, interpolation=Image.BILINEAR ) | |||
| def resize_seg( image: Image.Image ): | |||
| return transforms.functional.resize( image, 240, interpolation=Image.NEAREST) | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--root', type=str, required=True) | |||
| ROOT = parser.parse_args().root | |||
| IMG_DIR = os.path.join( ROOT, 'JPEGImages' ) | |||
| GT_DIR = os.path.join( ROOT, 'SegmentationClass' ) | |||
| NEW_ROOT = os.path.join( ROOT, '240' ) | |||
| NEW_IMG_DIR = os.path.join( NEW_ROOT, 'JPEGImages' ) | |||
| NEW_GT_DIR = os.path.join( NEW_ROOT, 'SegmentationClass' ) | |||
| if os.path.exists(NEW_ROOT): | |||
| print("Directory %s existed, please remove it before running this script"%NEW_ROOT) | |||
| return | |||
| os.mkdir(NEW_ROOT) | |||
| os.mkdir( NEW_IMG_DIR ) | |||
| os.mkdir( NEW_GT_DIR ) | |||
| copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) | |||
| copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) | |||
| shutil.copytree( os.path.join( ROOT, 'ImageSets'), os.path.join( NEW_ROOT, 'ImageSets' ) ) | |||
| if __name__=='__main__': | |||
| main() | |||
| @@ -0,0 +1,80 @@ | |||
| import os | |||
| import glob | |||
| from PIL import Image | |||
| import numpy as np | |||
| from scipy.io import loadmat | |||
| from torch.utils import data | |||
| from .utils import download_url, mkdir | |||
| from shutil import copyfile | |||
| class StanfordCars(data.Dataset): | |||
| """Dataset for Stanford Cars | |||
| """ | |||
| urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz', | |||
| 'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz', | |||
| 'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz', | |||
| 'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'} | |||
| def __init__(self, root, split='train', download=False, transform=None, target_transform=None): | |||
| self.root = os.path.abspath( os.path.expanduser(root) ) | |||
| self.split = split | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| if download: | |||
| self.download() | |||
| if self.split == 'train': | |||
| annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat') | |||
| else: | |||
| annos = os.path.join(self.root, 'devkit', | |||
| 'cars_test_annos_withlabels.mat') | |||
| annos = loadmat(annos) | |||
| size = len(annos['annotations'][0]) | |||
| self.files = glob.glob(os.path.join( | |||
| self.root, 'cars_'+self.split, '*.jpg')) | |||
| self.files.sort() | |||
| self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]]) | |||
| lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat')) | |||
| self.object_categories = [str(c[0]) | |||
| for c in lbl_annos['class_names'][0]] | |||
| print('Stanford Cars, Split: %s, Size: %d' % | |||
| (self.split, self.__len__())) | |||
| def __len__(self): | |||
| return len(self.files) | |||
| def __getitem__(self, idx): | |||
| img = Image.open(self.files[idx]).convert("RGB") # self.files already holds the full paths collected by glob | |||
| lbl = self.labels[idx] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| lbl = self.target_transform(lbl) | |||
| return img, lbl | |||
| def download(self): | |||
| import tarfile | |||
| mkdir(self.root) | |||
| for fname, url in self.urls.items(): | |||
| if not os.path.isfile(os.path.join(self.root, fname)): | |||
| download_url(url, self.root, fname) | |||
| if fname.endswith('tgz'): | |||
| print("Extracting %s..." % fname) | |||
| with tarfile.open(os.path.join(self.root, fname), "r:gz") as tar: | |||
| tar.extractall(path=self.root) | |||
| copyfile(os.path.join(self.root, 'cars_test_annos_withlabels.mat'), | |||
| os.path.join(self.root, 'devkit', 'cars_test_annos_withlabels.mat')) | |||
| @@ -0,0 +1,58 @@ | |||
| import os | |||
| import numpy as np | |||
| from PIL import Image | |||
| from scipy.io import loadmat | |||
| from torch.utils import data | |||
| from .utils import download_url | |||
| from shutil import move | |||
| class StanfordDogs(data.Dataset): | |||
| """Dataset for Stanford Dogs | |||
| """ | |||
| urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar", | |||
| "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar", | |||
| "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"} | |||
| def __init__(self, root, split='train', download=False, transform=None, target_transform=None): | |||
| self.root = os.path.abspath( os.path.expanduser(root) ) | |||
| self.split = split | |||
| self.transform = transform | |||
| self.target_transform = target_transform | |||
| if download: | |||
| self.download() | |||
| list_file = os.path.join(self.root, self.split+'_list.mat') | |||
| mat_file = loadmat(list_file) | |||
| size = len(mat_file['file_list']) | |||
| self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)] | |||
| self.labels = np.array( | |||
| [mat_file['labels'][i][0]-1 for i in range(size)]) | |||
| categories = os.listdir(os.path.join(self.root, 'Images')) | |||
| categories.sort() | |||
| self.object_categories = [c[10:] for c in categories] | |||
| print('Stanford Dogs, Split: %s, Size: %d' % | |||
| (self.split, self.__len__())) | |||
| def __len__(self): | |||
| return len(self.files) | |||
| def __getitem__(self, idx): | |||
| img = Image.open(os.path.join(self.root, 'Images', | |||
| self.files[idx])).convert("RGB") | |||
| lbl = self.labels[idx] | |||
| if self.transform is not None: | |||
| img = self.transform(img) | |||
| if self.target_transform is not None: | |||
| lbl = self.target_transform( lbl ) | |||
| return img, lbl | |||
| def download(self): | |||
| import tarfile | |||
| os.makedirs(self.root, exist_ok=True) | |||
| for fname, url in self.urls.items(): | |||
| if not os.path.isfile(os.path.join(self.root, fname)): | |||
| download_url(url, self.root, fname) | |||
| # extract file | |||
| print("Extracting %s..." % fname) | |||
| with tarfile.open(os.path.join(self.root, fname), "r") as tar: | |||
| tar.extractall(path=self.root) | |||
| @@ -0,0 +1,65 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import os | |||
| from glob import glob | |||
| from PIL import Image | |||
| from .utils import colormap | |||
| from torchvision.datasets import VisionDataset | |||
| class SunRGBD(VisionDataset): | |||
| cmap = colormap() | |||
| def __init__(self, | |||
| root, | |||
| split='train', | |||
| transform=None, | |||
| target_transform=None, | |||
| transforms=None): | |||
| super( SunRGBD, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms ) | |||
| self.root = root | |||
| self.split = split | |||
| self.images = glob(os.path.join(self.root, 'SUNRGBD-%s_images'%self.split, '*.jpg')) | |||
| self.labels = glob(os.path.join(self.root, '%s13labels'%self.split, '*.png')) | |||
| self.images.sort() | |||
| self.labels.sort() | |||
| def __getitem__(self, idx): | |||
| """ | |||
| Args: | |||
| - index (``int``): index of the item in the dataset | |||
| Returns: | |||
| A tuple of ``PIL.Image`` (image, label) where label is the ground-truth | |||
| of the image. | |||
| """ | |||
| img, label = Image.open(self.images[idx]), Image.open(self.labels[idx]) | |||
| if self.transforms is not None: | |||
| img, label = self.transforms(img, label) # paired transform, as in the other segmentation loaders | |||
| label = label - 1 # shift labels so void (0) wraps to the ignore index 255 | |||
| return img, label | |||
| def __len__(self): | |||
| return len(self.images) | |||
| @classmethod | |||
| def decode_fn(cls, mask): | |||
| """decode semantic mask to RGB image""" | |||
| return cls.cmap[mask.astype('uint8')+1] | |||
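| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the loader): './SUNRGBD' and the | |||
| # paired transform are assumptions. | |||
| if __name__ == '__main__': | |||
|     import numpy as np | |||
|     import torch | |||
|     def to_tensor_pair(img, label): | |||
|         from torchvision.transforms import functional as F | |||
|         return F.to_tensor(img), torch.from_numpy(np.array(label, dtype='uint8')) | |||
|     data = SunRGBD('./SUNRGBD', split='train', transforms=to_tensor_pair) | |||
|     img, label = data[0]  # label: 255 = void, 0..12 = the 13 classes | |||
|     rgb = SunRGBD.decode_fn(label.numpy()) | |||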
| @@ -0,0 +1,67 @@ | |||
| """ | |||
| Copyright 2020 Tianshu AI Platform. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ============================================================= | |||
| """ | |||
| import torch | |||
| from torch.utils.data import Dataset | |||
| import os | |||
| from PIL import Image | |||
| import random | |||
| from copy import deepcopy | |||
| def _collect_all_images(root, postfix=['png', 'jpg', 'jpeg', 'JPEG']): | |||
| images = [] | |||
| if isinstance( postfix, str): | |||
| postfix = [ postfix ] | |||
| for dirpath, dirnames, files in os.walk(root): | |||
| for pos in postfix: | |||
| for f in files: | |||
| if f.endswith( pos ): | |||
| images.append( os.path.join( dirpath, f ) ) | |||
| return images | |||
| def get_train_val_set(root, val_size=0.3): | |||
| if not isinstance(root, (list, tuple)): | |||
| root = [root] | |||
| train_set = [] | |||
| val_set = [] | |||
| for _root in root: | |||
| _part_train_set = _collect_all_images( _root ) | |||
| if os.path.isdir( os.path.join(_root, 'test') ): | |||
| _part_val_set = _collect_all_images( os.path.join(_root, 'test') ) | |||
| else: | |||
| _val_size = int( len(_part_train_set) * val_size ) | |||
| _part_val_set = random.sample( _part_train_set, k=_val_size ) | |||
| _val_lookup = set(_part_val_set) # O(1) lookup; also keeps held-out test images out of the train list | |||
| _part_train_set = [ d for d in _part_train_set if d not in _val_lookup ] | |||
| train_set.extend(_part_train_set) | |||
| val_set.extend(_part_val_set) | |||
| return train_set, val_set | |||
| class UnlabeledDataset(Dataset): | |||
| def __init__(self, data, transform=None, postfix=['png', 'jpg', 'jpeg', 'JPEG']): | |||
| self.transform = transform | |||
| self.data = data | |||
| def __getitem__(self, idx): | |||
| data = Image.open( self.data[idx] ) | |||
| if self.transform is not None: | |||
| data = self.transform(data) | |||
| return data | |||
| def __len__(self): | |||
| return len(self.data) | |||
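| # --------------------------------------------------------------------------- | |||
| # Usage sketch (illustrative, not part of the dataset): building unlabeled | |||
| # train/val loaders from one or more image folders; the root is an assumption. | |||
| if __name__ == '__main__': | |||
|     from torchvision import transforms | |||
|     from torch.utils.data import DataLoader | |||
|     train_files, val_files = get_train_val_set('./data/unlabeled_images') | |||
|     tf = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()]) | |||
|     train_set = UnlabeledDataset(train_files, transform=tf) | |||
|     val_set = UnlabeledDataset(val_files, transform=tf) | |||
|     batch = next(iter(DataLoader(train_set, batch_size=64, shuffle=True))) | |||
|     print(len(train_set), len(val_set), batch.shape) | |||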
| @@ -0,0 +1,161 @@ | |||
| # Modified from https://github.com/pytorch/vision | |||
| import os | |||
| import os.path | |||
| import hashlib | |||
| import errno | |||
| from tqdm import tqdm | |||
| import numpy as np | |||
| import torch | |||
| import random | |||
| def mkdir(dir): | |||
| if not os.path.isdir(dir): | |||
| os.mkdir(dir) | |||
| def colormap(N=256, normalized=False): | |||
| def bitget(byteval, idx): | |||
| return ((byteval & (1 << idx)) != 0) | |||
| dtype = 'float32' if normalized else 'uint8' | |||
| cmap = np.zeros((N, 3), dtype=dtype) | |||
| for i in range(N): | |||
| r = g = b = 0 | |||
| c = i | |||
| for j in range(8): | |||
| r = r | (bitget(c, 0) << 7-j) | |||
| g = g | (bitget(c, 1) << 7-j) | |||
| b = b | (bitget(c, 2) << 7-j) | |||
| c = c >> 3 | |||
| cmap[i] = np.array([r, g, b]) | |||
| cmap = cmap/255 if normalized else cmap | |||
| return cmap | |||
| DEFAULT_COLORMAP = colormap() | |||
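| # The palette built above is the standard PASCAL VOC colormap: every third bit | |||
| # of the label index feeds one color channel, MSB first, so nearby indices get | |||
| # visually distinct colors. A quick check of the first entries: | |||
| assert DEFAULT_COLORMAP[:4].tolist() == [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0]] | |||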
| def gen_bar_updater(pbar): | |||
| def bar_update(count, block_size, total_size): | |||
| if pbar.total is None and total_size: | |||
| pbar.total = total_size | |||
| progress_bytes = count * block_size | |||
| pbar.update(progress_bytes - pbar.n) | |||
| return bar_update | |||
| def check_integrity(fpath, md5=None): | |||
| if md5 is None: | |||
| return True | |||
| if not os.path.isfile(fpath): | |||
| return False | |||
| md5o = hashlib.md5() | |||
| with open(fpath, 'rb') as f: | |||
| # read in 1MB chunks | |||
| for chunk in iter(lambda: f.read(1024 * 1024), b''): | |||
| md5o.update(chunk) | |||
| md5c = md5o.hexdigest() | |||
| if md5c != md5: | |||
| return False | |||
| return True | |||
| def makedir_exist_ok(dirpath): | |||
| """ | |||
| Python2 support for os.makedirs(.., exist_ok=True) | |||
| """ | |||
| try: | |||
| os.makedirs(dirpath) | |||
| except OSError as e: | |||
| if e.errno == errno.EEXIST: | |||
| pass | |||
| else: | |||
| raise | |||
| def download_url(url, root, filename=None, md5=None): | |||
| """Download a file from a url and place it in root. | |||
| Args: | |||
| url (str): URL to download file from | |||
| root (str): Directory to place downloaded file in | |||
| filename (str): Name to save the file under. If None, use the basename of the URL | |||
| md5 (str): MD5 checksum of the download. If None, do not check | |||
| """ | |||
| from six.moves import urllib | |||
| root = os.path.expanduser(root) | |||
| if not filename: | |||
| filename = os.path.basename(url) | |||
| fpath = os.path.join(root, filename) | |||
| makedir_exist_ok(root) | |||
| # downloads file | |||
| if os.path.isfile(fpath) and check_integrity(fpath, md5): | |||
| print('Using downloaded and verified file: ' + fpath) | |||
| else: | |||
| try: | |||
| print('Downloading ' + url + ' to ' + fpath) | |||
| urllib.request.urlretrieve( | |||
| url, fpath, | |||
| reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) | |||
| ) | |||
| except OSError: | |||
| if url[:5] == 'https': | |||
| url = url.replace('https:', 'http:') | |||
| print('Failed download. Trying https -> http instead.' | |||
| ' Downloading ' + url + ' to ' + fpath) | |||
| urllib.request.urlretrieve( | |||
| url, fpath, | |||
| reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) | |||
| ) | |||
| def list_dir(root, prefix=False): | |||
| """List all directories at a given root | |||
| Args: | |||
| root (str): Path to directory whose folders need to be listed | |||
| prefix (bool, optional): If true, prepends the path to each result, otherwise | |||
| only returns the name of the directories found | |||
| """ | |||
| root = os.path.expanduser(root) | |||
| directories = list( | |||
| filter( | |||
| lambda p: os.path.isdir(os.path.join(root, p)), | |||
| os.listdir(root) | |||
| ) | |||
| ) | |||
| if prefix is True: | |||
| directories = [os.path.join(root, d) for d in directories] | |||
| return directories | |||
| def list_files(root, suffix, prefix=False): | |||
| """List all files ending with a suffix at a given root | |||
| Args: | |||
| root (str): Path to directory whose folders need to be listed | |||
| suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). | |||
| It uses the Python "str.endswith" method and is passed directly | |||
| prefix (bool, optional): If true, prepends the path to each result, otherwise | |||
| only returns the name of the files found | |||
| """ | |||
| root = os.path.expanduser(root) | |||
| files = list( | |||
| filter( | |||
| lambda p: os.path.isfile(os.path.join( | |||
| root, p)) and p.endswith(suffix), | |||
| os.listdir(root) | |||
| ) | |||
| ) | |||
| if prefix is True: | |||
| files = [os.path.join(root, d) for d in files] | |||
| return files | |||
| def set_seed(random_seed): | |||
| torch.manual_seed(random_seed) | |||
| torch.cuda.manual_seed(random_seed) | |||
| np.random.seed(random_seed) | |||
| random.seed(random_seed) | |||
| @@ -0,0 +1,209 @@ | |||
| # Modified from https://github.com/pytorch/vision | |||
| import os | |||
| import sys | |||
| import tarfile | |||
| import collections | |||
| import torch.utils.data as data | |||
| import shutil | |||
| import numpy as np | |||
| from .utils import colormap | |||
| from torchvision.datasets import VisionDataset | |||
| import torch | |||
| from PIL import Image | |||
| from torchvision.datasets.utils import download_url, check_integrity | |||
| DATASET_YEAR_DICT = { | |||
| '2012aug': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', | |||
| 'filename': 'VOCtrainval_11-May-2012.tar', | |||
| 'md5': '6cd6e144f989b92b3379bac3b3de84fd', | |||
| 'base_dir': 'VOCdevkit/VOC2012' | |||
| }, | |||
| '2012': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', | |||
| 'filename': 'VOCtrainval_11-May-2012.tar', | |||
| 'md5': '6cd6e144f989b92b3379bac3b3de84fd', | |||
| 'base_dir': 'VOCdevkit/VOC2012' | |||
| }, | |||
| '2011': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', | |||
| 'filename': 'VOCtrainval_25-May-2011.tar', | |||
| 'md5': '6c3384ef61512963050cb5d687e5bf1e', | |||
| 'base_dir': 'TrainVal/VOCdevkit/VOC2011' | |||
| }, | |||
| '2010': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', | |||
| 'filename': 'VOCtrainval_03-May-2010.tar', | |||
| 'md5': 'da459979d0c395079b5c75ee67908abb', | |||
| 'base_dir': 'VOCdevkit/VOC2010' | |||
| }, | |||
| '2009': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', | |||
| 'filename': 'VOCtrainval_11-May-2009.tar', | |||
| 'md5': '59065e4b188729180974ef6572f6a212', | |||
| 'base_dir': 'VOCdevkit/VOC2009' | |||
| }, | |||
| '2008': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', | |||
| 'filename': 'VOCtrainval_14-Jul-2008.tar', | |||
| 'md5': '2629fa636546599198acfcfbfcf1904a', | |||
| 'base_dir': 'VOCdevkit/VOC2008' | |||
| }, | |||
| '2007': { | |||
| 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', | |||
| 'filename': 'VOCtrainval_06-Nov-2007.tar', | |||
| 'md5': 'c52e279531787c972589f7e41ab4ae64', | |||
| 'base_dir': 'VOCdevkit/VOC2007' | |||
| } | |||
| } | |||
| class VOCSegmentation(VisionDataset): | |||
| """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset. | |||
| Args: | |||
| root (string): Root directory of the VOC Dataset. | |||
| year (string, optional): The dataset year; supports ``2007`` to ``2012``, plus | |||
| ``2012aug`` for the 2012 train split with augmented ``SegmentationClassAug`` masks. | |||
| image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` | |||
| download (bool, optional): If true, downloads the dataset from the internet and | |||
| puts it in root directory. If dataset is already downloaded, it is not | |||
| downloaded again. | |||
| transform (callable, optional): A function/transform that takes in a PIL image | |||
| and returns a transformed version. E.g., ``transforms.RandomCrop`` | |||
| """ | |||
| cmap = colormap() | |||
| def __init__(self, | |||
| root, | |||
| year='2012', | |||
| image_set='train', | |||
| download=False, | |||
| transform=None, | |||
| target_transform=None, | |||
| transforms=None, | |||
| ): | |||
| super(VOCSegmentation, self).__init__(root, transform=transform, target_transform=target_transform, transforms=transforms) | |||
| is_aug = False | |||
| if year == '2012aug': | |||
| is_aug = True | |||
| year = '2012' | |||
| self.root = os.path.expanduser(root) | |||
| self.year = year | |||
| self.url = DATASET_YEAR_DICT[year]['url'] | |||
| self.filename = DATASET_YEAR_DICT[year]['filename'] | |||
| self.md5 = DATASET_YEAR_DICT[year]['md5'] | |||
| self.image_set = image_set | |||
| base_dir = DATASET_YEAR_DICT[year]['base_dir'] | |||
| voc_root = os.path.join(self.root, base_dir) | |||
| image_dir = os.path.join(voc_root, 'JPEGImages') | |||
| if download: | |||
| download_extract(self.url, self.root, self.filename, self.md5) | |||
| if not os.path.isdir(voc_root): | |||
| raise RuntimeError('Dataset not found or corrupted.' + | |||
| ' You can use download=True to download it') | |||
| if is_aug and image_set=='train': | |||
| mask_dir = os.path.join(voc_root, 'SegmentationClassAug') | |||
| assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually" | |||
| split_f = os.path.join(self.root, 'train_aug.txt') | |||
| else: | |||
| mask_dir = os.path.join(voc_root, 'SegmentationClass') | |||
| splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') | |||
| split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') | |||
| if not os.path.exists(split_f): | |||
| raise ValueError( | |||
| 'Wrong image_set entered! Please use image_set="train", ' | |||
| '"trainval" or "val" (for year="2012aug", train_aug.txt must sit under root)') | |||
| with open(split_f, "r") as f: | |||
| file_names = [x.strip() for x in f.readlines()] | |||
| self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] | |||
| self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] | |||
| assert len(self.images) == len(self.masks) | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (image, target) where target is the image segmentation. | |||
| """ | |||
| img = Image.open(self.images[index]).convert('RGB') | |||
| target = Image.open(self.masks[index]) | |||
| if self.transforms is not None: | |||
| img, target = self.transforms(img, target) | |||
| # tensor transforms may add a channel dim; a raw PIL mask has no squeeze() | |||
| if torch.is_tensor(target): | |||
| target = target.squeeze(0) | |||
| return img, target | |||
| def __len__(self): | |||
| return len(self.images) | |||
| @classmethod | |||
| def decode_fn(cls, mask): | |||
| """decode semantic mask to RGB image""" | |||
| return cls.cmap[mask] | |||
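| # A minimal usage sketch (hypothetical root path; without transforms the pair is | |||
| # returned as PIL images, and decode_fn expects an integer label array): | |||
| # | |||
| # dataset = VOCSegmentation(root='./data', year='2012', image_set='val', download=True) | |||
| # img, mask = dataset[0] | |||
| # rgb = VOCSegmentation.decode_fn(np.array(mask)) # (H, W, 3) colorized labels | |||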
| def download_extract(url, root, filename, md5): | |||
| download_url(url, root, filename, md5) | |||
| with tarfile.open(os.path.join(root, filename), "r") as tar: | |||
| tar.extractall(path=root) | |||
| CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', | |||
| 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] | |||
| class VOCClassification(data.Dataset): | |||
| def __init__(self, | |||
| root, | |||
| year='2010', | |||
| split='train', | |||
| download=False, | |||
| transforms=None, | |||
| target_transforms=None): | |||
| voc_root = os.path.join(root, 'VOC{}'.format(year)) | |||
| # `download` is accepted for API symmetry but downloading is not implemented here | |||
| if not os.path.isdir(voc_root): | |||
| raise RuntimeError('Dataset not found or corrupted. Please prepare VOC{} manually.'.format(year)) | |||
| self.transforms = transforms | |||
| self.target_transforms = target_transforms | |||
| image_dir = os.path.join(voc_root, 'JPEGImages') | |||
| label_dir = os.path.join(voc_root, 'ImageSets/Main') | |||
| self.labels_list = [] | |||
| fname = os.path.join(label_dir, '{}.txt'.format(split)) | |||
| with open(fname) as f: | |||
| self.images = [os.path.join(image_dir, line.split()[0]+'.jpg') for line in f] | |||
| for clas in CLASSES: | |||
| labels = [] | |||
| with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f: | |||
| labels = [int(line.split()[1]) for line in f] | |||
| self.labels_list.append(labels) | |||
| assert len(self.images) == len(self.labels_list[0]) | |||
| def __getitem__(self, index): | |||
| """ | |||
| Args: | |||
| index (int): Index | |||
| Returns: | |||
| tuple: (image, labels) where labels lists one VOC ImageSets/Main annotation value per class. | |||
| """ | |||
| img = Image.open(self.images[index]).convert('RGB') | |||
| labels = [labels[index] for labels in self.labels_list] | |||
| if self.transforms is not None: | |||
| img = self.transforms(img) | |||
| if self.target_transforms is not None: | |||
| labels = self.target_transforms(labels) | |||
| return img, labels | |||
| def __len__(self): | |||
| return len(self.images) | |||
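| # A minimal usage sketch (hypothetical root; VOC ImageSets/Main files mark each | |||
| # class as 1 / -1 / 0, which is what the labels carry): | |||
| # | |||
| # dataset = VOCClassification(root='./data/VOCdevkit', year='2010', split='train') | |||
| # img, labels = dataset[0] # len(labels) == len(CLASSES) == 20 | |||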
| @@ -0,0 +1,3 @@ | |||
| from . import classification, segmentation | |||
| from torchvision import models as torchvision_models | |||
| @@ -0,0 +1,7 @@ | |||
| from .darknet import * | |||
| from .mobilenetv2 import * | |||
| from .resnet import * | |||
| from .vgg import * | |||
| from . import cifar | |||
| from .alexnet import alexnet | |||
| @@ -0,0 +1,63 @@ | |||
| # Modified from https://github.com/pytorch/vision | |||
| import torch.nn as nn | |||
| import torch.utils.model_zoo as model_zoo | |||
| __all__ = ['AlexNet', 'alexnet'] | |||
| model_urls = { | |||
| 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', | |||
| } | |||
| class AlexNet(nn.Module): | |||
| def __init__(self, num_classes=1000): | |||
| super(AlexNet, self).__init__() | |||
| self.features = nn.Sequential( | |||
| nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), | |||
| nn.ReLU(inplace=True), | |||
| nn.MaxPool2d(kernel_size=3, stride=2), | |||
| nn.Conv2d(64, 192, kernel_size=5, padding=2), | |||
| nn.ReLU(inplace=True), | |||
| nn.MaxPool2d(kernel_size=3, stride=2), | |||
| nn.Conv2d(192, 384, kernel_size=3, padding=1), | |||
| nn.ReLU(inplace=True), | |||
| nn.Conv2d(384, 256, kernel_size=3, padding=1), | |||
| nn.ReLU(inplace=True), | |||
| nn.Conv2d(256, 256, kernel_size=3, padding=1), | |||
| nn.ReLU(inplace=True), | |||
| nn.MaxPool2d(kernel_size=3, stride=2), | |||
| ) | |||
| self.classifier = nn.Sequential( | |||
| nn.Dropout(), | |||
| nn.Linear(256 * 6 * 6, 4096), | |||
| nn.ReLU(inplace=True), | |||
| nn.Dropout(), | |||
| nn.Linear(4096, 4096), | |||
| nn.ReLU(inplace=True), | |||
| nn.Linear(4096, num_classes), | |||
| ) | |||
| def forward(self, x): | |||
| x = self.features(x) | |||
| x = x.view(x.size(0), 256 * 6 * 6) | |||
| x = self.classifier(x) | |||
| return x | |||
| def alexnet(pretrained=False, **kwargs): | |||
| r"""AlexNet model architecture from the | |||
| `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. | |||
| Args: | |||
| pretrained (bool): If True, returns a model pre-trained on ImageNet | |||
| num_classes (int, optional): If given and != 1000, the 1000-way head is | |||
| replaced with a fresh ``nn.Linear(4096, num_classes)`` after loading. | |||
| """ | |||
| # build with the default 1000-way head so pretrained weights load cleanly, | |||
| # then swap in a fresh head when a different class count is requested | |||
| num_classes = kwargs.pop('num_classes', None) | |||
| model = AlexNet(**kwargs) | |||
| if pretrained: | |||
| model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) | |||
| if num_classes is not None and num_classes != 1000: | |||
| model.classifier[-1] = nn.Linear(4096, num_classes) | |||
| return model | |||
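| # A minimal fine-tuning sketch (the 10-class target is an assumption for illustration): | |||
| # | |||
| # model = alexnet(pretrained=True, num_classes=10) # ImageNet trunk, fresh 10-way head | |||
| # for p in model.features.parameters(): | |||
| # p.requires_grad = False # optionally freeze the convolutional trunk | |||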
| @@ -0,0 +1 @@ | |||
| from . import wrn | |||
| @@ -0,0 +1,108 @@ | |||
| # Adapted from https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| __all__ = ['wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2'] | |||
| class BasicBlock(nn.Module): | |||
| def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0): | |||
| super(BasicBlock, self).__init__() | |||
| self.bn1 = nn.BatchNorm2d(in_planes) | |||
| self.relu1 = nn.ReLU(inplace=True) | |||
| self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, | |||
| padding=1, bias=False) | |||
| self.bn2 = nn.BatchNorm2d(out_planes) | |||
| self.relu2 = nn.ReLU(inplace=True) | |||
| self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, | |||
| padding=1, bias=False) | |||
| self.dropout = nn.Dropout(dropout_rate) | |||
| self.equalInOut = (in_planes == out_planes) | |||
| # 1x1 projection on the shortcut only when the channel count changes | |||
| self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, | |||
| padding=0, bias=False) if not self.equalInOut else None | |||
| def forward(self, x): | |||
| if not self.equalInOut: | |||
| x = self.relu1(self.bn1(x)) | |||
| else: | |||
| out = self.relu1(self.bn1(x)) | |||
| out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) | |||
| out = self.dropout(out) | |||
| out = self.conv2(out) | |||
| return torch.add(x if self.equalInOut else self.convShortcut(x), out) | |||
| class NetworkBlock(nn.Module): | |||
| def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0): | |||
| super(NetworkBlock, self).__init__() | |||
| self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate) | |||
| def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate): | |||
| layers = [] | |||
| for i in range(nb_layers): | |||
| layers.append(block(in_planes if i == 0 else out_planes, out_planes, stride if i == 0 else 1, dropout_rate)) | |||
| return nn.Sequential(*layers) | |||
| def forward(self, x): | |||
| return self.layer(x) | |||
| class WideResNet(nn.Module): | |||
| def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0): | |||
| super(WideResNet, self).__init__() | |||
| nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] | |||
| assert (depth - 4) % 6 == 0, 'depth should be 6n+4' | |||
| n = (depth - 4) // 6 | |||
| block = BasicBlock | |||
| # 1st conv before any network block | |||
| self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, | |||
| padding=1, bias=False) | |||
| # 1st block | |||
| self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate) | |||
| # 2nd block | |||
| self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate) | |||
| # 3rd block | |||
| self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate) | |||
| # global average pooling and classifier | |||
| self.bn1 = nn.BatchNorm2d(nChannels[3]) | |||
| self.relu = nn.ReLU(inplace=True) | |||
| self.fc = nn.Linear(nChannels[3], num_classes) | |||
| self.nChannels = nChannels[3] | |||
| for m in self.modules(): | |||
| if isinstance(m, nn.Conv2d): | |||
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||
| m.weight.data.normal_(0, math.sqrt(2. / n)) | |||
| elif isinstance(m, nn.BatchNorm2d): | |||
| m.weight.data.fill_(1) | |||
| m.bias.data.zero_() | |||
| elif isinstance(m, nn.Linear): | |||
| m.bias.data.zero_() | |||
| def forward(self, x, return_features=False): | |||
| out = self.conv1(x) | |||
| out = self.block1(out) | |||
| out = self.block2(out) | |||
| out = self.block3(out) | |||
| out = self.relu(self.bn1(out)) | |||
| out = F.avg_pool2d(out, 8) # assumes 32x32 inputs (CIFAR-style), giving an 8x8 map here | |||
| features = out.view(-1, self.nChannels) | |||
| out = self.fc(features) | |||
| if return_features: | |||
| return out, features | |||
| else: | |||
| return out | |||
| def wrn_16_1(num_classes, dropout_rate=0): | |||
| return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate) | |||
| def wrn_16_2(num_classes, dropout_rate=0): | |||
| return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate) | |||
| def wrn_40_1(num_classes, dropout_rate=0): | |||
| return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate) | |||
| def wrn_40_2(num_classes, dropout_rate=0): | |||
| return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate) | |||
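| # A minimal usage sketch (assumes CIFAR-style 32x32 inputs, which the fixed 8x8 | |||
| # average pooling in forward() relies on): | |||
| # | |||
| # model = wrn_16_2(num_classes=10) | |||
| # logits = model(torch.randn(1, 3, 32, 32)) # -> (1, 10) | |||
| # logits, feats = model(torch.randn(1, 3, 32, 32), return_features=True) | |||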