From 5a0d9ad9f9e5111b7ac2f9f242b025a767f206f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=8B=E6=B1=9F=E5=A4=A9=E6=9E=A2?= Date: Wed, 30 Jun 2021 14:59:58 +0800 Subject: [PATCH] update model converter and model measuring --- model-converter/.gitignore | 2 + model-converter/Dockerfile | 7 + model-converter/main.py | 86 + model_measuring/README.md | 0 model_measuring/app.py | 130 ++ model_measuring/kamal/__init__.py | 5 + .../kamal/amalgamation/__init__.py | 4 + .../kamal/amalgamation/common_feature.py | 291 ++++ .../amalgamation/layerwise_amalgamation.py | 131 ++ .../kamal/amalgamation/recombination.py | 209 +++ .../kamal/amalgamation/task_branching.py | 295 ++++ model_measuring/kamal/core/__init__.py | 4 + model_measuring/kamal/core/attach.py | 45 + .../kamal/core/callbacks/__init__.py | 5 + model_measuring/kamal/core/callbacks/base.py | 28 + .../kamal/core/callbacks/eval_and_ckpt.py | 145 ++ .../kamal/core/callbacks/logging.py | 61 + .../kamal/core/callbacks/scheduler.py | 34 + .../kamal/core/callbacks/visualize.py | 161 ++ model_measuring/kamal/core/engine/__init__.py | 7 + model_measuring/kamal/core/engine/engine.py | 190 +++ .../kamal/core/engine/evaluator.py | 130 ++ model_measuring/kamal/core/engine/events.py | 92 + model_measuring/kamal/core/engine/hooks.py | 34 + .../kamal/core/engine/lr_finder.py | 214 +++ model_measuring/kamal/core/engine/trainer.py | 129 ++ model_measuring/kamal/core/exceptions.py | 22 + model_measuring/kamal/core/hub/__init__.py | 2 + model_measuring/kamal/core/hub/_hub.py | 288 ++++ .../kamal/core/hub/_module_mapping.py | 6 + model_measuring/kamal/core/hub/meta/TASK.py | 22 + .../kamal/core/hub/meta/__init__.py | 3 + model_measuring/kamal/core/hub/meta/input.py | 31 + model_measuring/kamal/core/hub/meta/meta.py | 44 + .../kamal/core/metrics/__init__.py | 8 + .../kamal/core/metrics/accuracy.py | 118 ++ model_measuring/kamal/core/metrics/average.py | 49 + .../kamal/core/metrics/confusion_matrix.py | 68 + model_measuring/kamal/core/metrics/normal.py | 86 + .../kamal/core/metrics/regression.py | 199 +++ .../kamal/core/metrics/stream_metrics.py | 82 + model_measuring/kamal/core/tasks/__init__.py | 3 + .../kamal/core/tasks/loss/__init__.py | 2 + .../kamal/core/tasks/loss/functional.py | 107 ++ model_measuring/kamal/core/tasks/loss/loss.py | 386 +++++ model_measuring/kamal/core/tasks/task.py | 186 +++ model_measuring/kamal/slim/__init__.py | 2 + .../kamal/slim/distillation/__init__.py | 12 + .../kamal/slim/distillation/attention.py | 47 + model_measuring/kamal/slim/distillation/cc.py | 55 + .../slim/distillation/data_free/__init__.py | 1 + .../kamal/slim/distillation/data_free/zskt.py | 99 ++ .../kamal/slim/distillation/hint.py | 86 + model_measuring/kamal/slim/distillation/kd.py | 90 + .../kamal/slim/distillation/nst.py | 44 + .../kamal/slim/distillation/pkt.py | 44 + .../kamal/slim/distillation/rkd.py | 45 + model_measuring/kamal/slim/distillation/sp.py | 44 + .../kamal/slim/distillation/svd.py | 45 + .../kamal/slim/distillation/vid.py | 90 + .../kamal/slim/prunning/__init__.py | 2 + model_measuring/kamal/slim/prunning/pruner.py | 37 + .../kamal/slim/prunning/strategy.py | 85 + .../kamal/transferability/README.md | 18 + .../kamal/transferability/__init__.py | 20 + .../kamal/transferability/depara/__init__.py | 3 + .../depara/attribution_graph.py | 184 ++ .../transferability/depara/attribution_map.py | 87 + .../kamal/transferability/trans_graph.py | 135 ++ .../kamal/transferability/trans_metric.py | 109 ++ model_measuring/kamal/utils/__init__.py | 2 + 
model_measuring/kamal/utils/_utils.py | 153 ++ model_measuring/kamal/utils/logger.py | 56 + model_measuring/kamal/vision/__init__.py | 3 + .../kamal/vision/datasets/__init__.py | 16 + .../kamal/vision/datasets/ade20k.py | 70 + .../kamal/vision/datasets/caltech.py | 226 +++ .../kamal/vision/datasets/camvid.py | 78 + .../kamal/vision/datasets/cityscapes.py | 146 ++ .../kamal/vision/datasets/cub200.py | 70 + .../kamal/vision/datasets/dataset.py | 57 + .../kamal/vision/datasets/fgvc_aircraft.py | 143 ++ model_measuring/kamal/vision/datasets/nyu.py | 84 + .../datasets/preprocess/prepare_caltech101.py | 61 + .../datasets/preprocess/prepare_stl10.py | 196 +++ .../datasets/preprocess/resize_camvid.py | 53 + .../datasets/preprocess/resize_cityscapes.py | 65 + .../vision/datasets/preprocess/resize_voc.py | 59 + .../datasets/preprocess/resize_voc_240.py | 57 + .../kamal/vision/datasets/stanford_cars.py | 80 + .../kamal/vision/datasets/stanford_dogs.py | 58 + .../kamal/vision/datasets/sunrgbd.py | 65 + .../kamal/vision/datasets/unlabeled.py | 67 + .../kamal/vision/datasets/utils.py | 161 ++ model_measuring/kamal/vision/datasets/voc.py | 209 +++ .../kamal/vision/models/__init__.py | 3 + .../vision/models/classification/__init__.py | 7 + .../vision/models/classification/alexnet.py | 63 + .../models/classification/cifar/__init__.py | 1 + .../vision/models/classification/cifar/wrn.py | 108 ++ .../vision/models/classification/darknet.py | 246 +++ .../vision/models/classification/densenet.py | 233 +++ .../models/classification/mobilenetv2.py | 186 +++ .../vision/models/classification/resnet.py | 336 ++++ .../kamal/vision/models/classification/vgg.py | 176 ++ .../kamal/vision/models/detection/__init__.py | 1 + .../vision/models/detection/craft/__init__.py | 1 + .../vision/models/detection/craft/craft.py | 96 ++ .../vision/models/detection/craft/vgg16_bn.py | 73 + .../vision/models/segmentation/__init__.py | 4 + .../models/segmentation/deeplab/__init__.py | 2 + .../models/segmentation/deeplab/deeplab.py | 148 ++ .../models/segmentation/deeplab/layer.py | 159 ++ .../models/segmentation/deeplab/utils.py | 55 + .../models/segmentation/linknet/__init__.py | 1 + .../models/segmentation/linknet/linknet.py | 194 +++ .../models/segmentation/segnet/__init__.py | 1 + .../models/segmentation/segnet/layer.py | 53 + .../models/segmentation/segnet/segnet.py | 194 +++ .../models/segmentation/unet/__init__.py | 1 + .../vision/models/segmentation/unet/layer.py | 51 + .../vision/models/segmentation/unet/unet.py | 58 + model_measuring/kamal/vision/models/utils.py | 115 ++ .../kamal/vision/sync_transforms/__init__.py | 1 + .../vision/sync_transforms/functional.py | 805 +++++++++ .../vision/sync_transforms/transforms.py | 1475 +++++++++++++++++ 126 files changed, 12617 insertions(+) create mode 100644 model-converter/.gitignore create mode 100644 model-converter/Dockerfile create mode 100644 model-converter/main.py create mode 100644 model_measuring/README.md create mode 100644 model_measuring/app.py create mode 100644 model_measuring/kamal/__init__.py create mode 100644 model_measuring/kamal/amalgamation/__init__.py create mode 100644 model_measuring/kamal/amalgamation/common_feature.py create mode 100644 model_measuring/kamal/amalgamation/layerwise_amalgamation.py create mode 100644 model_measuring/kamal/amalgamation/recombination.py create mode 100644 model_measuring/kamal/amalgamation/task_branching.py create mode 100644 model_measuring/kamal/core/__init__.py create mode 100644 model_measuring/kamal/core/attach.py create mode 
100644 model_measuring/kamal/core/callbacks/__init__.py create mode 100644 model_measuring/kamal/core/callbacks/base.py create mode 100644 model_measuring/kamal/core/callbacks/eval_and_ckpt.py create mode 100644 model_measuring/kamal/core/callbacks/logging.py create mode 100644 model_measuring/kamal/core/callbacks/scheduler.py create mode 100644 model_measuring/kamal/core/callbacks/visualize.py create mode 100644 model_measuring/kamal/core/engine/__init__.py create mode 100644 model_measuring/kamal/core/engine/engine.py create mode 100644 model_measuring/kamal/core/engine/evaluator.py create mode 100644 model_measuring/kamal/core/engine/events.py create mode 100644 model_measuring/kamal/core/engine/hooks.py create mode 100644 model_measuring/kamal/core/engine/lr_finder.py create mode 100644 model_measuring/kamal/core/engine/trainer.py create mode 100644 model_measuring/kamal/core/exceptions.py create mode 100644 model_measuring/kamal/core/hub/__init__.py create mode 100644 model_measuring/kamal/core/hub/_hub.py create mode 100644 model_measuring/kamal/core/hub/_module_mapping.py create mode 100644 model_measuring/kamal/core/hub/meta/TASK.py create mode 100644 model_measuring/kamal/core/hub/meta/__init__.py create mode 100644 model_measuring/kamal/core/hub/meta/input.py create mode 100644 model_measuring/kamal/core/hub/meta/meta.py create mode 100644 model_measuring/kamal/core/metrics/__init__.py create mode 100644 model_measuring/kamal/core/metrics/accuracy.py create mode 100644 model_measuring/kamal/core/metrics/average.py create mode 100644 model_measuring/kamal/core/metrics/confusion_matrix.py create mode 100644 model_measuring/kamal/core/metrics/normal.py create mode 100644 model_measuring/kamal/core/metrics/regression.py create mode 100644 model_measuring/kamal/core/metrics/stream_metrics.py create mode 100644 model_measuring/kamal/core/tasks/__init__.py create mode 100644 model_measuring/kamal/core/tasks/loss/__init__.py create mode 100644 model_measuring/kamal/core/tasks/loss/functional.py create mode 100644 model_measuring/kamal/core/tasks/loss/loss.py create mode 100644 model_measuring/kamal/core/tasks/task.py create mode 100644 model_measuring/kamal/slim/__init__.py create mode 100644 model_measuring/kamal/slim/distillation/__init__.py create mode 100644 model_measuring/kamal/slim/distillation/attention.py create mode 100644 model_measuring/kamal/slim/distillation/cc.py create mode 100644 model_measuring/kamal/slim/distillation/data_free/__init__.py create mode 100644 model_measuring/kamal/slim/distillation/data_free/zskt.py create mode 100644 model_measuring/kamal/slim/distillation/hint.py create mode 100644 model_measuring/kamal/slim/distillation/kd.py create mode 100644 model_measuring/kamal/slim/distillation/nst.py create mode 100644 model_measuring/kamal/slim/distillation/pkt.py create mode 100644 model_measuring/kamal/slim/distillation/rkd.py create mode 100644 model_measuring/kamal/slim/distillation/sp.py create mode 100644 model_measuring/kamal/slim/distillation/svd.py create mode 100644 model_measuring/kamal/slim/distillation/vid.py create mode 100644 model_measuring/kamal/slim/prunning/__init__.py create mode 100644 model_measuring/kamal/slim/prunning/pruner.py create mode 100644 model_measuring/kamal/slim/prunning/strategy.py create mode 100644 model_measuring/kamal/transferability/README.md create mode 100644 model_measuring/kamal/transferability/__init__.py create mode 100644 model_measuring/kamal/transferability/depara/__init__.py create mode 100644 
model_measuring/kamal/transferability/depara/attribution_graph.py create mode 100644 model_measuring/kamal/transferability/depara/attribution_map.py create mode 100644 model_measuring/kamal/transferability/trans_graph.py create mode 100644 model_measuring/kamal/transferability/trans_metric.py create mode 100644 model_measuring/kamal/utils/__init__.py create mode 100644 model_measuring/kamal/utils/_utils.py create mode 100644 model_measuring/kamal/utils/logger.py create mode 100644 model_measuring/kamal/vision/__init__.py create mode 100644 model_measuring/kamal/vision/datasets/__init__.py create mode 100644 model_measuring/kamal/vision/datasets/ade20k.py create mode 100644 model_measuring/kamal/vision/datasets/caltech.py create mode 100644 model_measuring/kamal/vision/datasets/camvid.py create mode 100644 model_measuring/kamal/vision/datasets/cityscapes.py create mode 100644 model_measuring/kamal/vision/datasets/cub200.py create mode 100644 model_measuring/kamal/vision/datasets/dataset.py create mode 100644 model_measuring/kamal/vision/datasets/fgvc_aircraft.py create mode 100644 model_measuring/kamal/vision/datasets/nyu.py create mode 100644 model_measuring/kamal/vision/datasets/preprocess/prepare_caltech101.py create mode 100644 model_measuring/kamal/vision/datasets/preprocess/prepare_stl10.py create mode 100644 model_measuring/kamal/vision/datasets/preprocess/resize_camvid.py create mode 100644 model_measuring/kamal/vision/datasets/preprocess/resize_cityscapes.py create mode 100644 model_measuring/kamal/vision/datasets/preprocess/resize_voc.py create mode 100644 model_measuring/kamal/vision/datasets/preprocess/resize_voc_240.py create mode 100644 model_measuring/kamal/vision/datasets/stanford_cars.py create mode 100644 model_measuring/kamal/vision/datasets/stanford_dogs.py create mode 100644 model_measuring/kamal/vision/datasets/sunrgbd.py create mode 100644 model_measuring/kamal/vision/datasets/unlabeled.py create mode 100644 model_measuring/kamal/vision/datasets/utils.py create mode 100644 model_measuring/kamal/vision/datasets/voc.py create mode 100644 model_measuring/kamal/vision/models/__init__.py create mode 100644 model_measuring/kamal/vision/models/classification/__init__.py create mode 100644 model_measuring/kamal/vision/models/classification/alexnet.py create mode 100644 model_measuring/kamal/vision/models/classification/cifar/__init__.py create mode 100644 model_measuring/kamal/vision/models/classification/cifar/wrn.py create mode 100644 model_measuring/kamal/vision/models/classification/darknet.py create mode 100644 model_measuring/kamal/vision/models/classification/densenet.py create mode 100644 model_measuring/kamal/vision/models/classification/mobilenetv2.py create mode 100644 model_measuring/kamal/vision/models/classification/resnet.py create mode 100644 model_measuring/kamal/vision/models/classification/vgg.py create mode 100644 model_measuring/kamal/vision/models/detection/__init__.py create mode 100644 model_measuring/kamal/vision/models/detection/craft/__init__.py create mode 100644 model_measuring/kamal/vision/models/detection/craft/craft.py create mode 100644 model_measuring/kamal/vision/models/detection/craft/vgg16_bn.py create mode 100644 model_measuring/kamal/vision/models/segmentation/__init__.py create mode 100644 model_measuring/kamal/vision/models/segmentation/deeplab/__init__.py create mode 100644 model_measuring/kamal/vision/models/segmentation/deeplab/deeplab.py create mode 100644 model_measuring/kamal/vision/models/segmentation/deeplab/layer.py create 
mode 100644 model_measuring/kamal/vision/models/segmentation/deeplab/utils.py create mode 100644 model_measuring/kamal/vision/models/segmentation/linknet/__init__.py create mode 100644 model_measuring/kamal/vision/models/segmentation/linknet/linknet.py create mode 100644 model_measuring/kamal/vision/models/segmentation/segnet/__init__.py create mode 100644 model_measuring/kamal/vision/models/segmentation/segnet/layer.py create mode 100644 model_measuring/kamal/vision/models/segmentation/segnet/segnet.py create mode 100644 model_measuring/kamal/vision/models/segmentation/unet/__init__.py create mode 100644 model_measuring/kamal/vision/models/segmentation/unet/layer.py create mode 100644 model_measuring/kamal/vision/models/segmentation/unet/unet.py create mode 100644 model_measuring/kamal/vision/models/utils.py create mode 100644 model_measuring/kamal/vision/sync_transforms/__init__.py create mode 100644 model_measuring/kamal/vision/sync_transforms/functional.py create mode 100644 model_measuring/kamal/vision/sync_transforms/transforms.py diff --git a/model-converter/.gitignore b/model-converter/.gitignore new file mode 100644 index 0000000..c749986 --- /dev/null +++ b/model-converter/.gitignore @@ -0,0 +1,2 @@ +/.idea/ +*.iml diff --git a/model-converter/Dockerfile b/model-converter/Dockerfile new file mode 100644 index 0000000..4eb7b87 --- /dev/null +++ b/model-converter/Dockerfile @@ -0,0 +1,7 @@ +FROM tensorflow/tensorflow:2.4.1 + +WORKDIR /app +RUN pip install web.py tf2onnx +COPY . /app + +ENTRYPOINT ["python3", "main.py"] diff --git a/model-converter/main.py b/model-converter/main.py new file mode 100644 index 0000000..0b9f863 --- /dev/null +++ b/model-converter/main.py @@ -0,0 +1,86 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +import json +import os +import subprocess +import logging +import web +from subprocess import PIPE + +urls = ( + '/hello', 'Hello', + '/model_convert', 'ModelConvert' +) +logging.basicConfig(filename='onnx.log', level=logging.DEBUG) + +class Hello(object): + def GET(self): + return 'service alive' + +class ModelConvert(object): + def POST(self): + data = web.data() + web.header('Content-Type', 'application/json') + try: + json_data = json.loads(data) + model_path = json_data['model_path'] + output_path = json_data['output_path'] + if not os.path.isdir(model_path): + msg = 'model_path is not a dir: %s' % model_path + logging.error(msg) + return json.dumps({'code': 501, 'msg': msg, 'data': ''}) + if not output_path.endswith('/'): + msg = 'output_path is not a dir: %s' % output_path + logging.error(msg) + return json.dumps({'code': 502, 'msg': msg, 'data': ''}) + exist_flag = exist(model_path) + if not exist_flag: + msg = 'SavedModel file does not exist at: %s' % model_path + logging.error(msg) + return json.dumps({'code': 503, 'msg': msg, 'data': ''}) + convert_flag, msg = convert(model_path, output_path) + if not convert_flag: + return json.dumps({'code': 504, 'msg': msg, 'data': ''}) + except Exception as e: + logging.error(str(e)) + return json.dumps({'code': 505, 'msg': str(e), 'data': ''}) + return json.dumps({'code': 200, 'msg': 'ok', 'data': msg}) + +def exist(model_path): + for file in os.listdir(model_path): + if file=='saved_model.pbtxt' or file=='saved_model.pb': + return True + return False + + +def convert(model_path, output_path): + output_path = output_path+'model.onnx' + try: + logging.info('model_path=%s, output_path=%s' % (model_path, output_path)) + result = subprocess.run(["python", "-m", "tf2onnx.convert", "--saved-model", model_path, "--output", output_path], stdout=PIPE, stderr=PIPE) + logging.info(repr(result)) + if result.returncode != 0: + return False, str(result.stderr) + except Exception as e: + logging.error(str(e)) + return False, str(e) + return True, output_path + +if __name__ == '__main__': + app = web.application(urls, globals()) + app.run() diff --git a/model_measuring/README.md b/model_measuring/README.md new file mode 100644 index 0000000..e69de29 diff --git a/model_measuring/app.py b/model_measuring/app.py new file mode 100644 index 0000000..0765c0c --- /dev/null +++ b/model_measuring/app.py @@ -0,0 +1,130 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ =============================================================
+"""
+
+dependencies = ['torch', 'kamal']  # required dependencies
+import json
+import traceback
+import os
+import torch
+import web
+import kamal
+from PIL import Image
+from kamal.transferability import TransferabilityGraph
+from kamal.transferability.trans_metric import AttrMapMetric
+from torchvision.models import *
+
+urls = (
+    '/model_measure/measure', 'Measure',
+    '/model_measure/package', 'Package',
+)
+app = web.application(urls, globals())
+class Package:
+    def POST(self):
+        req = web.data()
+        save_path_list = []
+        for json_data in json.loads(req):
+            try:
+                metadata = json_data['metadata']
+            except Exception as e:
+                traceback.print_exc()
+                return json.dumps(Response(506, 'Failed to package model, Error: %s' % (traceback.format_exc(limit=1)), save_path_list).__dict__)
+            entry_name = json_data['entry_name']
+            for name, fn in kamal.hub.list_entry(__file__):
+                if entry_name in name:
+                    try:
+                        dataset = metadata['dataset']
+                        model = fn(pretrained=False, num_classes=metadata['entry_args']['num_classes'])
+                        num_params = sum( [ torch.numel(p) for p in model.parameters() ] )
+                        save_path_for_measure = '%sfinegraind_%s/' % (json_data['ckpt'], entry_name)
+                        save_path = save_path_for_measure+'%s' % (metadata['name'])
+                        ckpt = self.file_name(json_data['ckpt'])
+                        if ckpt == '':
+                            return json.dumps(
+                                Response(506, 'Failed to package model [%s]: No .pth file was found in directory ckpt' % (entry_name), save_path_list).__dict__)
+                        model.load_state_dict(torch.load(ckpt), False)
+                        kamal.hub.save(  # packages the user's PyTorch model into the format described above and stores it at the specified location
+                            model,  # the nn.Module to be saved
+                            save_path=save_path,
+                            # name of the export directory
+                            entry_name=entry_name,  # entry-function name; must match the entry function defined above
+                            spec_name=None,  # name of this specific weight version; if None, an md5 hash is used instead
+                            code_path=__file__,  # code the model depends on; either a directory (which must contain a hubconf.py)
+                            # or the current hubconf.py itself -- this example reuses the model implementations from the dependencies, so this file is sufficient
+                            metadata=metadata,
+                            tags=dict(
+                                num_params=num_params,
+                                metadata=metadata,
+                                name=metadata['name'],
+                                url=metadata['url'],
+                                dataset=dataset,
+                                img_size=metadata['input']['size'],
+                                readme=json_data['readme'])
+
+                        )
+                        save_path_list.append(save_path_for_measure)
+                        return json.dumps(Response(200, 'Success', save_path_list).__dict__)
+                    except Exception:
+                        traceback.print_exc()
+                        return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)
+            return json.dumps(Response(506, 'Failed to package model [%s], Error: %s' % (entry_name, traceback.format_exc(limit=1)), save_path_list).__dict__)
+
+
+    def file_name(self, file_dir):
+        for root, dirs, files in os.walk(file_dir):
+            for file in files:
+                if file.endswith('pth'):
+                    return root + file
+        return ''
+
+
+class Measure:
+    def POST(self):
+        req = web.data()
+        json_data = json.loads(req)
+        print(json_data)
+        try:
+            measure_name = 'measure'
+            zoo_set = json_data['zoo_set']
+            probe_set_root = json_data['probe_set_root']
+            export_path = json_data['export_path']
+            output_filename_list = []
+            TG = TransferabilityGraph(zoo_set)
+            print("Add %s" % (probe_set_root))
+            imgs_set = list(os.listdir(os.path.join(probe_set_root)))
+            images = [Image.open(os.path.join(probe_set_root, img)) for img in imgs_set]
+            metric = AttrMapMetric(images, device=torch.device('cuda'))
+            TG.add_metric(probe_set_root, metric)
+            isExists = os.path.exists(export_path)
+            if not isExists:
+                # create the export directory if it does not exist
+                os.makedirs(export_path)
+            output_filename = export_path+'%s.json' % (measure_name)
+            TG.export_to_json(probe_set_root, output_filename, topk=3, normalize=True)
+ 
output_filename_list.append(output_filename) + except Exception: + traceback.print_exc() + return json.dumps(Response(506, 'Failed to generate measurement file of [%s], Error: %s' % (probe_set_root, traceback.format_exc(limit=1)), output_filename_list).__dict__) + return json.dumps(Response(200, 'Success', output_filename_list).__dict__) + +class Response: + def __init__(self, code, msg, data): + self.code = code + self.msg = msg + self.data = data + +if __name__ == "__main__": + app.run() \ No newline at end of file diff --git a/model_measuring/kamal/__init__.py b/model_measuring/kamal/__init__.py new file mode 100644 index 0000000..53cecec --- /dev/null +++ b/model_measuring/kamal/__init__.py @@ -0,0 +1,5 @@ +from .core import tasks, metrics, engine, callbacks, hub + +from . import amalgamation, slim, vision, transferability + +from .core import load, save \ No newline at end of file diff --git a/model_measuring/kamal/amalgamation/__init__.py b/model_measuring/kamal/amalgamation/__init__.py new file mode 100644 index 0000000..6e6b633 --- /dev/null +++ b/model_measuring/kamal/amalgamation/__init__.py @@ -0,0 +1,4 @@ +from .layerwise_amalgamation import LayerWiseAmalgamator +from .common_feature import CommonFeatureAmalgamator +from .task_branching import TaskBranchingAmalgamator, JointSegNet +from .recombination import RecombinationAmalgamator, CombinedModel \ No newline at end of file diff --git a/model_measuring/kamal/amalgamation/common_feature.py b/model_measuring/kamal/amalgamation/common_feature.py new file mode 100644 index 0000000..f05aea0 --- /dev/null +++ b/model_measuring/kamal/amalgamation/common_feature.py @@ -0,0 +1,291 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from kamal.core.engine.engine import Engine +from kamal.core.engine.hooks import FeatureHook +from kamal.core import tasks +from kamal.utils import set_mode, move_to_device + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import typing, time +import numpy as np + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +class ResBlock(nn.Module): + """ Residual Blocks + """ + def __init__(self, inplanes, planes, stride=1, momentum=0.1): + super(ResBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=momentum) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=momentum) + if stride > 1 or inplanes != planes: + self.downsample = nn.Sequential( + nn.Conv2d(inplanes, planes, kernel_size=1, + stride=stride, bias=False), + nn.BatchNorm2d(planes, momentum=momentum) + ) + else: + self.downsample = None + + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class CFL_FCBlock(nn.Module): + """Common Feature Blocks for Fully-Connected layer + + This module is used to capture the common features of multiple teachers and calculate mmd with features of student. + + **Parameters:** + - cs (int): channel number of student features + - channel_ts (list or tuple): channel number list of teacher features + - ch (int): channel number of hidden features + """ + def __init__(self, cs, cts, ch, k_size=5): + super(CFL_FCBlock, self).__init__() + + self.align_t = nn.ModuleList() + for ct in cts: + self.align_t.append( + nn.Sequential( + nn.Linear(ct, ch), + nn.ReLU(inplace=True) + ) + ) + + self.align_s = nn.Sequential( + nn.Linear(cs, ch), + nn.ReLU(inplace=True), + ) + + self.extractor = nn.Sequential( + nn.Linear(ch, ch), + nn.ReLU(), + nn.Linear(ch, ch), + ) + + self.dec_t = nn.ModuleList() + for ct in cts: + self.dec_t.append( + nn.Sequential( + nn.Linear(ch, ct), + nn.ReLU(inplace=True), + nn.Linear(ct, ct) + ) + ) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, fs, fts): + aligned_t = [self.align_t[i](fts[i]) for i in range(len(fts))] + aligned_s = self.align_s(fs) + + hts = [self.extractor(f) for f in aligned_t] + hs = self.extractor(aligned_s) + + _fts = [self.dec_t[i](hts[i]) for i in range(len(hts))] + return (hs, hts), (_fts, fts) + + +class CFL_ConvBlock(nn.Module): + """Common Feature Blocks for Convolutional layer + + This module is used to capture the common features of multiple teachers and calculate mmd with features of student. 
+ + **Parameters:** + - cs (int): channel number of student features + - channel_ts (list or tuple): channel number list of teacher features + - ch (int): channel number of hidden features + """ + def __init__(self, cs, cts, ch, k_size=5): + super(CFL_ConvBlock, self).__init__() + + self.align_t = nn.ModuleList() + for ct in cts: + self.align_t.append( + nn.Sequential( + nn.Conv2d(in_channels=ct, out_channels=ch, + kernel_size=1), + nn.BatchNorm2d(ch), + nn.ReLU(inplace=True) + ) + ) + + self.align_s = nn.Sequential( + nn.Conv2d(in_channels=cs, out_channels=ch, + kernel_size=1), + nn.BatchNorm2d(ch), + nn.ReLU(inplace=True), + ) + + self.extractor = nn.Sequential( + ResBlock(inplanes=ch, planes=ch, stride=1), + ResBlock(inplanes=ch, planes=ch, stride=1), + ) + + self.dec_t = nn.ModuleList() + for ct in cts: + self.dec_t.append( + nn.Sequential( + nn.Conv2d(ch, ch, kernel_size=1, stride=1), + nn.BatchNorm2d(ch), + nn.ReLU(inplace=True), + nn.Conv2d(ch, ct, kernel_size=1, stride=1) + ) + ) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, fs, fts): + aligned_t = [self.align_t[i](fts[i]) for i in range(len(fts))] + aligned_s = self.align_s(fs) + + hts = [self.extractor(f) for f in aligned_t] + hs = self.extractor(aligned_s) + + _fts = [self.dec_t[i](hts[i]) for i in range(len(hts))] + return (hs, hts), (_fts, fts) + +class CommonFeatureAmalgamator(Engine): + + def setup( + self, + student, + teachers, + layer_groups: typing.Sequence[typing.Sequence], + layer_channels: typing.Sequence[typing.Sequence], + dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + weights = [1.0, 1.0, 1.0], + on_layer_input=False, + device = None, + ): + self._dataloader = dataloader + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self._device = device + + self._model = self._student = student.to(self._device) + self._teachers = nn.ModuleList(teachers).to(self._device) + self._optimizer = optimizer + self._weights = weights + self._on_layer_input = on_layer_input + + amal_blocks = [] + for group, C in zip( layer_groups, layer_channels ): + hooks = [ FeatureHook(layer) for layer in group ] + if isinstance(group[0], nn.Linear): + amal_block = CFL_FCBlock( cs=C[0], cts=C[1:], ch=C[0]//4 ).to(self._device).train() + print("Building FC Blocks") + else: + amal_block = CFL_ConvBlock(cs=C[0], cts=C[1:], ch=C[0]//4).to(self._device).train() + print("Building Conv Blocks") + amal_blocks.append( (amal_block, hooks, C) ) + self._amal_blocks = amal_blocks + self._cfl_criterion = tasks.loss.CFLLoss( sigmas=[0.001, 0.01, 0.05, 0.1, 0.2, 1, 2] ) + + @property + def device(self): + return self._device + + def run(self, max_iter, start_iter=0, epoch_length=None): + block_params = [] + for block, _, _ in self._amal_blocks: + block_params.extend( list(block.parameters()) ) + if isinstance( self._optimizer, torch.optim.SGD ): + self._amal_optimimizer = torch.optim.SGD( block_params, lr=self._optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 ) + else: + self._amal_optimimizer = torch.optim.Adam( block_params, lr=self._optimizer.param_groups[0]['lr'], weight_decay=1e-4 ) + self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimimizer, T_max=max_iter ) + + with set_mode(self._student, training=True), \ + 
set_mode(self._teachers, training=False): + super( CommonFeatureAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) + + def step_fn(self, engine, batch): + start_time = time.perf_counter() + batch = move_to_device(batch, self._device) + data = batch[0] + s_out = self._student( data ) + with torch.no_grad(): + t_out = [ teacher( data ) for teacher in self._teachers ] + loss_amal = 0 + loss_recons = 0 + for amal_block, hooks, C in self._amal_blocks: + features = [ h.feat_in if self._on_layer_input else h.feat_out for h in hooks ] + fs, fts = features[0], features[1:] + (hs, hts), (_fts, fts) = amal_block( fs, fts ) + _loss_amal, _loss_recons = self._cfl_criterion( hs, hts, _fts, fts ) + loss_amal += _loss_amal + loss_recons += _loss_recons + loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, dim=1 ) ) + loss_dict = { + 'loss_kd': self._weights[0]*loss_kd, + 'loss_amal': self._weights[1]*loss_amal, + 'loss_recons': self._weights[2]*loss_recons + } + loss = sum(loss_dict.values()) + self._optimizer.zero_grad() + self._amal_optimimizer.zero_grad() + loss.backward() + self._optimizer.step() + self._amal_optimimizer.step() + self._amal_scheduler.step() + step_time = time.perf_counter() - start_time + + metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } + metrics.update({ + 'total_loss': loss.item(), + 'step_time': step_time, + 'lr': float( self._optimizer.param_groups[0]['lr'] ) + }) + return metrics diff --git a/model_measuring/kamal/amalgamation/layerwise_amalgamation.py b/model_measuring/kamal/amalgamation/layerwise_amalgamation.py new file mode 100644 index 0000000..4880fa5 --- /dev/null +++ b/model_measuring/kamal/amalgamation/layerwise_amalgamation.py @@ -0,0 +1,131 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from kamal.core.engine.engine import Engine +from kamal.core.engine.hooks import FeatureHook +from kamal.core import tasks + +from kamal.utils import set_mode +import typing +import time +from kamal.utils import move_to_device, set_mode + +class AmalBlock(nn.Module): + def __init__(self, cs, cts): + super( AmalBlock, self ).__init__() + self.cs, self.cts = cs, cts + self.enc = nn.Conv2d( in_channels=sum(self.cts), out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True ) + self.fam = nn.Conv2d( in_channels=self.cs, out_channels=self.cs, kernel_size=1, stride=1, padding=0, bias=True ) + self.dec = nn.Conv2d( in_channels=self.cs, out_channels=sum(self.cts), kernel_size=1, stride=1, padding=0, bias=True ) + + def forward(self, fs, fts): + rep = self.enc( torch.cat( fts, dim=1 ) ) + _fts = self.dec( rep ) + _fts = torch.split( _fts, self.cts, dim=1 ) + _fs = self.fam( fs ) + return rep, _fs, _fts + +class LayerWiseAmalgamator(Engine): + + def setup( + self, + student, + teachers, + layer_groups: typing.Sequence[typing.Sequence], + layer_channels: typing.Sequence[typing.Sequence], + dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + weights = [1., 1., 1.], + device=None, + ): + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self._device = device + self._dataloader = dataloader + self.model = self.student = student.to(self.device) + self.teachers = nn.ModuleList(teachers).to(self.device) + self.optimizer = optimizer + self._weights = weights + amal_blocks = [] + + for group, C in zip(layer_groups, layer_channels): + hooks = [ FeatureHook(layer) for layer in group ] + amal_block = AmalBlock(cs=C[0], cts=C[1:]).to(self.device).train() + amal_blocks.append( (amal_block, hooks, C) ) + self._amal_blocks = amal_blocks + @property + def device(self): + return self._device + def run(self, max_iter, start_iter=0, epoch_length=None ): + block_params = [] + for block, _, _ in self._amal_blocks: + block_params.extend( list(block.parameters()) ) + if isinstance( self.optimizer, torch.optim.SGD ): + self._amal_optimimizer = torch.optim.SGD( block_params, lr=self.optimizer.param_groups[0]['lr'], momentum=0.9, weight_decay=1e-4 ) + else: + self._amal_optimimizer = torch.optim.Adam( block_params, lr=self.optimizer.param_groups[0]['lr'], weight_decay=1e-4 ) + self._amal_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self._amal_optimimizer, T_max=max_iter ) + + with set_mode(self.student, training=True), \ + set_mode(self.teachers, training=False): + super( LayerWiseAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) + + @property + def device(self): + return self._device + + def step_fn(self, engine, batch): + start_time = time.perf_counter() + batch = move_to_device(batch, self._device) + data = batch[0] + s_out = self.student( data ) + with torch.no_grad(): + t_out = [ teacher( data ) for teacher in self.teachers ] + loss_amal = 0 + loss_recons = 0 + for amal_block, hooks, C in self._amal_blocks: + features = [ h.feat_out for h in hooks ] + fs, fts = features[0], features[1:] + rep, _fs, _fts = amal_block( fs, fts ) + loss_amal += F.mse_loss( _fs, rep.detach() ) + loss_recons += sum( [ F.mse_loss( _ft, ft ) for (_ft, ft) in zip( _fts, fts ) ] ) + loss_kd = tasks.loss.kldiv( s_out, torch.cat( t_out, 
dim=1 ) ) + #loss_kd = F.mse_loss( s_out, torch.cat( t_out, dim=1 ) ) + loss_dict = { "loss_kd": self._weights[0] * loss_kd, + "loss_amal": self._weights[1] * loss_amal, + "loss_recons": self._weights[2] * loss_recons } + loss = sum(loss_dict.values()) + self.optimizer.zero_grad() + self._amal_optimimizer.zero_grad() + loss.backward() + self.optimizer.step() + self._amal_optimimizer.step() + self._amal_scheduler.step() + step_time = time.perf_counter() - start_time + metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } + metrics.update({ + 'total_loss': loss.item(), + 'step_time': step_time, + 'lr': float( self.optimizer.param_groups[0]['lr'] ) + }) + return metrics + + diff --git a/model_measuring/kamal/amalgamation/recombination.py b/model_measuring/kamal/amalgamation/recombination.py new file mode 100644 index 0000000..11f14c6 --- /dev/null +++ b/model_measuring/kamal/amalgamation/recombination.py @@ -0,0 +1,209 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy + +from typing import Callable + +from kamal.core.engine.engine import Engine +from kamal.core.engine.trainer import KDTrainer +from kamal.core.engine.hooks import FeatureHook +from kamal.core import tasks +import math + +from kamal.slim.prunning import Pruner, strategy + +def _assert_same_type(layers, layer_type=None): + if layer_type is None: + layer_type = type(layers[0]) + + assert all(isinstance(l, layer_type) for l in layers), 'Model archictures must be the same' + +def _get_layers(model_list): + submodel = [ model.modules() for model in model_list ] + for layers in zip(*submodel): + _assert_same_type(layers) + yield layers + +def bn_combine_fn(layers): + """Combine 2D Batch Normalization Layers + + **Parameters:** + - **layers** (BatchNorm2D): Batch Normalization Layers. + """ + _assert_same_type(layers, nn.BatchNorm2d) + num_features = sum(l.num_features for l in layers) + combined_bn = nn.BatchNorm2d(num_features=num_features, + eps=layers[0].eps, + momentum=layers[0].momentum, + affine=layers[0].affine, + track_running_stats=layers[0].track_running_stats) + combined_bn.running_mean = torch.cat( + [l.running_mean for l in layers], dim=0).clone() + combined_bn.running_var = torch.cat( + [l.running_var for l in layers], dim=0).clone() + + if combined_bn.affine: + combined_bn.weight = torch.nn.Parameter( + torch.cat([l.weight.data.clone() for l in layers], dim=0).clone()) + combined_bn.bias = torch.nn.Parameter( + torch.cat([l.bias.data.clone() for l in layers], dim=0).clone()) + return combined_bn + + +def conv2d_combine_fn(layers): + """Combine 2D Conv Layers + + **Parameters:** + - **layers** (Conv2d): Conv Layers. 
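+
+    Illustration (based on the implementation below): combining two Conv2d(3, 16)
+    layers yields a single block-diagonal Conv2d whose in/out channel counts are
+    the sums of the inputs', e.g.
+
+        convs = [nn.Conv2d(3, 16, kernel_size=3, padding=1),
+                 nn.Conv2d(3, 16, kernel_size=3, padding=1)]
+        merged = conv2d_combine_fn(convs)   # -> Conv2d(6, 32, kernel_size=(3, 3), padding=(1, 1))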
+ """ + _assert_same_type(layers, nn.Conv2d) + + CO, CI = 0, 0 + for l in layers: + O, I, H, W = l.weight.shape + CO += O + CI += I + + dtype = layers[0].weight.dtype + device = layers[0].weight.device + + combined_weight = torch.nn.Parameter( + torch.zeros(CO, CI, H, W, dtype=dtype, device=device)) + if layers[0].bias is not None: + combined_bias = torch.nn.Parameter( + torch.zeros(CO, dtype=dtype, device=device)) + else: + combined_bias = None + co_offset = 0 + ci_offset = 0 + for idx, l in enumerate(layers): + co_len, ci_len = l.weight.shape[0], l.weight.shape[1] + combined_weight[co_offset: co_offset+co_len, + ci_offset: ci_offset+ci_len, :, :] = l.weight.clone() + if combined_bias is not None: + combined_bias[co_offset: co_offset+co_len] = l.bias.clone() + co_offset += co_len + ci_offset += ci_offset + combined_conv2d = nn.Conv2d(in_channels=CI, + out_channels=CO, + kernel_size=layers[0].weight.shape[-2:], + stride=layers[0].stride, + padding=layers[0].padding, + bias=layers[0].bias is not None) + combined_conv2d.weight.data = combined_weight + if combined_bias is not None: + combined_conv2d.bias.data = combined_bias + for p in combined_conv2d.parameters(): + p.requires_grad = True + return combined_conv2d + + +def combine_models(models): + """Combine modules with parser + + **Parameters:** + - **models** (nn.Module): modules to be combined. + - **combine_parser** (function): layer selector + """ + def _recursively_combine(module): + module_output = module + + if isinstance( module, nn.Conv2d ): + combined_module = conv2d_combine_fn( layer_mapping[module] ) + elif isinstance( module, nn.BatchNorm2d ): + combined_module = bn_combine_fn( layer_mapping[module] ) + else: + combined_module = module + + if combined_module is not None: + module_output = combined_module + + for name, child in module.named_children(): + module_output.add_module(name, _recursively_combine(child)) + return module_output + + models = deepcopy(models) + combined_model = deepcopy(models[0]) # copy the model archicture and modify it with _recursively_combine + + layer_mapping = {} + for combined_layer, layers in zip(combined_model.modules(), _get_layers(models)): + layer_mapping[combined_layer] = layers # link to teachers + combined_model = _recursively_combine(combined_model) + return combined_model + + +class CombinedModel(nn.Module): + def __init__(self, models): + super( Combination, self ).__init__() + self.combined_model = combine_models( models ) + self.expand = len(models) + + def forward(self, x): + x.repeat( -1, x.shape[1]*self.expand, -1, -1 ) + return self.combined_model(x) + + +class PruningKDTrainer(KDTrainer): + def setup( + self, + student, + teachers, + task, + dataloader: torch.utils.data.DataLoader, + get_optimizer_and_scheduler:Callable=None, + pruning_rounds=5, + device=None, + ): + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self._device = device + self._dataloader = dataloader + self.model = self.student = student.to(self.device) + self.teachers = nn.ModuleList(teachers).to(self.device) + self.get_optimizer_and_scheduler = get_optimizer_and_scheduler + @property + def device(self): + return self._device + def run(self, max_iter, start_iter=0, epoch_length=None, pruning_rounds=3, target_model_size=0.6 ): + pruning_size_per_round = 1 - math.pow( target_model_size, 1/pruning_rounds ) + prunner = Pruner( strategy.LNStrategy(n=1) ) + for pruning_round in range(pruning_rounds): + prunner.prune( self.student, rate=pruning_size_per_round, 
example_inputs=torch.randn(1,3,240,240) ) + self.student.to(self.device) + if self.get_optimizer_and_scheduler: + self.optimizer, self.scheduler = self.get_optimizer_and_scheduler( self.student ) + else: + self.optimizer = torch.optim.Adam( self.student.parameters(), lr=1e-4, weight_decay=1e-5 ) + self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max= (max_iter-start_iter)//pruning_rounds ) + step_iter = (max_iter - start_iter)//pruning_rounds + + with set_mode(self.student, training=True), \ + set_mode(self.teachers, training=False): + super( RecombinationAmalgamation, self ).run(self.step_fn, self._dataloader, + start_iter=start_iter+step_iter*pruning_round , max_iter=start_iter+step_iter*(pruning_round+1), epoch_length=epoch_length) + + def step_fn(self, engine, batch): + metrics = super(RecombinationAmalgamation, self).step_fn( engine, batch ) + self.scheduler.step() + return metrics + +class RecombinationAmalgamator(PruningKDTrainer): + pass \ No newline at end of file diff --git a/model_measuring/kamal/amalgamation/task_branching.py b/model_measuring/kamal/amalgamation/task_branching.py new file mode 100644 index 0000000..ddf9c1a --- /dev/null +++ b/model_measuring/kamal/amalgamation/task_branching.py @@ -0,0 +1,295 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from kamal.core.engine.engine import Engine +from kamal.core.engine.hooks import FeatureHook +from kamal.core import tasks +from kamal.utils import move_to_device, set_mode +from kamal.core.hub import meta +from kamal import vision +import kamal + +from kamal.utils import set_mode +import typing +import time + +from copy import deepcopy +import random +import numpy as np +from collections import defaultdict +import numbers + +class BranchySegNet(nn.Module): + def __init__(self, out_channels, segnet_fn=vision.models.segmentation.segnet_vgg16_bn): + super(BranchySegNet, self).__init__() + channels=[512, 512, 256, 128, 64] + self.register_buffer( 'branch_indices', torch.zeros((len(out_channels),)) ) + + self.student_b_decoders_list = nn.ModuleList() + self.student_adaptors_list = nn.ModuleList() + + ses = [] + for i in range(5): + se = int(channels[i]/4) + ses.append(16 if se < 16 else se) + + for oc in out_channels: + segnet = self.get_segnet( oc, segnet_fn ) + decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:])) + adaptors = nn.ModuleList() + for i in range(5): + adaptor = nn.Sequential( + nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False), + nn.ReLU(), + nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False), + nn.Sigmoid() + ) + adaptors.append(adaptor) + self.student_b_decoders_list.append(decoders) + self.student_adaptors_list.append(adaptors) + + self.student_encoders = nn.ModuleList(deepcopy(list(segnet.children())[0:5])) + self.student_decoders = nn.ModuleList(deepcopy(list(segnet.children())[5:])) + + def set_branch(self, branch_indices): + assert len(branch_indices)==len(self.student_b_decoders_list) + self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device) + + def get_segnet(self, oc, segnet_fn): + return segnet_fn( num_classes=oc, pretrained_backbone=True ) + + def forward(self, inputs): + output_list = [] + down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs) + down2, indices_2, unpool_shape2 = self.student_encoders[1](down1) + down3, indices_3, unpool_shape3 = self.student_encoders[2](down2) + down4, indices_4, unpool_shape4 = self.student_encoders[3](down3) + down5, indices_5, unpool_shape5 = self.student_encoders[4](down4) + + up5 = self.student_decoders[0](down5, indices_5, unpool_shape5) + up4 = self.student_decoders[1](up5, indices_4, unpool_shape4) + up3 = self.student_decoders[2](up4, indices_3, unpool_shape3) + up2 = self.student_decoders[3](up3, indices_2, unpool_shape2) + up1 = self.student_decoders[4](up2, indices_1, unpool_shape1) + + decoder_features = [down5, up5, up4, up3, up2] + decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1] + decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1] + + # Mimic teachers. + for i in range(len(self.branch_indices)): + out_idx = self.branch_indices[i] + output = decoder_features[out_idx] + output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:3])) + for j in range(out_idx, 5): + output = self.student_b_decoders_list[i][j]( + output, + decoder_indices[j], + decoder_shapes[j] + ) + output_list.append( output ) + return output_list + +class JointSegNet(nn.Module): + """The online student model to learn from any number of single teacher with 'SegNet' structure. 
+ + **Parameters:** + - **teachers** (list of 'Module' object): Teachers with 'SegNet' structure to learn from. + - **indices** (list of int): Where to branch out for each task. + - **phase** (string): Should be 'block' or 'finetune'. Useful only in training mode. + - **channels** (list of int, optional): Parameter to build adaptor modules, corresponding to that of 'SegNet'. + """ + def __init__(self, teachers, student=None, channels=[512, 512, 256, 128, 64]): + super(JointSegNet, self).__init__() + self.register_buffer( 'branch_indices', torch.zeros((2,)) ) + + if student is None: + student = teachers[0] + + self.student_encoders = nn.ModuleList(deepcopy(list(teachers[0].children())[0:5])) + self.student_decoders = nn.ModuleList(deepcopy(list(teachers[0].children())[5:])) + self.student_b_decoders_list = nn.ModuleList() + self.student_adaptors_list = nn.ModuleList() + + ses = [] + for i in range(5): + se = int(channels[i]/4) + ses.append(16 if se < 16 else se) + + for teacher in teachers: + decoders = nn.ModuleList(deepcopy(list(teacher.children())[5:])) + adaptors = nn.ModuleList() + for i in range(5): + adaptor = nn.Sequential( + nn.Conv2d(channels[i], ses[i], kernel_size=1, bias=False), + nn.ReLU(), + nn.Conv2d(ses[i], channels[i], kernel_size=1, bias=False), + nn.Sigmoid() + ) + adaptors.append(adaptor) + self.student_b_decoders_list.append(decoders) + self.student_adaptors_list.append(adaptors) + + def set_branch(self, branch_indices): + assert len(branch_indices)==len(self.student_b_decoders_list) + self.branch_indices = torch.from_numpy( np.array( branch_indices ) ).to(self.branch_indices.device) + + def forward(self, inputs): + + output_list = [] + + down1, indices_1, unpool_shape1 = self.student_encoders[0](inputs) + down2, indices_2, unpool_shape2 = self.student_encoders[1](down1) + down3, indices_3, unpool_shape3 = self.student_encoders[2](down2) + down4, indices_4, unpool_shape4 = self.student_encoders[3](down3) + down5, indices_5, unpool_shape5 = self.student_encoders[4](down4) + + up5 = self.student_decoders[0](down5, indices_5, unpool_shape5) + up4 = self.student_decoders[1](up5, indices_4, unpool_shape4) + up3 = self.student_decoders[2](up4, indices_3, unpool_shape3) + up2 = self.student_decoders[3](up3, indices_2, unpool_shape2) + up1 = self.student_decoders[4](up2, indices_1, unpool_shape1) + + decoder_features = [down5, up5, up4, up3, up2] + decoder_indices = [indices_5, indices_4, indices_3, indices_2, indices_1] + decoder_shapes = [unpool_shape5, unpool_shape4, unpool_shape3, unpool_shape2, unpool_shape1] + + # Mimic teachers. 
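+        # For each task i, branch_indices[i] selects the decoder depth at which the shared
+        # trunk hands off: the chosen shared feature map is re-weighted by a small SE-style
+        # adaptor (1x1 conv -> ReLU -> 1x1 conv -> Sigmoid over the pooled feature map) and
+        # then finished by that task's private decoder stack.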
+ for i in range(len(self.branch_indices)): + out_idx = self.branch_indices[i] + output = decoder_features[out_idx] + output = output * self.student_adaptors_list[i][out_idx](F.avg_pool2d(output, output.shape[2:3])) + for j in range(out_idx, 5): + output = self.student_b_decoders_list[i][j]( + output, + decoder_indices[j], + decoder_shapes[j] + ) + output_list.append( output ) + return output_list + + +class TaskBranchingAmalgamator(Engine): + def setup( + self, + joint_student: JointSegNet, + teachers, + tasks, + dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + device=None, + ): + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self._device = device + self._dataloader = dataloader + self.student = self.model = joint_student.to(self._device) + self.teachers = nn.ModuleList(teachers).to(self._device) + self.tasks = tasks + self.optimizer = optimizer + + self.is_finetuning=False + + @property + def device(self): + return self._device + + def run(self, max_iter, start_iter=0, epoch_length=None, stage_callback=None ): + # Branching + with set_mode(self.student, training=True), \ + set_mode(self.teachers, training=False): + super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=start_iter, max_iter=max_iter//2, epoch_length=epoch_length) + branch = self.find_the_best_branch( self._dataloader ) + self.logger.info("[Task Branching] the best branch indices: %s"%(branch)) + + if stage_callback is not None: + stage_callback() + + # Finetuning + self.is_finetuning = True + with set_mode(self.student, training=True), \ + set_mode(self.teachers, training=False): + super( TaskBranchingAmalgamator, self ).run(self.step_fn, self._dataloader, start_iter=max_iter-max_iter//2, max_iter=max_iter, epoch_length=epoch_length) + + def find_the_best_branch(self, dataloader): + + with set_mode(self.student, training=False), \ + set_mode(self.teachers, training=False), \ + torch.no_grad(): + n_blocks = len(self.student.student_decoders) + branch_loss = { task: [0 for _ in range(n_blocks)] for task in self.tasks } + for batch in dataloader: + batch = move_to_device(batch, self.device) + data = batch if isinstance(batch, torch.Tensor) else batch[0] + for candidate_branch in range( n_blocks ): + self.student.set_branch( [candidate_branch for _ in range(len(self.teachers))] ) + s_out_list = self.student( data ) + t_out_list = [ teacher( data ) for teacher in self.teachers ] + for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ): + task_loss = task.get_loss( s_out, t_out ) + branch_loss[task][candidate_branch] += sum(task_loss.values()) + best_brach = [] + for task in self.tasks: + best_brach.append( int(np.argmin( branch_loss[task] )) ) + + self.student.set_branch(best_brach) + return best_brach + + @property + def device(self): + return self._device + + def step_fn(self, engine, batch): + start_time = time.perf_counter() + batch = move_to_device(batch, self._device) + data = batch[0] + #data = batch if isinstance(batch, torch.Tensor) else batch[0] + data, None + n_blocks = len(self.student.student_decoders) + if not self.is_finetuning: + rand_branch_indices = [ random.randint(0, n_blocks-1) for _ in range(len(self.teachers)) ] + self.student.set_branch(rand_branch_indices) + + s_out_list = self.student( data ) + with torch.no_grad(): + t_out_list = [ teacher( data ) for teacher in self.teachers ] + + loss_dict = {} + for task, s_out, t_out in zip( self.tasks, s_out_list, t_out_list ): + task_loss 
= task.get_loss( s_out, t_out ) + loss_dict.update( task_loss ) + loss = sum(loss_dict.values()) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + step_time = time.perf_counter() - start_time + metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } + metrics.update({ + 'total_loss': loss.item(), + 'step_time': step_time, + 'lr': float( self.optimizer.param_groups[0]['lr'] ), + 'branch': self.student.branch_indices.cpu().numpy().tolist() + }) + return metrics + + diff --git a/model_measuring/kamal/core/__init__.py b/model_measuring/kamal/core/__init__.py new file mode 100644 index 0000000..437fc6c --- /dev/null +++ b/model_measuring/kamal/core/__init__.py @@ -0,0 +1,4 @@ +from . import engine, tasks, metrics, callbacks, exceptions, hub +from .attach import AttachTo + +from .hub import load, save \ No newline at end of file diff --git a/model_measuring/kamal/core/attach.py b/model_measuring/kamal/core/attach.py new file mode 100644 index 0000000..f928254 --- /dev/null +++ b/model_measuring/kamal/core/attach.py @@ -0,0 +1,45 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from typing import Sequence, Callable +from numbers import Number +from kamal.core import exceptions + +class AttachTo(object): + """ Attach task, metrics or visualizer to specified model outputs + """ + def __init__(self, attach_to=None): + if attach_to is not None and not isinstance(attach_to, (Sequence, Number, str, Callable) ): + raise exceptions.InvalidMapping + self._attach_to = attach_to + + def __call__(self, *tensors): + if self._attach_to is not None: + if isinstance(self._attach_to, Callable): + return self._attach_to( *tensors ) + if isinstance(self._attach_to, Sequence): + _attach_to = self._attach_to + else: + _attach_to = [ self._attach_to for _ in range(len(tensors)) ] + _attach_to = _attach_to[:len(tensors)] + tensors = [ tensor[index] for (tensor, index) in zip( tensors, _attach_to ) ] + if len(tensors)==1: + tensors = tensors[0] + return tensors + + def __repr__(self): + rep = "AttachTo: %s"%(self._attach_to) \ No newline at end of file diff --git a/model_measuring/kamal/core/callbacks/__init__.py b/model_measuring/kamal/core/callbacks/__init__.py new file mode 100644 index 0000000..fca4ef6 --- /dev/null +++ b/model_measuring/kamal/core/callbacks/__init__.py @@ -0,0 +1,5 @@ +from .logging import MetricsLogging, ProgressCallback +from .base import Callback +from .eval_and_ckpt import EvalAndCkpt +from .scheduler import LRSchedulerCallback +from .visualize import VisualizeOutputs, VisualizeSegmentation, VisualizeDepth \ No newline at end of file diff --git a/model_measuring/kamal/core/callbacks/base.py b/model_measuring/kamal/core/callbacks/base.py new file mode 100644 index 0000000..166b49d --- /dev/null +++ b/model_measuring/kamal/core/callbacks/base.py @@ -0,0 +1,28 @@ +""" + Copyright 2020 Tianshu AI Platform. 
All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import abc + +class Callback(abc.ABC): + r""" Base Class for Callbacks + """ + def __init__(self): + pass + + @abc.abstractmethod + def __call__(self, engine): + pass \ No newline at end of file diff --git a/model_measuring/kamal/core/callbacks/eval_and_ckpt.py b/model_measuring/kamal/core/callbacks/eval_and_ckpt.py new file mode 100644 index 0000000..fd617de --- /dev/null +++ b/model_measuring/kamal/core/callbacks/eval_and_ckpt.py @@ -0,0 +1,145 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .base import Callback +import weakref +from kamal import utils +from typing import Sequence, Optional +import numbers +import os, shutil +import torch + +class EvalAndCkpt(Callback): + def __init__(self, + model, + evaluator, + metric_name:str, + metric_mode:str ='max', + save_type:Optional[Sequence]=('best', 'latest'), + ckpt_dir:str ='checkpoints', + ckpt_prefix:str =None, + log_tag:str ='model', + weights_only:bool =True, + verbose:bool =False,): + super(EvalAndCkpt, self).__init__() + self.metric_name = metric_name + assert metric_mode in ('max', 'min'), "metric_mode should be 'max' or 'min'" + self._metric_mode = metric_mode + + self._model = weakref.ref( model ) + self._evaluator = evaluator + self._ckpt_dir = ckpt_dir + self._ckpt_prefix = "" if ckpt_prefix is None else (ckpt_prefix+'_') + if isinstance(save_type, str): + save_type = ( save_type, ) + self._save_type = save_type + if self._save_type is not None: + for save_type in self._save_type: + assert save_type in ('best', 'latest', 'all'), \ + 'save_type should be None or a subset of (\"best\", \"latest\", \"all\")' + self._log_tag = log_tag + self._weights_only = weights_only + self._verbose = verbose + + self._best_score = -999999 if self._metric_mode=='max' else 99999. 
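+        # sentinel initial score: the first evaluation result always becomes the current best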
+ self._best_ckpt = None + self._latest_ckpt = None + + @property + def best_ckpt(self): + return self._best_ckpt + + @property + def latest_ckpt(self): + return self._latest_ckpt + + @property + def best_score(self): + return self._best_score + + def __call__(self, trainer): + model = self._model() + results = self._evaluator.eval( model, device=trainer.device ) + results = utils.flatten_dict(results) + current_score = results[self.metric_name] + + scalar_results = { k: float(v) for (k, v) in results.items() if isinstance(v, numbers.Number) or len(v.shape)==0 } + if trainer.logger is not None: + trainer.logger.info( "[Eval %s] Iter %d/%d: %s"%(self._log_tag, trainer.state.iter, trainer.state.max_iter, scalar_results) ) + trainer.state.metrics.update( scalar_results ) + # Visualize results if trainer.tb_writer is not None + + if trainer.tb_writer is not None: + for k, v in scalar_results.items(): + log_tag = "%s:%s"%(self._log_tag, k) + trainer.tb_writer.add_scalar(log_tag, v, global_step=trainer.state.iter) + + if self._save_type is not None: + pth_path_list = [] + # interval model + if 'interval' in self._save_type: + pth_path = os.path.join(self._ckpt_dir, "%s%08d_%s_%.3f.pth" + % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) + pth_path_list.append(pth_path) + + # the latest model + if 'latest' in self._save_type: + pth_path = os.path.join(self._ckpt_dir, "%slatest_%08d_%s_%.3f.pth" + % (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) + # remove the old ckpt + if self._latest_ckpt is not None and os.path.exists(self._latest_ckpt): + os.remove(self._latest_ckpt) + pth_path_list.append(pth_path) + self._latest_ckpt = pth_path + + # the best model + if 'best' in self._save_type: + if (current_score >= self._best_score and self._metric_mode=='max') or \ + (current_score <= self._best_score and self._metric_mode=='min'): + pth_path = os.path.join(self._ckpt_dir, "%sbest_%08d_%s_%.4f.pth" % + (self._ckpt_prefix, trainer.state.iter, self.metric_name, current_score)) + # remove the old ckpt + if self._best_ckpt is not None and os.path.exists(self._best_ckpt): + os.remove(self._best_ckpt) + pth_path_list.append(pth_path) + self._best_score = current_score + self._best_ckpt = pth_path + + # save model + if self._verbose and trainer.logger: + trainer.logger.info("Model saved as:") + obj = model.state_dict() if self._weights_only else model + os.makedirs( self._ckpt_dir, exist_ok=True ) + for pth_path in pth_path_list: + torch.save(obj, pth_path) + if self._verbose and trainer.logger: + trainer.logger.info("\t%s" % (pth_path)) + + def final_ckpt(self, ckpt_prefix=None, ckpt_dir=None, add_md5=False): + if ckpt_dir is None: + ckpt_dir = self._ckpt_dir + if ckpt_prefix is None: + ckpt_prefix = self._ckpt_prefix + if self._save_type is not None: + #if 'latest' in self._save_type and self._latest_ckpt is not None: + # os.makedirs(ckpt_dir, exist_ok=True) + # save_name = "%slatest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._latest_ckpt)) + # shutil.copy2(self._latest_ckpt, os.path.join(ckpt_dir, save_name)) + if 'best' in self._save_type and self._best_ckpt is not None: + os.makedirs(ckpt_dir, exist_ok=True) + save_name = "%sbest%s.pth"%(ckpt_prefix, "" if not add_md5 else "-%s"%utils.md5(self._best_ckpt)) + shutil.copy2(self._best_ckpt, os.path.join(ckpt_dir, save_name)) \ No newline at end of file diff --git a/model_measuring/kamal/core/callbacks/logging.py b/model_measuring/kamal/core/callbacks/logging.py new file mode 100644 
index 0000000..79c6119 --- /dev/null +++ b/model_measuring/kamal/core/callbacks/logging.py @@ -0,0 +1,61 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .base import Callback +import numbers +from tqdm import tqdm + +class MetricsLogging(Callback): + def __init__(self, keys): + super(MetricsLogging, self).__init__() + self._keys = keys + + def __call__(self, engine): + if engine.logger==None: + return + state = engine.state + content = "Iter %d/%d (Epoch %d/%d, Batch %d/%d)"%( + state.iter, state.max_iter, + state.current_epoch, state.max_epoch, + state.current_batch_index, state.max_batch_index + ) + for key in self._keys: + value = state.metrics.get(key, None) + if value is not None: + if isinstance(value, numbers.Number): + content += " %s=%.4f"%(key, value) + if engine.tb_writer is not None: + engine.tb_writer.add_scalar(key, value, global_step=state.iter) + elif isinstance(value, (list, tuple)): + content += " %s=%s"%(key, value) + + engine.logger.info(content) + +class ProgressCallback(Callback): + def __init__(self, max_iter=100, tag=None): + self._max_iter = max_iter + self._tag = tag + #self._pbar = tqdm(total=self._max_iter, desc=self._tag) + + def __call__(self, engine): + self._pbar.update(1) + if self._pbar.n==self._max_iter: + self._pbar.close() + + def reset(self): + self._pbar = tqdm(total=self._max_iter, desc=self._tag) + \ No newline at end of file diff --git a/model_measuring/kamal/core/callbacks/scheduler.py b/model_measuring/kamal/core/callbacks/scheduler.py new file mode 100644 index 0000000..8924949 --- /dev/null +++ b/model_measuring/kamal/core/callbacks/scheduler.py @@ -0,0 +1,34 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from .base import Callback +from typing import Sequence + +class LRSchedulerCallback(Callback): + r""" LR scheduler callback + """ + def __init__(self, schedulers=None): + super(LRSchedulerCallback, self).__init__() + if not isinstance(schedulers, Sequence): + schedulers = ( schedulers, ) + self._schedulers = schedulers + + def __call__(self, trainer): + if self._schedulers is None: + return + for sched in self._schedulers: + sched.step() \ No newline at end of file diff --git a/model_measuring/kamal/core/callbacks/visualize.py b/model_measuring/kamal/core/callbacks/visualize.py new file mode 100644 index 0000000..2819302 --- /dev/null +++ b/model_measuring/kamal/core/callbacks/visualize.py @@ -0,0 +1,161 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .base import Callback +from typing import Callable, Union, Sequence +import weakref +import random +from kamal.utils import move_to_device, set_mode, split_batch, colormap +from kamal.core.attach import AttachTo +import torch +import numpy as np + +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use('agg') +import math +import numbers + +class VisualizeOutputs(Callback): + def __init__(self, + model, + dataset: torch.utils.data.Dataset, + idx_list_or_num_vis: Union[int, Sequence]=5, + normalizer: Callable=None, + prepare_fn: Callable=None, + decode_fn: Callable=None, # decode targets and preds + tag: str='viz'): + super(VisualizeOutputs, self).__init__() + self._dataset = dataset + self._model = weakref.ref(model) + if isinstance(idx_list_or_num_vis, int): + self.idx_list = self._get_vis_idx_list(self._dataset, idx_list_or_num_vis) + elif isinstance(idx_list_or_num_vis, Sequence): + self.idx_list = idx_list_or_num_vis + self._normalizer = normalizer + self._decode_fn = decode_fn + if prepare_fn is None: + prepare_fn = VisualizeOutputs.get_prepare_fn() + self._prepare_fn = prepare_fn + self._tag = tag + + def _get_vis_idx_list(self, dataset, num_vis): + return random.sample(list(range(len(dataset))), num_vis) + + @torch.no_grad() + def __call__(self, trainer): + if trainer.tb_writer is None: + trainer.logger.warning("summary writer was not found in trainer") + return + device = trainer.device + model = self._model() + with torch.no_grad(), set_mode(model, training=False): + for img_id, idx in enumerate(self.idx_list): + batch = move_to_device(self._dataset[idx], device) + batch = [ d.unsqueeze(0) for d in batch ] + inputs, targets, preds = self._prepare_fn(model, batch) + if self._normalizer is not None: + inputs = self._normalizer(inputs) + inputs = inputs.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() + targets = targets.detach().cpu().numpy() + if self._decode_fn: # to RGB 0~1 NCHW + preds = self._decode_fn(preds) + targets = self._decode_fn(targets) + inputs = inputs[0] + preds = preds[0] + targets = targets[0] + 
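+                # write the input / ground-truth / prediction triplet as a single image grid to TensorBoard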
trainer.tb_writer.add_images("%s-%d"%(self._tag, img_id), np.stack( [inputs, targets, preds], axis=0), global_step=trainer.state.iter) + + @staticmethod + def get_prepare_fn(attach_to=None, pred_fn=lambda x: x): + attach_to = AttachTo(attach_to) + def wrapper(model, batch): + inputs, targets = split_batch(batch) + outputs = model(inputs) + outputs, targets = attach_to(outputs, targets) + return inputs, targets, pred_fn(outputs) + return wrapper + + @staticmethod + def get_seg_decode_fn(cmap=colormap(), index_transform=lambda x: x+1): # 255->0, 0->1, + def wrapper(preds): + if len(preds.shape)>3: + preds = preds.squeeze(1) + out = cmap[ index_transform(preds.astype('uint8')) ] + out = out.transpose(0, 3, 1, 2) / 255 + return out + return wrapper + + @staticmethod + def get_depth_decode_fn(max_depth, log_scale=True, cmap=plt.get_cmap('jet')): + def wrapper(preds): + if log_scale: + _max_depth = np.log( max_depth ) + preds = np.log( preds ) + else: + _max_depth = max_depth + if len(preds.shape)>3: + preds = preds.squeeze(1) + out = (cmap(preds.clip(0, _max_depth)/_max_depth)).transpose(0, 3, 1, 2)[:, :3] + return out + return wrapper + +class VisualizeSegmentation(VisualizeOutputs): + def __init__( + self, model, dataset: torch.utils.data.Dataset, idx_list_or_num_vis: Union[int, Sequence]=5, + cmap = colormap(), + attach_to=None, + + normalizer: Callable=None, + prepare_fn: Callable=None, + decode_fn: Callable=None, + tag: str='seg' + ): + if prepare_fn is None: + prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x.max(1)[1]) + if decode_fn is None: + decode_fn = VisualizeOutputs.get_seg_decode_fn(cmap=cmap, index_transform=lambda x: x+1) + + super(VisualizeSegmentation, self).__init__( + model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis, + normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn, + tag=tag + ) + +class VisualizeDepth(VisualizeOutputs): + def __init__( + self, model, dataset: torch.utils.data.Dataset, + idx_list_or_num_vis: Union[int, Sequence]=5, + max_depth = 10, + log_scale = True, + attach_to = None, + + normalizer: Callable=None, + prepare_fn: Callable=None, + decode_fn: Callable=None, + tag: str='depth' + ): + if prepare_fn is None: + prepare_fn = VisualizeOutputs.get_prepare_fn(attach_to=attach_to, pred_fn=lambda x: x) + if decode_fn is None: + decode_fn = VisualizeOutputs.get_depth_decode_fn(max_depth=max_depth, log_scale=log_scale) + super(VisualizeDepth, self).__init__( + model=model, dataset=dataset, idx_list_or_num_vis=idx_list_or_num_vis, + normalizer=normalizer, prepare_fn=prepare_fn, decode_fn=decode_fn, + tag=tag + ) \ No newline at end of file diff --git a/model_measuring/kamal/core/engine/__init__.py b/model_measuring/kamal/core/engine/__init__.py new file mode 100644 index 0000000..a8df97d --- /dev/null +++ b/model_measuring/kamal/core/engine/__init__.py @@ -0,0 +1,7 @@ +from . import evaluator +from . import trainer +from . import lr_finder +from . import hooks +from . import events + +from .engine import DefaultEvents, Event \ No newline at end of file diff --git a/model_measuring/kamal/core/engine/engine.py b/model_measuring/kamal/core/engine/engine.py new file mode 100644 index 0000000..f74cb46 --- /dev/null +++ b/model_measuring/kamal/core/engine/engine.py @@ -0,0 +1,190 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import torch +import torch.nn as nn +import abc, math, weakref, typing, time +from typing import Any, Callable, Optional, Sequence +import numpy as np + +from kamal.core.engine.events import DefaultEvents, Event +from kamal.core import tasks +from kamal.utils import set_mode, move_to_device, get_logger +from collections import defaultdict + +import numbers +import contextlib + +class State(object): + def __init__(self): + self.iter = 0 + self.max_iter = None + self.epoch_length = None + self.dataloader = None + self.seed = None + + self.metrics=dict() + self.batch=None + + @property + def current_epoch(self): + if self.epoch_length is not None: + return self.iter // self.epoch_length + return None + + @property + def max_epoch(self): + if self.epoch_length is not None: + return self.max_iter // self.epoch_length + return None + + @property + def current_batch_index(self): + if self.epoch_length is not None: + return self.iter % self.epoch_length + return None + + @property + def max_batch_index(self): + return self.epoch_length + + def __repr__(self): + rep = "State:\n" + for attr, value in self.__dict__.items(): + if not isinstance(value, (numbers.Number, str, dict)): + value = type(value) + rep += "\t{}: {}\n".format(attr, value) + return rep + +class Engine(abc.ABC): + def __init__(self, logger=None, tb_writer=None): + self._logger = logger if logger else get_logger(name='kamal', color=True) + self._tb_writer = tb_writer + self._callbacks = defaultdict(list) + self._allowed_events = [ *DefaultEvents ] + self._state = State() + + def reset(self): + self._state = State() + + def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None): + self.state.iter = self._state.start_iter = start_iter + self.state.max_iter = max_iter + self.state.epoch_length = epoch_length if epoch_length else len(dataloader) + self.state.dataloader = dataloader + self.state.dataloader_iter = iter(dataloader) + self.state.step_fn = step_fn + + self.trigger_events(DefaultEvents.BEFORE_RUN) + for self.state.iter in range( start_iter, max_iter ): + if self.state.epoch_length!=None and \ + self.state.iter%self.state.epoch_length==0: # Epoch Start + self.trigger_events(DefaultEvents.BEFORE_EPOCH) + self.trigger_events(DefaultEvents.BEFORE_STEP) + self.state.batch = self._get_batch() + step_output = step_fn(self, self.state.batch) + if isinstance(step_output, dict): + self.state.metrics.update(step_output) + self.trigger_events(DefaultEvents.AFTER_STEP) + if self.state.epoch_length!=None and \ + (self.state.iter+1)%self.state.epoch_length==0: # Epoch End + self.trigger_events(DefaultEvents.AFTER_EPOCH) + self.trigger_events(DefaultEvents.AFTER_RUN) + + def _get_batch(self): + try: + batch = next( self.state.dataloader_iter ) + except StopIteration: + self.state.dataloader_iter = iter(self.state.dataloader) # reset iterator + batch = next( self.state.dataloader_iter ) + if not isinstance(batch, (list, tuple)): + batch = [ batch, ] # no targets + return batch + + @property + def state(self): + return self._state + + @property 
+    def logger(self):
+        return self._logger
+
+    @property
+    def tb_writer(self):
+        return self._tb_writer
+
+    def add_callback(self, event: Event, callbacks ):
+        if not isinstance(callbacks, Sequence):
+            callbacks = [callbacks]
+        if event in self._allowed_events:
+            for callback in callbacks:
+                if callback not in self._callbacks[event]:
+                    if event.trigger!=event.default_trigger:
+                        callback = self._trigger_wrapper(self, event.trigger, callback )
+                    self._callbacks[event].append( callback )
+        callbacks = [ RemovableCallback(self, event, c) for c in callbacks ]
+        return ( callbacks[0] if len(callbacks)==1 else callbacks )
+
+    def remove_callback(self, event, callback):
+        for c in self._callbacks[event]:
+            if c==callback:
+                self._callbacks[event].remove( c )
+                return True
+        return False
+
+    @staticmethod
+    def _trigger_wrapper(engine, trigger, callback):
+        def wrapper(*args, **kwargs) -> Any:
+            if trigger(engine):
+                return callback(engine)
+        return wrapper
+
+    def trigger_events(self, *events):
+        for e in events:
+            if e in self._allowed_events:
+                for callback in self._callbacks[e]:
+                    callback(self)
+
+    def register_events(self, *events):
+        for e in events:
+            if e not in self._allowed_events:
+                self._allowed_events.append( e )
+
+    @contextlib.contextmanager
+    def save_current_callbacks(self):
+        temp = self._callbacks
+        self._callbacks = defaultdict(list)
+        yield
+        self._callbacks = temp
+
+class RemovableCallback:
+    def __init__(self, engine, event, callback):
+        self._engine = weakref.ref(engine)
+        self._callback = weakref.ref(callback)
+        self._event = weakref.ref(event)
+
+    @property
+    def callback(self):
+        return self._callback()
+
+    def remove(self):
+        engine = self._engine()
+        callback = self._callback()
+        event = self._event()
+        return engine.remove_callback(event, callback)
+
+
\ No newline at end of file
diff --git a/model_measuring/kamal/core/engine/evaluator.py b/model_measuring/kamal/core/engine/evaluator.py
new file mode 100644
index 0000000..c5dbb6e
--- /dev/null
+++ b/model_measuring/kamal/core/engine/evaluator.py
@@ -0,0 +1,130 @@
+"""
+    Copyright 2020 Tianshu AI Platform. All Rights Reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+ ============================================================= +""" + +import abc, sys +import torch +from tqdm import tqdm +from kamal.core import metrics +from kamal.utils import set_mode +from typing import Any, Callable +from .engine import Engine +from .events import DefaultEvents +from kamal.core import callbacks + +import weakref +from kamal.utils import move_to_device, split_batch + +class BasicEvaluator(Engine): + def __init__(self, + dataloader: torch.utils.data.DataLoader, + metric: metrics.MetricCompose, + eval_fn: Callable=None, + tag: str='Eval', + progress: bool=False ): + super( BasicEvaluator, self ).__init__() + self.dataloader = dataloader + self.metric = metric + self.progress = progress + if progress: + self.porgress_callback = self.add_callback( + DefaultEvents.AFTER_STEP, callbacks=callbacks.ProgressCallback(max_iter=len(self.dataloader), tag=tag)) + self._model = None + self._tag = tag + if eval_fn is None: + eval_fn = BasicEvaluator.default_eval_fn + self.eval_fn = eval_fn + + def eval(self, model, device=None): + device = device if device is not None else \ + torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self._model = weakref.ref(model) # use weakref here + self.device = device + self.metric.reset() + model.to(device) + if self.progress: + self.porgress_callback.callback.reset() + with torch.no_grad(), set_mode(model, training=False): + super(BasicEvaluator, self).run( self.step_fn, self.dataloader, max_iter=len(self.dataloader) ) + return self.metric.get_results() + + @property + def model(self): + if self._model is not None: + return self._model() + return None + + def step_fn(self, engine, batch): + batch = move_to_device(batch, self.device) + self.eval_fn( engine, batch ) + + @staticmethod + def default_eval_fn(evaluator, batch): + model = evaluator.model + inputs, targets = split_batch(batch) + outputs = model( inputs ) + evaluator.metric.update( outputs, targets ) + + +class TeacherEvaluator(BasicEvaluator): + def __init__(self, + dataloader: torch.utils.data.DataLoader, + teacher: torch.nn.Module, + task, + metric: metrics.MetricCompose, + eval_fn: Callable=None, + tag: str='Eval', + progress: bool=False ): + if eval_fn is None: + eval_fn = TeacherEvaluator.default_eval_fn + super( TeacherEvaluator, self ).__init__(dataloader=dataloader, metric=metric, eval_fn=eval_fn, tag=tag, progress=progress) + self._teacher = teacher + self.task = task + + def eval(self, model, device=None): + self.teacher.to(device) + with set_mode(self.teacher, training=False): + return super(TeacherEvaluator, self).eval( model, device=device ) + + @property + def model(self): + if self._model is not None: + return self._model() + return None + + @property + def teacher(self): + return self._teacher + + def step_fn(self, engine, batch): + batch = move_to_device(batch, self.device) + self.eval_fn( engine, batch ) + + @staticmethod + def default_eval_fn(evaluator, batch): + model = evaluator.model + teacher = evaluator.teacher + + inputs, targets = split_batch(batch) + outputs = model( inputs ) + + # get teacher outputs + if isinstance(teacher, torch.nn.ModuleList): + targets = [ task.predict(tea(inputs)) for (tea, task) in zip(teacher, evaluator.task) ] + else: + t_outputs = teacher(inputs) + targets = evaluator.task.predict( t_outputs ) + evaluator.metric.update( outputs, targets ) \ No newline at end of file diff --git a/model_measuring/kamal/core/engine/events.py b/model_measuring/kamal/core/engine/events.py new file mode 100644 index 0000000..f70b494 --- 
/dev/null +++ b/model_measuring/kamal/core/engine/events.py @@ -0,0 +1,92 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from typing import Callable, Optional +from enum import Enum + +class Event(object): + def __init__(self, value: str, event_trigger: Optional[Callable]=None ): + if event_trigger is None: + event_trigger = Event.default_trigger + self._trigger = event_trigger + self._name_ = self._value_ = value + + @property + def trigger(self): + return self._trigger + + @property + def name(self): + """The name of the Enum member.""" + return self._name_ + + @property + def value(self): + """The value of the Enum member.""" + return self._value_ + + @staticmethod + def default_trigger(engine): + return True + + @staticmethod + def once_trigger(): + is_triggered = False + def wrapper(engine): + if is_triggered: + return False + is_triggered=True + return True + return wrapper + + @staticmethod + def every_trigger(every: int): + def wrapper(engine): + return every>0 and (engine.state.iter % every)==0 + return wrapper + + def __call__(self, every: Optional[int]=None, once: Optional[bool]=None ): + if every is not None: + assert once is None + return Event(self.value, event_trigger=Event.every_trigger(every) ) + if once is not None: + return Event(self.value, event_trigger=Event.once_trigger() ) + return Event(self.value) + + def __hash__(self): + return hash(self._name_) + + def __eq__(self, other): + if hasattr(other, 'value'): + return self.value==other.value + else: + return + +class DefaultEvents(Event, Enum): + BEFORE_RUN = "before_train" + AFTER_RUN = "after_train" + + BEFORE_EPOCH = "before_epoch" + AFTER_EPOCH = "after_epoch" + + BEFORE_STEP = "before_step" + AFTER_STEP = "after_step" + + BEFORE_GET_BATCH = "before_get_batch" + AFTER_GET_BATCH = "after_get_batch" + + \ No newline at end of file diff --git a/model_measuring/kamal/core/engine/hooks.py b/model_measuring/kamal/core/engine/hooks.py new file mode 100644 index 0000000..7d0bcce --- /dev/null +++ b/model_measuring/kamal/core/engine/hooks.py @@ -0,0 +1,34 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +class FeatureHook(): + def __init__(self, module): + self.module = module + self.feat_in = None + self.feat_out = None + self.register() + + def register(self): + self._hook = self.module.register_forward_hook(self.hook_fn_forward) + + def remove(self): + self._hook.remove() + + def hook_fn_forward(self, module, fea_in, fea_out): + self.feat_in = fea_in[0] + self.feat_out = fea_out + diff --git a/model_measuring/kamal/core/engine/lr_finder.py b/model_measuring/kamal/core/engine/lr_finder.py new file mode 100644 index 0000000..6e184d1 --- /dev/null +++ b/model_measuring/kamal/core/engine/lr_finder.py @@ -0,0 +1,214 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import torch +import os +import tempfile, uuid +import contextlib +import numpy as np +from tqdm import tqdm +from kamal.core import callbacks +from kamal.core.engine import evaluator +from kamal.core.engine.events import DefaultEvents + +class _ProgressCallback(callbacks.Callback): + def __init__(self, max_iter, tag): + self._tag = tag + self._max_iter = max_iter + self._pbar = tqdm(total=self._max_iter, desc=self._tag) + + def __call__(self, enigine): + self._pbar.update(1) + + def reset(self): + self._pbar = tqdm(total=self._max_iter, desc=self._tag) + +class _LinearLRScheduler(torch.optim.lr_scheduler._LRScheduler): + """ Linear Scheduler + """ + def __init__(self, + optimizer, + lr_range, + max_iter, + last_epoch=-1): + self.lr_range = lr_range + self.max_iter = max_iter + super(_LinearLRScheduler, self).__init__(optimizer, last_epoch) + + def get_lr(self): + r = self.last_epoch / self.max_iter + return [self.lr_range[0] + (self.lr_range[1]-self.lr_range[0]) * r for base_lr in self.base_lrs] + +class _ExponentialLRScheduler(torch.optim.lr_scheduler._LRScheduler): + def __init__(self, optimizer, lr_range, max_iter, last_epoch=-1): + self.lr_range = lr_range + self.max_iter = max_iter + super(_ExponentialLRScheduler, self).__init__(optimizer, last_epoch) + + def get_lr(self): + r = self.last_epoch / (self.max_iter - 1) + return [self.lr_range[0] * (self.lr_range[1] / self.lr_range[0]) ** r for base_lr in self.base_lrs] + +class _LRFinderCallback(object): + def __init__(self, + model, + optimizer, + metric_name, + evaluator, + smooth_momentum): + self._model = model + self._optimizer = optimizer + self._metric_name = metric_name + self._evaluator = evaluator + self._smooth_momentum = smooth_momentum + self._records = [] + + @property + def records(self): + return self._records + + def reset(self): + self._records = [] + + def __call__(self, trainer): + model = self._model + optimizer = self._optimizer + if self._evaluator is not None: + results = self._evaluator.eval( model ) + score = float(results[ self._metric_name ]) + else: + score = float(trainer.state.metrics[self._metric_name]) + + if self._smooth_momentum>0 and 
len(self._records)>0: + score = self._records[-1][1] * self._smooth_momentum + score * (1-self._smooth_momentum) + + lr = optimizer.param_groups[0]['lr'] + self._records.append( ( lr, score ) ) + + +class LRFinder(object): + + def _reset(self): + init_state = torch.load( self._temp_file ) + self.trainer.model.load_state_dict( init_state['model'] ) + self.trainer.optimizer.load_state_dict( init_state['optimizer'] ) + try: + self.trainer.reset() + except: pass + + def adjust_learning_rate(self, optimizer, lr): + for group in optimizer.param_groups: + group['lr'] = lr + + def _get_default_lr_range(self, optimizer): + if isinstance( optimizer, torch.optim.Adam ): + return ( 1e-5, 1e-2 ) + elif isinstance( optimizer, torch.optim.SGD ): + return ( 1e-3, 0.2 ) + else: + return ( 1e-5, 0.5) + + def plot(self, polyfit: int=3, log_x=True): + import matplotlib.pyplot as plt + lrs = [ rec[0] for rec in self.records ] + scores = [ rec[1] for rec in self.records ] + + if polyfit is not None: + z = np.polyfit( lrs, scores, deg=polyfit ) + fitted_score = np.polyval( z, lrs ) + + fig, ax = plt.subplots() + ax.plot(lrs, scores, label='score') + ax.plot(lrs, fitted_score, label='polyfit') + + if log_x: + plt.xscale('log') + + ax.set_xlabel("Learning rate") + ax.set_ylabel("Score") + return fig + + def suggest( self, mode='min', skip_begin=10, skip_end=5, polyfit=None ): + + scores = np.array( [ self.records[i][1] for i in range( len(self.records) ) ] ) + lrs = np.array( [ self.records[i][0] for i in range( len(self.records) ) ] ) + if polyfit is not None: + z = np.polyfit( lrs, scores, deg=polyfit ) + scores = np.polyval( z, lrs ) + + grad = np.gradient( scores )[skip_begin:-skip_end] + index = grad.argmin() if mode=='min' else grad.argmax() + index = skip_begin + index + return index, self.records[index][0] + + def find(self, + optimizer, + model, + trainer, + metric_name='total_loss', + metric_mode='min', + evaluator=None, + lr_range=[1e-4, 0.1], + max_iter=100, + num_eval=None, + smooth_momentum=0.9, + scheduler='exp', # exp + polyfit=None, # None + skip=[10, 5], + progress=True): + + self.optimizer = optimizer + self.model = model + self.trainer = trainer + # save init state + _filename = str(uuid.uuid4())+'.pth' + _tempdir = tempfile.gettempdir() + self._temp_file = os.path.join(_tempdir, _filename) + init_state = { + 'optimizer': optimizer.state_dict(), + 'model': model.state_dict() + } + torch.save(init_state, self._temp_file) + + if num_eval is None or num_eval > max_iter: + num_eval = max_iter + if lr_range is None: + lr_range = self._get_default_lr_range(optimizer) + + interval = max_iter // num_eval + if scheduler=='exp': + lr_sched = _ExponentialLRScheduler(optimizer, lr_range, max_iter=max_iter) + else: + lr_sched = _LinearLRScheduler(optimizer, lr_range, max_iter=max_iter) + + self._lr_callback = callbacks.LRSchedulerCallback(schedulers=[lr_sched]) + self._finder_callback = _LRFinderCallback(model, optimizer, metric_name, evaluator, smooth_momentum) + self.adjust_learning_rate( self.optimizer, lr_range[0] ) + with self.trainer.save_current_callbacks(): + trainer.add_callback( + DefaultEvents.AFTER_STEP, callbacks=[ + self._lr_callback, + _ProgressCallback(max_iter, '[LR Finder]') ]) + trainer.add_callback( + DefaultEvents.AFTER_STEP(interval), callbacks=self._finder_callback) + self.trainer.run( start_iter=0, max_iter=max_iter ) + + self.records = self._finder_callback.records # get records + index, best_lr = self.suggest(mode=metric_mode, skip_begin=skip[0], skip_end=skip[1], polyfit=polyfit) 
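+        # restore the model/optimizer snapshot saved at the start of find() and release references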
+ self._reset() + del self.model, self.optimizer, self.trainer + return best_lr \ No newline at end of file diff --git a/model_measuring/kamal/core/engine/trainer.py b/model_measuring/kamal/core/engine/trainer.py new file mode 100644 index 0000000..54d0028 --- /dev/null +++ b/model_measuring/kamal/core/engine/trainer.py @@ -0,0 +1,129 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import torch +import torch.nn as nn +from kamal.core.engine.engine import Engine, Event, DefaultEvents, State +from kamal.core import tasks +from kamal.utils import set_mode, move_to_device, get_logger, split_batch +from typing import Callable, Mapping, Any, Sequence +import time +import weakref + +class BasicTrainer(Engine): + def __init__( self, + logger=None, + tb_writer=None): + super(BasicTrainer, self).__init__(logger=logger, tb_writer=tb_writer) + + def setup(self, + model: torch.nn.Module, + task: tasks.Task, + dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + device: torch.device=None): + + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self.device = device + if isinstance(task, Sequence): + task = tasks.TaskCompose(task) + self.task = task + self.model = model + self.dataloader = dataloader + self.optimizer = optimizer + return self + + def run( self, max_iter, start_iter=0, epoch_length=None): + self.model.to(self.device) + with set_mode(self.model, training=True): + super( BasicTrainer, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) + + def step_fn(self, engine, batch): + model = self.model + start_time = time.perf_counter() + batch = move_to_device(batch, self.device) + inputs, targets = split_batch(batch) + outputs = model(inputs) + loss_dict = self.task.get_loss(outputs, targets) # get loss + loss = sum( loss_dict.values() ) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + step_time = time.perf_counter() - start_time + metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } + metrics.update({ + 'total_loss': loss.item(), + 'step_time': step_time, + 'lr': float( self.optimizer.param_groups[0]['lr'] ) + }) + return metrics + + +class KDTrainer(BasicTrainer): + + def setup(self, + student: torch.nn.Module, + teacher: torch.nn.Module, + task: tasks.Task, + dataloader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + device: torch.device=None): + + super(KDTrainer, self).setup( + model=student, task=task, dataloader=dataloader, optimizer=optimizer, device=device) + if isinstance(teacher, (list, tuple)): + if len(teacher)==1: + teacher=teacher[0] + else: + teacher = nn.ModuleList(teacher) + self.student = self.model + self.teacher = teacher + return self + + def run( self, max_iter, start_iter=0, epoch_length=None): + self.student.to(self.device) + 
self.teacher.to(self.device)
+
+        with set_mode(self.student, training=True), \
+             set_mode(self.teacher, training=False):
+            super( BasicTrainer, self ).run(
+                self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length)
+
+    def step_fn(self, engine, batch):
+        model = self.model
+        start_time = time.perf_counter()
+        batch = move_to_device(batch, self.device)
+        inputs, targets = split_batch(batch)
+        outputs = model(inputs)
+        if isinstance(self.teacher, nn.ModuleList):
+            soft_targets = [ t(inputs) for t in self.teacher ]
+        else:
+            soft_targets = self.teacher(inputs)
+        loss_dict = self.task.get_loss(outputs, soft_targets) # get loss
+        loss = sum( loss_dict.values() )
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+        step_time = time.perf_counter() - start_time
+        metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() }
+        metrics.update({
+            'total_loss': loss.item(),
+            'step_time': step_time,
+            'lr': float( self.optimizer.param_groups[0]['lr'] )
+        })
+        return metrics
diff --git a/model_measuring/kamal/core/exceptions.py b/model_measuring/kamal/core/exceptions.py
new file mode 100644
index 0000000..37c2d30
--- /dev/null
+++ b/model_measuring/kamal/core/exceptions.py
@@ -0,0 +1,22 @@
+"""
+    Copyright 2020 Tianshu AI Platform. All Rights Reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+    =============================================================
+"""
+
+class DataTypeError(Exception):
+    pass
+
+class InvalidMapping(Exception):
+    pass
\ No newline at end of file
diff --git a/model_measuring/kamal/core/hub/__init__.py b/model_measuring/kamal/core/hub/__init__.py
new file mode 100644
index 0000000..ee9a73b
--- /dev/null
+++ b/model_measuring/kamal/core/hub/__init__.py
@@ -0,0 +1,2 @@
+from ._hub import *
+from . import meta
\ No newline at end of file
diff --git a/model_measuring/kamal/core/hub/_hub.py b/model_measuring/kamal/core/hub/_hub.py
new file mode 100644
index 0000000..77c8905
--- /dev/null
+++ b/model_measuring/kamal/core/hub/_hub.py
@@ -0,0 +1,288 @@
+"""
+    Copyright 2020 Tianshu AI Platform. All Rights Reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+ ============================================================= +""" + +from pip._internal import main as pipmain +from typing import Sequence, Optional + +import torch +import os, sys, re +import shutil, inspect +import importlib.util +from glob import glob + +from ruamel.yaml import YAML +from ruamel.yaml import comments +from ._module_mapping import PACKAGE_NAME_TO_IMPORT_NAME +import hashlib + +import inspect + +_DEFAULT_PROTOCOL = 2 +_DEPENDENCY_FILE = "requirements.txt" +_CODE_DIR = "code" +_WEIGHTS_DIR = "weight" +_TAGS_DIR = "tag" + +yaml = YAML() + +def _replace_invalid_char(name): + return name.replace('-', '_').replace(' ', '_') + +def save( + model: torch.nn.Module, + save_path: str, + + entry_name: str, + spec_name: str, + code_path: str, # path to SOURCE_CODE_DIR or hubconf.py + + metadata: dict, + tags: dict = None, + ignore_files: Sequence = None, + save_arch: bool = False, + overwrite: bool = False, +): + entry_name = _replace_invalid_char(entry_name) + if spec_name is not None: + spec_name = _replace_invalid_char(spec_name) + + if not os.path.isdir(save_path): + overwrite = True + + save_path = os.path.abspath(save_path) + export_code_path = os.path.join(save_path, _CODE_DIR) + export_weights_path = os.path.join(save_path, _WEIGHTS_DIR) + os.makedirs(save_path, exist_ok=True) + os.makedirs(export_code_path, exist_ok=True) + os.makedirs(export_weights_path, exist_ok=True) + + code_path = os.path.abspath(code_path) + if os.path.isdir( code_path ): + if overwrite: + shutil.rmtree(export_code_path) + _copy_file_or_tree(src=code_path, dst=export_code_path) # overwrite old files + elif code_path.endswith('.py'): + if overwrite: + shutil.copy2(src=code_path, dst=os.path.join( export_code_path, 'hubconf.py' )) # overwrite old files + + if hasattr(model, 'SETUP_INFO'): + del model.SETUP_INFO + if hasattr(model, 'METADATA'): + del model.METADATA + + model_and_metadata = { + 'model': model if save_arch else model.state_dict(), + 'metadata': metadata, + } + temp_pth = os.path.join(export_weights_path, 'temp.pth') + torch.save(model_and_metadata, temp_pth) + if spec_name is None: + _md5 = md5( temp_pth ) + spec_name = _md5 + shutil.move( temp_pth, os.path.join(export_weights_path, '%s-%s.pth'%(entry_name, spec_name)) ) + + if tags is not None: + save_tags( tags, save_path, entry_name, spec_name ) + +def list_entry(path): + path = os.path.abspath( os.path.expanduser( path ) ) + if path.endswith('.py'): + code_dir = os.path.dirname( path ) + entry_file = path + elif os.path.isdir(path): + code_dir = os.path.join( path, _CODE_DIR ) + entry_file = os.path.join(path, _CODE_DIR, 'app.py') + else: + raise NotImplementedError + sys.path.insert(0, code_dir) + spec = importlib.util.spec_from_file_location('app', entry_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + entry_list = [ ( f, getattr(module, f) ) for f in dir(module) if callable(getattr(module, f)) and not f.startswith('_') ] + sys.path.remove(code_dir) + return entry_list + +def list_spec(path, entry_name=None): + path = os.path.abspath( os.path.expanduser( path ) ) + weight_dir = os.path.join( path, _WEIGHTS_DIR ) + spec_list = [ f.split('-') for f in os.listdir( weight_dir ) if f.endswith('.pth')] + spec_list = [ (f[0], f[1][:-4]) for f in spec_list ] + if entry_name is not None: + spec_list = [ s for s in spec_list if s[0]==entry_name ] + return spec_list + +def load(path: str, entry_name:str=None, spec_name: str=None, pretrained=True, **kwargs): + """ + check dependencies and load pytorch 
models. + Args: + path: path to the model package + pretrained: load the pretrained model + """ + path = os.path.abspath(path) + if not os.path.exists(path): + raise FileNotFoundError + + code_dir = os.path.join(path, _CODE_DIR) + if os.path.isdir(path): + cwd = os.getcwd() + os.chdir(code_dir) + sys.path.insert(0, code_dir) + spec = importlib.util.spec_from_file_location( + 'hubconf', os.path.join(code_dir, 'hubconf.py')) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if hasattr( module, 'dependencies' ): + deps = getattr( module, 'dependencies') + for dep in deps: + _import_with_auto_install(dep) + + if entry_name is None: # auto select + pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) )] + assert len(pth_file)<=1, "Loading models with more than one weight files (.pth) is ambiguous" + pth_file = pth_file[0] + entry_name = pth_file.split('-')[0] + entry_fn = getattr( module, entry_name ) + else: + entry_fn = getattr( module, entry_name ) + if spec_name is None: + pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) ) if pth.startswith(entry_name) ] + assert len(pth_file)<=1, "Loading models with more than one weight files (.pth) is ambiguous" + pth_file = pth_file[0] + else: + pth_file = '%s-%s.pth'%(entry_name, spec_name) + + try: + model_and_metadata = torch.load(os.path.join(path, _WEIGHTS_DIR, pth_file), map_location='cpu' ) + except: raise FileNotFoundError + + if isinstance( model_and_metadata['model'], torch.nn.Module ): + model = model_and_metadata['model'] + else: + entry_args = model_and_metadata['metadata']['entry_args'] + if entry_args is None: + entry_args = dict() + model = entry_fn( **entry_args ) + if pretrained: + model.load_state_dict( model_and_metadata['model'], False) + + # setup metadata and atlas info + model.METADATA = model_and_metadata['metadata'] + model.SETUP_INFO = {"path": path, "entry_name": entry_name} + sys.path.pop(0) + os.chdir(cwd) + return model + raise NotImplementedError + +def load_metadata(path, entry_name=None, spec_name=None): + path = os.path.abspath(path) + if os.path.isdir(path): + if entry_name is None: # auto select + pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) )] + assert len(pth_file)<=1, "Loading models with more than one weight files (.pth) is ambiguous" + pth_file = pth_file[0] + else: + if spec_name is None: + pth_file = [pth for pth in os.listdir( os.path.join(path, _WEIGHTS_DIR) ) if pth.startswith(entry_name) ] + assert len(pth_file)<=1, "Loading models with more than one weight files (.pth) is ambiguous" + pth_file = pth_file[0] + else: + pth_file = '%s-%s.pth'%(entry_name, spec_name) + try: + metadata = torch.load( os.path.join( path, _WEIGHTS_DIR, pth_file ), map_location='cpu' )['metadata'] + except: + FileNotFoundError + return metadata + +def load_tags(path, entry_name, spec_name): + path = os.path.abspath(path) + tags_path = os.path.join(path, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name )) + if os.path.isfile(tags_path): + return _to_python_type(_yaml_load(tags_path)) + return dict() + +def save_tags(tags, path, entry_name, spec_name): + path = os.path.abspath(path) + if tags is None: + tags = {} + tags_path = os.path.join(path, _TAGS_DIR, '%s-%s.yml'%( entry_name, spec_name )) + os.makedirs( os.path.join( path, _TAGS_DIR ), exist_ok=True ) + _yaml_dump(tags_path, tags) + +def _yaml_dump(f, obj): + with open(f, 'w') as f: + yaml.dump(obj, f) + +def _yaml_load(f): + with open(f, 'r') as f: + return yaml.load(f) + 
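+# Usage sketch (illustrative only): how a package exported with `save` can be reloaded
+# with `load`. The path 'exports/my_model' and the metadata values below are hypothetical
+# examples for this sketch, not constants defined by this module.
+#
+#   from kamal.core import hub
+#   from kamal.core.hub import meta
+#   hub.save(model, save_path='exports/my_model', entry_name='resnet18', spec_name=None,
+#            code_path='hubconf.py',
+#            metadata=meta.Metadata(name='resnet18', dataset='cifar10',
+#                                   task=meta.TASK.CLASSIFICATION, url='',
+#                                   input=meta.ImageInput(size=32, range=[0, 1],
+#                                                         space='rgb', normalize=None),
+#                                   entry_args=dict(num_classes=10),
+#                                   other_metadata=dict(num_classes=10)))
+#   reloaded = hub.load('exports/my_model', pretrained=True)
+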
+def _to_python_type(data): + if isinstance(data, dict): + for k, v in data.items(): + data[k] = _to_python_type(v) + return dict(data) + elif isinstance(data, comments.CommentedSeq ): + for idx, v in enumerate(data): + data[idx] = _to_python_type(v) + return list(data) + return data + +def md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + +def _get_package_name_and_version(package): + _version_sym_list = ('==', '>=', '<=') + for sym in _version_sym_list: + if sym in package: + return package.split(sym) + return package, None + +def _import_with_auto_install(package): + package_name, version = _get_package_name_and_version(package) + package_name = package_name.strip() + import_name = PACKAGE_NAME_TO_IMPORT_NAME.get( + package_name, package_name).replace('-', '_') + try: + return __import__(import_name) + except ImportError: + try: + pipmain.main(['install', package]) + except: + pipmain(['install', package]) + return __import__(import_name) + +from distutils.dir_util import copy_tree +def _copy_file_or_tree(src, dst): + if os.path.isfile(src): + shutil.copy2(src=src, dst=dst) + else: + copy_tree(src=src, dst=dst) + +def _glob_list(path_list, recursive=False): + results = [] + for path in path_list: + if '*' in path: + path = list(glob(path, recursive=recursive)) + results.extend(path) + else: + results.append(path) + return results diff --git a/model_measuring/kamal/core/hub/_module_mapping.py b/model_measuring/kamal/core/hub/_module_mapping.py new file mode 100644 index 0000000..386cc2d --- /dev/null +++ b/model_measuring/kamal/core/hub/_module_mapping.py @@ -0,0 +1,6 @@ +PACKAGE_NAME_TO_IMPORT_NAME = { + 'opencv-python': 'cv2', + 'pillow': 'PIL', + 'scikit-learn': 'sklearn', + 'scikit-image': 'scikit-image', +} \ No newline at end of file diff --git a/model_measuring/kamal/core/hub/meta/TASK.py b/model_measuring/kamal/core/hub/meta/TASK.py new file mode 100644 index 0000000..b4ea775 --- /dev/null +++ b/model_measuring/kamal/core/hub/meta/TASK.py @@ -0,0 +1,22 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +CLASSIFICATION = 'classification' +SEGMENTATION = 'segmentation' +DETECTION = 'detection' +DEPTH = 'depth' +AUTOENCODER = 'autoencoder' \ No newline at end of file diff --git a/model_measuring/kamal/core/hub/meta/__init__.py b/model_measuring/kamal/core/hub/meta/__init__.py new file mode 100644 index 0000000..71f74a0 --- /dev/null +++ b/model_measuring/kamal/core/hub/meta/__init__.py @@ -0,0 +1,3 @@ +from .meta import Metadata +from .input import ImageInput +from . 
import TASK diff --git a/model_measuring/kamal/core/hub/meta/input.py b/model_measuring/kamal/core/hub/meta/input.py new file mode 100644 index 0000000..e6a3600 --- /dev/null +++ b/model_measuring/kamal/core/hub/meta/input.py @@ -0,0 +1,31 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from typing import Union, Sequence + +# INPUT Metadata +def ImageInput( size: Union[Sequence[int], int], + range: Union[Sequence[int], int], + space: str, + normalize: Union[Sequence]=None): + assert space in ['rgb', 'gray', 'bgr', 'rgbd'] + return dict( + size=size, + range=range, + space=space, + normalize=normalize + ) \ No newline at end of file diff --git a/model_measuring/kamal/core/hub/meta/meta.py b/model_measuring/kamal/core/hub/meta/meta.py new file mode 100644 index 0000000..ee795fd --- /dev/null +++ b/model_measuring/kamal/core/hub/meta/meta.py @@ -0,0 +1,44 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from copy import deepcopy +import abc, os + +from . import TASK +import torch + +__all__ = ['yaml', 'ImageInput', 'MetaData', 'AtlasEntryBase'] + +# Model Metadata +def Metadata( name: str, + dataset: str, + task: int, + url: str, + input: dict, + entry_args: dict, + other_metadata: dict): + if task in [TASK.SEGMENTATION, TASK.CLASSIFICATION]: + assert 'num_classes' in other_metadata + return dict( + name=name, + dataset=dataset, + task=task, + url=url, + input=dict(input), + entry_args=entry_args, + other_metadata=other_metadata + ) \ No newline at end of file diff --git a/model_measuring/kamal/core/metrics/__init__.py b/model_measuring/kamal/core/metrics/__init__.py new file mode 100644 index 0000000..ce13dcf --- /dev/null +++ b/model_measuring/kamal/core/metrics/__init__.py @@ -0,0 +1,8 @@ +from .stream_metrics import Metric, MetricCompose +from .accuracy import Accuracy, TopkAccuracy +from .confusion_matrix import ConfusionMatrix, IoU, mIoU +from .regression import * +from .average import AverageMetric + + + \ No newline at end of file diff --git a/model_measuring/kamal/core/metrics/accuracy.py b/model_measuring/kamal/core/metrics/accuracy.py new file mode 100644 index 0000000..026b321 --- /dev/null +++ b/model_measuring/kamal/core/metrics/accuracy.py @@ -0,0 +1,118 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import numpy as np +import torch +from kamal.core.metrics.stream_metrics import Metric +from typing import Callable + +__all__=['Accuracy', 'TopkAccuracy'] + +class Accuracy(Metric): + def __init__(self, attach_to=None): + super(Accuracy, self).__init__(attach_to=attach_to) + self.reset() + + @torch.no_grad() + def update(self, outputs, targets): + outputs, targets = self._attach(outputs, targets) + outputs = outputs.max(1)[1] + self._correct += ( outputs.view(-1)==targets.view(-1) ).sum() + self._cnt += torch.numel( targets ) + + def get_results(self): + return (self._correct / self._cnt).detach().cpu() + + def reset(self): + self._correct = self._cnt = 0.0 + + +class TopkAccuracy(Metric): + def __init__(self, topk=5, attach_to=None): + super(TopkAccuracy, self).__init__(attach_to=attach_to) + self._topk = topk + self.reset() + + @torch.no_grad() + def update(self, outputs, targets): + outputs, targets = self._attach(outputs, targets) + _, outputs = outputs.topk(self._topk, dim=1, largest=True, sorted=True) + correct = outputs.eq( targets.view(-1, 1).expand_as(outputs) ) + self._correct += correct[:, :self._topk].view(-1).float().sum(0).item() + self._cnt += len(targets) + + def get_results(self): + return self._correct / self._cnt + + def reset(self): + self._correct = 0.0 + self._cnt = 0.0 + + +class StreamCEMAPMetrics(): + @property + def PRIMARY_METRIC(self): + return "eap" + + def __init__(self): + self.reset() + + def update(self, logits, targets): + preds = logits.max(1)[1] + # targets: -1 negative, 0 difficult, 1 positive + if isinstance(preds, torch.Tensor): + preds = preds.cpu().numpy() + targets = targets.cpu().numpy() + + self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0) + self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0) + + def get_results(self): + nTest = self.targets.shape[0] + nLabel = self.targets.shape[1] + eap = np.zeros(nTest) + for i in range(0,nTest): + R = np.sum(self.targets[i,:]==1) + for j in range(0,nLabel): + if self.targets[i,j]==1: + r = np.sum(self.preds[i,np.nonzero(self.targets[i,:]!=0)]>=self.preds[i,j]) + rb = np.sum(self.preds[i,np.nonzero(self.targets[i,:]==1)] >= self.preds[i,j]) + + eap[i] = eap[i] + rb/(r*1.0) + eap[i] = eap[i]/R + # emap = np.nanmean(ap) + + cap = np.zeros(nLabel) + for i in range(0,nLabel): + R = np.sum(self.targets[:,i]==1) + for j in range(0,nTest): + if self.targets[j,i]==1: + r = np.sum(self.preds[np.nonzero(self.targets[:,i]!=0),i] >= self.preds[j,i]) + rb = np.sum(self.preds[np.nonzero(self.targets[:,i]==1),i] >= self.preds[j,i]) + cap[i] = cap[i] + rb/(r*1.0) + cap[i] = cap[i]/R + # cmap = np.nanmean(ap) + return { + 'eap': eap, + 'emap': np.nanmean(eap), + 'cap': cap, + 'cmap': np.nanmean(cap), + } + + def reset(self): + self.preds = None + self.targets = None \ No newline at end of file diff --git a/model_measuring/kamal/core/metrics/average.py 
b/model_measuring/kamal/core/metrics/average.py new file mode 100644 index 0000000..e0eeada --- /dev/null +++ b/model_measuring/kamal/core/metrics/average.py @@ -0,0 +1,49 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import numpy as np +import torch +from kamal.core.metrics.stream_metrics import Metric +from typing import Callable + +__all__=['AverageMetric'] + +class AverageMetric(Metric): + def __init__(self, fn:Callable, attach_to=None): + super(AverageMetric, self).__init__(attach_to=attach_to) + self._fn = fn + self.reset() + + @torch.no_grad() + def update(self, outputs, targets): + + outputs, targets = self._attach(outputs, targets) + m = self._fn( outputs, targets ) + + if m.ndim > 1: + self._cnt += m.shape[0] + self._accum += m.sum(0) + else: + self._cnt += 1 + self._accum += m + + def get_results(self): + return (self._accum / self._cnt).detach().cpu() + + def reset(self): + self._accum = 0. + self._cnt = 0. \ No newline at end of file diff --git a/model_measuring/kamal/core/metrics/confusion_matrix.py b/model_measuring/kamal/core/metrics/confusion_matrix.py new file mode 100644 index 0000000..4864686 --- /dev/null +++ b/model_measuring/kamal/core/metrics/confusion_matrix.py @@ -0,0 +1,68 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+    =============================================================
+"""
+
+from .stream_metrics import Metric
+import torch
+from typing import Callable
+
+class ConfusionMatrix(Metric):
+    def __init__(self, num_classes, ignore_idx=None, attach_to=None):
+        super(ConfusionMatrix, self).__init__(attach_to=attach_to)
+        self._num_classes = num_classes
+        self._ignore_idx = ignore_idx
+        self.reset()
+
+    @torch.no_grad()
+    def update(self, outputs, targets):
+        outputs, targets = self._attach(outputs, targets)
+        if self.confusion_matrix.device != outputs.device:
+            self.confusion_matrix = self.confusion_matrix.to(device=outputs.device)
+        preds = outputs.max(1)[1].flatten()
+        targets = targets.flatten()
+        # keep only valid class indices so that bincount stays in range
+        mask = (targets >= 0) & (targets < self._num_classes)
+        if self._ignore_idx is not None:
+            mask = mask & (targets != self._ignore_idx)
+        preds, targets = preds[mask], targets[mask]
+        hist = torch.bincount( self._num_classes * targets + preds,
+            minlength=self._num_classes ** 2 ).view(self._num_classes, self._num_classes)
+        self.confusion_matrix += hist
+
+    def get_results(self):
+        return self.confusion_matrix.detach().cpu()
+
+    def reset(self):
+        self._cnt = 0
+        self.confusion_matrix = torch.zeros(self._num_classes, self._num_classes, dtype=torch.int64, requires_grad=False)
+
+class IoU(Metric):
+    def __init__(self, confusion_matrix: ConfusionMatrix, attach_to=None):
+        self._confusion_matrix = confusion_matrix
+
+    def update(self, outputs, targets):
+        pass # update will be done in confusion matrix
+
+    def reset(self):
+        pass
+
+    def get_results(self):
+        cm = self._confusion_matrix.get_results()
+        iou = cm.diag() / (cm.sum(dim=1) + cm.sum(dim=0) - cm.diag() + 1e-9)
+        return iou
+
+class mIoU(IoU):
+    def get_results(self):
+        return super(mIoU, self).get_results().mean()
diff --git a/model_measuring/kamal/core/metrics/normal.py b/model_measuring/kamal/core/metrics/normal.py
new file mode 100644
index 0000000..c5cf241
--- /dev/null
+++ b/model_measuring/kamal/core/metrics/normal.py
@@ -0,0 +1,86 @@
+"""
+    Copyright 2020 Tianshu AI Platform. All Rights Reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+    =============================================================
+"""
+
+import numpy as np
+import torch
+from kamal.core.metrics.stream_metrics import StreamMetricsBase
+
+class NormalPredictionMetrics(StreamMetricsBase):
+
+    @property
+    def PRIMARY_METRIC(self):
+        return 'mean angle'
+
+    def __init__(self, thresholds):
+        self.thresholds = thresholds
+        self.preds = None
+        self.targets = None
+        self.masks = None
+
+    @torch.no_grad()
+    def update(self, preds, targets, masks):
+        """
+        **Type**: numpy.ndarray or torch.Tensor
+        **Shape:**
+            - **preds**: $(N, 3, H, W)$.
+            - **targets**: $(N, 3, H, W)$.
+            - **masks**: $(N, 1, H, W)$.
+ """ + if isinstance(preds, torch.Tensor): + preds = preds.cpu().numpy() + targets = targets.cpu().numpy() + masks = masks.cpu().numpy() + + self.preds = preds if self.preds is None else np.append(self.preds, preds, axis=0) + self.targets = targets if self.targets is None else np.append(self.targets, targets, axis=0) + self.masks = masks if self.masks is None else np.append(self.masks, masks, axis=0) + + def get_results(self): + """ + **Returns:** + - **mean angle** + - **median angle** + - **precents for angle within thresholds** + """ + products = np.sum(self.preds * self.targets, axis=1) + + angles = np.arccos(np.clip(products, -1.0, 1.0)) / np.pi * 180 + self.masks = self.masks.squeeze(1) + angles = angles[self.masks == 1] + + mean_angle = np.mean(angles) + median_angle = np.median(angles) + count = self.masks.sum() + + threshold_percents = {} + for threshold in self.thresholds: + # threshold_percents[threshold] = np.sum((angles < threshold)) / count + threshold_percents[threshold] = np.mean(angles < threshold) + + if return_key_metric: + return ('absolute relative', ard) + + return { + 'mean angle': mean_angle, + 'median angle': median_angle, + 'percents within thresholds': threshold_percents + } + + def reset(self): + self.preds = None + self.targets = None + self.masks = None \ No newline at end of file diff --git a/model_measuring/kamal/core/metrics/regression.py b/model_measuring/kamal/core/metrics/regression.py new file mode 100644 index 0000000..6a97e8c --- /dev/null +++ b/model_measuring/kamal/core/metrics/regression.py @@ -0,0 +1,199 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import numpy as np +import torch + +from kamal.core.metrics.stream_metrics import Metric +from typing import Callable + +__all__=['MeanSquaredError', 'RootMeanSquaredError', 'MeanAbsoluteError', + 'ScaleInveriantMeanSquaredError', 'RelativeDifference', + 'AbsoluteRelativeDifference', 'SquaredRelativeDifference', 'Threshold' ] + +class MeanSquaredError(Metric): + def __init__(self, log_scale=False, attach_to=None): + super(MeanSquaredError, self).__init__( attach_to=attach_to ) + self.reset() + self.log_scale=log_scale + + @torch.no_grad() + def update(self, outputs, targets): + outputs, targets = self._attach(outputs, targets) + if self.log_scale: + diff = torch.sum((torch.log(outputs+1e-8) - torch.log(targets+1e-8))**2) + else: + diff = torch.sum((outputs - targets)**2) + self._accum_sq_diff += diff + self._cnt += torch.numel(outputs) + + def get_results(self): + return (self._accum_sq_diff / self._cnt).detach().cpu() + + def reset(self): + self._accum_sq_diff = 0. + self._cnt = 0. 
+
+
+class RootMeanSquaredError(MeanSquaredError):
+    def get_results(self):
+        return torch.sqrt( (self._accum_sq_diff / self._cnt) ).detach().cpu()
+
+
+class MeanAbsoluteError(Metric):
+    def __init__(self, attach_to=None ):
+        super(MeanAbsoluteError, self).__init__( attach_to=attach_to )
+        self.reset()
+
+    @torch.no_grad()
+    def update(self, outputs, targets):
+        outputs, targets = self._attach(outputs, targets)
+        diff = torch.sum((outputs - targets).abs())
+        self._accum_abs_diff += diff
+        self._cnt += torch.numel(outputs)
+
+    def get_results(self):
+        return (self._accum_abs_diff / self._cnt).detach().cpu()
+
+    def reset(self):
+        self._accum_abs_diff = 0.
+        self._cnt = 0.
+
+
+class ScaleInveriantMeanSquaredError(Metric):
+    def __init__(self, attach_to=None ):
+        super(ScaleInveriantMeanSquaredError, self).__init__( attach_to=attach_to )
+        self.reset()
+
+    @torch.no_grad()
+    def update(self, outputs, targets):
+        outputs, targets = self._attach(outputs, targets)
+        diff_log = torch.log( outputs+1e-8 ) - torch.log( targets+1e-8 )
+        self._accum_log_diff += diff_log.sum()
+        self._accum_sq_log_diff += (diff_log**2).sum()
+        self._cnt += torch.numel(outputs)
+
+    def get_results(self):
+        return ( self._accum_sq_log_diff / self._cnt - 0.5 * (self._accum_log_diff**2 / self._cnt**2) ).detach().cpu()
+
+    def reset(self):
+        self._accum_log_diff = 0.
+        self._accum_sq_log_diff = 0.
+        self._cnt = 0.
+
+
+class RelativeDifference(Metric):
+    def __init__(self, attach_to=None ):
+        super(RelativeDifference, self).__init__( attach_to=attach_to )
+        self.reset()
+
+    @torch.no_grad()
+    def update(self, outputs, targets):
+        outputs, targets = self._attach(outputs, targets)
+        diff = (outputs - targets).abs()
+        self._accum_abs_rel += (diff/targets).sum()
+        self._cnt += torch.numel(outputs)
+
+    def get_results(self):
+        return (self._accum_abs_rel / self._cnt).detach().cpu()
+
+    def reset(self):
+        self._accum_abs_rel = 0.
+        self._cnt = 0.
+
+
+class AbsoluteRelativeDifference(Metric):
+    def __init__(self, attach_to=None ):
+        super(AbsoluteRelativeDifference, self).__init__( attach_to=attach_to )
+        self.reset()
+
+    @torch.no_grad()
+    def update(self, outputs, targets):
+        outputs, targets = self._attach(outputs, targets)
+        diff = (outputs - targets).abs()
+        self._accum_abs_rel += (diff/targets).sum()
+        self._cnt += torch.numel(outputs)
+
+    def get_results(self):
+        return (self._accum_abs_rel / self._cnt).detach().cpu()
+
+    def reset(self):
+        self._accum_abs_rel = 0.
+        self._cnt = 0.
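As a quick sanity check of AbsoluteRelativeDifference, the streamed result should match mean(|outputs - targets| / targets) computed directly. The values below are made up for illustration, and the default attach_to=None pass-through is again assumed.

    import torch
    from kamal.core.metrics import AbsoluteRelativeDifference

    outputs = torch.tensor([2.0, 4.0, 8.0])
    targets = torch.tensor([1.0, 4.0, 10.0])

    metric = AbsoluteRelativeDifference()
    metric.update(outputs, targets)

    manual = ((outputs - targets).abs() / targets).mean()  # (1.0 + 0.0 + 0.2) / 3 = 0.4
    print(metric.get_results(), manual)  # both report 0.4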
+ + +class SquaredRelativeDifference(Metric): + def __init__(self, attach_to=None ): + super(SquaredRelativeDifference, self).__init__( attach_to=attach_to ) + self.reset() + + @torch.no_grad() + def update(self, outputs, targets): + outputs, targets = self._attach(outputs, targets) + diff = (outputs - targets)**2 + self._accum_sq_rel += (diff/targets).sum() + self._cnt += torch.numel(outputs) + + def get_results(self): + return (self._accum_sq_rel / self._cnt).detach().cpu() + + def reset(self): + self._accum_sq_rel = 0. + self._cnt = 0. + + +class Threshold(Metric): + def __init__(self, thresholds=[1.25, 1.25**2, 1.25**3], attach_to=None ): + super(Threshold, self).__init__( attach_to=attach_to ) + self.thresholds = thresholds + self.reset() + + + @torch.no_grad() + def update(self, outputs, targets): + outputs, targets = self._attach(outputs, targets) + sigma = torch.max(outputs / targets, targets / outputs) + for thres in self.thresholds: + self._accum_thres[thres]+=torch.sum( sigma t_H: + f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) + elif s_H < t_H: + f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) + else: + pass + return (self.at(f_s) - self.at(f_t)).pow(2).mean() + + def at(self, f): + return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1)) + +class NSTLoss(nn.Module): + """like what you like: knowledge distill via neuron selectivity transfer""" + def __init__(self): + super(NSTLoss, self).__init__() + pass + + def forward(self, g_s, g_t): + return sum([self.nst_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) + + def nst_loss(self, f_s, f_t): + s_H, t_H = f_s.shape[2], f_t.shape[2] + if s_H > t_H: + f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) + elif s_H < t_H: + f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) + else: + pass + + f_s = f_s.view(f_s.shape[0], f_s.shape[1], -1) + f_s = F.normalize(f_s, dim=2) + f_t = f_t.view(f_t.shape[0], f_t.shape[1], -1) + f_t = F.normalize(f_t, dim=2) + + # set full_loss as False to avoid unnecessary computation + full_loss = True + if full_loss: + return (self.poly_kernel(f_t, f_t).mean().detach() + self.poly_kernel(f_s, f_s).mean() + - 2 * self.poly_kernel(f_s, f_t).mean()) + else: + return self.poly_kernel(f_s, f_s).mean() - 2 * self.poly_kernel(f_s, f_t).mean() + + def poly_kernel(self, a, b): + a = a.unsqueeze(1) + b = b.unsqueeze(2) + res = (a * b).sum(-1).pow(2) + return res + +class SPLoss(nn.Module): + """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author""" + def __init__(self): + super(SPLoss, self).__init__() + + def forward(self, g_s, g_t): + return sum([self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]) + + def similarity_loss(self, f_s, f_t): + bsz = f_s.shape[0] + f_s = f_s.view(bsz, -1) + f_t = f_t.view(bsz, -1) + + G_s = torch.mm(f_s, torch.t(f_s)) + # G_s = G_s / G_s.norm(2) + G_s = torch.nn.functional.normalize(G_s) + G_t = torch.mm(f_t, torch.t(f_t)) + # G_t = G_t / G_t.norm(2) + G_t = torch.nn.functional.normalize(G_t) + + G_diff = G_t - G_s + loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) + return loss + +class RKDLoss(nn.Module): + """Relational Knowledge Disitllation, CVPR2019""" + def __init__(self, w_d=25, w_a=50): + super(RKDLoss, self).__init__() + self.w_d = w_d + self.w_a = w_a + + def forward(self, f_s, f_t): + student = f_s.view(f_s.shape[0], -1) + teacher = f_t.view(f_t.shape[0], -1) + + # RKD distance loss + with torch.no_grad(): + t_d = self.pdist(teacher, squared=False) + mean_td = t_d[t_d > 0].mean() + t_d = t_d / mean_td + + d = self.pdist(student, 
squared=False) + mean_d = d[d > 0].mean() + d = d / mean_d + + loss_d = F.smooth_l1_loss(d, t_d) + + # RKD Angle loss + with torch.no_grad(): + td = (teacher.unsqueeze(0) - teacher.unsqueeze(1)) + norm_td = F.normalize(td, p=2, dim=2) + t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) + + sd = (student.unsqueeze(0) - student.unsqueeze(1)) + norm_sd = F.normalize(sd, p=2, dim=2) + s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) + + loss_a = F.smooth_l1_loss(s_angle, t_angle) + + loss = self.w_d * loss_d + self.w_a * loss_a + + return loss + + @staticmethod + def pdist(e, squared=False, eps=1e-12): + e_square = e.pow(2).sum(dim=1) + prod = e @ e.t() + res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) + + if not squared: + res = res.sqrt() + + res = res.clone() + res[range(len(e)), range(len(e))] = 0 + return res + +class PKTLoss(nn.Module): + """Probabilistic Knowledge Transfer for deep representation learning""" + def __init__(self): + super(PKTLoss, self).__init__() + + def forward(self, f_s, f_t): + return self.cosine_similarity_loss(f_s, f_t) + + @staticmethod + def cosine_similarity_loss(output_net, target_net, eps=0.0000001): + # Normalize each vector by its norm + output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True)) + output_net = output_net / (output_net_norm + eps) + output_net[output_net != output_net] = 0 + + target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True)) + target_net = target_net / (target_net_norm + eps) + target_net[target_net != target_net] = 0 + + # Calculate the cosine similarity + model_similarity = torch.mm(output_net, output_net.transpose(0, 1)) + target_similarity = torch.mm(target_net, target_net.transpose(0, 1)) + + # Scale cosine similarity to 0..1 + model_similarity = (model_similarity + 1.0) / 2.0 + target_similarity = (target_similarity + 1.0) / 2.0 + + # Transform them into probabilities + model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True) + target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True) + + # Calculate the KL-divergence + loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps))) + + return loss + +class SVDLoss(nn.Module): + """ + Self-supervised Knowledge Distillation using Singular Value Decomposition + """ + def __init__(self, k=1): + super(SVDLoss, self).__init__() + self.k = k + + def forward(self, g_s, g_t): + v_sb = None + v_tb = None + losses = [] + for i, f_s, f_t in zip(range(len(g_s)), g_s, g_t): + + u_t, s_t, v_t = self.svd(f_t, self.k) + u_s, s_s, v_s = self.svd(f_s, self.k + 3) + v_s, v_t = self.align_rsv(v_s, v_t) + s_t = s_t.unsqueeze(1) + v_t = v_t * s_t + v_s = v_s * s_t + + if i > 0: + s_rbf = torch.exp(-(v_s.unsqueeze(2) - v_sb.unsqueeze(1)).pow(2) / 8) + t_rbf = torch.exp(-(v_t.unsqueeze(2) - v_tb.unsqueeze(1)).pow(2) / 8) + + l2loss = (s_rbf - t_rbf.detach()).pow(2) + l2loss = torch.where(torch.isfinite(l2loss), l2loss, torch.zeros_like(l2loss)) + losses.append(l2loss.sum()) + + v_tb = v_t + v_sb = v_s + + bsz = g_s[0].shape[0] + losses = [l / bsz for l in losses] + return sum(losses) + + def svd(self, feat, n=1): + size = feat.shape + assert len(size) == 4 + + x = feat.view(-1, size[1], size[2] * size[2]).transpose(-2, -1) + u, s, v = torch.svd(x) + + u = self.removenan(u) + s = self.removenan(s) + v = self.removenan(v) + + if n > 0: + u = F.normalize(u[:, :, :n], dim=1) + s = F.normalize(s[:, :n], dim=1) 
+ v = F.normalize(v[:, :, :n], dim=1) + + return u, s, v + + @staticmethod + def removenan(x): + x = torch.where(torch.isfinite(x), x, torch.zeros_like(x)) + return x + + @staticmethod + def align_rsv(a, b): + cosine = torch.matmul(a.transpose(-2, -1), b) + max_abs_cosine, _ = torch.max(torch.abs(cosine), 1, keepdim=True) + mask = torch.where(torch.eq(max_abs_cosine, torch.abs(cosine)), + torch.sign(cosine), torch.zeros_like(cosine)) + a = torch.matmul(a, mask) + return a, b + + + + diff --git a/model_measuring/kamal/core/tasks/task.py b/model_measuring/kamal/core/tasks/task.py new file mode 100644 index 0000000..9bc69bb --- /dev/null +++ b/model_measuring/kamal/core/tasks/task.py @@ -0,0 +1,186 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import abc +import torch +import torch.nn as nn +import torch.nn.functional as F +import sys +import typing +from typing import Callable, Dict, List, Any +from collections import Mapping, Sequence +from . import loss +from kamal.core import metrics, exceptions +from kamal.core.attach import AttachTo + +class Task(object): + def __init__(self, name): + self.name = name + + @abc.abstractmethod + def get_loss( self, outputs, targets ) -> Dict: + pass + + @abc.abstractmethod + def predict(self, outputs) -> Any: + pass + +class GeneralTask(Task): + def __init__(self, + name: str, + loss_fn: Callable, + scaling:float=1.0, + pred_fn: Callable=lambda x: x, + attach_to=None): + super(GeneralTask, self).__init__(name) + self._attach = AttachTo(attach_to) + self.loss_fn = loss_fn + self.pred_fn = pred_fn + self.scaling = scaling + + def get_loss(self, outputs, targets): + outputs, targets = self._attach(outputs, targets) + return { self.name: self.loss_fn( outputs, targets ) * self.scaling } + + def predict(self, outputs): + outputs = self._attach(outputs) + return self.pred_fn(outputs) + + def __repr__(self): + rep = "Task: [%s loss_fn=%s scaling=%.4f attach=%s]"%(self.name, str(self.loss_fn), self.scaling, self._attach) + return rep + +class TaskCompose(list): + def __init__(self, tasks: list): + for task in tasks: + if isinstance(task, Task): + self.append(task) + + def get_loss(self, outputs, targets): + loss_dict = {} + for task in self: + loss_dict.update( task.get_loss( outputs, targets ) ) + return loss_dict + + def predict(self, outputs): + results = [] + for task in self: + results.append( task.predict( outputs ) ) + return results + + def __repr__(self): + rep="TaskCompose: \n" + for task in self: + rep+="\t%s\n"%task + +class StandardTask: + @staticmethod + def classification(name='ce', scaling=1.0, attach_to=None): + return GeneralTask( name=name, + loss_fn=nn.CrossEntropyLoss(), + scaling=scaling, + pred_fn=lambda x: x.max(1)[1], + attach_to=attach_to ) + + @staticmethod + def binary_classification(name='bce', scaling=1.0, attach_to=None): + return GeneralTask(name=name, + loss_fn=F.binary_cross_entropy_with_logits, + 
scaling=scaling, + pred_fn=lambda x: (x>0.5), + attach_to=attach_to ) + + @staticmethod + def regression(name='mse', scaling=1.0, attach_to=None): + return GeneralTask(name=name, + loss_fn=nn.MSELoss(), + scaling=scaling, + pred_fn=lambda x: x, + attach_to=attach_to ) + + @staticmethod + def segmentation(name='ce', scaling=1.0, attach_to=None): + return GeneralTask(name=name, + loss_fn=nn.CrossEntropyLoss(ignore_index=255), + scaling=scaling, + pred_fn=lambda x: x.max(1)[1], + attach_to=attach_to ) + + @staticmethod + def monocular_depth(name='l1', scaling=1.0, attach_to=None): + return GeneralTask(name=name, + loss_fn=nn.L1Loss(), + scaling=scaling, + pred_fn=lambda x: x, + attach_to=attach_to) + + @staticmethod + def detection(): + raise NotImplementedError + + @staticmethod + def distillation(name='kld', T=1.0, scaling=1.0, attach_to=None): + return GeneralTask(name=name, + loss_fn=loss.KLDiv(T=T), + scaling=scaling, + pred_fn=lambda x: x.max(1)[1], + attach_to=attach_to) + + +class StandardMetrics(object): + + @staticmethod + def classification(attach_to=None): + return metrics.MetricCompose( + metric_dict={'acc': metrics.Accuracy(attach_to=attach_to)} + ) + + @staticmethod + def regression(attach_to=None): + return metrics.MetricCompose( + metric_dict={'mse': metrics.MeanSquaredError(attach_to=attach_to)} + ) + + @staticmethod + def segmentation(num_classes, ignore_idx=255, attach_to=None): + confusion_matrix = metrics.ConfusionMatrix(num_classes=num_classes, ignore_idx=ignore_idx, attach_to=attach_to) + return metrics.MetricCompose( + metric_dict={'acc': metrics.Accuracy(attach_to=attach_to), + 'confusion_matrix': confusion_matrix , + 'miou': metrics.mIoU(confusion_matrix)} + ) + + @staticmethod + def monocular_depth(attach_to=None): + return metrics.MetricCompose( + metric_dict={ + 'rmse': metrics.RootMeanSquaredError(attach_to=attach_to), + 'rmse_log': metrics.RootMeanSquaredError( log_scale=True,attach_to=attach_to ), + 'rmse_scale_inv': metrics.ScaleInveriantMeanSquaredError(attach_to=attach_to), + 'abs rel': metrics.AbsoluteRelativeDifference(attach_to=attach_to), + 'sq rel': metrics.SquaredRelativeDifference(attach_to=attach_to), + 'percents within thresholds': metrics.Threshold( thresholds=[1.25, 1.25**2, 1.25**3], attach_to=attach_to ) + } + ) + + @staticmethod + def loss_metric(loss_fn): + return metrics.MetricCompose( + metric_dict={ + 'loss': metrics.AverageMetric( loss_fn ) + } + ) diff --git a/model_measuring/kamal/slim/__init__.py b/model_measuring/kamal/slim/__init__.py new file mode 100644 index 0000000..a7762da --- /dev/null +++ b/model_measuring/kamal/slim/__init__.py @@ -0,0 +1,2 @@ +from .prunning import Pruner, strategy +from .distillation import * \ No newline at end of file diff --git a/model_measuring/kamal/slim/distillation/__init__.py b/model_measuring/kamal/slim/distillation/__init__.py new file mode 100644 index 0000000..37c6066 --- /dev/null +++ b/model_measuring/kamal/slim/distillation/__init__.py @@ -0,0 +1,12 @@ +from .kd import KDDistiller +from .hint import * +from .attention import AttentionDistiller +from .nst import NSTDistiller +from .sp import SPDistiller +from .rkd import RKDDistiller +from .pkt import PKTDistiller +from .svd import SVDDistiller +from .cc import * +from .vid import * + +from . 
import data_free \ No newline at end of file diff --git a/model_measuring/kamal/slim/distillation/attention.py b/model_measuring/kamal/slim/distillation/attention.py new file mode 100644 index 0000000..93e878f --- /dev/null +++ b/model_measuring/kamal/slim/distillation/attention.py @@ -0,0 +1,47 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import AttentionLoss +from kamal.core.tasks.loss import KDLoss +from kamal.utils import set_mode, move_to_device + +import torch +import torch.nn as nn + +import time + +class AttentionDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(AttentionDistiller, self).__init__( logger, tb_writer ) + + def setup(self, + student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, + stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super(AttentionDistiller, self).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device) + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self._at_loss = AttentionLoss() + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + g_s = feat_s[1:-1] + g_t = feat_t[1:-1] + return self._at_loss(g_s, g_t) diff --git a/model_measuring/kamal/slim/distillation/cc.py b/model_measuring/kamal/slim/distillation/cc.py new file mode 100644 index 0000000..e22189a --- /dev/null +++ b/model_measuring/kamal/slim/distillation/cc.py @@ -0,0 +1,55 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import KDLoss + +import torch.nn as nn +import torch._ops + +import time + +class CCDistiller(KDDistiller): + + def __init__(self, logger=None, tb_writer=None ): + super(CCDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, embed_s, embed_t, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None ): + super(CCDistiller, self).setup( + student, teacher, dataloader, optimizer, T=T, gamma=gamma, alpha=alpha, device=device) + self.embed_s = embed_s.to(self.device).train() + self.embed_t = embed_t.to(self.device).train() + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + f_s = self.embed_s(feat_s[-1]) + f_t = self.embed_t(feat_t[-1]) + return torch.mean((torch.abs(f_s-f_t)[:-1] * torch.abs(f_s-f_t)[1:]).sum(1)) + +class LinearEmbed(nn.Module): + def __init__(self, dim_in=1024, dim_out=128): + super(LinearEmbed, self).__init__() + self.linear = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = x.view(x.shape[0], -1) + x = self.linear(x) + return x diff --git a/model_measuring/kamal/slim/distillation/data_free/__init__.py b/model_measuring/kamal/slim/distillation/data_free/__init__.py new file mode 100644 index 0000000..514c58b --- /dev/null +++ b/model_measuring/kamal/slim/distillation/data_free/__init__.py @@ -0,0 +1 @@ +from .zskt import ZSKTDistiller \ No newline at end of file diff --git a/model_measuring/kamal/slim/distillation/data_free/zskt.py b/model_measuring/kamal/slim/distillation/data_free/zskt.py new file mode 100644 index 0000000..2ba6cae --- /dev/null +++ b/model_measuring/kamal/slim/distillation/data_free/zskt.py @@ -0,0 +1,99 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +import torch +import time +import torch.nn.functional as F +from kamal.slim.distillation.kd import KDDistiller +from kamal.utils import set_mode +from kamal.core.tasks.loss import kldiv + +class ZSKTDistiller(KDDistiller): + def __init__( self, + student, + teacher, + generator, + z_dim, + logger=None, + viz=None): + super(ZSKTDistiller, self).__init__(logger, viz) + self.teacher = teacher + self.model = self.student = student + self.generator = generator + self.z_dim = z_dim + + def train(self, start_iter, max_iter, optim_s, optim_g, device=None): + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self.device = device + self.optim_s, self.optim_g = optim_s, optim_g + + self.model.to(self.device) + self.teacher.to(self.device) + self.generator.to(self.device) + self.train_loader = [0, ] + + with set_mode(self.student, training=True), \ + set_mode(self.teacher, training=False), \ + set_mode(self.generator, training=True): + super( ZSKTDistiller, self ).train( start_iter, max_iter ) + + def search_optimizer(self, evaluator, train_loader, hpo_space=None, mode='min', max_evals=20, max_iters=400): + optimizer = hpo.search_optimizer(self, train_loader, evaluator=evaluator, hpo_space=hpo_space, mode=mode, max_evals=max_evals, max_iters=max_iters) + return optimizer + + def step(self): + start_time = time.perf_counter() + + # Adv + z = torch.randn( self.z_dim ).to(self.device) + fake = self.generator( z ) + self.optim_g.zero_grad() + t_out = self.teacher( fake ) + s_out = self.student( fake ) + loss_g = -kldiv( s_out, t_out ) + loss_g.backward() + self.optim_g.step() + + with torch.no_grad(): + fake = self.generator( z ) + t_out = self.teacher( fake.detach() ) + for _ in range(10): + self.optim_s.zero_grad() + s_out = self.student( fake.detach() ) + loss_s = kldiv( s_out, t_out ) + loss_s.backward() + self.optim_s.step() + + loss_dict = { + 'loss_g': loss_g, + 'loss_s': loss_s, + } + + step_time = time.perf_counter() - start_time + + # record training info + info = loss_dict + info['step_time'] = step_time + info['lr_s'] = float( self.optim_s.param_groups[0]['lr'] ) + info['lr_g'] = float( self.optim_g.param_groups[0]['lr'] ) + self.history.put_scalars( **info ) + + def reset(self): + self.history = None + self._train_loader_iter = iter(train_loader) + self.iter = self.start_iter diff --git a/model_measuring/kamal/slim/distillation/hint.py b/model_measuring/kamal/slim/distillation/hint.py new file mode 100644 index 0000000..910070f --- /dev/null +++ b/model_measuring/kamal/slim/distillation/hint.py @@ -0,0 +1,86 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import KDLoss + +import torch.nn as nn +import torch._ops + +import time + + +class HintDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(HintDistiller, self).__init__( logger, tb_writer ) + + def setup(self, + student, teacher, regressor, dataloader, optimizer, + hint_layer=2, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, + stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( HintDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) + self.regressor = regressor + self._hint_layer = hint_layer + self._beta = beta + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self.regressor.to(device) + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + f_s = self.regressor(feat_s[self._hint_layer]) + f_t = feat_t[self._hint_layer] + return nn.functional.mse_loss(f_s, f_t) + +class Regressor(nn.Module): + """ + Convolutional regression for FitNet + @inproceedings{tian2019crd, + title={Contrastive Representation Distillation}, + author={Yonglong Tian and Dilip Krishnan and Phillip Isola}, + booktitle={International Conference on Learning Representations}, + year={2020} + } + """ + + def __init__(self, s_shape, t_shape, is_relu=True): + super(Regressor, self).__init__() + self.is_relu = is_relu + _, s_C, s_H, s_W = s_shape + _, t_C, t_H, t_W = t_shape + if s_H == 2 * t_H: + self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) + elif s_H * 2 == t_H: + self.conv = nn.ConvTranspose2d( + s_C, t_C, kernel_size=4, stride=2, padding=1) + elif s_H >= t_H: + self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) + else: + raise NotImplemented( + 'student size {}, teacher size {}'.format(s_H, t_H)) + self.bn = nn.BatchNorm2d(t_C) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + if self.is_relu: + return self.relu(self.bn(x)) + else: + return self.bn(x) diff --git a/model_measuring/kamal/slim/distillation/kd.py b/model_measuring/kamal/slim/distillation/kd.py new file mode 100644 index 0000000..f570ed8 --- /dev/null +++ b/model_measuring/kamal/slim/distillation/kd.py @@ -0,0 +1,90 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from kamal.core.engine.trainer import Engine +from kamal.core.tasks.loss import kldiv +import torch.nn.functional as F +from kamal.utils.logger import get_logger +from kamal.utils import set_mode, move_to_device +import weakref + +import torch +import torch.nn as nn + +import time +import numpy as np + +class KDDistiller(Engine): + def __init__( self, + logger=None, + tb_writer=None): + super(KDDistiller, self).__init__(logger=logger, tb_writer=tb_writer) + + def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, device=None): + self.model = self.student = student + self.teacher = teacher + self.dataloader = dataloader + self.optimizer = optimizer + if device is None: + device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) + self.device = device + + self.T = T + self.gamma = gamma + self.alpha = alpha + self.beta = beta + + self.student.to(self.device) + self.teacher.to(self.device) + + def run( self, max_iter, start_iter=0, epoch_length=None): + with set_mode(self.student, training=True), \ + set_mode(self.teacher, training=False): + super( KDDistiller, self ).run( self.step_fn, self.dataloader, start_iter=start_iter, max_iter=max_iter, epoch_length=epoch_length) + + def additional_kd_loss(self, engine, batch): + return batch[0].new_zeros(1) + + def step_fn(self, engine, batch): + student = self.student + teacher = self.teacher + start_time = time.perf_counter() + batch = move_to_device(batch, self.device) + inputs, targets = batch + outputs = student(inputs) + with torch.no_grad(): + soft_targets = teacher(inputs) + + loss_dict = { "loss_kld": self.alpha * kldiv(outputs, soft_targets, T=self.T), + "loss_ce": self.beta * F.cross_entropy( outputs, targets ), + "loss_additional": self.gamma * self.additional_kd_loss(engine, batch) } + + loss = sum( loss_dict.values() ) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + step_time = time.perf_counter() - start_time + metrics = { loss_name: loss_value.item() for (loss_name, loss_value) in loss_dict.items() } + metrics.update({ + 'total_loss': loss.item(), + 'step_time': step_time, + 'lr': float( self.optimizer.param_groups[0]['lr'] ) + }) + return metrics + + + \ No newline at end of file diff --git a/model_measuring/kamal/slim/distillation/nst.py b/model_measuring/kamal/slim/distillation/nst.py new file mode 100644 index 0000000..a7d2be9 --- /dev/null +++ b/model_measuring/kamal/slim/distillation/nst.py @@ -0,0 +1,44 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import KDLoss +from kamal.core.tasks.loss import NSTLoss + +import torch +import torch.nn as nn + +import time + +class NSTDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(NSTDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( NSTDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self._nst_loss = NSTLoss() + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + g_s = feat_s[1:-1] + g_t = feat_t[1:-1] + return self._nst_loss(g_s, g_t) diff --git a/model_measuring/kamal/slim/distillation/pkt.py b/model_measuring/kamal/slim/distillation/pkt.py new file mode 100644 index 0000000..32b30aa --- /dev/null +++ b/model_measuring/kamal/slim/distillation/pkt.py @@ -0,0 +1,44 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import PKTLoss +from kamal.core.tasks.loss import KDLoss + +import torch +import torch.nn as nn + +import time + +class PKTDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(PKTDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( PKTDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self._pkt_loss = PKTLoss() + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + f_s = feat_s[-1] + f_t = feat_t[-1] + return self._pkt_loss(f_s, f_t) diff --git a/model_measuring/kamal/slim/distillation/rkd.py b/model_measuring/kamal/slim/distillation/rkd.py new file mode 100644 index 0000000..c7caebe --- /dev/null +++ b/model_measuring/kamal/slim/distillation/rkd.py @@ -0,0 +1,45 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import RKDLoss +from kamal.core.tasks.loss import KDLoss + +import torch +import torch.nn as nn + +import time + +class RKDDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(RKDDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( RKDDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, gamma=gamma, alpha=alpha, device=device ) + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self._rkd_loss = RKDLoss() + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + f_s = feat_s[-1] + f_t = feat_t[-1] + return self._rkd_loss(f_s, f_t) + \ No newline at end of file diff --git a/model_measuring/kamal/slim/distillation/sp.py b/model_measuring/kamal/slim/distillation/sp.py new file mode 100644 index 0000000..6d3db56 --- /dev/null +++ b/model_measuring/kamal/slim/distillation/sp.py @@ -0,0 +1,44 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import SPLoss +from kamal.core.tasks.loss import KDLoss + +import torch +import torch.nn as nn + +import time + +class SPDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(SPDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( SPDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self._sp_loss = SPLoss() + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + g_s = [feat_s[-2]] + g_t = [feat_t[-2]] + return self._sp_loss(g_s, g_t) diff --git a/model_measuring/kamal/slim/distillation/svd.py b/model_measuring/kamal/slim/distillation/svd.py new file mode 100644 index 0000000..c8f3ee7 --- /dev/null +++ b/model_measuring/kamal/slim/distillation/svd.py @@ -0,0 +1,45 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from .kd import KDDistiller +from kamal.core.tasks.loss import SVDLoss +from kamal.core.tasks.loss import KDLoss + +import torch +import torch.nn as nn + +import time + + +class SVDDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(SVDDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( SVDDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self._svd_loss = SVDLoss() + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + g_s = feat_s[1:-1] + g_t = feat_t[1:-1] + return self._svd_loss( g_s, g_t ) \ No newline at end of file diff --git a/model_measuring/kamal/slim/distillation/vid.py b/model_measuring/kamal/slim/distillation/vid.py new file mode 100644 index 0000000..b382f1b --- /dev/null +++ b/model_measuring/kamal/slim/distillation/vid.py @@ -0,0 +1,90 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. 
+ + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import numpy as np +import time +import torch.nn as nn +import torch._ops +import torch.nn.functional as F +from .kd import KDDistiller +from kamal.utils import set_mode +from kamal.core.tasks.loss import KDLoss + +class VIDDistiller(KDDistiller): + def __init__(self, logger=None, tb_writer=None ): + super(VIDDistiller, self).__init__( logger, tb_writer ) + + def setup(self, student, teacher, dataloader, optimizer, regressor_l, T=1.0, alpha=1.0, beta=1.0, gamma=1.0, stu_hooks=[], tea_hooks=[], out_flags=[], device=None): + super( VIDDistiller, self ).setup( + student, teacher, dataloader, optimizer, T=T, alpha=alpha, beta=beta, gamma=gamma, device=device ) + self.regressor_l = regressor_l + self.stu_hooks = stu_hooks + self.tea_hooks = tea_hooks + self.out_flags = out_flags + self.regressor_l = [regressor.to(self.device).train() for regressor in self.regressor_l] + + def additional_kd_loss(self, engine, batch): + feat_s = [f.feat_out if flag else f.feat_in for (f, flag) in zip(self.stu_hooks, self.out_flags)] + feat_t = [f.feat_out.detach() if flag else f.feat_in for (f, flag) in zip(self.tea_hooks, self.out_flags)] + g_s = feat_s[1:-1] + g_t = feat_t[1:-1] + return sum([c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, self.regressor_l)]) + +class VIDRegressor(nn.Module): + def __init__(self, + num_input_channels, + num_mid_channel, + num_target_channels, + init_pred_var=5.0, + eps=1e-5): + super(VIDRegressor, self).__init__() + + def conv1x1(in_channels, out_channels, stride=1): + return nn.Conv2d( + in_channels, out_channels, + kernel_size=1, padding=0, + bias=False, stride=stride) + + self.regressor = nn.Sequential( + conv1x1(num_input_channels, num_mid_channel), + nn.ReLU(), + conv1x1(num_mid_channel, num_mid_channel), + nn.ReLU(), + conv1x1(num_mid_channel, num_target_channels), + ) + self.log_scale = torch.nn.Parameter( + np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels) + ) + self.eps = eps + + def forward(self, input, target): + # pool for dimentsion match + s_H, t_H = input.shape[2], target.shape[2] + if s_H > t_H: + input = F.adaptive_avg_pool2d(input, (t_H, t_H)) + elif s_H < t_H: + target = F.adaptive_avg_pool2d(target, (s_H, s_H)) + else: + pass + pred_mean = self.regressor(input) + pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps + pred_var = pred_var.view(1, -1, 1, 1) + neg_log_prob = 0.5*( + (pred_mean-target)**2/pred_var+torch.log(pred_var) + ) + loss = torch.mean(neg_log_prob) + return loss diff --git a/model_measuring/kamal/slim/prunning/__init__.py b/model_measuring/kamal/slim/prunning/__init__.py new file mode 100644 index 0000000..29b7e23 --- /dev/null +++ b/model_measuring/kamal/slim/prunning/__init__.py @@ -0,0 +1,2 @@ +from .pruner import Pruner +from .strategy import LNStrategy, RandomStrategy \ No newline at end of file diff --git a/model_measuring/kamal/slim/prunning/pruner.py b/model_measuring/kamal/slim/prunning/pruner.py new 
file mode 100644 index 0000000..9e6e2c3 --- /dev/null +++ b/model_measuring/kamal/slim/prunning/pruner.py @@ -0,0 +1,37 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from copy import deepcopy +import os +import torch + +class Pruner(object): + def __init__(self, strategy): + self.strategy = strategy + + def prune(self, model, rate=0.1, example_inputs=None): + ori_num_params = sum( [ torch.numel(p) for p in model.parameters() ] ) + model = deepcopy(model).cpu() + model = self._prune( model, rate=rate, example_inputs=example_inputs ) + new_num_params = sum( [ torch.numel(p) for p in model.parameters() ] ) + print( "%d=>%d, %.2f%% params were pruned"%( ori_num_params, new_num_params, 100*(ori_num_params-new_num_params)/ori_num_params ) ) + return model + + def _prune(self, model, **kargs): + return self.strategy( model, **kargs) + + \ No newline at end of file diff --git a/model_measuring/kamal/slim/prunning/strategy.py b/model_measuring/kamal/slim/prunning/strategy.py new file mode 100644 index 0000000..5689724 --- /dev/null +++ b/model_measuring/kamal/slim/prunning/strategy.py @@ -0,0 +1,85 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +import torch_pruning as tp +import abc +import torch +import torch.nn as nn +import random +import numpy as np + +_PRUNABLE_MODULES= tp.DependencyGraph.PRUNABLE_MODULES + +class BaseStrategy(abc.ABC): + + @abc.abstractmethod + def select(self, layer_to_prune): + pass + + def __call__(self, model, rate=0.1, example_inputs=None): + if example_inputs is None: + example_inputs = torch.randn( 1,3,256,256 ) + + DG = tp.DependencyGraph() + DG.build_dependency(model, example_inputs=example_inputs) + + prunable_layers = [] + total_params = 0 + num_accumulative_conv_params = [ 0, ] + + for m in model.modules(): + if isinstance(m, _PRUNABLE_MODULES ) : + nparam = tp.utils.count_prunable_params( m ) + total_params += nparam + if isinstance(m, (nn.modules.conv._ConvNd, nn.Linear)): + prunable_layers.append( m ) + num_accumulative_conv_params.append( num_accumulative_conv_params[-1]+nparam ) + prunable_layers.pop(-1) # remove the last layer + num_accumulative_conv_params.pop(-1) # remove the last layer + + num_conv_params = num_accumulative_conv_params[-1] + num_accumulative_conv_params = [ ( num_accumulative_conv_params[i], num_accumulative_conv_params[i+1] ) for i in range(len(num_accumulative_conv_params)-1) ] + + def map_param_idx_to_conv_layer(i): + for l, accu in zip( prunable_layers, num_accumulative_conv_params ): + if accu[0]<=i and i +icon + + + diff --git a/model_measuring/kamal/transferability/__init__.py b/model_measuring/kamal/transferability/__init__.py new file mode 100644 index 0000000..7798dac --- /dev/null +++ b/model_measuring/kamal/transferability/__init__.py @@ -0,0 +1,20 @@ +from kamal.transferability.trans_graph import TransferabilityGraph +from kamal.transferability.trans_metric import AttrMapMetric + +import kamal +from kamal.vision import sync_transforms as sT +import os +import torch +from PIL import Image + +if __name__=='__main__': + zoo = '/tmp/pycharm_project_225/kamal/transferability/model2' + TG = TransferabilityGraph(zoo) + probe_set_root = '/tmp/pycharm_project_225/kamal/transferability/probe_data' + for probe_set in os.listdir( probe_set_root ): + print("Add %s"%(probe_set)) + imgs_set = list( os.listdir( os.path.join( probe_set_root, probe_set ) ) ) + images = [ Image.open( os.path.join(probe_set_root, probe_set, img) ) for img in imgs_set ] + metric = AttrMapMetric(images, device=torch.device('cuda')) + TG.add_metric( probe_set, metric) + TG.export_to_json(probe_set, 'exported_metrics/%s.json'%(probe_set), topk=3, normalize=True) diff --git a/model_measuring/kamal/transferability/depara/__init__.py b/model_measuring/kamal/transferability/depara/__init__.py new file mode 100644 index 0000000..ab0493a --- /dev/null +++ b/model_measuring/kamal/transferability/depara/__init__.py @@ -0,0 +1,3 @@ +from .attribution_graph import get_attribution_graph, graph_similarity + +from .attribution_map import attribution_map, attr_map_distance \ No newline at end of file diff --git a/model_measuring/kamal/transferability/depara/attribution_graph.py b/model_measuring/kamal/transferability/depara/attribution_graph.py new file mode 100644 index 0000000..cb8f5f7 --- /dev/null +++ b/model_measuring/kamal/transferability/depara/attribution_graph.py @@ -0,0 +1,184 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +from typing import Type, Dict, Any +import itertools + +import numpy as np +from numpy import Inf +import networkx + +import torch +from torch.nn import Module +from torch.utils.data import Dataset + +from .attribution_map import attribution_map, attr_map_similarity + + +def graph_to_array(graph: networkx.Graph): + weight_matrix = np.zeros((len(graph.nodes), len(graph.nodes))) + for i, n1 in enumerate(graph.nodes): + for j, n2 in enumerate(graph.nodes): + try: + dist = graph[n1][n2]["weight"] + except KeyError: + dist = 1 + weight_matrix[i, j] = dist + return weight_matrix + + +class FeatureMapExtractor(): + def __init__(self, module: Module): + self.module = module + self.feature_pool: Dict[str, Dict[str, Any]] = dict() + self.register_hooks() + + def register_hooks(self): + for name, m in self.module.named_modules(): + if "pool" in name: + m.name = name + self.feature_pool[name] = dict() + + def hook(m: Module, input, output): + self.feature_pool[m.name]["feature"] = input + self.feature_pool[name]["handle"] = m.register_forward_hook(hook) + + def _forward(self, x): + self.module(x) + + def remove_hooks(self): + for name, cfg in self.feature_pool.items(): + cfg["handle"].remove() + cfg.clear() + self.feature_pool.clear() + + def extract_final_map(self, x): + self._forward(x) + feature_map = None + max_channel = 0 + min_size = Inf + for name, cfg in self.feature_pool.items(): + f = cfg["feature"] + if len(f) == 1 and isinstance(f[0], torch.Tensor): + f = f[0] + if f.dim() == 4: # BxCxHxW + b, c, h, w = f.shape + if c >= max_channel and 1 < h * w <= min_size: + feature_map = f + max_channel = c + min_size = h * w + return feature_map + + +def get_attribution_graph( + model: Module, + attribution_type: Type, + with_noise: bool, + probe_data: Dataset, + device: torch.device, + norm_square: bool = False, +): + attribution_graph = networkx.Graph() + model = model.to(device) + extractor = FeatureMapExtractor(model) + for i, x in enumerate(probe_data): + x = x.to(device) + x.requires_grad_() + + attribution = attribution_map( + func=lambda x: extractor.extract_final_map(x), + attribution_type=attribution_type, + with_noise=with_noise, + probe_data=x.unsqueeze(0), + norm_square=norm_square + ) + + attribution_graph.add_node(i, attribution_map=attribution) + + nodes = attribution_graph.nodes + for i, j in itertools.product(nodes, nodes): + if i < j: + weight = attr_map_similarity( + attribution_graph.nodes(data=True)[i]["attribution_map"], + attribution_graph.nodes(data=True)[j]["attribution_map"] + ) + attribution_graph.add_edge(i, j, weight=weight) + + return attribution_graph + + +def edge_to_embedding(graph: networkx.Graph): + adj = graph_to_array(graph) + up_tri_mask = np.tri(*adj.shape[-2:], k=0, dtype=bool) + return adj[up_tri_mask] + + +def embedding_to_rank(embedding: np.ndarray): + order = embedding.argsort() + ranks = order.argsort() + return ranks + + +def graph_similarity(g1: networkx.Graph, g2: networkx.Graph, Lambda: float = 1.0): + nodes_1 = g1.nodes(data=True) + nodes_2 = g2.nodes(data=True) + assert 
len(nodes_1) == len(nodes_2) + + # calculate vertex similarity + v_s = 0 + n = len(g1.nodes) + for i in range(n): + v_s += attr_map_similarity( + map_1=g1.nodes(data=True)[i]["attribution_map"], + map_2=g2.nodes(data=True)[i]["attribution_map"] + ) + vertex_similarity = v_s / n + + # calculate edges similarity + emb_1 = edge_to_embedding(g1) + emb_2 = edge_to_embedding(g2) + rank_1 = embedding_to_rank(emb_1) + rank_2 = embedding_to_rank(emb_2) + k = emb_1.shape[0] + edge_similarity = 1 - 6 * np.sum(np.square(rank_1 - rank_2)) / (k ** 3 - k) + + return vertex_similarity + Lambda * edge_similarity + + +if __name__ == "__main__": + from captum.attr import InputXGradient + from torchvision.models import resnet34 + + model_1 = resnet34(num_classes=10) + graph_1 = get_attribution_graph( + model_1, + attribution_type=InputXGradient, + with_noise=False, + probe_data=torch.rand(10, 3, 244, 244), + device=torch.device("cpu") + ) + + model_2 = resnet34(num_classes=10) + graph_2 = get_attribution_graph( + model_2, + attribution_type=InputXGradient, + with_noise=False, + probe_data=torch.rand(10, 3, 244, 244), + device=torch.device("cpu") + ) + + print(graph_similarity(graph_1, graph_2)) diff --git a/model_measuring/kamal/transferability/depara/attribution_map.py b/model_measuring/kamal/transferability/depara/attribution_map.py new file mode 100644 index 0000000..50c4018 --- /dev/null +++ b/model_measuring/kamal/transferability/depara/attribution_map.py @@ -0,0 +1,87 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import torch +from typing import Type, Callable +from captum.attr import Attribution +from captum.attr import NoiseTunnel + + +def with_norm(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, square: bool = False): + x = func(x) + x = torch.norm(x.flatten(1), dim=1, p=2) + if square: + x = torch.pow(x, 2) + return x + + +def attribution_map( + func: Callable[[torch.Tensor], torch.Tensor], + attribution_type: Type, + with_noise: bool, + probe_data: torch.Tensor, + norm_square: bool = False, + **attribution_kwargs +) -> torch.Tensor: + """ + Calculate attribution map with given attribution type(algorithm). + Args: + model: pytorch module + attribution_type: attribution algorithm, e.g. IntegratedGradients, InputXGradient, ... 
+ with_noise: whether to add noise tunnel + probe_data: input data to model + device: torch.device("cuda: 0") + attribution_kwargs: other kwargs for attribution method + Return: attribution map + """ + attribution: Attribution = attribution_type(lambda x: with_norm(func, x, norm_square)) + if with_noise: + attribution = NoiseTunnel(attribution) + attr_map = attribution.attribute( + inputs=probe_data, + target=None, + **attribution_kwargs + ) + return attr_map.detach() + +def attr_map_distance(map_1: torch.Tensor, map_2: torch.Tensor): + if map_1.shape != map_2.shape: + map_1 = torch.nn.functional.interpolate( map_1, size=map_2.shape[-2:] ) + #dist = torch.dist(map_1.flatten(1), map_2.flatten(1), p=2).mean() + dist = 1 - torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean() + return dist.item() + +def attr_map_similarity(map_1: torch.Tensor, map_2: torch.Tensor): + assert(map_1.shape == map_2.shape) + dist = torch.cosine_similarity(map_1.flatten(1), map_2.flatten(1)).mean() + return dist.item() + + +if __name__ == "__main__": + import captum + + def ff(x): + return x ** 2 + + m = attribution_map( + ff, + captum.attr.InputXGradient, + with_noise=False, + probe_data=torch.tensor([[1, 2, 3, 4]], dtype=torch.float, requires_grad=True), + norm_square=True + ) + print(m) diff --git a/model_measuring/kamal/transferability/trans_graph.py b/model_measuring/kamal/transferability/trans_graph.py new file mode 100644 index 0000000..b63ebba --- /dev/null +++ b/model_measuring/kamal/transferability/trans_graph.py @@ -0,0 +1,135 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" +import torch +import networkx as nx +from . 
import depara +import os, abc +from typing import Callable +from kamal import hub +import json, numbers + +from tqdm import tqdm + +class Node(object): + def __init__(self, hub_root, entry_name, spec_name): + self.hub_root = hub_root + self.entry_name = entry_name + self.spec_name = spec_name + + @property + def model(self): + return hub.load( self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name ).eval() + + @property + def tag(self): + return hub.load_tags(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name) + + @property + def metadata(self): + return hub.load_metadata(self.hub_root, entry_name=self.entry_name, spec_name=self.spec_name) + +class TransferabilityGraph(object): + def __init__(self, model_zoo_set): + self.model_zoo_set = model_zoo_set + # self.model_zoo = os.path.abspath( os.path.expanduser( model_zoo ) ) + self._graphs = dict() + self._models = dict() + self._register_models() + + def _register_models(self): + cnt = 0 + for model_zoo in self.model_zoo_set: + model_zoo = os.path.abspath(os.path.expanduser(model_zoo)) + for hub_root in self._list_modelzoo(model_zoo): + for entry_name, spec_name in hub.list_spec(hub_root): + node = Node( hub_root, entry_name, spec_name ) + name = node.metadata['name'] + self._models[name] = node + cnt += 1 + print("%d models has been registered!"%cnt) + + def _list_modelzoo(self, zoo_dir): + zoo_list = [] + def _traverse(path): + for item in os.listdir(path): + item_path = os.path.join(path, item) + if os.path.isdir(item_path): + if os.path.exists(os.path.join( item_path, 'code/hubconf.py' )): + zoo_list.append(item_path) + else: + _traverse( item_path ) + _traverse(zoo_dir) + return zoo_list + + def add_metric(self, metric_name, metric): + self._graphs[metric_name] = g = nx.DiGraph() + g.add_nodes_from( self._models.values() ) + for n1 in self._models.values(): + for n2 in tqdm(self._models.values()): + if n1!=n2 and not g.has_edge(n1, n2): + try: + g.add_edge(n1, n2, dist=metric( n1, n2 )) + except: + ori_device = metric.device + metric.device = torch.device('cpu') + g.add_edge(n1, n2, dist=metric( n1, n2 )) + metric.device = ori_device + + def export_to_json(self, metric_name, output_filename, topk=None, normalize=False): + graph = self._graphs.get( metric_name, None ) + assert graph is not None + graph_data={ + 'nodes': [], + 'edges': [], + } + node_to_idx = {} + for i, node in enumerate(self._models.values()): + tags = node.tag + metadata = node.metadata + node_data = { k:v for (k, v) in tags.items() if isinstance(v, (numbers.Number, str) ) } + node_data['name'] = metadata['name'] + node_data['task'] = metadata['task'] + node_data['dataset'] = metadata['dataset'] + node_data['url'] = metadata['url'] + node_data['id'] = i + graph_data['nodes'].append({'tags': node_data}) + node_to_idx[node] = i + + # record Edges + edge_list = graph_data['edges'] + topk_dist = { idx: [] for idx in range(len( self._models )) } + for i, edge in enumerate(graph.edges.data('dist')): + s, t, d = int( node_to_idx[edge[0]] ), int( node_to_idx[edge[1]] ), float(edge[2]) + topk_dist[s].append(d) + edge_list.append([ + s, t, d # source, target, distance + ]) + + if isinstance(topk, int): + for i, dist in topk_dist.items(): + dist.sort() + topk_dist[i] = dist[topk] + graph_data['edges'] = [ edge for edge in edge_list if edge[2] < topk_dist[edge[0]] ] + + if normalize: + edge_dist = [e[2] for e in graph_data['edges']] + min_dist, max_dist = min(edge_dist), max(edge_dist) + for e in graph_data['edges']: + e[2] = (e[2] - min_dist+1e-8) / 
(max_dist - min_dist+1e-8) + + with open(output_filename, 'w') as fp: + json.dump(graph_data, fp) \ No newline at end of file diff --git a/model_measuring/kamal/transferability/trans_metric.py b/model_measuring/kamal/transferability/trans_metric.py new file mode 100644 index 0000000..c930bd6 --- /dev/null +++ b/model_measuring/kamal/transferability/trans_metric.py @@ -0,0 +1,109 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import abc + +from . import depara +from captum.attr import InputXGradient +import torch +from kamal.core import hub +from kamal.vision import sync_transforms as sT + +class TransMetric(abc.ABC): + def __init__(self): + pass + + def __call__(self, a, b) -> float: + return 0 + +class DeparaMetric(TransMetric): + def __init__(self, data, device): + self.data = data + self.device = device + self._cache = {} + + def _get_transform(self, metadata): + input_metadata = metadata['input'] + size = input_metadata['size'] + space = input_metadata['space'] + drange = input_metadata['range'] + normalize = input_metadata['normalize'] + if size==None: + size=224 + if isinstance(size, (list, tuple)): + size = size[-1] + transform = [ + sT.Resize(size), + sT.CenterCrop(size), + ] + if space=='bgr': + transform.append(sT.FlipChannels()) + if list(drange)==[0, 1]: + transform.append( sT.ToTensor() ) + elif list(drange)==[0, 255]: + transform.append( sT.ToTensor(normalize=False, dtype=torch.float) ) + else: + raise NotImplementedError + if normalize is not None: + transform.append(sT.Normalize( mean=normalize['mean'], std=normalize['std'] )) + return sT.Compose(transform) + + def _get_attr_graph(self, n): + transform = self._get_transform(n.metadata) + data = torch.stack( [ transform( d ) for d in self.data ], dim=0 ) + return depara.get_attribution_graph( + n.model, + attribution_type=InputXGradient, + with_noise=False, + probe_data=data, + device=self.device + ) + + def __call__(self, n1, n2): + attrgraph_1 = self._cache.get(n1, None) + attrgraph_2 = self._cache.get(n2, None) + if attrgraph_1 is None: + self._cache[n1] = attrgraph_1 = self._get_attr_graph(n1).cpu() + if attrgraph_2 is None: + self._cache[n2] = attrgraph_2 = self._get_attr_graph(n2).cpu() + result = depara.graph_similarity(attrgraph_1, attrgraph_2) + self._cache[n1] = self._cache[n1] + self._cache[n2] = self._cache[n2] + +class AttrMapMetric(DeparaMetric): + def _get_attr_map(self, n): + transform = self._get_transform(n.metadata) + data = torch.stack( [ transform( d ).to(self.device) for d in self.data ], dim=0 ) + return depara.attribution_map( + n.model.to(self.device), + attribution_type=InputXGradient, + with_noise=False, + probe_data=data, + ) + + def __call__(self, n1, n2): + attrgraph_1 = self._cache.get(n1, None) + attrgraph_2 = self._cache.get(n2, None) + if attrgraph_1 is None: + self._cache[n1] = attrgraph_1 = self._get_attr_map(n1).cpu() + if attrgraph_2 is None: + self._cache[n2] = 
attrgraph_2 = self._get_attr_map(n2).cpu()
+        result = depara.attr_map_distance(attrgraph_1, attrgraph_2)
+        self._cache[n1] = self._cache[n1]
+        self._cache[n2] = self._cache[n2]
+        return result
+
diff --git a/model_measuring/kamal/utils/__init__.py b/model_measuring/kamal/utils/__init__.py
new file mode 100644
index 0000000..3b000ec
--- /dev/null
+++ b/model_measuring/kamal/utils/__init__.py
@@ -0,0 +1,2 @@
+from ._utils import *
+from .logger import get_logger
\ No newline at end of file
diff --git a/model_measuring/kamal/utils/_utils.py b/model_measuring/kamal/utils/_utils.py
new file mode 100644
index 0000000..4560f23
--- /dev/null
+++ b/model_measuring/kamal/utils/_utils.py
@@ -0,0 +1,153 @@
+"""
+    Copyright 2020 Tianshu AI Platform. All Rights Reserved.
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+    =============================================================
+"""
+
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+import random
+from copy import deepcopy
+import contextlib, hashlib
+
+def split_batch(batch):
+    if isinstance(batch, (list, tuple)):
+        inputs, *targets = batch
+        if len(targets)==1:
+            targets = targets[0]
+        return inputs, targets
+    else:
+        return [batch, None]
+
+@contextlib.contextmanager
+def set_mode(model, training=True):
+    ori_mode = model.training
+    model.train(training)
+    yield
+    model.train(ori_mode)
+
+def move_to_device(obj, device):
+    # moves a tensor, a sequence of tensors, or a module onto the given device
+    if isinstance(obj, torch.Tensor):
+        return obj.to(device=device)
+    elif isinstance( obj, (list, tuple) ):
+        return [ o.to(device=device) for o in obj ]
+    elif isinstance(obj, nn.Module):
+        return obj.to(device=device)
+
+
+def pack_images(images, col=None, channel_last=False):
+    # N, C, H, W
+    if isinstance(images, (list, tuple) ):
+        images = np.stack(images, 0)
+    if channel_last:
+        images = images.transpose(0,3,1,2) # make it channel first
+    assert len(images.shape)==4
+    assert isinstance(images, np.ndarray)
+
+    N,C,H,W = images.shape
+    if col is None:
+        col = int(math.ceil(math.sqrt(N)))
+    row = int(math.ceil(N / col))
+    pack = np.zeros( (C, H*row, W*col), dtype=images.dtype )
+    for idx, img in enumerate(images):
+        h = (idx//col) * H
+        w = (idx% col) * W
+        pack[:, h:h+H, w:w+W] = img
+    return pack
+
+def normalize(tensor, mean, std, reverse=False):
+    if reverse:
+        _mean = [ -m / s for m, s in zip(mean, std) ]
+        _std = [ 1/s for s in std ]
+    else:
+        _mean = mean
+        _std = std
+    _mean = torch.as_tensor(_mean, dtype=tensor.dtype, device=tensor.device)
+    _std = torch.as_tensor(_std, dtype=tensor.dtype, device=tensor.device)
+    tensor = (tensor - _mean[None, :, None, None]) / (_std[None, :, None, None])
+    return tensor
+
+class Normalizer(object):
+    def __init__(self, mean, std, reverse=False):
+        self.mean = mean
+        self.std = std
+        self.reverse = reverse
+
+    def __call__(self, x):
+        if self.reverse:
+            return self.denormalize(x)
+        else:
+            return self.normalize(x)
+
+    def normalize(self, x):
+        return normalize( x, self.mean, self.std )
+
+    def denormalize(self, x):
+        return normalize( x, self.mean, self.std, reverse=True )
+
+
+def colormap(N=256,
normalized=False): + def bitget(byteval, idx): + return ((byteval & (1 << idx)) != 0) + + dtype = 'float32' if normalized else 'uint8' + cmap = np.zeros((N, 3), dtype=dtype) + for i in range(N): + r = g = b = 0 + c = i + for j in range(8): + r = r | (bitget(c, 0) << 7-j) + g = g | (bitget(c, 1) << 7-j) + b = b | (bitget(c, 2) << 7-j) + c = c >> 3 + + cmap[i] = np.array([r, g, b]) + + cmap = cmap/255 if normalized else cmap + return cmap + +DEFAULT_COLORMAP = colormap() + +def flatten_dict(dic): + flattned = dict() + + def _flatten(prefix, d): + for k, v in d.items(): + if isinstance(v, dict): + if prefix is None: + _flatten( k, v ) + else: + _flatten( prefix+'%s/'%k, v ) + else: + flattned[ (prefix+'%s/'%k).strip('/') ] = v + + _flatten('', dic) + return flattned + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + +def count_parameters(model): + return sum( [ p.numel() for p in model.parameters() ] ) + +def md5(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() \ No newline at end of file diff --git a/model_measuring/kamal/utils/logger.py b/model_measuring/kamal/utils/logger.py new file mode 100644 index 0000000..8a58b23 --- /dev/null +++ b/model_measuring/kamal/utils/logger.py @@ -0,0 +1,56 @@ +import logging +import os, sys +from termcolor import colored + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + log = super(_ColorfulFormatter, self).formatMessage(record) + + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "yellow", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + + return prefix + " " + log + +def get_logger(name='Kamal', output=None, color=True): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + # STDOUT + stdout_handler = logging.StreamHandler( stream=sys.stdout ) + stdout_handler.setLevel( logging.DEBUG ) + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" ) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S") + else: + formatter = plain_formatter + stdout_handler.setFormatter(formatter) + + logger.addHandler(stdout_handler) + + # FILE + if output is not None: + if output.endswith('.txt') or output.endswith('.log'): + os.makedirs(os.path.dirname(output), exist_ok=True) + filename = output + else: + os.makedirs(output, exist_ok=True) + filename = os.path.join(output, "log.txt") + file_handler = logging.FileHandler(filename) + file_handler.setFormatter(plain_formatter) + file_handler.setLevel(logging.DEBUG) + logger.addHandler(file_handler) + return logger + + diff --git a/model_measuring/kamal/vision/__init__.py b/model_measuring/kamal/vision/__init__.py new file mode 100644 index 0000000..6c37b03 --- /dev/null +++ b/model_measuring/kamal/vision/__init__.py @@ -0,0 +1,3 @@ +from . import models +from . import datasets +from . 
import sync_transforms \ No newline at end of file diff --git a/model_measuring/kamal/vision/datasets/__init__.py b/model_measuring/kamal/vision/datasets/__init__.py new file mode 100644 index 0000000..5f2df69 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/__init__.py @@ -0,0 +1,16 @@ +from .ade20k import ADE20K +from .caltech import Caltech101, Caltech256 +from .camvid import CamVid +from .cityscapes import Cityscapes +from .cub200 import CUB200 +from .fgvc_aircraft import FGVCAircraft +from .nyu import NYUv2 +from .stanford_cars import StanfordCars +from .stanford_dogs import StanfordDogs +from .sunrgbd import SunRGBD +from .voc import VOCClassification, VOCSegmentation +from .dataset import LabelConcatDataset + +from torchvision import datasets as torchvision_datasets + +from .unlabeled import UnlabeledDataset \ No newline at end of file diff --git a/model_measuring/kamal/vision/datasets/ade20k.py b/model_measuring/kamal/vision/datasets/ade20k.py new file mode 100644 index 0000000..f6a8fa8 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/ade20k.py @@ -0,0 +1,70 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + +import collections +import torch +import torchvision +import numpy as np +from PIL import Image +import os +from torchvision.datasets import VisionDataset +from .utils import colormap + +class ADE20K(VisionDataset): + cmap = colormap() + + def __init__( + self, + root, + split="training", + transform=None, + target_transform=None, + transforms=None, + ): + super( ADE20K, self ).__init__( root=root, transforms=transforms, transform=transform, target_transform=target_transform ) + assert split in ['training', 'validation'], "split should be \'training\' or \'validation\'" + self.root = os.path.expanduser(root) + self.split = split + self.num_classes = 150 + + img_list = [] + lbl_list = [] + img_dir = os.path.join( self.root, 'images', self.split ) + lbl_dir = os.path.join( self.root, 'annotations', self.split ) + + for img_name in os.listdir( img_dir ): + img_list.append( os.path.join( img_dir, img_name ) ) + lbl_list.append( os.path.join( lbl_dir, img_name[:-3]+'png') ) + + self.img_list = img_list + self.lbl_list = lbl_list + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, index): + img = Image.open( self.img_list[index] ) + lbl = Image.open( self.lbl_list[index] ) + if self.transforms: + img, lbl = self.transforms(img, lbl) + lbl = np.array(lbl, dtype='uint8')-1 # 1-150 => 0-149 + 255 + return img, lbl + + @classmethod + def decode_seg_to_color(cls, mask): + """decode semantic mask to RGB image""" + return cls.cmap[mask+1] diff --git a/model_measuring/kamal/vision/datasets/caltech.py b/model_measuring/kamal/vision/datasets/caltech.py new file mode 100644 index 0000000..885f1c8 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/caltech.py @@ -0,0 +1,226 @@ +# Modified from 
https://github.com/pytorch/vision/blob/master/torchvision/datasets/caltech.py
+from __future__ import print_function
+from PIL import Image
+import os
+import os.path
+
+from torchvision.datasets.vision import VisionDataset
+from torchvision.datasets.utils import download_url
+
+
+class Caltech101(VisionDataset):
+    """`Caltech 101 `_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where directory
+            ``caltech101`` exists or will be saved to if download is set to True.
+        target_type (string or list, optional): Type of target to use, ``category`` or
+            ``annotation``. Can also be a list to output a tuple with all specified target types.
+            ``category`` represents the target class, and ``annotation`` is a list of points
+            from a hand-generated outline. Defaults to ``category``.
+        transform (callable, optional): A function/transform that takes in a PIL image
+            and returns a transformed version. E.g., ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+    """
+
+    def __init__(self, root, target_type="category", train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
+        self.train = train
+        self.dir_name = '101_ObjectCategories_split/train' if self.train else '101_ObjectCategories_split/test'
+
+        os.makedirs(self.root, exist_ok=True)
+        if isinstance(target_type, list):
+            self.target_type = target_type
+        else:
+            self.target_type = [target_type]
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if download:
+            self.download()
+
+        if not self._check_integrity():
+            raise RuntimeError('Dataset not found or corrupted.' +
+                               ' You can use download=True to download it')
+
+        self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
+        self.categories.remove("BACKGROUND_Google")  # this is not a real class
+
+        # For some reason, the category names in "101_ObjectCategories" and
+        # "Annotations" do not always match. This is a manual map between the
+        # two. Defaults to using same name, since most names are fine.
+        name_map = {"Faces": "Faces_2",
+                    "Faces_easy": "Faces_3",
+                    "Motorbikes": "Motorbikes_16",
+                    "airplanes": "Airplanes_Side_2"}
+        self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
+
+        self.index = []
+        self.y = []
+        for (i, c) in enumerate(self.categories):
+            file_names = os.listdir(os.path.join(self.root, self.dir_name, c))
+            n = len(file_names)
+            self.index.extend( file_names )
+            self.y.extend(n * [i])
+
+        print(self.train, len(self.index))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where the type of target is specified by target_type.
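+
+        Example (an illustrative sketch only; the root path is hypothetical and the
+        pre-split ``101_ObjectCategories_split`` folders produced by
+        ``prepare_caltech101.py`` are assumed to already exist)::
+
+            from torchvision import transforms
+
+            dataset = Caltech101(root='./data', train=True,
+                                 transform=transforms.ToTensor())
+            img, target = dataset[0]  # target is the integer category index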
+ """ + import scipy.io + + img = Image.open(os.path.join(self.root, + self.dir_name, + self.categories[self.y[index]], + self.index[index])).convert("RGB") + target = [] + for t in self.target_type: + if t == "category": + target.append(self.y[index]) + elif t == "annotation": + data = scipy.io.loadmat(os.path.join(self.root, + "Annotations", + self.annotation_categories[self.y[index]], + "annotation_{:04d}.mat".format(self.index[index]))) + target.append(data["obj_contour"]) + else: + raise ValueError("Target type \"{}\" is not recognized.".format(t)) + target = tuple(target) if len(target) > 1 else target[0] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def _check_integrity(self): + # can be more robust and check hash of files + return os.path.exists(os.path.join(self.root, "101_ObjectCategories")) + + def __len__(self): + return len(self.index) + + def download(self): + import tarfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", + self.root, + "101_ObjectCategories.tar.gz", + "b224c7392d521a49829488ab0f1120d9") + download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", + self.root, + "101_Annotations.tar", + "6f83eeb1f24d99cab4eb377263132c91") + + # extract file + with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar: + tar.extractall(path=self.root) + + with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar: + tar.extractall(path=self.root) + + def extra_repr(self): + return "Target type: {target_type}".format(**self.__dict__) + + +class Caltech256(VisionDataset): + """`Caltech 256 `_ Dataset. + + Args: + root (string): Root directory of dataset where directory + ``caltech256`` exists or will be saved to if download is set to True. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + def __init__(self, root, + transform=None, target_transform=None, + download=False): + super(Caltech256, self).__init__(os.path.join(root, 'caltech256')) + os.makedirs(self.root, exist_ok=True) + self.transform = transform + self.target_transform = target_transform + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) + self.index = [] + self.y = [] + for (i, c) in enumerate(self.categories): + n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c))) + self.index.extend(range(1, n + 1)) + self.y.extend(n * [i]) + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. 
+ """ + img = Image.open(os.path.join(self.root, + "256_ObjectCategories", + self.categories[self.y[index]], + "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]))) + + target = self.y[index] + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def _check_integrity(self): + # can be more robust and check hash of files + return os.path.exists(os.path.join(self.root, "256_ObjectCategories")) + + def __len__(self): + return len(self.index) + + def download(self): + import tarfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", + self.root, + "256_ObjectCategories.tar", + "67b4f42ca05d46448c6bb8ecd2220f6d") + + # extract file + with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar: + tar.extractall(path=self.root) diff --git a/model_measuring/kamal/vision/datasets/camvid.py b/model_measuring/kamal/vision/datasets/camvid.py new file mode 100644 index 0000000..63f367b --- /dev/null +++ b/model_measuring/kamal/vision/datasets/camvid.py @@ -0,0 +1,78 @@ +# Modified from https://github.com/davidtvs/PyTorch-ENet/blob/master/data/camvid.py +import os +import torch.utils.data as data +from glob import glob +from PIL import Image +import numpy as np +from torchvision.datasets import VisionDataset + +class CamVid(VisionDataset): + """CamVid dataset loader where the dataset is arranged as in https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid. + + Args: + root (string): + split (string): The type of dataset: 'train', 'val', 'trainval', or 'test' + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. Default: None. + target_transform (callable, optional): A function/transform that takes in the target and transform it. Default: None. + transforms (callable, optional): A function/transform that takes in both the image and target and transform them. Default: None. + """ + + cmap = np.array([ + (128, 128, 128), + (128, 0, 0), + (192, 192, 128), + (128, 64, 128), + (60, 40, 222), + (128, 128, 0), + (192, 128, 128), + (64, 64, 128), + (64, 0, 128), + (64, 64, 0), + (0, 128, 192), + (0, 0, 0), + ]) + + def __init__(self, + root, + split='train', + transform=None, + target_transform=None, + transforms=None): + assert split in ('train', 'val', 'test', 'trainval') + super( CamVid, self ).__init__(root=root, transforms=transforms, transform=transform, target_transform=target_transform) + self.root = os.path.expanduser(root) + self.split = split + + if split == 'trainval': + self.images = glob(os.path.join(self.root, 'train', '*.png')) + glob(os.path.join(self.root, 'val', '*.png')) + self.labels = glob(os.path.join(self.root, 'trainannot', '*.png')) + glob(os.path.join(self.root, 'valannot', '*.png')) + else: + self.images = glob(os.path.join(self.root, self.split, '*.png')) + self.labels = glob(os.path.join(self.root, self.split+'annot', '*.png')) + + self.images.sort() + self.labels.sort() + + def __getitem__(self, idx): + """ + Args: + - index (``int``): index of the item in the dataset + Returns: + A tuple of ``PIL.Image`` (image, label) where label is the ground-truth + of the image. 
+ """ + + img, label = Image.open(self.images[idx]), Image.open(self.labels[idx]) + if self.transforms is not None: + img, label = self.transforms(img, label) + label[label == 11] = 255 # ignore void + return img, label.squeeze(0) + + def __len__(self): + return len(self.images) + + @classmethod + def decode_fn(cls, mask): + """decode semantic mask to RGB image""" + mask[mask == 255] = 11 + return cls.cmap[mask] \ No newline at end of file diff --git a/model_measuring/kamal/vision/datasets/cityscapes.py b/model_measuring/kamal/vision/datasets/cityscapes.py new file mode 100644 index 0000000..67fbf83 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/cityscapes.py @@ -0,0 +1,146 @@ +# Modified from https://github.com/pytorch/vision/blob/master/torchvision/datasets/cityscapes.py +import json +import os +from collections import namedtuple + +import torch +import torch.utils.data as data +from PIL import Image +import numpy as np +from torchvision.datasets import VisionDataset + +class Cityscapes(VisionDataset): + """Cityscapes Dataset. + + Args: + root (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located. + split (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val' + mode (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types. + transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + """ + + # Based on https://github.com/mcordts/cityscapesScripts + CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', + 'has_instances', 'ignore_in_eval', 'color']) + classes = [ + CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), + CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), + CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), + CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), + CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), + CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), + CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), + CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), + CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), + CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), + CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), + CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), + CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), + CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), + CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), + CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, 
(250, 170, 30)), + CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), + CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), + CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), + CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), + CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), + CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), + CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), + CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), + CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), + CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), + CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), + CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), + CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), + CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), + CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)), + ] + + _TRAIN_ID_TO_COLOR = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)] + _TRAIN_ID_TO_COLOR.append([0, 0, 0]) + _TRAIN_ID_TO_COLOR = np.array(_TRAIN_ID_TO_COLOR) + _ID_TO_TRAIN_ID = np.array([c.train_id for c in classes]) + + def __init__(self, root, split='train', mode='gtFine', target_type='semantic', transform=None, target_transform=None, transforms=None): + super(Cityscapes, self).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms ) + self.root = os.path.expanduser(root) + self.mode = mode + self.target_type = target_type + + self.images_dir = os.path.join(self.root, 'leftImg8bit', split) + self.targets_dir = os.path.join(self.root, self.mode, split) + self.split = split + + self.images = [] + self.targets = [] + + if split not in ['train', 'test', 'val']: + raise ValueError('Invalid split for mode! Please use split="train", split="test"' + ' or split="val"') + + if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): + raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' + ' specified "split" and "mode" are inside the "root" directory') + + for city in os.listdir(self.images_dir): + img_dir = os.path.join(self.images_dir, city) + target_dir = os.path.join(self.targets_dir, city) + + for file_name in os.listdir(img_dir): + self.images.append(os.path.join(img_dir, file_name)) + target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], + self._get_target_suffix(self.mode, self.target_type)) + self.targets.append(os.path.join(target_dir, target_name)) + + @classmethod + def encode_target(cls, target): + if isinstance( target, torch.Tensor ): + return torch.from_numpy( cls._ID_TO_TRAIN_ID[np.array(target)] ) + else: + return cls._ID_TO_TRAIN_ID[target] + + @classmethod + def decode_fn(cls, target): + target[target == 255] = 19 + #target = target.astype('uint8') + 1 + return cls._TRAIN_ID_TO_COLOR[target] + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is a tuple of all target types if target_type is a list with more + than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 
+ """ + image = Image.open(self.images[index]).convert('RGB') + target = Image.open(self.targets[index]) + if self.transforms: + image, target = self.transforms(image, target) + target = self.encode_target(target) + return image, target + + def __len__(self): + return len(self.images) + + def _load_json(self, path): + with open(path, 'r') as file: + data = json.load(file) + return data + + def _get_target_suffix(self, mode, target_type): + if target_type == 'instance': + return '{}_instanceIds.png'.format(mode) + elif target_type == 'semantic': + return '{}_labelIds.png'.format(mode) + elif target_type == 'color': + return '{}_color.png'.format(mode) + elif target_type == 'polygon': + return '{}_polygons.json'.format(mode) + elif target_type == 'depth': + return '{}_disparity.png'.format(mode) \ No newline at end of file diff --git a/model_measuring/kamal/vision/datasets/cub200.py b/model_measuring/kamal/vision/datasets/cub200.py new file mode 100644 index 0000000..7c57b16 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/cub200.py @@ -0,0 +1,70 @@ +# Modified from https://github.com/TDeVries/cub2011_dataset/blob/master/cub2011.py +import os +import pandas as pd +from torchvision.datasets.folder import default_loader +from .utils import download_url +from torch.utils.data import Dataset +import shutil + + +class CUB200(Dataset): + base_folder = 'images' + url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz' + filename = 'CUB_200_2011.tgz' + tgz_md5 = '97eceeb196236b17998738112f37df78' + + def __init__(self, root, split='train', transform=None, target_transform=None, loader=default_loader, download=False): + self.root = os.path.abspath( os.path.expanduser( root ) ) + self.transform = transform + self.target_transform = target_transform + self.loader = default_loader + self.split = split + + if download: + self.download() + self._load_metadata() + categories = os.listdir(os.path.join( + self.root, 'CUB_200_2011', 'images')) + categories.sort() + self.object_categories = [c[4:] for c in categories] + print('CUB200, Split: %s, Size: %d' % (self.split, self.__len__())) + + def _load_metadata(self): + images = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'images.txt'), sep=' ', + names=['img_id', 'filepath']) + image_class_labels = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'image_class_labels.txt'), + sep=' ', names=['img_id', 'target'], encoding='latin-1', engine='python') + train_test_split = pd.read_csv(os.path.join(self.root, 'CUB_200_2011', 'train_test_split.txt'), + sep=' ', names=['img_id', 'is_training_img'], encoding='latin-1', engine='python') + data = images.merge(image_class_labels, on='img_id') + self.data = data.merge(train_test_split, on='img_id') + + if self.split == 'train': + self.data = self.data[self.data.is_training_img == 1] + else: + self.data = self.data[self.data.is_training_img == 0] + + def download(self): + import tarfile + os.makedirs(self.root, exist_ok=True) + if not os.path.isfile(os.path.join(self.root, self.filename)): + download_url(self.url, self.root, self.filename) + print("Extracting %s..." 
% self.filename) + with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: + tar.extractall(path=self.root) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data.iloc[idx] + path = os.path.join(self.root, 'CUB_200_2011', + self.base_folder, sample.filepath) + lbl = sample.target - 1 + img = self.loader(path) + + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + lbl = self.target_transform(lbl) + return img, lbl diff --git a/model_measuring/kamal/vision/datasets/dataset.py b/model_measuring/kamal/vision/datasets/dataset.py new file mode 100644 index 0000000..315855d --- /dev/null +++ b/model_measuring/kamal/vision/datasets/dataset.py @@ -0,0 +1,57 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================= +""" + + +from torchvision.datasets import VisionDataset +from PIL import Image +import torch + +class LabelConcatDataset(VisionDataset): + """Dataset as a concatenation of dataset's lables. + + This class is useful to assemble the same dataset's labels. + + Arguments: + datasets (sequence): List of datasets to be concatenated + tasks (list) : List of teacher tasks + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. 
+ """ + + def __init__(self, datasets, transforms=None, transform=None, target_transform=None): + super(LabelConcatDataset, self).__init__( + root=None, transforms=transforms, transform=transform, target_transform=target_transform) + assert len(datasets) > 0, 'datasets should not be an empty iterable' + self.datasets = list(datasets) + for d in self.datasets: + for t in ('transform', 'transforms', 'target_transform'): + if hasattr( d, t ): + setattr( d, t, None ) + + def __getitem__(self, idx): + targets_list = [] + for dst in self.datasets: + image, target = dst[idx] + targets_list.append(target) + if self.transforms is not None: + image, *targets_list = self.transforms( image, *targets_list ) + data = [ image, *targets_list ] + return data + + def __len__(self): + return len(self.datasets[0].images) diff --git a/model_measuring/kamal/vision/datasets/fgvc_aircraft.py b/model_measuring/kamal/vision/datasets/fgvc_aircraft.py new file mode 100644 index 0000000..591d0ee --- /dev/null +++ b/model_measuring/kamal/vision/datasets/fgvc_aircraft.py @@ -0,0 +1,143 @@ +#Modified from https://github.com/pytorch/vision/pull/467/files +from __future__ import print_function +import torch.utils.data as data +from torchvision.datasets.folder import pil_loader, accimage_loader, default_loader +from PIL import Image +import os +import numpy as np + +from .utils import download_url, mkdir + + +def make_dataset(dir, image_ids, targets): + assert(len(image_ids) == len(targets)) + images = [] + dir = os.path.expanduser(dir) + for i in range(len(image_ids)): + item = (os.path.join(dir, 'fgvc-aircraft-2013b', 'data', 'images', + '%s.jpg' % image_ids[i]), targets[i]) + images.append(item) + return images + + +def find_classes(classes_file): + # read classes file, separating out image IDs and class names + image_ids = [] + targets = [] + f = open(classes_file, 'r') + for line in f: + split_line = line.split(' ') + image_ids.append(split_line[0]) + targets.append(' '.join(split_line[1:])) + f.close() + + # index class names + classes = np.unique(targets) + class_to_idx = {classes[i]: i for i in range(len(classes))} + targets = [class_to_idx[c] for c in targets] + + return (image_ids, targets, classes, class_to_idx) + + +class FGVCAircraft(data.Dataset): + """`FGVC-Aircraft `_ Dataset. + Args: + root (string): Root directory path to dataset. + class_type (string, optional): The level of FGVC-Aircraft fine-grain classification + to label data with (i.e., ``variant``, ``family``, or ``manufacturer``). + transforms (callable, optional): A function/transforms that takes in a PIL image + and returns a transformed version. E.g. ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transforms that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in the root directory. If dataset is already downloaded, it is not + downloaded again. + """ + url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz' + class_types = ('variant', 'family', 'manufacturer') + splits = ('train', 'val', 'trainval', 'test') + + def __init__(self, root, class_type='variant', split='train', transform=None, + target_transform=None, loader=default_loader, download=False): + if split not in self.splits: + raise ValueError('Split "{}" not found. 
Valid splits are: {}'.format( + split, ', '.join(self.splits), + )) + if class_type not in self.class_types: + raise ValueError('Class type "{}" not found. Valid class types are: {}'.format( + class_type, ', '.join(self.class_types), + )) + + self.root = root + self.class_type = class_type + self.split = split + self.classes_file = os.path.join(self.root, 'fgvc-aircraft-2013b', 'data', + 'images_%s_%s.txt' % (self.class_type, self.split)) + if download: + self.download() + + (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file) + samples = make_dataset(self.root, image_ids, targets) + + self.transform = transform + self.target_transform = target_transform + self.loader = loader + + self.samples = samples + self.classes = classes + self.class_to_idx = class_to_idx + + with open(os.path.join(self.root, 'fgvc-aircraft-2013b/data', 'variants.txt')) as f: + self.object_categories = [ + line.strip('\n') for line in f.readlines()] + print('FGVC-Aircraft, Split: %s, Size: %d' % (self.split, self.__len__())) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + return sample, target + + def __len__(self): + return len(self.samples) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format( + tmp, self.transforms.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format( + tmp, self.target_transforms.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + def _check_exists(self): + return os.path.exists(os.path.join(self.root, 'data', 'images')) and \ + os.path.exists(self.classes_file) + + def download(self): + """Download the FGVC-Aircraft data if it doesn't exist already.""" + from six.moves import urllib + import tarfile + + mkdir(self.root) + + fpath = os.path.join(self.root, 'fgvc-aircraft-2013b.tar.gz') + if not os.path.isfile(fpath): + download_url(self.url, self.root, 'fgvc-aircraft-2013b.tar.gz') + print("Extracting fgvc-aircraft-2013b.tar.gz...") + with tarfile.open(fpath, "r:gz") as tar: + tar.extractall(path=self.root) diff --git a/model_measuring/kamal/vision/datasets/nyu.py b/model_measuring/kamal/vision/datasets/nyu.py new file mode 100644 index 0000000..0c3e737 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/nyu.py @@ -0,0 +1,84 @@ +# Modified from https://github.com/VainF/nyuv2-python-toolkit +import os +import torch +import torch.utils.data as data +from PIL import Image +from scipy.io import loadmat +import numpy as np +import glob +from torchvision import transforms +from torchvision.datasets import VisionDataset +import random + +from .utils import colormap + +class NYUv2(VisionDataset): + """NYUv2 dataset + See https://github.com/VainF/nyuv2-python-toolkit for more details. + + Args: + root (string): Root directory path. + split (string, optional): 'train' for training set, and 'test' for test set. Default: 'train'. + target_type (string, optional): Type of target to use, ``semantic``, ``depth`` or ``normal``. 
+ num_classes (int, optional): The number of classes, must be 40 or 13. Default:13. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + transforms (callable, optional): A function/transform that takes input sample and its target as entry and returns a transformed version. + """ + cmap = colormap() + def __init__(self, + root, + split='train', + target_type='semantic', + num_classes=13, + transforms=None, + transform=None, + target_transform=None): + super( NYUv2, self ).__init__(root, transforms=transforms, transform=transform, target_transform=target_transform) + assert(split in ('train', 'test')) + + self.root = root + self.split = split + self.target_type = target_type + self.num_classes = num_classes + + split_mat = loadmat(os.path.join(self.root, 'splits.mat')) + idxs = split_mat[self.split+'Ndxs'].reshape(-1) - 1 + + img_names = os.listdir( os.path.join(self.root, 'image', self.split) ) + img_names.sort() + images_dir = os.path.join(self.root, 'image', self.split) + self.images = [os.path.join(images_dir, name) for name in img_names] + + self._is_depth = False + if self.target_type=='semantic': + semantic_dir = os.path.join(self.root, 'seg%d'%self.num_classes, self.split) + self.labels = [os.path.join(semantic_dir, name) for name in img_names] + self.targets = self.labels + + if self.target_type=='depth': + depth_dir = os.path.join(self.root, 'depth', self.split) + self.depths = [os.path.join(depth_dir, name) for name in img_names] + self.targets = self.depths + self._is_depth = True + + if self.target_type=='normal': + normal_dir = os.path.join(self.root, 'normal', self.split) + self.normals = [os.path.join(normal_dir, name) for name in img_names] + self.targets = self.normals + + def __getitem__(self, idx): + image = Image.open(self.images[idx]) + target = Image.open(self.targets[idx]) + if self.transforms is not None: + image, target = self.transforms( image, target ) + return image, target + + def __len__(self): + return len(self.images) + + @classmethod + def decode_fn(cls, mask: np.ndarray): + """decode semantic mask to RGB image""" + mask = mask.astype('uint8') + 1 # 255 => 0 + return cls.cmap[mask] diff --git a/model_measuring/kamal/vision/datasets/preprocess/prepare_caltech101.py b/model_measuring/kamal/vision/datasets/preprocess/prepare_caltech101.py new file mode 100644 index 0000000..83d52cf --- /dev/null +++ b/model_measuring/kamal/vision/datasets/preprocess/prepare_caltech101.py @@ -0,0 +1,61 @@ + +import os, sys +from glob import glob +import random +import argparse +from PIL import Image + +if __name__=='__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_root', type=str, default='../101_ObjectCategories') + parser.add_argument('--test_split', type=float, default=0.3) + args = parser.parse_args() + + SAVE_DIR = os.path.join( os.path.dirname(args.data_root), 'caltech101_data' ) + if not os.path.exists(SAVE_DIR): + os.mkdir(SAVE_DIR) + + # Train + TRAIN_DIR = os.path.join( SAVE_DIR, 'train' ) + if not os.path.exists(TRAIN_DIR): + os.mkdir(TRAIN_DIR) + + # Test + TEST_DIR = os.path.join( SAVE_DIR, 'test' ) + if not os.path.exists(TEST_DIR): + os.mkdir(TEST_DIR) + + img_folders = os.listdir(args.data_root) + img_folders.sort() + + for folder in img_folders: + if folder=='Faces': + continue + print('Processing %s'%(folder)) + + img_paths = glob(os.path.join( args.data_root, 
folder, '*.jpg') ) + img_name = [os.path.split(p)[-1] for p in img_paths] + + random.shuffle(img_name) + + img_n = len(img_name) + test_n = int(args.test_split * img_n) + + test_set = img_name[:test_n] + train_set = img_name[test_n:] + + # test + dst_path = os.path.join(TEST_DIR, folder) + if not os.path.exists(dst_path): + os.mkdir(dst_path) + for test_name in test_set: + img = Image.open(os.path.join( args.data_root, folder, test_name )) + img.save( os.path.join(dst_path, test_name ) ) + + # train + dst_path = os.path.join(TRAIN_DIR, folder) + if not os.path.exists(dst_path): + os.mkdir(dst_path) + for train_name in train_set: + img = Image.open(os.path.join( args.data_root, folder, train_name )) + img.save( os.path.join(dst_path, train_name ) ) \ No newline at end of file diff --git a/model_measuring/kamal/vision/datasets/preprocess/prepare_stl10.py b/model_measuring/kamal/vision/datasets/preprocess/prepare_stl10.py new file mode 100644 index 0000000..352a18f --- /dev/null +++ b/model_measuring/kamal/vision/datasets/preprocess/prepare_stl10.py @@ -0,0 +1,196 @@ +from __future__ import print_function + +import sys +import os, sys, tarfile, errno +import numpy as np +import matplotlib.pyplot as plt +import argparse + +if sys.version_info >= (3, 0, 0): + import urllib.request as urllib # ugly but works +else: + import urllib + +try: + from imageio import imsave +except: + from scipy.misc import imsave + +print(sys.version_info) + +# image shape +HEIGHT = 96 +WIDTH = 96 +DEPTH = 3 + +# size of a single image in bytes +SIZE = HEIGHT * WIDTH * DEPTH + +# path to the directory with the data +#DATA_DIR = './data' + +# url of the binary data +#DATA_URL = 'http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz' + +# path to the binary train file with image data +#DATA_PATH = './data/stl10_binary/train_X.bin' + +# path to the binary train file with labels +#LABEL_PATH = './data/stl10_binary/train_y.bin' + +def read_labels(path_to_labels): + """ + :param path_to_labels: path to the binary file containing labels from the STL-10 dataset + :return: an array containing the labels + """ + with open(path_to_labels, 'rb') as f: + labels = np.fromfile(f, dtype=np.uint8) + return labels + + +def read_all_images(path_to_data): + """ + :param path_to_data: the file containing the binary images from the STL-10 dataset + :return: an array containing all the images + """ + + with open(path_to_data, 'rb') as f: + # read whole file in uint8 chunks + everything = np.fromfile(f, dtype=np.uint8) + + # We force the data into 3x96x96 chunks, since the + # images are stored in "column-major order", meaning + # that "the first 96*96 values are the red channel, + # the next 96*96 are green, and the last are blue." + # The -1 is since the size of the pictures depends + # on the input file, and this way numpy determines + # the size on its own. + + images = np.reshape(everything, (-1, 3, 96, 96)) + + # Now transpose the images into a standard image format + # readable by, for example, matplotlib.imshow + # You might want to comment this line or reverse the shuffle + # if you will use a learning algorithm like CNN, since they like + # their channels separated. + images = np.transpose(images, (0, 3, 2, 1)) + return images + + +def read_single_image(image_file): + """ + CAREFUL! - this method uses a file as input instead of the path - so the + position of the reader will be remembered outside of context of this method. 
+ :param image_file: the open file containing the images + :return: a single image + """ + # read a single image, count determines the number of uint8's to read + image = np.fromfile(image_file, dtype=np.uint8, count=SIZE) + # force into image matrix + image = np.reshape(image, (3, 96, 96)) + # transpose to standard format + # You might want to comment this line or reverse the shuffle + # if you will use a learning algorithm like CNN, since they like + # their channels separated. + image = np.transpose(image, (2, 1, 0)) + return image + + +def plot_image(image): + """ + :param image: the image to be plotted in a 3-D matrix format + :return: None + """ + plt.imshow(image) + plt.show() + +def save_image(image, name): + imsave("%s.png" % name, image, format="png") + +def download_and_extract(DATA_DIR): + """ + Download and extract the STL-10 dataset + :return: None + """ + dest_directory = DATA_DIR + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\rDownloading %s %.2f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress) + print('Downloaded', filename) + tarfile.open(filepath, 'r:gz').extractall(dest_directory) + +def save_images(images, labels, save_dir): + print("Saving images to disk") + i = 0 + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + for image in images: + if labels is not None: + label = labels[i] + directory = os.path.join( save_dir, str(label) ) + if not os.path.exists(directory): + os.makedirs(directory) + else: + directory = save_dir + + filename = os.path.join( directory, str(i) ) + print(filename) + save_image(image, filename) + i = i+1 + +if __name__ == "__main__": + # download data if needed + # download_and_extract() + parser = argparse.ArgumentParser() + parser.add_argument('--DATA_DIR', type=str, default='../stl10_binary') + args = parser.parse_args() + + BASE_PATH = os.path.join( args.DATA_DIR, 'stl10_data' ) + if not os.path.exists(BASE_PATH): + os.mkdir(BASE_PATH) + + # Train + DATA_PATH = os.path.join( args.DATA_DIR, 'train_X.bin') + LABEL_PATH = os.path.join( args.DATA_DIR, 'train_y.bin') + + print('Preparing Train') + images = read_all_images(DATA_PATH) + print(images.shape) + labels = read_labels(LABEL_PATH) + print(labels.shape) + save_images(images, labels, os.path.join(BASE_PATH, 'train')) + + # Test + DATA_PATH = os.path.join( args.DATA_DIR, 'test_X.bin') + LABEL_PATH = os.path.join( args.DATA_DIR, 'test_y.bin') + + #with open(DATA_PATH) as f: + # image = read_single_image(f) + # plot_image(image) + print('Preparing Test') + images = read_all_images(DATA_PATH) + print(images.shape) + labels = read_labels(LABEL_PATH) + print(labels.shape) + save_images(images, labels, os.path.join(BASE_PATH, 'test')) + + # Unlabeled + print('Preparing Unlabeled') + DATA_PATH = os.path.join( args.DATA_DIR, 'unlabeled_X.bin') + images = read_all_images(DATA_PATH) + save_images(images, None, os.path.join(BASE_PATH, 'unlabeled')) + + + + + + + diff --git a/model_measuring/kamal/vision/datasets/preprocess/resize_camvid.py b/model_measuring/kamal/vision/datasets/preprocess/resize_camvid.py new file mode 100644 index 0000000..a208061 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/preprocess/resize_camvid.py @@ -0,0 +1,53 @@ +import 
argparse +import os +from PIL import Image +import numpy as np + +SIZE=(480, 320) # W, H + +def is_image(path: str): + return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) + +def copy_and_resize( src_dir, dst_dir, resize_fn ): + + for file_or_dir in os.listdir( src_dir ): + src = os.path.join( src_dir, file_or_dir ) + dst = os.path.join( dst_dir, file_or_dir ) + if os.path.isdir( src ): + os.mkdir( dst ) + copy_and_resize( src, dst, resize_fn ) + elif is_image( src ): + print(src, ' -> ', dst) + image = Image.open( src ) + resized_image = resize_fn(image) + resized_image.save( dst ) + +def resize_input( image: Image.Image ): + return image.resize( SIZE, resample=Image.BICUBIC ) + +def resize_seg( image: Image.Image ): + return image.resize( SIZE, resample=Image.NEAREST ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--root', type=str, required=True) + ROOT = parser.parse_args().root + NEW_ROOT = os.path.join( ROOT, '%d_%d' % SIZE ) + os.mkdir(NEW_ROOT) + for split in ['train', 'val', 'test']: + IMG_DIR = os.path.join( ROOT, split ) + GT_DIR = os.path.join( ROOT, split+'annot' ) + NEW_IMG_DIR = os.path.join( NEW_ROOT, split ) + NEW_GT_DIR = os.path.join( NEW_ROOT, split+'annot' ) + + os.mkdir( NEW_IMG_DIR ) + os.mkdir( NEW_GT_DIR ) + + copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) + copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) + + + + +if __name__=='__main__': + main() diff --git a/model_measuring/kamal/vision/datasets/preprocess/resize_cityscapes.py b/model_measuring/kamal/vision/datasets/preprocess/resize_cityscapes.py new file mode 100644 index 0000000..3412dd3 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/preprocess/resize_cityscapes.py @@ -0,0 +1,65 @@ +import argparse +import os +from PIL import Image +import numpy as np + +SIZE=(640, 320) + +def is_image(path: str): + return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) + +def copy_and_resize( src_dir, dst_dir, resize_fn ): + + for file_or_dir in os.listdir( src_dir ): + src = os.path.join( src_dir, file_or_dir ) + dst = os.path.join( dst_dir, file_or_dir ) + if os.path.isdir( src ): + os.mkdir( dst ) + copy_and_resize( src, dst, resize_fn ) + elif is_image( src ): + print(src, ' -> ', dst) + image = Image.open( src ) + resized_image = resize_fn(image) + resized_image.save( dst ) + +def resize_input( image: Image.Image ): + return image.resize( SIZE, resample=Image.BICUBIC ) + +def resize_seg( image: Image.Image ): + return image.resize( SIZE, resample=Image.NEAREST ) + +def resize_depth( image: Image.Image ): + return image.resize( SIZE, resample=Image.NEAREST ) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--root', type=str, required=True) + + ROOT = parser.parse_args().root + IMG_DIR = os.path.join( ROOT, 'leftImg8bit' ) + GT_DIR = os.path.join( ROOT, 'gtFine' ) + DEPTH_DIR = os.path.join( ROOT, 'disparity' ) + + NEW_ROOT = os.path.join( ROOT, '%d_%d' % SIZE ) + NEW_IMG_DIR = os.path.join( NEW_ROOT, 'leftImg8bit' ) + NEW_GT_DIR = os.path.join( NEW_ROOT, 'gtFine' ) + NEW_DEPTH_DIR = os.path.join( NEW_ROOT, 'disparity' ) + + if os.path.exists(NEW_ROOT): + print("Directory %s already exists, please remove it before running this script"%NEW_ROOT) + return + + os.mkdir(NEW_ROOT) + os.mkdir( NEW_IMG_DIR ) + os.mkdir( NEW_GT_DIR ) + os.mkdir( NEW_DEPTH_DIR ) + + copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) + copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) + copy_and_resize( DEPTH_DIR, 
NEW_DEPTH_DIR, resize_depth ) + + + + +if __name__=='__main__': + main() diff --git a/model_measuring/kamal/vision/datasets/preprocess/resize_voc.py b/model_measuring/kamal/vision/datasets/preprocess/resize_voc.py new file mode 100644 index 0000000..836d3f4 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/preprocess/resize_voc.py @@ -0,0 +1,59 @@ +import argparse +import os +from PIL import Image +import numpy as np +from torchvision import transforms +import shutil + +SIZE=240 + +def is_image(path: str): + return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) + +def copy_and_resize( src_dir, dst_dir, resize_fn ): + + for file_or_dir in os.listdir( src_dir ): + src = os.path.join( src_dir, file_or_dir ) + dst = os.path.join( dst_dir, file_or_dir ) + if os.path.isdir( src ): + os.mkdir( dst ) + copy_and_resize( src, dst, resize_fn ) + elif is_image( src ): + print(src, ' -> ', dst) + image = Image.open( src ) + resized_image = resize_fn(image) + resized_image.save( dst ) + +def resize_input( image: Image.Image ): + return transforms.functional.resize( image, SIZE, interpolation=Image.BILINEAR ) + +def resize_seg( image: Image.Image ): + return transforms.functional.resize( image, SIZE, interpolation=Image.NEAREST) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--root', type=str, required=True) + + ROOT = parser.parse_args().root + IMG_DIR = os.path.join( ROOT, 'JPEGImages' ) + GT_DIR = os.path.join( ROOT, 'SegmentationClass' ) + + NEW_ROOT = os.path.join( ROOT, str(SIZE) ) + NEW_IMG_DIR = os.path.join( NEW_ROOT, 'JPEGImages' ) + NEW_GT_DIR = os.path.join( NEW_ROOT, 'SegmentationClass' ) + + if os.path.exists(NEW_ROOT): + print("Directory %s existed, please remove it before running this script"%NEW_ROOT) + return + + os.mkdir(NEW_ROOT) + os.mkdir( NEW_IMG_DIR ) + os.mkdir( NEW_GT_DIR ) + + copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) + copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) + shutil.copytree( os.path.join( ROOT, 'ImageSets'), os.path.join( NEW_ROOT, 'ImageSets' ) ) + + +if __name__=='__main__': + main() diff --git a/model_measuring/kamal/vision/datasets/preprocess/resize_voc_240.py b/model_measuring/kamal/vision/datasets/preprocess/resize_voc_240.py new file mode 100644 index 0000000..bd66551 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/preprocess/resize_voc_240.py @@ -0,0 +1,57 @@ +import argparse +import os +from PIL import Image +import numpy as np +from torchvision import transforms +import shutil + +def is_image(path: str): + return path.endswith( 'png' ) or path.endswith('jpg') or path.endswith( 'jpeg' ) + +def copy_and_resize( src_dir, dst_dir, resize_fn ): + + for file_or_dir in os.listdir( src_dir ): + src = os.path.join( src_dir, file_or_dir ) + dst = os.path.join( dst_dir, file_or_dir ) + if os.path.isdir( src ): + os.mkdir( dst ) + copy_and_resize( src, dst, resize_fn ) + elif is_image( src ): + print(src, ' -> ', dst) + image = Image.open( src ) + resized_image = resize_fn(image) + resized_image.save( dst ) + +def resize_input( image: Image.Image ): + return transforms.functional.resize( image, 240, interpolation=Image.BILINEAR ) + +def resize_seg( image: Image.Image ): + return transforms.functional.resize( image, 240, interpolation=Image.NEAREST) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--root', type=str, required=True) + + ROOT = parser.parse_args().root + IMG_DIR = os.path.join( ROOT, 'JPEGImages' ) + GT_DIR = os.path.join( ROOT, 
'SegmentationClass' ) + + NEW_ROOT = os.path.join( ROOT, '240' ) + NEW_IMG_DIR = os.path.join( NEW_ROOT, 'JPEGImages' ) + NEW_GT_DIR = os.path.join( NEW_ROOT, 'SegmentationClass' ) + + if os.path.exists(NEW_ROOT): + print("Directory %s existed, please remove it before running this script"%NEW_ROOT) + return + + os.mkdir(NEW_ROOT) + os.mkdir( NEW_IMG_DIR ) + os.mkdir( NEW_GT_DIR ) + + copy_and_resize( IMG_DIR, NEW_IMG_DIR, resize_input ) + copy_and_resize( GT_DIR, NEW_GT_DIR, resize_seg ) + shutil.copytree( os.path.join( ROOT, 'ImageSets'), os.path.join( NEW_ROOT, 'ImageSets' ) ) + + +if __name__=='__main__': + main() diff --git a/model_measuring/kamal/vision/datasets/stanford_cars.py b/model_measuring/kamal/vision/datasets/stanford_cars.py new file mode 100644 index 0000000..4707ee1 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/stanford_cars.py @@ -0,0 +1,80 @@ +import os +import glob +from PIL import Image +import numpy as np +from scipy.io import loadmat + +from torch.utils import data +from .utils import download_url, mkdir + +from shutil import copyfile + + +class StanfordCars(data.Dataset): + """Dataset for Stanford Cars + """ + + urls = {'cars_train.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_train.tgz', + 'cars_test.tgz': 'http://imagenet.stanford.edu/internal/car196/cars_test.tgz', + 'car_devkit.tgz': 'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz', + 'cars_test_annos_withlabels.mat': 'http://imagenet.stanford.edu/internal/car196/cars_test_annos_withlabels.mat'} + + def __init__(self, root, split='train', download=False, transform=None, target_transform=None): + self.root = os.path.abspath( os.path.expanduser(root) ) + self.split = split + self.transform = transform + self.target_transform = target_transform + + if download: + self.download() + + if self.split == 'train': + annos = os.path.join(self.root, 'devkit', 'cars_train_annos.mat') + else: + annos = os.path.join(self.root, 'devkit', + 'cars_test_annos_withlabels.mat') + + annos = loadmat(annos) + size = len(annos['annotations'][0]) + + self.files = glob.glob(os.path.join( + self.root, 'cars_'+self.split, '*.jpg')) + self.files.sort() + + self.labels = np.array([int(l[4])-1 for l in annos['annotations'][0]]) + + lbl_annos = loadmat(os.path.join(self.root, 'devkit', 'cars_meta.mat')) + + self.object_categories = [str(c[0]) + for c in lbl_annos['class_names'][0]] + + print('Stanford Cars, Split: %s, Size: %d' % + (self.split, self.__len__())) + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + img = Image.open(os.path.join(self.root, 'Images', + self.files[idx])).convert("RGB") + lbl = self.labels[idx] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + lbl = self.target_transform(lbl) + return img, lbl + + def download(self): + import tarfile + + mkdir(self.root) + for fname, url in self.urls.items(): + if not os.path.isfile(os.path.join(self.root, fname)): + download_url(url, self.root, fname) + if fname.endswith('tgz'): + print("Extracting %s..." 
% fname) + with tarfile.open(os.path.join(self.root, fname), "r:gz") as tar: + tar.extractall(path=self.root) + + copyfile(os.path.join(self.root, 'cars_test_annos_withlabels.mat'), + os.path.join(self.root, 'devkit', 'cars_test_annos_withlabels.mat')) diff --git a/model_measuring/kamal/vision/datasets/stanford_dogs.py b/model_measuring/kamal/vision/datasets/stanford_dogs.py new file mode 100644 index 0000000..644ace9 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/stanford_dogs.py @@ -0,0 +1,58 @@ +import os +import numpy as np +from PIL import Image +from scipy.io import loadmat + +from torch.utils import data +from .utils import download_url +from shutil import move + +class StanfordDogs(data.Dataset): + """Dataset for Stanford Dogs + """ + urls = {"images.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar", + "annotation.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/annotation.tar", + "lists.tar": "http://vision.stanford.edu/aditya86/ImageNetDogs/lists.tar"} + + def __init__(self, root, split='train', download=False, transform=None, target_transform=None): + self.root = os.path.abspath( os.path.expanduser(root) ) + self.split = split + self.transform = transform + self.target_transform = target_transform + if download: + self.download() + list_file = os.path.join(self.root, self.split+'_list.mat') + mat_file = loadmat(list_file) + size = len(mat_file['file_list']) + self.files = [str(mat_file['file_list'][i][0][0]) for i in range(size)] + self.labels = np.array( + [mat_file['labels'][i][0]-1 for i in range(size)]) + categories = os.listdir(os.path.join(self.root, 'Images')) + categories.sort() + self.object_categories = [c[10:] for c in categories] + print('Stanford Dogs, Split: %s, Size: %d' % + (self.split, self.__len__())) + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + img = Image.open(os.path.join(self.root, 'Images', + self.files[idx])).convert("RGB") + lbl = self.labels[idx] + if self.transform is not None: + img = self.transform(img) + if self.target_transform is not None: + lbl = self.target_transform( lbl ) + return img, lbl + + def download(self): + import tarfile + os.makedirs(self.root, exist_ok=True) + for fname, url in self.urls.items(): + if not os.path.isfile(os.path.join(self.root, fname)): + download_url(url, self.root, fname) + # extract file + print("Extracting %s..." % fname) + with tarfile.open(os.path.join(self.root, fname), "r") as tar: + tar.extractall(path=self.root) diff --git a/model_measuring/kamal/vision/datasets/sunrgbd.py b/model_measuring/kamal/vision/datasets/sunrgbd.py new file mode 100644 index 0000000..392e189 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/sunrgbd.py @@ -0,0 +1,65 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + + +import os +from glob import glob +from PIL import Image +from .utils import colormap +from torchvision.datasets import VisionDataset + +class SunRGBD(VisionDataset): + cmap = colormap() + def __init__(self, + root, + split='train', + transform=None, + target_transform=None, + transforms=None): + super( SunRGBD, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms ) + self.root = root + self.split = split + + self.images = glob(os.path.join(self.root, 'SUNRGBD-%s_images'%self.split, '*.jpg')) + self.labels = glob(os.path.join(self.root, '%s13labels'%self.split, '*.png')) + + self.images.sort() + self.labels.sort() + + def __getitem__(self, idx): + """ + Args: + - index (``int``): index of the item in the dataset + Returns: + A tuple of ``PIL.Image`` (image, label) where label is the ground-truth + of the image. + """ + + img, label = Image.open(self.images[idx]), Image.open(self.labels[idx]) + + if self.transform is not None: + img, label = self.transform(img, label) + label = label-1 # void 0=>255 + return img, label + + def __len__(self): + return len(self.images) + + @classmethod + def decode_fn(cls, mask): + """decode semantic mask to RGB image""" + return cls.cmap[mask.astype('uint8')+1] \ No newline at end of file diff --git a/model_measuring/kamal/vision/datasets/unlabeled.py b/model_measuring/kamal/vision/datasets/unlabeled.py new file mode 100644 index 0000000..fd0c09a --- /dev/null +++ b/model_measuring/kamal/vision/datasets/unlabeled.py @@ -0,0 +1,67 @@ +""" + Copyright 2020 Tianshu AI Platform. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ ============================================================= +""" + +import torch +from torch.utils.data import Dataset +import os + +from PIL import Image +import random +from copy import deepcopy + +def _collect_all_images(root, postfix=['png', 'jpg', 'jpeg', 'JPEG']): + images = [] + if isinstance( postfix, str): + postfix = [ postfix ] + for dirpath, dirnames, files in os.walk(root): + for pos in postfix: + for f in files: + if f.endswith( pos ): + images.append( os.path.join( dirpath, f ) ) + return images + +def get_train_val_set(root, val_size=0.3): + if not isinstance(root, (list, tuple)): + root = [root] + train_set = [] + val_set = [] + for _root in root: + _part_train_set = _collect_all_images( _root) + if os.path.isdir( os.path.join(_root, 'test') ): + _part_val_set = _collect_all_images( os.path.join(_root, 'test') ) + else: + _val_size = int( len(_part_train_set) * val_size ) + _part_val_set = random.sample( _part_train_set, k=_val_size ) + _part_train_set = [ d for d in _part_train_set if d not in _part_val_set ] + train_set.extend(_part_train_set) + val_set.extend(_part_val_set) + return train_set, val_set + +class UnlabeledDataset(Dataset): + def __init__(self, data, transform=None, postfix=['png', 'jpg', 'jpeg', 'JPEG']): + self.transform = transform + self.data = data + + def __getitem__(self, idx): + data = Image.open( self.data[idx] ) + if self.transform is not None: + data = self.transform(data) + return data + + def __len__(self): + return len(self.data) + diff --git a/model_measuring/kamal/vision/datasets/utils.py b/model_measuring/kamal/vision/datasets/utils.py new file mode 100644 index 0000000..4fc8910 --- /dev/null +++ b/model_measuring/kamal/vision/datasets/utils.py @@ -0,0 +1,161 @@ +# Modified from https://github.com/pytorch/vision +import os +import os.path +import hashlib +import errno +from tqdm import tqdm +import numpy as np +import torch +import random + +def mkdir(dir): + if not os.path.isdir(dir): + os.mkdir(dir) + +def colormap(N=256, normalized=False): + def bitget(byteval, idx): + return ((byteval & (1 << idx)) != 0) + + dtype = 'float32' if normalized else 'uint8' + cmap = np.zeros((N, 3), dtype=dtype) + for i in range(N): + r = g = b = 0 + c = i + for j in range(8): + r = r | (bitget(c, 0) << 7-j) + g = g | (bitget(c, 1) << 7-j) + b = b | (bitget(c, 2) << 7-j) + c = c >> 3 + + cmap[i] = np.array([r, g, b]) + + cmap = cmap/255 if normalized else cmap + return cmap + +DEFAULT_COLORMAP = colormap() + +def gen_bar_updater(pbar): + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def check_integrity(fpath, md5=None): + if md5 is None: + return True + if not os.path.isfile(fpath): + return False + md5o = hashlib.md5() + with open(fpath, 'rb') as f: + # read in 1MB chunks + for chunk in iter(lambda: f.read(1024 * 1024), b''): + md5o.update(chunk) + md5c = md5o.hexdigest() + if md5c != md5: + return False + return True + + +def makedir_exist_ok(dirpath): + """ + Python2 support for os.makedirs(.., exist_ok=True) + """ + try: + os.makedirs(dirpath) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: + raise + +def download_url(url, root, filename=None, md5=None): + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str): Name to save the file under. 
If None, use the basename of the URL + md5 (str): MD5 checksum of the download. If None, do not check + """ + from six.moves import urllib + + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir_exist_ok(root) + + # downloads file + if os.path.isfile(fpath) and check_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + else: + try: + print('Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) + except OSError: + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + ' Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) + + +def list_dir(root, prefix=False): + """List all directories at a given root + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = list( + filter( + lambda p: os.path.isdir(os.path.join(root, p)), + os.listdir(root) + ) + ) + + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + + return directories + + +def list_files(root, suffix, prefix=False): + """List all files ending with a suffix at a given root + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). + It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = list( + filter( + lambda p: os.path.isfile(os.path.join( + root, p)) and p.endswith(suffix), + os.listdir(root) + ) + ) + + if prefix is True: + files = [os.path.join(root, d) for d in files] + + return files + +def set_seed(random_seed): + torch.manual_seed(random_seed) + torch.cuda.manual_seed(random_seed) + np.random.seed(random_seed) + random.seed(random_seed) diff --git a/model_measuring/kamal/vision/datasets/voc.py b/model_measuring/kamal/vision/datasets/voc.py new file mode 100644 index 0000000..688b7cc --- /dev/null +++ b/model_measuring/kamal/vision/datasets/voc.py @@ -0,0 +1,209 @@ +# Modified from https://github.com/pytorch/vision +import os +import sys +import tarfile +import collections +import torch.utils.data as data +import shutil +import numpy as np +from .utils import colormap +from torchvision.datasets import VisionDataset +import torch +from PIL import Image +from torchvision.datasets.utils import download_url, check_integrity + +DATASET_YEAR_DICT = { + '2012aug': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '6cd6e144f989b92b3379bac3b3de84fd', + 'base_dir': 'VOCdevkit/VOC2012' + }, + '2012': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '6cd6e144f989b92b3379bac3b3de84fd', + 'base_dir': 'VOCdevkit/VOC2012' + }, + '2011': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', + 'filename': 'VOCtrainval_25-May-2011.tar', + 'md5': 
'6c3384ef61512963050cb5d687e5bf1e', + 'base_dir': 'TrainVal/VOCdevkit/VOC2011' + }, + '2010': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', + 'filename': 'VOCtrainval_03-May-2010.tar', + 'md5': 'da459979d0c395079b5c75ee67908abb', + 'base_dir': 'VOCdevkit/VOC2010' + }, + '2009': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', + 'filename': 'VOCtrainval_11-May-2009.tar', + 'md5': '59065e4b188729180974ef6572f6a212', + 'base_dir': 'VOCdevkit/VOC2009' + }, + '2008': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', + 'filename': 'VOCtrainval_11-May-2012.tar', + 'md5': '2629fa636546599198acfcfbfcf1904a', + 'base_dir': 'VOCdevkit/VOC2008' + }, + '2007': { + 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', + 'filename': 'VOCtrainval_06-Nov-2007.tar', + 'md5': 'c52e279531787c972589f7e41ab4ae64', + 'base_dir': 'VOCdevkit/VOC2007' + } +} + +class VOCSegmentation(VisionDataset): + """`Pascal VOC `_ Segmentation Dataset. + Args: + root (string): Root directory of the VOC Dataset. + year (string, optional): The dataset year, supports years 2007 to 2012. + image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val`` + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + """ + cmap = colormap() + def __init__(self, + root, + year='2012', + image_set='train', + download=False, + transform=None, + target_transform=None, + transforms=None, + ): + super( VOCSegmentation, self ).__init__( root, transform=transform, target_transform=target_transform, transforms=transforms ) + + is_aug=False + if year=='2012aug': + is_aug = True + year = '2012' + + self.root = os.path.expanduser(root) + self.year = year + self.url = DATASET_YEAR_DICT[year]['url'] + self.filename = DATASET_YEAR_DICT[year]['filename'] + self.md5 = DATASET_YEAR_DICT[year]['md5'] + + self.image_set = image_set + base_dir = DATASET_YEAR_DICT[year]['base_dir'] + voc_root = os.path.join(self.root, base_dir) + image_dir = os.path.join(voc_root, 'JPEGImages') + + if download: + download_extract(self.url, self.root, self.filename, self.md5) + + if not os.path.isdir(voc_root): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + if is_aug and image_set=='train': + mask_dir = os.path.join(voc_root, 'SegmentationClassAug') + assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually" + split_f = os.path.join( self.root, 'train_aug.txt') + else: + mask_dir = os.path.join(voc_root, 'SegmentationClass') + splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation') + split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') + + if not os.path.exists(split_f): + raise ValueError( + 'Wrong image_set entered! 
Please use image_set="train" ' + 'or image_set="trainval" or image_set="val"') + + with open(os.path.join(split_f), "r") as f: + file_names = [x.strip() for x in f.readlines()] + + self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] + self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] + assert (len(self.images) == len(self.masks)) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is the image segmentation. + """ + img = Image.open(self.images[index]).convert('RGB') + target = Image.open(self.masks[index]) + if self.transforms is not None: + img, target = self.transforms(img, target) + return img, target.squeeze(0) + + def __len__(self): + return len(self.images) + + @classmethod + def decode_fn(cls, mask): + """decode semantic mask to RGB image""" + return cls.cmap[mask] + +def download_extract(url, root, filename, md5): + download_url(url, root, filename, md5) + with tarfile.open(os.path.join(root, filename), "r") as tar: + tar.extractall(path=root) + +CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', + 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'] + +class VOCClassification(data.Dataset): + def __init__(self, + root, + year='2010', + split='train', + download=False, + transforms=None, + target_transforms=None): + + voc_root = os.path.join(root, 'VOC{}'.format(year)) + if not os.path.isdir(voc_root): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + self.transforms = transforms + self.target_transforms = target_transforms + image_dir = os.path.join(voc_root, 'JPEGImages') + label_dir = os.path.join(voc_root, 'ImageSets/Main') + self.labels_list = [] + + fname = os.path.join(label_dir, '{}.txt'.format(split)) + with open(fname) as f: + self.images = [os.path.join(image_dir, line.split()[0]+'.jpg') for line in f] + + for clas in CLASSES: + labels = [] + with open(os.path.join(label_dir, '{}_{}.txt'.format(clas, split))) as f: + labels = [int(line.split()[1]) for line in f] + self.labels_list.append(labels) + + assert (len(self.images) == len(self.labels_list[0])) + + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is the image segmentation. + """ + img = Image.open(self.images[index]).convert('RGB') + labels = [labels[index] for labels in self.labels_list] + + if self.transforms is not None: + img = self.transforms(img) + + if self.target_transforms is not None: + labels = self.target_transforms(labels) + + return img, labels + + def __len__(self): + return len(self.images) \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/__init__.py b/model_measuring/kamal/vision/models/__init__.py new file mode 100644 index 0000000..2b8cd84 --- /dev/null +++ b/model_measuring/kamal/vision/models/__init__.py @@ -0,0 +1,3 @@ +from . import classification, segmentation + +from torchvision import models as torchvision_models \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/__init__.py b/model_measuring/kamal/vision/models/classification/__init__.py new file mode 100644 index 0000000..2f3ccec --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/__init__.py @@ -0,0 +1,7 @@ +from .darknet import * +from .mobilenetv2 import * +from .resnet import * +from .vgg import * +from . 
import cifar + +from .alexnet import alexnet \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/alexnet.py b/model_measuring/kamal/vision/models/classification/alexnet.py new file mode 100644 index 0000000..8f7c856 --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/alexnet.py @@ -0,0 +1,63 @@ +# Modified from https://github.com/pytorch/vision +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + + +__all__ = ['AlexNet', 'alexnet'] + + +model_urls = { + 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', +} + + +class AlexNet(nn.Module): + + def __init__(self, num_classes=1000): + super(AlexNet, self).__init__() + self.features = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(64, 192, kernel_size=5, padding=2), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + nn.Conv2d(192, 384, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(384, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2), + ) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(256 * 6 * 6, 4096), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(inplace=True), + nn.Linear(4096, num_classes), + ) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), 256 * 6 * 6) + x = self.classifier(x) + return x + + +def alexnet(pretrained=False, **kwargs): + r"""AlexNet model architecture from the + `"One weird trick..." `_ paper. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + num_classes = kwargs.pop('num_classes', None) + model = AlexNet(**kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) + if num_classes is not None and num_classes!=1000: + model.classifier[-1] = nn.Linear(4096, num_classes) + return model \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/cifar/__init__.py b/model_measuring/kamal/vision/models/classification/cifar/__init__.py new file mode 100644 index 0000000..fea00a4 --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/cifar/__init__.py @@ -0,0 +1 @@ +from . 
import wrn \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/cifar/wrn.py b/model_measuring/kamal/vision/models/classification/cifar/wrn.py new file mode 100644 index 0000000..7abb8e8 --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/cifar/wrn.py @@ -0,0 +1,108 @@ +#Adapted from https://github.com/polo5/ZeroShotKnowledgeTransfer/blob/master/models/wresnet.py + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['wrn'] + +class BasicBlock(nn.Module): + def __init__(self, in_planes, out_planes, stride, dropout_rate=0.0): + super(BasicBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.relu1 = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_planes) + self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.dropout = nn.Dropout( dropout_rate ) + self.equalInOut = (in_planes == out_planes) + self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, + padding=0, bias=False) or None + + def forward(self, x): + if not self.equalInOut: + x = self.relu1(self.bn1(x)) + else: + out = self.relu1(self.bn1(x)) + out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) + out = self.dropout(out) + out = self.conv2(out) + return torch.add(x if self.equalInOut else self.convShortcut(x), out) + + +class NetworkBlock(nn.Module): + def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropout_rate=0.0): + super(NetworkBlock, self).__init__() + self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropout_rate) + + def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropout_rate): + layers = [] + for i in range(nb_layers): + layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropout_rate)) + return nn.Sequential(*layers) + + def forward(self, x): + return self.layer(x) + + +class WideResNet(nn.Module): + def __init__(self, depth, num_classes, widen_factor=1, dropout_rate=0.0): + super(WideResNet, self).__init__() + nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] + assert (depth - 4) % 6 == 0, 'depth should be 6n+4' + n = (depth - 4) // 6 + block = BasicBlock + # 1st conv before any network block + self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, + padding=1, bias=False) + # 1st block + self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropout_rate) + # 2nd block + self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropout_rate) + # 3rd block + self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropout_rate) + # global average pooling and classifier + self.bn1 = nn.BatchNorm2d(nChannels[3]) + self.relu = nn.ReLU(inplace=True) + self.fc = nn.Linear(nChannels[3], num_classes) + self.nChannels = nChannels[3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + + def forward(self, x, return_features=False): + out = self.conv1(x) + out = self.block1(out) + out = self.block2(out) + out = self.block3(out) + out = self.relu(self.bn1(out)) + out = F.avg_pool2d(out, 8) + features = out.view(-1, self.nChannels) + out = self.fc(features) + if return_features: + return out, features + else: + return out + +def wrn_16_1(num_classes, dropout_rate=0): + return WideResNet(depth=16, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate) + +def wrn_16_2(num_classes, dropout_rate=0): + return WideResNet(depth=16, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate) + +def wrn_40_1(num_classes, dropout_rate=0): + return WideResNet(depth=40, num_classes=num_classes, widen_factor=1, dropout_rate=dropout_rate) + +def wrn_40_2(num_classes, dropout_rate=0): + return WideResNet(depth=40, num_classes=num_classes, widen_factor=2, dropout_rate=dropout_rate) \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/darknet.py b/model_measuring/kamal/vision/models/classification/darknet.py new file mode 100644 index 0000000..c9a35f4 --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/darknet.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from ..utils import download_from_url, load_darknet_weights + +model_urls = { + 'darknet19': 'https://pjreddie.com/media/files/darknet19.weights', + 'darknet19_448': 'https://pjreddie.com/media/files/darknet19_448.weights', + 'darknet53': 'https://pjreddie.com/media/files/darknet53.weights', + 'darknet53_448': 'https://pjreddie.com/media/files/darknet53_448.weights' +} + + +def conv3x3(in_planes, out_planes, padding=1, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=padding, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + padding=0, + bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, planes, norm_layer=None, residual=True): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.residual = residual + self.block = nn.Sequential( + conv1x1(planes, planes // 2), + norm_layer(planes // 2), + nn.LeakyReLU(0.1, inplace=True), + conv3x3(planes // 2, planes), + norm_layer(planes), + nn.LeakyReLU(0.1, inplace=True), + ) + + def forward(self, x): + identity = x + out = self.block(x) + if self.residual: + out = out + identity + return out + +class DarkNet(nn.Module): + def __init__(self, layers, num_classes=1000, pooling=False, residual=True): + super(DarkNet, self).__init__() + self.inplanes = 32 + self.pooling = pooling + self.residual = residual + + features = [ + conv3x3(3, self.inplanes), + nn.BatchNorm2d(self.inplanes), + nn.LeakyReLU(0.1, inplace=True), + ] + features.extend(self._make_layer(64, layers[0])) + features.extend(self._make_layer(128, layers[1])) + features.extend(self._make_layer(256, layers[2])) + features.extend(self._make_layer(512, layers[3])) + features.extend(self._make_layer(1024, layers[4])) + + self.features = nn.Sequential(*features) + self.conv = nn.Conv2d(1024, + num_classes, + kernel_size=(1, 1), + 
stride=(1, 1)) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.classifier = nn.Linear(1024, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, + mode='fan_out', + nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, planes, blocks): + layers = [] + if self.pooling == True: + layers.append(nn.MaxPool2d(2, 2)) # downsample with maxpooling + layers.extend([ + conv3x3(self.inplanes, planes, stride=1 if self.pooling else 2), + nn.BatchNorm2d(planes), + nn.LeakyReLU(0.1, inplace=True), + ]) + + for _ in range(blocks): + layers.append(BasicBlock(planes, residual=self.residual)) + self.inplanes = planes + return layers + + def load_weights(self, weights_file, change): + load_darknet_weights(self, weights_file) + if change: + new_order = [ + 278, 212, 250, 193, 217, 147, 387, 285, 350, 283, 286, 353, + 334, 150, 249, 362, 246, 166, 218, 172, 177, 148, 357, 386, + 178, 202, 194, 271, 229, 290, 175, 163, 191, 276, 299, 197, + 380, 364, 339, 359, 251, 165, 157, 361, 179, 268, 233, 356, + 266, 264, 225, 349, 335, 375, 282, 204, 352, 272, 187, 256, + 294, 277, 174, 234, 351, 176, 280, 223, 154, 262, 203, 190, + 370, 298, 384, 292, 170, 342, 241, 340, 348, 245, 365, 253, + 288, 239, 153, 185, 158, 211, 192, 382, 224, 216, 284, 367, + 228, 160, 152, 376, 338, 270, 296, 366, 169, 265, 183, 345, + 199, 244, 381, 236, 195, 238, 240, 155, 221, 259, 181, 343, + 354, 369, 196, 231, 207, 184, 252, 232, 331, 242, 201, 162, + 255, 210, 371, 274, 372, 373, 209, 243, 222, 378, 254, 206, + 186, 205, 341, 261, 248, 215, 267, 189, 289, 214, 273, 198, + 333, 200, 279, 188, 161, 346, 295, 332, 347, 379, 344, 260, + 388, 180, 230, 257, 151, 281, 377, 208, 247, 363, 258, 164, + 168, 358, 336, 227, 368, 355, 237, 330, 171, 291, 219, 213, + 149, 385, 337, 220, 263, 156, 383, 159, 287, 275, 374, 173, + 269, 293, 167, 226, 297, 182, 235, 360, 105, 101, 102, 104, + 103, 106, 763, 879, 780, 805, 401, 310, 327, 117, 579, 620, + 949, 404, 895, 405, 417, 812, 554, 576, 814, 625, 472, 914, + 484, 871, 510, 628, 724, 403, 833, 913, 586, 847, 657, 450, + 537, 444, 671, 565, 705, 428, 791, 670, 561, 547, 820, 408, + 407, 436, 468, 511, 609, 627, 656, 661, 751, 817, 573, 575, + 665, 803, 555, 569, 717, 864, 867, 675, 734, 757, 829, 802, + 866, 660, 870, 880, 603, 612, 690, 431, 516, 520, 564, 453, + 495, 648, 493, 846, 553, 703, 423, 857, 559, 765, 831, 861, + 526, 736, 532, 548, 894, 948, 950, 951, 952, 953, 954, 955, + 956, 957, 988, 989, 998, 984, 987, 990, 687, 881, 494, 541, + 577, 641, 642, 822, 420, 486, 889, 594, 402, 546, 513, 566, + 875, 593, 684, 699, 432, 683, 776, 558, 985, 986, 972, 979, + 970, 980, 976, 977, 973, 975, 978, 974, 596, 499, 623, 726, + 740, 621, 587, 512, 473, 731, 784, 792, 730, 491, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, + 98, 99, 100, 127, 128, 129, 130, 132, 131, 133, 134, 135, 137, + 138, 139, 140, 141, 142, 143, 136, 144, 145, 146, 2, 3, 4, 5, + 6, 389, 391, 0, 1, 390, 392, 393, 396, 397, 394, 395, 33, 34, + 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 51, 49, + 50, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, + 67, 68, 25, 26, 27, 28, 29, 30, 31, 32, 902, 908, 696, 589, + 691, 801, 632, 650, 782, 673, 545, 686, 828, 811, 827, 583, + 426, 769, 685, 778, 409, 530, 892, 604, 835, 704, 826, 531, + 823, 
845, 635, 447, 745, 837, 633, 755, 456, 471, 413, 764, + 744, 508, 878, 517, 626, 398, 480, 798, 527, 590, 681, 916, + 595, 856, 742, 800, 886, 786, 613, 844, 600, 479, 694, 723, + 739, 571, 476, 843, 758, 753, 746, 592, 836, 714, 475, 807, + 761, 535, 464, 584, 616, 507, 695, 677, 772, 783, 676, 785, + 795, 470, 607, 818, 862, 678, 718, 872, 645, 674, 815, 69, 70, + 71, 72, 73, 74, 75, 76, 77, 78, 79, 126, 118, 119, 120, 121, + 122, 123, 124, 125, 300, 301, 302, 303, 304, 305, 306, 307, + 308, 309, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, + 321, 322, 323, 324, 325, 326, 107, 108, 109, 110, 111, 112, + 113, 114, 115, 116, 328, 329, 606, 550, 651, 544, 766, 859, + 891, 882, 534, 760, 897, 521, 567, 909, 469, 505, 849, 813, + 406, 873, 706, 821, 839, 888, 425, 580, 698, 663, 624, 410, + 449, 497, 668, 832, 727, 762, 498, 598, 634, 506, 682, 863, + 483, 743, 582, 415, 424, 454, 467, 509, 788, 860, 865, 562, + 500, 915, 536, 458, 649, 421, 460, 525, 489, 716, 912, 825, + 581, 799, 877, 672, 781, 599, 729, 708, 437, 935, 945, 936, + 937, 938, 939, 940, 941, 942, 943, 944, 946, 947, 794, 608, + 478, 591, 774, 412, 771, 923, 679, 522, 568, 855, 697, 770, + 503, 492, 640, 662, 876, 868, 416, 931, 741, 614, 926, 901, + 615, 921, 816, 796, 440, 518, 455, 858, 643, 638, 712, 560, + 433, 850, 597, 737, 713, 887, 918, 574, 927, 834, 900, 552, + 501, 966, 542, 787, 496, 601, 922, 819, 452, 962, 429, 551, + 777, 838, 441, 996, 924, 619, 911, 958, 457, 636, 899, 463, + 533, 809, 969, 666, 869, 693, 488, 840, 659, 964, 907, 789, + 465, 540, 446, 474, 841, 738, 448, 588, 722, 709, 707, 925, + 411, 747, 414, 982, 439, 710, 462, 669, 399, 667, 735, 523, + 732, 810, 968, 752, 920, 749, 754, 961, 524, 652, 629, 793, + 664, 688, 658, 459, 930, 883, 653, 768, 700, 995, 549, 655, + 515, 874, 711, 435, 934, 991, 466, 721, 999, 481, 477, 618, + 994, 631, 585, 400, 538, 519, 903, 965, 720, 490, 854, 905, + 427, 896, 418, 430, 434, 514, 578, 904, 992, 487, 680, 422, + 637, 617, 556, 654, 692, 646, 733, 602, 808, 715, 756, 893, + 482, 917, 719, 919, 442, 563, 906, 890, 689, 775, 748, 451, + 443, 701, 797, 851, 842, 647, 967, 963, 461, 790, 910, 773, + 960, 981, 572, 993, 830, 898, 528, 804, 610, 779, 611, 728, + 759, 529, 419, 929, 885, 852, 570, 539, 630, 928, 932, 750, + 639, 848, 502, 605, 997, 983, 725, 644, 445, 806, 485, 622, + 853, 884, 438, 971, 933, 702, 557, 504, 767, 824, 959, 543 + ] + conv_layers = [ layer for layer in self.modules() if isinstance( layer, nn.Conv2d)] + last_conv_layer = conv_layers[-1] + weight = last_conv_layer.weight.data + new_weight = torch.zeros(weight.size(), dtype=weight.dtype) + for i, idx in enumerate(new_order): + new_weight[idx] = weight[i] + last_conv_layer.weight.data.copy_( new_weight ) + + + def forward(self, x): + x = self.features(x) + x = self.conv(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + return x + + + +def darknet19(pretrained=False, change=False, progress=True, **kwargs): + model = DarkNet(layers=[0, 1, 1, 2, 2], pooling=True, residual=False) + if pretrained: + weights_file = download_from_url(model_urls['darknet19'], + progress=progress) + model.load_weights(weights_file, change) + return model + + +def darknet19_448(pretrained=False, change=False, progress=True, **kwargs): + model = DarkNet(layers=[0, 1, 1, 2, 2], pooling=True, residual=False) + if pretrained: + weights_file = download_from_url(model_urls['darknet19_448'], + progress=progress) + model.load_weights(weights_file, change) + return model + + +def darknet53(pretrained=False, 
change=False, progress=True, **kwargs): + model = DarkNet(layers=[1, 2, 8, 8, 4], pooling=False, residual=True) + if pretrained: + weights_file = download_from_url(model_urls['darknet53'], + progress=progress) + model.load_weights(weights_file, change) + return model + + +def darknet53_448(pretrained=False, change=False, progress=True, **kwargs): + model = DarkNet(layers=[1, 2, 8, 8, 4], pooling=False, residual=True) + if pretrained: + weights_file = download_from_url(model_urls['darknet53_448'], + progress=progress) + model.load_weights(weights_file, change) + return model \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/densenet.py b/model_measuring/kamal/vision/models/classification/densenet.py new file mode 100644 index 0000000..3bcd186 --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/densenet.py @@ -0,0 +1,233 @@ +# This implementation is based on the DenseNet-BC implementation in torchvision +# https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py + +# Modified from https://github.com/pytorch/vision + + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from collections import OrderedDict +from copy import deepcopy + + +def _bn_function_factory(norm, relu, conv): + def bn_function(*inputs): + concated_features = torch.cat(inputs, 1) + bottleneck_output = conv(relu(norm(concated_features))) + return bottleneck_output + + return bn_function + + +class _DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): + super(_DenseLayer, self).__init__() + self.add_module('norm1', nn.BatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate, + kernel_size=1, stride=1, bias=False)), + self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, + kernel_size=3, stride=1, padding=1, bias=False)), + self.drop_rate = drop_rate + self.efficient = efficient + + def forward(self, *prev_features): + bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) + if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): + bottleneck_output = cp.checkpoint(bn_function, *prev_features) + else: + bottleneck_output = bn_function(*prev_features) + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + return new_features + + +class _Transition(nn.Sequential): + def __init__(self, num_input_features, num_output_features): + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm2d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, + kernel_size=1, stride=1, bias=False)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class _DenseBlock(nn.Module): + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + drop_rate=drop_rate, + efficient=efficient, + ) + 
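+            # each _DenseLayer receives the concatenation of all earlier outputs, so its input width is num_input_features + i * growth_rate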
self.add_module('denselayer%d' % (i + 1), layer) + + def forward(self, init_features): + features = [init_features] + for name, layer in self.named_children(): + new_features = layer(*features) + features.append(new_features) + return torch.cat(features, 1) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" ` + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 3 or 4 ints) - how many layers in each pooling block + num_init_features (int) - the number of filters to learn in the first convolution layer + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. + efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. + """ + def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5, + num_init_features=24, bn_size=4, drop_rate=0, + num_classes=10, small_inputs=True, efficient=False): + + super(DenseNet, self).__init__() + assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' + + # First convolution + if small_inputs: + self.preprocess = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), + ])) + else: + self.preprocess = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ])) + self.preprocess.add_module('norm0', nn.BatchNorm2d(num_init_features)) + self.preprocess.add_module('relu0', nn.ReLU(inplace=True)) + self.preprocess.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, + ceil_mode=False)) + + # Each denseblock + num_features = num_init_features + self.denseblocks = nn.ModuleList() + self.transitions = nn.ModuleList() + for i, num_layers in enumerate(block_config): + block = _DenseBlock( + num_layers=num_layers, + num_input_features=num_features, + bn_size=bn_size, + growth_rate=growth_rate, + drop_rate=drop_rate, + efficient=efficient, + ) + self.denseblocks.append(block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + trans = _Transition(num_input_features=num_features, + num_output_features=int(num_features * compression)) + self.transitions.append(trans) + num_features = int(num_features * compression) + + # Final batch norm + self.norm_final = nn.BatchNorm2d(num_features) + + # Linear layer + self.fc = nn.Linear(num_features, num_classes) + + # Initialization + for name, param in self.named_parameters(): + if 'conv' in name and 'weight' in name: + n = param.size(0) * param.size(2) * param.size(3) + param.data.normal_().mul_(math.sqrt(2. 
/ n)) + elif 'norm' in name and 'weight' in name: + param.data.fill_(1) + elif 'norm' in name and 'bias' in name: + param.data.fill_(0) + elif 'fc' in name and 'bias' in name: + param.data.fill_(0) + + def forward(self, x): + features = self.preprocess(x) + for i in range(len(self.denseblocks)): + features = self.denseblocks[i](features) + if i < len(self.transitions): + features = self.transitions[i](features) + features = self.norm_final(features) + + out = F.relu(features, inplace=True) + out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) + out = self.fc(out) + return out + + +class JointDenseNet(nn.Module): + def __init__(self, teachers, indices, phase): + super(JointDenseNet, self).__init__() + assert(len(indices) == len(teachers)) + self.indices = indices + self.phase = phase + self.preprocess = deepcopy(teachers[0].preprocess) + self.denseblocks = deepcopy(teachers[0].denseblocks) + self.transitions = deepcopy(teachers[0].transitions) + + # Initialization before copying modules and parameters of teachers. + for name, param in self.named_parameters(): + if 'conv' in name and 'weight' in name: + n = param.size(0) * param.size(2) * param.size(3) + param.data.normal_().mul_(math.sqrt(2. / n)) + elif 'norm' in name and 'weight' in name: + param.data.fill_(1) + elif 'norm' in name and 'bias' in name: + param.data.fill_(0) + + self.norm_finals = nn.ModuleList([deepcopy(teacher.norm_final) for teacher in teachers]) + self.fcs = nn.ModuleList([deepcopy(teacher.fc) for teacher in teachers]) + + self.teacher_denseblocks_list = nn.ModuleList([deepcopy(teacher.denseblocks) for teacher in teachers]) + self.teacher_transitions_list = nn.ModuleList([deepcopy(teacher.transitions) for teacher in teachers]) + + # Whether to fix parameters of branches from teachers when training each block. + for name, param in self.teacher_denseblocks_list.named_parameters(): + param.requires_grad = (self.phase != 'block') + for name, param in self.teacher_transitions_list.named_parameters(): + param.requires_grad = (self.phase != 'block') + for name, param in self.norm_finals.named_parameters(): + param.requires_grad = (self.phase != 'block') + for name, param in self.fcs.named_parameters(): + param.requires_grad = (self.phase != 'block') + + def forward(self, x): + num_b = len(self.denseblocks) + x = self.preprocess(x) + features_list = [None for i in range(len(self.indices))] + + out_idx = max(self.indices) + for i in range(out_idx): + x = self.denseblocks[i](x) + if i < num_b-1: + x = self.transitions[i](x) + for j in range(len(self.indices)): + if i == self.indices[j]-1: + features_list[j] = x + + # Mimic teachers. 
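+        # Each branched feature is forwarded through the remaining dense blocks/transitions copied from its own teacher, then through that teacher's final norm and fc below.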
+ for i in range(len(self.indices)): + for j in range(self.indices[i], num_b): + features_list[i] = self.teacher_denseblocks_list[i][j](features_list[i]) + if j < num_b-1: + features_list[i] = self.teacher_transitions_list[i][j](features_list[i]) + + features_list = [self.norm_finals[i](f) for i,f in enumerate(features_list)] + outs = [F.relu(f, inplace=True) for f in features_list] + outs = [F.adaptive_avg_pool2d(out, (1,1)).view(features_list[0].size(0), -1) for out in outs] + outs = [self.fcs[i](outs[i]) for i in range(len(self.indices))] + out = torch.cat(outs, dim=1) + + return out \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/mobilenetv2.py b/model_measuring/kamal/vision/models/classification/mobilenetv2.py new file mode 100644 index 0000000..2ec0fb4 --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/mobilenetv2.py @@ -0,0 +1,186 @@ +# Modified from https://github.com/pytorch/vision +from torch import nn +from torchvision.models.utils import load_state_dict_from_url +import torch.nn.functional as F + +__all__ = ['MobileNetV2', 'mobilenet_v2'] + + +model_urls = { + 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', +} + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
+ if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1): + #padding = (kernel_size - 1) // 2 + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + +def fixed_padding(kernel_size, dilation): + kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) + pad_total = kernel_size_effective - 1 + pad_beg = pad_total // 2 + pad_end = pad_total - pad_beg + return (pad_beg, pad_end, pad_beg, pad_end) + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, dilation, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + self.input_padding = fixed_padding( 3, dilation ) + + def forward(self, x): + x_pad = F.pad(x, self.input_padding) + if self.use_res_connect: + return x + self.conv(x_pad) + else: + return self.conv(x_pad) + +class MobileNetV2(nn.Module): + def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + """ + super(MobileNetV2, self).__init__() + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + self.output_stride = output_stride + current_stride = 1 + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(3, input_channel, stride=2)] + current_stride *= 2 + dilation=1 + previous_dilation = 1 + + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + previous_dilation = dilation + if current_stride == output_stride: + stride = 1 + dilation *= s + else: + stride = s + current_stride *= s + output_channel = int(c * width_mult) + + for i in range(n): + if i==0: + features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t)) + else: + 
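+                    # repeated blocks within a group keep stride 1 and use the current dilation rate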
features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t)) + input_channel = output_channel + # building last several layers + features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # building classifier + self.classifier = nn.Sequential( + nn.Dropout(0.2), + nn.Linear(self.last_channel, num_classes), + ) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def forward(self, x): + x = self.features(x) + x = x.mean([2, 3]) + x = self.classifier(x) + return x + + +def mobilenet_v2(pretrained=False, progress=True, **kwargs): + """ + Constructs a MobileNetV2 architecture from + `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + model = MobileNetV2(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], + progress=progress) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/resnet.py b/model_measuring/kamal/vision/models/classification/resnet.py new file mode 100644 index 0000000..0a1b9cb --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/resnet.py @@ -0,0 +1,336 @@ +# Modified from https://github.com/pytorch/vision +import torch +import torch.nn as nn +from torchvision.models.utils import load_state_dict_from_url + + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', + 'wide_resnet50_2', 'wide_resnet101_2'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', + 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', + 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise 
ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + 
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + +def _resnet(arch, block, layers, pretrained, progress, **kwargs): + num_classes = kwargs.pop('num_classes', None) + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + model.load_state_dict(state_dict) + if num_classes is not None and num_classes!=1000: + model.fc = nn.Linear( model.fc.in_features, num_classes ) + return model + + +def resnet18(pretrained=False, progress=True, **kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, + **kwargs) + + +def resnet34(pretrained=False, progress=True, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet50(pretrained=False, progress=True, **kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, + **kwargs) + + +def resnet101(pretrained=False, progress=True, **kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return 
_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, + **kwargs) + + +def resnet152(pretrained=False, progress=True, **kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, + **kwargs) + + +def resnext50_32x4d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-50 32x4d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def resnext101_32x8d(pretrained=False, progress=True, **kwargs): + r"""ResNeXt-101 32x8d model from + `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) + + +def wide_resnet50_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-50-2 model from + `"Wide Residual Networks" `_ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], + pretrained, progress, **kwargs) + + +def wide_resnet101_2(pretrained=False, progress=True, **kwargs): + r"""Wide ResNet-101-2 model from + `"Wide Residual Networks" `_ + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
+ Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + kwargs['width_per_group'] = 64 * 2 + return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], + pretrained, progress, **kwargs) \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/classification/vgg.py b/model_measuring/kamal/vision/models/classification/vgg.py new file mode 100644 index 0000000..1d9beee --- /dev/null +++ b/model_measuring/kamal/vision/models/classification/vgg.py @@ -0,0 +1,176 @@ +# Modified from https://github.com/pytorch/vision +import torch +import torch.nn as nn +from torchvision.models.utils import load_state_dict_from_url + + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + + +model_urls = { + 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', + 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', + 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', + 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', + 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', + 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', + 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', + 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', +} + + +class VGG(nn.Module): + + def __init__(self, features, num_classes=1000, init_weights=True): + super(VGG, self).__init__() + self.features = features + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + if init_weights: + self._initialize_weights() + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.Sequential(*layers) + + +cfgs = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs): + if pretrained: + kwargs['init_weights'] = False + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], + progress=progress) + 
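+        # init_weights was forced to False above, so every parameter comes from the downloaded checkpoint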
model.load_state_dict(state_dict) + return model + + +def vgg11(pretrained=False, progress=True, **kwargs): + r"""VGG 11-layer model (configuration "A") from + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) + + +def vgg11_bn(pretrained=False, progress=True, **kwargs): + r"""VGG 11-layer model (configuration "A") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) + + +def vgg13(pretrained=False, progress=True, **kwargs): + r"""VGG 13-layer model (configuration "B") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) + + +def vgg13_bn(pretrained=False, progress=True, **kwargs): + r"""VGG 13-layer model (configuration "B") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) + + +def vgg16(pretrained=False, progress=True, **kwargs): + r"""VGG 16-layer model (configuration "D") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) + + +def vgg16_bn(pretrained=False, progress=True, **kwargs): + r"""VGG 16-layer model (configuration "D") with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) + + +def vgg19(pretrained=False, progress=True, **kwargs): + r"""VGG 19-layer model (configuration "E") + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) + + +def vgg19_bn(pretrained=False, progress=True, **kwargs): + r"""VGG 19-layer model (configuration 'E') with batch normalization + `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/detection/__init__.py 
b/model_measuring/kamal/vision/models/detection/__init__.py new file mode 100644 index 0000000..1e202b7 --- /dev/null +++ b/model_measuring/kamal/vision/models/detection/__init__.py @@ -0,0 +1 @@ +from .craft import * \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/detection/craft/__init__.py b/model_measuring/kamal/vision/models/detection/craft/__init__.py new file mode 100644 index 0000000..1e202b7 --- /dev/null +++ b/model_measuring/kamal/vision/models/detection/craft/__init__.py @@ -0,0 +1 @@ +from .craft import * \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/detection/craft/craft.py b/model_measuring/kamal/vision/models/detection/craft/craft.py new file mode 100644 index 0000000..42ab3a6 --- /dev/null +++ b/model_measuring/kamal/vision/models/detection/craft/craft.py @@ -0,0 +1,96 @@ +""" +Copyright (c) 2019-present NAVER Corp. +MIT License +""" + +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + +# from .vgg16_bn import vgg16_bn, init_weights +# from kamal.vision.models.classification.vgg import vgg16_bn +from .vgg16_bn import vgg16_bn + +class double_conv(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch): + super(double_conv, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1), + nn.BatchNorm2d(mid_ch), + nn.ReLU(inplace=True), + nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + x = self.conv(x) + return x + + +class CRAFT(nn.Module): + def __init__(self, pretrained=False, freeze=False): + super(CRAFT, self).__init__() + + """ Base network """ + self.basenet = vgg16_bn(pretrained, freeze) + + """ U network """ + self.upconv1 = double_conv(1024, 512, 256) + self.upconv2 = double_conv(512, 256, 128) + self.upconv3 = double_conv(256, 128, 64) + self.upconv4 = double_conv(128, 64, 32) + + num_class = 2 + self.conv_cls = nn.Sequential( + nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), + nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True), + nn.Conv2d(16, num_class, kernel_size=1), + ) + + self.init_weights(self.upconv1.modules()) + self.init_weights(self.upconv2.modules()) + self.init_weights(self.upconv3.modules()) + self.init_weights(self.upconv4.modules()) + self.init_weights(self.conv_cls.modules()) + + def init_weights(self, modules): + for m in modules: + if isinstance(m, nn.Conv2d): + init.xavier_uniform_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + def forward(self, x): + """ Base network """ + sources = self.basenet(x) + + """ U network """ + y = torch.cat([sources[0], sources[1]], dim=1) + y = self.upconv1(y) + + y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[2]], dim=1) + y = self.upconv2(y) + + y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[3]], dim=1) + y = self.upconv3(y) + + y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False) + y = torch.cat([y, sources[4]], dim=1) + feature = 
self.upconv4(y) + + y = self.conv_cls(feature) + + return y.permute(0,2,3,1), feature \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/detection/craft/vgg16_bn.py b/model_measuring/kamal/vision/models/detection/craft/vgg16_bn.py new file mode 100644 index 0000000..f3f21a7 --- /dev/null +++ b/model_measuring/kamal/vision/models/detection/craft/vgg16_bn.py @@ -0,0 +1,73 @@ +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.init as init +from torchvision import models +from torchvision.models.vgg import model_urls + +def init_weights(modules): + for m in modules: + if isinstance(m, nn.Conv2d): + init.xavier_uniform_(m.weight.data) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + +class vgg16_bn(torch.nn.Module): + def __init__(self, pretrained=True, freeze=True): + super(vgg16_bn, self).__init__() + model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://') + vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(12): # conv2_2 + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 19): # conv3_3 + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(19, 29): # conv4_3 + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(29, 39): # conv5_3 + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + + # fc6, fc7 without atrous conv + self.slice5 = torch.nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), + nn.Conv2d(1024, 1024, kernel_size=1) + ) + + if not pretrained: + init_weights(self.slice1.modules()) + init_weights(self.slice2.modules()) + init_weights(self.slice3.modules()) + init_weights(self.slice4.modules()) + + init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 + + if freeze: + for param in self.slice1.parameters(): # only first conv + param.requires_grad= False + + def forward(self, X): + h = self.slice1(X) + h_relu2_2 = h + h = self.slice2(h) + h_relu3_2 = h + h = self.slice3(h) + h_relu4_3 = h + h = self.slice4(h) + h_relu5_3 = h + h = self.slice5(h) + h_fc7 = h + vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) + out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) + return out diff --git a/model_measuring/kamal/vision/models/segmentation/__init__.py b/model_measuring/kamal/vision/models/segmentation/__init__.py new file mode 100644 index 0000000..eae7c34 --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/__init__.py @@ -0,0 +1,4 @@ +from .deeplab import * +from .linknet import * +from .segnet import * +from .unet import * \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/deeplab/__init__.py b/model_measuring/kamal/vision/models/segmentation/deeplab/__init__.py new file mode 100644 index 0000000..57be0ea --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/deeplab/__init__.py @@ -0,0 +1,2 @@ +from .deeplab import * +from .layer import convert_to_separable_conv \ No newline at end of file diff 
--git a/model_measuring/kamal/vision/models/segmentation/deeplab/deeplab.py b/model_measuring/kamal/vision/models/segmentation/deeplab/deeplab.py new file mode 100644 index 0000000..88ddc5d --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/deeplab/deeplab.py @@ -0,0 +1,148 @@ +# modified from https://github.com/VainF/DeepLabV3Plus-Pytorch +from .utils import IntermediateLayerGetter +from .layer import DeepLabv3Head, DeepLabv3PlusHead +from ...classification import mobilenetv2, resnet + +import torch.nn as nn +import torch.nn.functional as F + +from torchvision.models.utils import load_state_dict_from_url + +__all__=['DeepLabV3', + 'deeplabv3_mobilenetv2', 'deeplabv3_resnet50', 'deeplabv3_resnet101', + 'deeplabv3plus_mobilenetv2', 'deeplabv3plus_resnet50', 'deeplabv3plus_resnet101'] + +model_urls = { + 'deeplabv3_mobilenetv2': None, + 'deeplabv3_resnet50': None, + 'deeplabv3_resnet101': None, + + 'deeplabv3plus_mobilenetv2': None, + 'deeplabv3plus_resnet50': None, + 'deeplabv3plus_resnet101': None, +} + +class DeepLabV3(nn.Module): + def __init__(self, arch='deeplabv3_mobilenetv2', num_classes=21, output_stride=16, pretrained_backbone=False, aspp_dilate=None): + super(DeepLabV3, self).__init__() + assert arch in __all__[1:], "arch_name for deeplab should be one of %s"%( __all__[1:] ) + + arch_type, backbone_name = arch.split('_') + + if backbone_name=='mobilenetv2': + backbone, classifier = _segm_mobilenet(arch_type, backbone_name, num_classes, + output_stride=output_stride, pretrained_backbone=pretrained_backbone, aspp_dilate=aspp_dilate) + elif backbone_name.startswith('resnet'): + backbone, classifier = _segm_resnet(arch_type, backbone_name, num_classes, + output_stride=output_stride, pretrained_backbone=pretrained_backbone, aspp_dilate=aspp_dilate) + else: + print("backbone nam") + raise NotImplementedError + + self.backbone = backbone + self.classifier = classifier + + def forward(self, x): + input_shape = x.shape[-2:] + features = self.backbone(x) + x = self.classifier(features) + x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + return x + +def _segm_resnet(name, backbone_name, num_classes, output_stride, pretrained_backbone, aspp_dilate=None): + + if output_stride==8: + replace_stride_with_dilation=[False, True, True] + aspp_dilate = [12, 24, 36] if aspp_dilate is None else aspp_dilate + else: + replace_stride_with_dilation=[False, False, True] + aspp_dilate = [6, 12, 18] if aspp_dilate is None else aspp_dilate + + backbone = resnet.__dict__[backbone_name]( + pretrained=pretrained_backbone, + replace_stride_with_dilation=replace_stride_with_dilation) + + inplanes = 2048 + low_level_planes = 256 + + if name=='deeplabv3plus': + return_layers = {'layer4': 'out', 'layer1': 'low_level'} + classifier = DeepLabv3PlusHead(inplanes, low_level_planes, num_classes, aspp_dilate) + elif name=='deeplabv3': + return_layers = {'layer4': 'out'} + classifier = DeepLabv3Head(inplanes , num_classes, aspp_dilate) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + #model = DeepLabV3(backbone, classifier) + return backbone, classifier + +def _segm_mobilenet(name, backbone_name, num_classes, output_stride, pretrained_backbone, aspp_dilate=None): + if aspp_dilate is None: + if output_stride==8: + aspp_dilate = [12, 24, 36] + else: + aspp_dilate = [6, 12, 18] + + backbone = mobilenetv2.mobilenet_v2(pretrained=pretrained_backbone, output_stride=output_stride) + + backbone.low_level_features = backbone.features[0:4] + 
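+    # the remaining inverted-residual blocks (features[4:-1]) become the high-level branch that feeds the ASPP head; the final 1x1 ConvBNReLU is dropped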
backbone.high_level_features = backbone.features[4:-1] + backbone.features = None + backbone.classifier = None + + inplanes = 320 + low_level_planes = 24 + + if name=='deeplabv3plus': + return_layers = {'high_level_features': 'out', 'low_level_features': 'low_level'} + classifier = DeepLabv3PlusHead(inplanes, low_level_planes, num_classes, aspp_dilate) + elif name=='deeplabv3': + return_layers = {'high_level_features': 'out'} + classifier = DeepLabv3Head(inplanes , num_classes, aspp_dilate) + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + #model = DeepLabV3(backbone, classifier) + return backbone, classifier + +def deeplabv3_mobilenetv2(pretrained=False, progress=True, **kwargs): + model = DeepLabV3(arch='deeplabv3_mobilenetv2', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + +def deeplabv3_resnet50(pretrained=False, progress=True, **kwargs): + model = DeepLabV3(arch='deeplabv3_resnet50', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + +def deeplabv3_resnet101(pretrained=False, progress=True, **kwargs): + model = DeepLabV3(arch='deeplabv3_resnet101', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def deeplabv3plus_mobilenetv2(pretrained=False, progress=True, **kwargs): + model = DeepLabV3(arch='deeplabv3plus_mobilenetv2', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + +def deeplabv3plus_resnet50(pretrained=False, progress=True, **kwargs): + model = DeepLabV3(arch='deeplabv3plus_resnet50', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + +def deeplabv3plus_resnet101(pretrained=False, progress=True, **kwargs): + model = DeepLabV3(arch='deeplabv3plus_resnet101', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/deeplab/layer.py b/model_measuring/kamal/vision/models/segmentation/deeplab/layer.py new file mode 100644 index 0000000..69b1e8d --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/deeplab/layer.py @@ -0,0 +1,159 @@ +# modified from https://github.com/VainF/DeepLabV3Plus-Pytorch +import torch +from torch import nn +from torch.nn import functional as F + +__all__ = ["DeepLabV3"] + + +class DeepLabv3PlusHead(nn.Module): + def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]): + super(DeepLabv3PlusHead, self).__init__() + self.project = nn.Sequential( + nn.Conv2d(low_level_channels, 48, 1, bias=False), + nn.BatchNorm2d(48), + nn.ReLU(inplace=True), + ) + + self.aspp = ASPP(in_channels, aspp_dilate) + + self.classifier = nn.Sequential( + nn.Conv2d(304, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, num_classes, 1) + ) + self._init_weight() + + def forward(self, feature): + low_level_feature = self.project( feature['low_level'] ) + output_feature = self.aspp(feature['out']) + output_feature = F.interpolate(output_feature, 
size=low_level_feature.shape[2:], mode='bilinear', align_corners=False) + return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) ) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +class DeepLabv3Head(nn.Module): + def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]): + super(DeepLabv3Head, self).__init__() + + self.classifier = nn.Sequential( + ASPP(in_channels, aspp_dilate), + nn.Conv2d(256, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, num_classes, 1) + ) + self._init_weight() + + def forward(self, feature): + return self.classifier( feature['out'] ) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +class AtrousSeparableConvolution(nn.Module): + """ Atrous Separable Convolution + """ + def __init__(self, in_channels, out_channels, kernel_size, + stride=1, padding=0, dilation=1, bias=True): + super(AtrousSeparableConvolution, self).__init__() + self.body = nn.Sequential( + # Separable Conv + nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ), + # PointWise Conv + nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias), + ) + + self._init_weight() + + def forward(self, x): + return self.body(x) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + modules = [ + nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ] + super(ASPPConv, self).__init__(*modules) + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super(ASPPPooling, self).__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True)) + + def forward(self, x): + size = x.shape[-2:] + x = super(ASPPPooling, self).forward(x) + return F.interpolate(x, size=size, mode='bilinear', align_corners=False) + +class ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates): + super(ASPP, self).__init__() + out_channels = 256 + modules = [] + modules.append(nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True))) + + rate1, rate2, rate3 = tuple(atrous_rates) + modules.append(ASPPConv(in_channels, out_channels, rate1)) + modules.append(ASPPConv(in_channels, out_channels, rate2)) + modules.append(ASPPConv(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + nn.Dropout(0.1),) + + def forward(self, x): + res = [] + for conv in self.convs: + 
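+            # run x through every ASPP branch (1x1 projection, three dilated 3x3 convs, image-level pooling) and fuse the results by concatenation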
res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) + + + +def convert_to_separable_conv(module): + new_module = module + if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1: + new_module = AtrousSeparableConvolution(module.in_channels, + module.out_channels, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.bias) + for name, child in module.named_children(): + new_module.add_module(name, convert_to_separable_conv(child)) + return new_module \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/deeplab/utils.py b/model_measuring/kamal/vision/models/segmentation/deeplab/utils.py new file mode 100644 index 0000000..4c1f10e --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/deeplab/utils.py @@ -0,0 +1,55 @@ +# modified from https://github.com/VainF/DeepLabV3Plus-Pytorch +from collections import OrderedDict +import torch.nn as nn + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + Arguments: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). 
+ Examples:: + >>> m = torchvision.models.resnet18(pretrained=True) + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + def __init__(self, model, return_layers): + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + + orig_return_layers = return_layers + return_layers = {k: v for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.named_children(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/linknet/__init__.py b/model_measuring/kamal/vision/models/segmentation/linknet/__init__.py new file mode 100644 index 0000000..4731931 --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/linknet/__init__.py @@ -0,0 +1 @@ +from .linknet import * \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/linknet/linknet.py b/model_measuring/kamal/vision/models/segmentation/linknet/linknet.py new file mode 100644 index 0000000..07147a5 --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/linknet/linknet.py @@ -0,0 +1,194 @@ +import torch.nn as nn +from ...classification.resnet import BasicBlock, Bottleneck, resnet18, resnet34, resnet50, resnet101, resnet152 + +__all__= ['LinkNet', 'linknet_resnet18','linknet_resnet34','linknet_resnet50','linknet_resnet101','linknet_resnet152'] + +model_urls = { + 'linknet_resnet18': None, + 'linknet_resnet34': None, + 'linknet_resnet50': None, + 'linknet_resnet101': None, + 'linknet_resnet152': None, +} + +_arch_dict = { + 'linknet_resnet18': ( (2, 2, 2, 2), BasicBlock ), + 'linknet_resnet34': ( (2, 2, 2, 2), BasicBlock ), + 'linknet_resnet50': ( (3, 4, 6, 3), Bottleneck ), + 'linknet_resnet101': ( (3, 4, 23, 3), Bottleneck ), + 'linknet_resnet152': ( (3, 8, 36, 3), Bottleneck ), +} + +_backbone_dict = { + 'linknet_resnet18': resnet18, + 'linknet_resnet34': resnet34, + 'linknet_resnet50': resnet50, + 'linknet_resnet101': resnet101, + 'linknet_resnet152': resnet152, +} + +__all__=['LinkNet'] + +class LinkNetDecoder(nn.Sequential): + def __init__(self, in_channels, out_channels, stride=1): + super(LinkNetDecoder, self).__init__( + nn.Conv2d(in_channels, in_channels//4, kernel_size=1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(in_channels//4), + nn.ReLU(inplace=True), + + # upsample + nn.ConvTranspose2d(in_channels//4, in_channels//4, kernel_size=3, stride=stride, padding=1, output_padding=int(stride==2)), + nn.BatchNorm2d(in_channels//4), + nn.ReLU(inplace=True), + + nn.Conv2d(in_channels//4, out_channels, kernel_size=1, padding=0, stride=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + +class LinkNet(nn.Module): + def __init__(self, arch='linknet_resnet18', 
num_classes=21, in_channels=3, pretrained_backbone=False, channel_list=(64, 128, 256, 512 ), block=BasicBlock): + super(LinkNet, self).__init__() + + # predefined arch + if isinstance(arch, str): + arch_name = arch + assert arch_name in _arch_dict.keys(), "arch_name for SegNet should be one of %s"%( _arch_dict.keys() ) + arch, block = _arch_dict[arch_name] + # customized arch + elif isinstance( arch, (list, tuple) ): + arch_name = 'customized' + + # Encoder + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.inplanes = 64 + + # Encoder + self.layer1 = self._make_layer(block, planes=channel_list[0], blocks=arch[0]) + self.layer2 = self._make_layer(block, planes=channel_list[1], blocks=arch[1], stride=2) + self.layer3 = self._make_layer(block, planes=channel_list[2], blocks=arch[2], stride=2) + self.layer4 = self._make_layer(block, planes=channel_list[3], blocks=arch[3], stride=2) + + decoder_channel_list = [ c*block.expansion for c in channel_list ] + # Decoder + self.decoder4 = LinkNetDecoder(decoder_channel_list[3], decoder_channel_list[2], stride=2) + self.decoder3 = LinkNetDecoder(decoder_channel_list[2], decoder_channel_list[1], stride=2) + self.decoder2 = LinkNetDecoder(decoder_channel_list[1], decoder_channel_list[0], stride=2) + self.decoder1 = LinkNetDecoder(decoder_channel_list[0], decoder_channel_list[0]) + + # Final Classifier + self.classifier = nn.Sequential( + nn.ConvTranspose2d(decoder_channel_list[0], 32, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + + nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + + nn.ConvTranspose2d(32, num_classes, kernel_size=2, stride=2, padding=0) + ) + + if pretrained_backbone: + self.load_from_pretrained_resnet(arch_name) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + return nn.Sequential(*layers) + + def load_from_pretrained_resnet(self, resnet): + if isinstance(resnet, str): + resnet = _backbone_dict[ resnet ](pretrained=True) + + def copy_params(layer1, layer2): + for p1, p2 in zip( layer1.parameters(), layer2.parameters() ): + p1.data = p2.data + + linknet_part = [ self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4 ] + resnet_part = [ resnet.conv1, resnet.bn1, resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 ] + + for linknet_layer, resnet_layer in zip( linknet_part, resnet_part ): + copy_params( linknet_layer, resnet_layer ) + + def forward(self, x): + # Encoder + out_size = x.shape[2:] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + e1 = self.layer1(x) + e2 = self.layer2(e1) + e3 = self.layer3(e2) + e4 = self.layer4(e3) + + d4 = self.decoder4(e4) + d4 = d4+e3 + d3 = self.decoder3(d4) + d3 = d3+e2 + d2 = self.decoder2(d3) + d2 = d2+e1 + d1 = self.decoder1(d2) + logits = self.classifier(d1) + if 
logits.shape[2:]!=out_size:
+            logits = nn.functional.interpolate( logits, size=out_size, mode='bilinear', align_corners=True )
+        return logits
+
+
+def linknet_resnet18(pretrained=False, progress=True, **kwargs):
+    model = LinkNet(arch='linknet_resnet18', **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['linknet_resnet18'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+def linknet_resnet34(pretrained=False, progress=True, **kwargs):
+    model = LinkNet(arch='linknet_resnet34', **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['linknet_resnet34'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+def linknet_resnet50(pretrained=False, progress=True, **kwargs):
+    model = LinkNet(arch='linknet_resnet50', **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['linknet_resnet50'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+def linknet_resnet101(pretrained=False, progress=True, **kwargs):
+    model = LinkNet(arch='linknet_resnet101', **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['linknet_resnet101'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+def linknet_resnet152(pretrained=False, progress=True, **kwargs):
+    model = LinkNet(arch='linknet_resnet152', **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['linknet_resnet152'], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
diff --git a/model_measuring/kamal/vision/models/segmentation/segnet/__init__.py b/model_measuring/kamal/vision/models/segmentation/segnet/__init__.py
new file mode 100644
index 0000000..18c2942
--- /dev/null
+++ b/model_measuring/kamal/vision/models/segmentation/segnet/__init__.py
@@ -0,0 +1 @@
+from .segnet import *
\ No newline at end of file
diff --git a/model_measuring/kamal/vision/models/segmentation/segnet/layer.py b/model_measuring/kamal/vision/models/segmentation/segnet/layer.py
new file mode 100644
index 0000000..0f5274e
--- /dev/null
+++ b/model_measuring/kamal/vision/models/segmentation/segnet/layer.py
@@ -0,0 +1,53 @@
+# Modified from https://github.com/meetshah1995/pytorch-semseg/blob/801fb20054/ptsemseg/models/segnet.py
+import torch.nn as nn
+
+class ConvBNRelu(nn.Sequential):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, dilation=1, batch_norm=True, activation=True):
+        conv_mod = nn.Conv2d(int(in_channels), int(out_channels), kernel_size=kernel_size, padding=padding, stride=stride, bias=bias, dilation=dilation)
+        if batch_norm:
+            cbr_unit = [conv_mod, nn.BatchNorm2d(int(out_channels)) ]
+        else:
+            cbr_unit = [ conv_mod ]
+
+        if activation==True:
+            cbr_unit.append( nn.ReLU(inplace=True) )
+        super(ConvBNRelu, self).__init__(*cbr_unit)
+
+class SegnetDown(nn.Module):
+    def __init__(self, in_channels, out_channels, num_convs=2, batch_norm=True):
+        super(SegnetDown, self).__init__()
+        layers = [ConvBNRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, batch_norm=batch_norm)]
+        for _ in range(num_convs-1):
+            layers.append( ConvBNRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1, batch_norm=batch_norm) )
+
+        self.layers = nn.Sequential( *layers )
+        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
+
+    def forward(self, inputs):
+        outputs = self.layers(inputs)
+        ori_shape = outputs.size()
+        outputs, indices = self.maxpool_with_argmax(outputs)
+        return outputs, indices, ori_shape
+
+class SegnetUp(nn.Module):
+    def __init__(self,
in_channels, out_channels, num_convs=2, outer_most=False, batch_norm=True): + super(SegnetUp, self).__init__() + if outer_most: + batch_norm = False + activation = False + else: + activation = True + + layers = [] + for _ in range(num_convs-1): + layers.append( ConvBNRelu(in_channels, in_channels, kernel_size=3, stride=1, padding=1, batch_norm=batch_norm) ) + # remove relu if it is the outer most layer + layers.append( ConvBNRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, batch_norm=batch_norm, activation=activation) ) + self.unpool = nn.MaxUnpool2d(2, 2) + self.layers = nn.Sequential( *layers ) + + def forward(self, inputs, indices, ori_shape): + outputs = self.unpool(input=inputs, indices=indices, output_size=ori_shape) + outputs = self.layers(outputs) + return outputs + diff --git a/model_measuring/kamal/vision/models/segmentation/segnet/segnet.py b/model_measuring/kamal/vision/models/segmentation/segnet/segnet.py new file mode 100644 index 0000000..4dc86db --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/segnet/segnet.py @@ -0,0 +1,194 @@ +# Modified from https://github.com/meetshah1995/pytorch-semseg/blob/801fb20054/ptsemseg/models/segnet.py + +from .layer import SegnetDown, SegnetUp +import torch.nn as nn +from ...classification import vgg +from torchvision.models.utils import load_state_dict_from_url + +__all__=[ 'SegNet', + 'segnet_vgg11', 'segnet_vgg13','segnet_vgg16','segnet_vgg19', + 'segnet_vgg11_bn','segnet_vgg13_bn','segnet_vgg16_bn','segnet_vgg19_bn' ] + +model_urls = { + 'segnet_vgg11': None, + 'segnet_vgg13': None, + 'segnet_vgg16': None, + 'segnet_vgg19': None, + 'segnet_vgg11_bn': None, + 'segnet_vgg13_bn': None, + 'segnet_vgg16_bn': None, + 'segnet_vgg19_bn': None, +} + +_arch_dict = { + 'segnet_vgg11_bn': [1, 1, 2, 2, 2], + 'segnet_vgg13_bn': [2, 2, 2, 2, 2], + 'segnet_vgg16_bn': [2, 2, 3, 3, 3], + 'segnet_vgg19_bn': [2, 2, 4, 4, 4], + 'segnet_vgg11': [1, 1, 2, 2, 2], + 'segnet_vgg13': [2, 2, 2, 2, 2], + 'segnet_vgg16': [2, 2, 3, 3, 3], + 'segnet_vgg19': [2, 2, 4, 4, 4], +} + +_backbone_dict = { + 'segnet_vgg11_bn': vgg.vgg11_bn, + 'segnet_vgg13_bn': vgg.vgg13_bn, + 'segnet_vgg16_bn': vgg.vgg16_bn, + 'segnet_vgg19_bn': vgg.vgg19_bn, + 'segnet_vgg11': vgg.vgg11, + 'segnet_vgg13': vgg.vgg13, + 'segnet_vgg16': vgg.vgg16, + 'segnet_vgg19': vgg.vgg19, +} + + +class SegNet( nn.Module ): + def __init__(self, arch='segnet_vgg16_bn', num_classes=21, in_channels=3, pretrained_backbone=False, batch_norm=True, channel_list=(64, 128, 256, 512, 512)): + super( SegNet, self ).__init__() + assert len(channel_list)==5, 'length of channel_list must be 5' + + # predefined arch + if isinstance(arch, str): + arch_name = arch + assert arch_name in _arch_dict.keys(), "arch_name for SegNet should be one of %s"%( _arch_dict.keys() ) + arch = _arch_dict[arch_name] + batch_norm=True if 'bn' in arch_name else False + # customized arch + elif isinstance( arch, (list, tuple) ): + arch_name = 'customized' + + self.num_classes = num_classes + self.in_channels = in_channels + + self.down1 = SegnetDown(self.in_channels, channel_list[0], num_convs=arch[0], batch_norm=batch_norm) # 64 + self.down2 = SegnetDown(channel_list[0], channel_list[1], num_convs=arch[1], batch_norm=batch_norm) # 128 + self.down3 = SegnetDown(channel_list[1], channel_list[2], num_convs=arch[2], batch_norm=batch_norm) # 256 + self.down4 = SegnetDown(channel_list[2], channel_list[3], num_convs=arch[3], batch_norm=batch_norm) # 512 + self.down5 = SegnetDown(channel_list[3], channel_list[4], 
num_convs=arch[4], batch_norm=batch_norm) # 512 + + self.up5 = SegnetUp(channel_list[4], channel_list[3], num_convs=arch[4], batch_norm=batch_norm) # 512 + self.up4 = SegnetUp(channel_list[3], channel_list[2], num_convs=arch[3], batch_norm=batch_norm) # 256 + self.up3 = SegnetUp(channel_list[2], channel_list[1], num_convs=arch[2], batch_norm=batch_norm) # 128 + self.up2 = SegnetUp(channel_list[1], channel_list[0], num_convs=arch[1], batch_norm=batch_norm) # 64 + self.up1 = SegnetUp(channel_list[0], self.num_classes, num_convs=arch[0], outer_most=True, batch_norm=batch_norm) + + if pretrained_backbone: + assert arch_name!='customized', 'Only predefined archs have pretrained weights' + self.load_from_pretrained_vgg(arch_name) + + def load_from_pretrained_vgg(self, vgg): + if isinstance(vgg, str): + vgg = _backbone_dict[ vgg ](pretrained=True) + + _blocks = [self.down1, self.down2, self.down3, self.down4, self.down5] + segnet_features = [] + for _block in _blocks: + for _layer in _block.layers: + segnet_features.extend( _layer ) + + vgg_features = [ layer for layer in vgg.features if not isinstance( layer, nn.MaxPool2d ) ] + + for segnet_layer, vgg_layer in zip( segnet_features, vgg_features ): + assert type( segnet_layer ) == type( vgg_layer ), "Inconsistant layer: %s, %s"%(type( segnet_features ), type( vgg_layer )) + if isinstance( segnet_layer, nn.Conv2d ): + segnet_layer.weight.data = vgg_layer.weight.data + segnet_layer.bias.data = vgg_layer.bias.data + elif isinstance( segnet_layer, nn.BatchNorm2d): + segnet_layer.weight.data = vgg_layer.weight.data + segnet_layer.bias.data = vgg_layer.bias.data + segnet_layer.running_mean.data = vgg_layer.running_mean.data + segnet_layer.running_var.data = vgg_layer.running_var.data + + def forward(self, inputs): + down1, indices_1, unpool_shape1 = self.down1(inputs) + down2, indices_2, unpool_shape2 = self.down2(down1) + down3, indices_3, unpool_shape3 = self.down3(down2) + down4, indices_4, unpool_shape4 = self.down4(down3) + down5, indices_5, unpool_shape5 = self.down5(down4) + + up5 = self.up5(down5, indices_5, unpool_shape5) + up4 = self.up4(up5, indices_4, unpool_shape4) + up3 = self.up3(up4, indices_3, unpool_shape3) + up2 = self.up2(up3, indices_2, unpool_shape2) + up1 = self.up1(up2, indices_1, unpool_shape1) + return up1 + +def segnet_vgg11(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg11', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg11'], progress=progress) + model.load_state_dict(state_dict) + return model + +def segnet_vgg11_bn(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg11_bn', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg11_bn'], progress=progress) + model.load_state_dict(state_dict) + return model + +def segnet_vgg13(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg13', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg13'], progress=progress) + model.load_state_dict(state_dict) + return model + +def segnet_vgg13_bn(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. 
+ """ + model = SegNet(arch='segnet_vgg13_bn', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg13_bn'], progress=progress) + model.load_state_dict(state_dict) + return model + + +def segnet_vgg16(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg16', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg16'], progress=progress) + model.load_state_dict(state_dict) + return model + +def segnet_vgg16_bn(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg16_bn', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg16_bn'], progress=progress) + model.load_state_dict(state_dict) + return model + +def segnet_vgg19(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg19', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg19'], progress=progress) + model.load_state_dict(state_dict) + return model + +def segnet_vgg19_bn(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. + """ + model = SegNet(arch='segnet_vgg19_bn', **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['segnet_vgg19_bn'], progress=progress) + model.load_state_dict(state_dict) + return model + +if __name__=='__main__': + import torch + model = SegNet(num_classes=21) + print(model) + print( model( torch.randn(1,3,256,256) ).shape ) diff --git a/model_measuring/kamal/vision/models/segmentation/unet/__init__.py b/model_measuring/kamal/vision/models/segmentation/unet/__init__.py new file mode 100644 index 0000000..b804fc6 --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/unet/__init__.py @@ -0,0 +1 @@ +from .unet import * \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/unet/layer.py b/model_measuring/kamal/vision/models/segmentation/unet/layer.py new file mode 100644 index 0000000..f4998e6 --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/unet/layer.py @@ -0,0 +1,51 @@ +# Modified from https://github.com/meetshah1995/pytorch-semseg +import torch.nn as nn +import torch.nn.functional as F +import torch + +class ConvBNRelu(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=True, dilation=1, batch_norm=True, activation=True): + conv_mod = nn.Conv2d(int(in_channels), int(out_channels), kernel_size=kernel_size, padding=padding, stride=stride, bias=bias, dilation=dilation) + if batch_norm: + cbr_unit = [conv_mod, nn.BatchNorm2d(int(out_channels)) ] + else: + cbr_unit = [ conv_mod ] + + if activation==True: + cbr_unit.append( nn.ReLU(inplace=True) ) + super(ConvBNRelu, self).__init__(*cbr_unit) + +class DoubleConv(nn.Sequential): + def __init__(self, in_channels, out_channels, batch_norm): + super(DoubleConv, self).__init__( + ConvBNRelu( in_channels, out_channels, 3, 1, 1, batch_norm=batch_norm), + ConvBNRelu( out_channels, out_channels, 3, 1, 1, batch_norm=batch_norm) + ) + +class Down(nn.Module): + def __init__(self, in_channels, out_channels, batch_norm=True): + super(Down, self).__init__() + self.double_conv = DoubleConv(in_channels, out_channels, batch_norm ) + self.downsample = nn.MaxPool2d(2,2) + + def 
forward(self, inputs): + conv_features = self.double_conv(inputs) + outputs = self.downsample(conv_features) + return outputs, conv_features + +class Up(nn.Module): + def __init__(self, in_channels, out_channels, batch_norm=True, deconv=True): + super(Up, self).__init__() + self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) if deconv else nn.UpsamplingBilinear2d(scale_factor=2) + self.double_conv = DoubleConv(in_channels, out_channels, batch_norm ) + + def forward(self, inputs, skip_inputs): + outputs = self.upsample(inputs) + padding_h = skip_inputs.size()[2] - outputs.size()[2] + padding_w = skip_inputs.size()[3] - outputs.size()[3] + padding = [ padding_w//2, padding_w-padding_w//2, padding_h//2, padding_h-padding_h//2 ] + outputs = F.pad(outputs, pad=padding) + + outputs = torch.cat( [ skip_inputs, outputs ], dim=1 ) + outputs = self.double_conv(outputs) + return outputs \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/segmentation/unet/unet.py b/model_measuring/kamal/vision/models/segmentation/unet/unet.py new file mode 100644 index 0000000..72d5c70 --- /dev/null +++ b/model_measuring/kamal/vision/models/segmentation/unet/unet.py @@ -0,0 +1,58 @@ +# Modified from https://github.com/meetshah1995/pytorch-semseg +import torch.nn as nn +import torch.nn.functional as F +import torch +from .layer import DoubleConv, Down, Up + +__all__=['UNet', 'unet'] + +model_urls = { + 'unet': None, +} + +class UNet(nn.Module): + def __init__(self, num_classes=21, in_channels=3, deconv=True, batch_norm=True, channel_list=(64, 128, 256, 512, 1024)): + super(UNet, self).__init__() + assert len(channel_list)==5, 'length of channel_list must be 5' + # downsampling + self.down1 = Down(in_channels, channel_list[0], batch_norm) # 64 + self.down2 = Down(channel_list[0], channel_list[1], batch_norm) # 128 + self.down3 = Down(channel_list[1], channel_list[2], batch_norm) # 256 + self.down4 = Down(channel_list[2], channel_list[3], batch_norm) # 512 + self.center = DoubleConv(channel_list[3], channel_list[4], batch_norm) # 1024 + + # upsampling + self.up4 = Up(channel_list[4], channel_list[3], batch_norm, deconv) # 512 + self.up3 = Up(channel_list[3], channel_list[2], batch_norm, deconv) # 256 + self.up2 = Up(channel_list[2], channel_list[1], batch_norm, deconv) # 128 + self.up1 = Up(channel_list[1], channel_list[0], batch_norm, deconv) # 64 + + self.classifier = nn.Conv2d(channel_list[0], num_classes, 1) + + def forward(self, inputs): + out_size = inputs.shape[2:] + out, conv_features1 = self.down1(inputs) + out, conv_features2 = self.down2(out) + out, conv_features3 = self.down3(out) + out, conv_features4 = self.down4(out) + + out = self.center(out) + + out = self.up4(out, conv_features4) + out = self.up3(out, conv_features3) + out = self.up2(out, conv_features2) + out = self.up1(out, conv_features1) + + out = self.classifier(out) + if out.shape[2:]!=out_size: + out = nn.functional.interpolate( out, size=out_size, mode='bilinear', align_corners=True ) + return out + +def unet(pretrained=False, progress=True, **kwargs): + """Constructs a DeepLabV3+ model with a mobilenet backbone. 
+ """ + model = UNet(**kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/model_measuring/kamal/vision/models/utils.py b/model_measuring/kamal/vision/models/utils.py new file mode 100644 index 0000000..fe581e4 --- /dev/null +++ b/model_measuring/kamal/vision/models/utils.py @@ -0,0 +1,115 @@ +# A synchronized version modified from https://github.com/pytorch/vision +import os, sys +import torch +import torch.nn as nn +import numpy as np + +from torch.hub import * + +def _get_torch_home(): + torch_home = os.path.expanduser( + os.getenv(ENV_TORCH_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'))) + return torch_home + +def download_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False): + r""" + Adapted from torchvision.models.utils.load_state_dict_from_url + This function only download files from the specified url and return its path as a string. + It is used to get weight files in other formats. + """ + + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + if model_dir is None: + torch_home = _get_torch_home() + model_dir = os.path.join(torch_home, 'checkpoints') + + try: + os.makedirs(model_dir) + except OSError as e: + if e.errno == errno.EEXIST: + # Directory already exists, ignore. + pass + else: + # Unexpected OSError, re-raise. + raise + + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = HASH_REGEX.search(filename).group(1) if check_hash else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + + return cached_file + + +def load_darknet_weights( model, darknet_file): + layers_with_params = [ layer for layer in model.modules() \ + if isinstance( layer, (nn.Conv2d, nn.Linear, nn.BatchNorm2d) ) ] + + with open(darknet_file, 'rb') as fp: + major, minor, revision = np.fromfile(fp, dtype=np.int32, count=3) + if major*10 + minor >= 2 and major < 1000 and minor < 1000: + seen = np.fromfile(fp, dtype=np.int64, count=1) + else: + seen = np.fromfile(fp, dtype=np.int32, count=1) + #transpose = (major > 1000) | (minor > 1000); + + weights = np.fromfile( fp, dtype=np.float32 ) + + offset = 0 + for i, layer in enumerate( layers_with_params ) : + if isinstance( layer, nn.Conv2d ): + conv = layer + if i width, then image will be rescaled to + :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)` + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + + Returns: + PIL Image: Resized image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)): + raise TypeError('Got inappropriate size arg: {}'.format(size)) + + if isinstance(size, int): + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size[::-1], interpolation) + + +def scale(*args, **kwargs): + warnings.warn("The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") + return resize(*args, **kwargs) + + +def pad(img, padding, fill=0, padding_mode='constant'): + r"""Pad the given PIL Image on all sides with specified padding mode and fill value. + + Args: + img (PIL Image): Image to be padded. + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill: Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + Returns: + PIL Image: Padded image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + if not isinstance(padding, (numbers.Number, tuple)): + raise TypeError('Got inappropriate padding arg') + if not isinstance(fill, (numbers.Number, str, tuple)): + raise TypeError('Got inappropriate fill arg') + if not isinstance(padding_mode, str): + raise TypeError('Got inappropriate padding_mode arg') + + if isinstance(padding, Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \ + 'Padding mode should be either constant, edge, reflect or symmetric' + + if padding_mode == 'constant': + if img.mode == 'P': + palette = img.getpalette() + image = ImageOps.expand(img, border=padding, fill=fill) + image.putpalette(palette) + return image + + return ImageOps.expand(img, border=padding, fill=fill) + else: + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + if isinstance(padding, Sequence) and len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + if isinstance(padding, Sequence) and len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + if img.mode == 'P': + palette = img.getpalette() + img = np.asarray(img) + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + img = Image.fromarray(img) + img.putpalette(palette) + return img + + img = np.asarray(img) + # RGB image + if len(img.shape) == 3: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) + # Grayscale image + if len(img.shape) == 2: + img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) + + return Image.fromarray(img) + + +def crop(img, i, j, h, w): + """Crop the given PIL Image. + + Args: + img (PIL Image): Image to be cropped. + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped image. + w (int): Width of the cropped image. + + Returns: + PIL Image: Cropped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.crop((j, i, j + w, i + h)) + + +def center_crop(img, output_size): + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + w, h = img.size + th, tw = output_size + i = int(round((h - th) / 2.)) + j = int(round((w - tw) / 2.)) + return crop(img, i, j, th, tw) + + +def resized_crop(img, i, j, h, w, size, interpolation=Image.BILINEAR): + """Crop the given PIL Image and resize it to desired size. + + Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. + + Args: + img (PIL Image): Image to be cropped. + i (int): i in (i,j) i.e coordinates of the upper left corner + j (int): j in (i,j) i.e coordinates of the upper left corner + h (int): Height of the cropped image. + w (int): Width of the cropped image. + size (sequence or int): Desired output size. Same semantics as ``resize``. + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR``. + Returns: + PIL Image: Cropped image. + """ + assert _is_pil_image(img), 'img should be PIL Image' + img = crop(img, i, j, h, w) + img = resize(img, size, interpolation) + return img + + +def hflip(img): + """Horizontally flip the given PIL Image. 
+ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Horizontall flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def _get_perspective_coeffs(startpoints, endpoints): + """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. + + In Perspective Transform each pixel (x, y) in the orignal image gets transformed as, + (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) ) + + Args: + List containing [top-left, top-right, bottom-right, bottom-left] of the orignal image, + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed + image + Returns: + octuple (a, b, c, d, e, f, g, h) for transforming each pixel. + """ + matrix = [] + + for p1, p2 in zip(endpoints, startpoints): + matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + + A = torch.tensor(matrix, dtype=torch.float) + B = torch.tensor(startpoints, dtype=torch.float).view(8) + res = torch.gels(B, A)[0] + return res.squeeze_(1).tolist() + + +def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC): + """Perform perspective transform of the given PIL Image. + + Args: + img (PIL Image): Image to be transformed. + coeffs (tuple) : 8-tuple (a, b, c, d, e, f, g, h) which contains the coefficients. + for a perspective transform. + interpolation: Default- Image.BICUBIC + Returns: + PIL Image: Perspectively transformed Image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + coeffs = _get_perspective_coeffs(startpoints, endpoints) + return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation) + + +def vflip(img): + """Vertically flip the given PIL Image. + + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Vertically flipped image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +def five_crop(img, size): + """Crop the given PIL Image into four corners and the central crop. + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + + Returns: + tuple: tuple (tl, tr, bl, br, center) + Corresponding top left, top right, bottom left, bottom right and center crop. + """ + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + + w, h = img.size + crop_h, crop_w = size + if crop_w > w or crop_h > h: + raise ValueError("Requested crop size {} is bigger than input size {}".format(size, + (h, w))) + tl = img.crop((0, 0, crop_w, crop_h)) + tr = img.crop((w - crop_w, 0, w, crop_h)) + bl = img.crop((0, h - crop_h, crop_w, h)) + br = img.crop((w - crop_w, h - crop_h, w, h)) + center = center_crop(img, (crop_h, crop_w)) + return (tl, tr, bl, br, center) + + +def ten_crop(img, size, vertical_flip=False): + r"""Crop the given PIL Image into four corners and the central crop plus the + flipped version of these (horizontal flipping is used by default). + + .. 
Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + vertical_flip (bool): Use vertical flipping instead of horizontal + + Returns: + tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + Corresponding top left, top right, bottom left, bottom right and center crop + and same for the flipped image. + """ + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + + first_five = five_crop(img, size) + + if vertical_flip: + img = vflip(img) + else: + img = hflip(img) + + second_five = five_crop(img, size) + return first_five + second_five + + +def adjust_brightness(img, brightness_factor): + """Adjust brightness of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + brightness_factor (float): How much to adjust the brightness. Can be + any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns: + PIL Image: Brightness adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """Adjust contrast of an Image. + + Args: + img (PIL Image): PIL Image to be adjusted. + contrast_factor (float): How much to adjust the contrast. Can be any + non negative number. 0 gives a solid gray image, 1 gives the + original image while 2 increases the contrast by a factor of 2. + + Returns: + PIL Image: Contrast adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """Adjust color saturation of an image. + + Args: + img (PIL Image): PIL Image to be adjusted. + saturation_factor (float): How much to adjust the saturation. 0 will + give a black and white image, 1 will give the original image while + 2 will enhance the saturation by a factor of 2. + + Returns: + PIL Image: Saturation adjusted image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """Adjust hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args: + img (PIL Image): PIL Image to be adjusted. + hue_factor (float): How much to shift the hue channel. Should be in + [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in + HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. 
+ + Returns: + PIL Image: Hue adjusted image. + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def adjust_gamma(img, gamma, gain=1): + r"""Perform gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args: + img (PIL Image): PIL Image to be adjusted. + gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + gain (float): The constant multiplier. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + + gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part + + img = img.convert(input_mode) + return img + + +def rotate(img, angle, resample=False, expand=False, center=None): + """Rotate the image by angle. + + + Args: + img (PIL Image): PIL Image to be rotated. + angle (float or int): In degrees degrees counter clockwise order. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + return img.rotate(angle, resample, expand, center) + + +def _get_inverse_affine_matrix(center, angle, translate, scale, shear): + # Helper method to compute inverse matrix for affine transformation + + # As it is explained in PIL.Image.rotate + # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RSS is rotation with scale and shear matrix + # RSS(a, scale, shear) = [ cos(a)*scale -sin(a + shear)*scale 0] + # [ sin(a)*scale cos(a + shear)*scale 0] + # [ 0 0 1] + # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 + + angle = math.radians(angle) + shear = math.radians(shear) + scale = 1.0 / scale + + # Inverted rotation matrix with scale and shear + d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle) + matrix = [ + math.cos(angle + shear), math.sin(angle + shear), 0, + -math.sin(angle), math.cos(angle), 0 + ] + matrix = [scale / d * m for m in matrix] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1]) + matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1]) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + matrix[2] += center[0] + matrix[5] += center[1] + return matrix + + +def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): + """Apply affine transformation on the image keeping image center invariant + + Args: + img (PIL Image): PIL Image to be rotated. + angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. + translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) + scale (float): overall scale + shear (float): shear angle value in degrees between -180 to 180, clockwise direction. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. + See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "Argument translate should be a list or tuple of length 2" + + assert scale > 0.0, "Argument scale should be positive" + + output_size = img.size + center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] >= '5' else {} + return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs) + + +def to_grayscale(img, num_output_channels=1): + """Convert image to grayscale version of image. + + Args: + img (PIL Image): Image to be converted to grayscale. + + Returns: + PIL Image: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + + if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) + + if num_output_channels == 1: + img = img.convert('L') + elif num_output_channels == 3: + img = img.convert('L') + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, 'RGB') + else: + raise ValueError('num_output_channels should be either 1 or 3') + + return img + + +def flip_channels(img): + img = np.array(img)[:, :, ::-1] + return Image.fromarray(img.astype(np.uint8)) \ No newline at end of file diff --git a/model_measuring/kamal/vision/sync_transforms/transforms.py b/model_measuring/kamal/vision/sync_transforms/transforms.py new file mode 100644 index 0000000..58d76c4 --- /dev/null +++ b/model_measuring/kamal/vision/sync_transforms/transforms.py @@ -0,0 +1,1475 @@ +# A synchronized version modified from https://github.com/pytorch/vision +from __future__ import division +import torch +import math +import sys +import random +from PIL import Image +try: + import accimage +except ImportError: + accimage = None +import numpy as np +import numbers +import types +import collections +import warnings +import typing + +from . import functional as F + +if sys.version_info < (3, 3): + Sequence = collections.Sequence + Iterable = collections.Iterable +else: + Sequence = collections.abc.Sequence + Iterable = collections.abc.Iterable + + +__all__ = [ "Sync", "Multi","Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", + "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", + "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", + "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", + "RandomPerspective", "FlipChannels", "RandomErasing", "ToRGB", "ToGRAY"] + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + +class Sync(object): + def __init__(self, *transforms): + self.transforms = transforms + + def __call__(self, *inputs): + shared_params = None + outputs = [] + if len(self.transforms)==1: + for input in inputs: + out, shared_params = self.transforms[0](input, params=shared_params, return_params=True) + outputs.append( out ) + else: + assert len(inputs) == len(self.transforms), \ + "Expected %d inputs, but got %d"%( len(self.transforms), len(inputs) ) + for (input, trans) in zip( inputs, self.transforms ): + out, shared_params = trans(input, params=shared_params, return_params=True) + outputs.append( out ) + return outputs + + def _format_transform_repr(self, transform, head): + lines = transform.__repr__().splitlines() + return (["{}{}".format(head, lines[0])] + + ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + + def __repr__(self): + body = [self.__class__.__name__] + for transform in self.transforms: + body += self._format_transform_repr(transform, "Transform: ") + return '\n'.join(body) + +class Multi(object): + def __init__(self, *transforms): + self.transforms = transforms + + def __call__(self, *inputs): + if len(self.transforms)==1: + outputs = [ self.transforms[0](input) for input in inputs ] + else: + assert len(inputs) == len(self.transforms), \ + "Expected %d inputs, but got %d"%( len(self.transforms), len(inputs) ) + outputs = [] + for (input, trans) in zip( inputs, self.transforms ): + outputs.append( trans(input) 
if trans is not None else input )
+        return outputs
+
+    def _format_transform_repr(self, transform, head):
+        lines = transform.__repr__().splitlines()
+        return (["{}{}".format(head, lines[0])] +
+                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+
+    def __repr__(self):
+        body = [self.__class__.__name__]
+        for transform in self.transforms:
+            body += self._format_transform_repr(transform, "Transform: ")
+        return '\n'.join(body)
+
+class Compose(object):
+    """Composes several transforms together.
+
+    Args:
+        transforms (list of ``Transform`` objects): list of transforms to compose.
+
+    Example:
+        >>> transforms.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.ToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = []
+        for t in transforms:
+            if isinstance( t, typing.Sequence ):
+                self.transforms.append( Multi( *t ) )
+            else:
+                self.transforms.append(t)
+
+    def __call__(self, *imgs):
+        if len(imgs)==1:
+            imgs = imgs[0]
+            for t in self.transforms:
+                imgs = t(imgs)
+            return imgs
+        else:
+            for t in self.transforms:
+                imgs = t(*imgs)
+            return imgs
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += '    {0}'.format(t)
+        format_string += '\n)'
+        return format_string
+
+
+class ToTensor(object):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W).
+
+    This class is identical to torchvision.transforms.ToTensor if normalize=True.
+    If normalize=False, tensors of type dtype will be returned without scaling.
+    """
+    def __init__(self, normalize=True, dtype=None):
+        self.normalize=normalize
+        self.dtype=dtype
+
+    def __call__(self, pic, params=None, return_params=False):
+        """
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        if return_params:
+            return F.to_tensor(pic, self.normalize, self.dtype), None
+        return F.to_tensor(pic, self.normalize, self.dtype)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(Normalize={0})'.format(self.normalize)
+
+class ToPILImage(object):
+    """Convert a tensor or an ndarray to PIL Image.
+
+    Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
+    H x W x C to a PIL Image while preserving the value range.
+
+    Args:
+        mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
+            If ``mode`` is ``None`` (default) there are some assumptions made about the input data:
+             - If the input has 4 channels, the ``mode`` is assumed to be ``RGBA``.
+             - If the input has 3 channels, the ``mode`` is assumed to be ``RGB``.
+             - If the input has 2 channels, the ``mode`` is assumed to be ``LA``.
+             - If the input has 1 channel, the ``mode`` is determined by the data type (i.e ``int``, ``float``,
+               ``short``).
+
+    .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
+    """
+    def __init__(self, mode=None):
+        self.mode = mode
+
+    def __call__(self, pic, params=None, return_params=False):
+        """
+        Args:
+            pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
+
+        Returns:
+            PIL Image: Image converted to PIL Image.
+ + """ + if return_params: + return F.to_pil_image(pic, self.mode), None + return F.to_pil_image(pic, self.mode) + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + if self.mode is not None: + format_string += 'mode={0}'.format(self.mode) + format_string += ')' + return format_string + + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + + .. note: + This transform acts out of place, i.e., it does not mutates the input tensor. + + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std, inplace=False): + self.mean = mean + self.std = std + self.inplace = inplace + + def __call__(self, tensor, params=None, return_params=False ): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + + Returns: + Tensor: Normalized Tensor image. + """ + if return_params: + F.normalize(tensor, self.mean, self.std, self.inplace), None + return F.normalize(tensor, self.mean, self.std, self.inplace) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + +class Resize(object): + """Resize the input PIL Image to the given size. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size) + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR`` + """ + + def __init__(self, size, interpolation=Image.BILINEAR): + assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) + self.size = size + self.interpolation = interpolation + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be scaled. + + Returns: + PIL Image: Rescaled image. + """ + if return_params: + return F.resize(img, self.size, self.interpolation), None + return F.resize(img, self.size, self.interpolation) + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) + + +class Scale(Resize): + """ + Note: This transform is deprecated in favor of Resize. + """ + def __init__(self, *args, **kwargs): + warnings.warn("The use of the transforms.Scale transform is deprecated, " + + "please use transforms.Resize instead.") + super(Scale, self).__init__(*args, **kwargs) + + +class CenterCrop(object): + """Crops the given PIL Image at the center. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be cropped. + + Returns: + PIL Image: Cropped image. 
+ """ + if return_params: + return F.center_crop(img, self.size), None + return F.center_crop(img, self.size) + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class Pad(object): + """Pad the given PIL Image on all sides with the given "pad" value. + + Args: + padding (int or tuple): Padding on each border. If a single int is provided this + is used to pad all borders. If tuple of length 2 is provided this is the padding + on left/right and top/bottom respectively. If a tuple of length 4 is provided + this is the padding for the left, top, right and bottom borders + respectively. + fill (int or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode (str): Type of padding. Should be: constant, edge, reflect or symmetric. + Default is constant. + + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value at the edge of the image + + - reflect: pads with reflection of image without repeating the last value on the edge + + For example, padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image repeating the last value on the edge + + For example, padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + """ + + def __init__(self, padding, fill=0, padding_mode='constant'): + assert isinstance(padding, (numbers.Number, tuple)) + assert isinstance(fill, (numbers.Number, str, tuple)) + assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'] + if isinstance(padding, Sequence) and len(padding) not in [2, 4]: + raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " + + "{} element tuple".format(len(padding))) + + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def __call__(self, img, params=None, return_params=False ): + """ + Args: + img (PIL Image): Image to be padded. + + Returns: + PIL Image: Padded image. + """ + if return_params: + return F.pad(img, self.padding, self.fill, self.padding_mode), None + return F.pad(img, self.padding, self.fill, self.padding_mode) + + def __repr__(self): + return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ + format(self.padding, self.fill, self.padding_mode) + + +class Lambda(object): + """Apply a user-defined lambda as a transform. + + Args: + lambd (function): Lambda/function to be used for transform. 
+ """ + + def __init__(self, lambd): + assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" + self.lambd = lambd + + def __call__(self, img, params=None, return_params=False): + if return_params: + return self.lambd(img), None + return self.lambd(img) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class RandomTransforms(object): + """Base class for a list of transformations with randomness + + Args: + transforms (list or tuple): list of transformations + """ + + def __init__(self, transforms): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class RandomApply(RandomTransforms): + """Apply randomly a list of transformations with a given probability + + Args: + transforms (list or tuple): list of transformations + p (float): probability + """ + + def __init__(self, transforms, p=0.5): + super(RandomApply, self).__init__(transforms) + self.p = p + + def __call__(self, img, params=None, return_params=False): + if params is None: + p = random.random() + else: + p = params + if self.p < p: + return img + for t in self.transforms: + img = t(img) + if return_params: + return img, p + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += '\n p={}'.format(self.p) + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + +class RandomOrder(RandomTransforms): + """Apply a list of transformations in a random order + """ + def __call__(self, img, params=None, return_params=False): + if params is None: # no sync + order = list(range(len(self.transforms))) + random.shuffle(order) + else: + order = params + for i in order: + img = self.transforms[i](img) + if return_params: + return img, order + return img + +class RandomChoice(RandomTransforms): + """Apply single transformation randomly picked from a list + """ + def __call__(self, img, params=None, return_params=False): + if params is None: + t = random.choice(self.transforms) + else: + t = params + if return_params: + return t(img), t + return t(img) + +class RandomCrop(object): + """Crop the given PIL Image at a random location. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is None, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. If a sequence of length 2 is provided, it is used to + pad left/right, top/bottom borders, respectively. + pad_if_needed (boolean): It will pad the image if smaller than the + desired size to avoid raising an exception. Since cropping is done + after padding, the padding seems to be done at a random offset. + fill: Pixel fill value for constant fill. Default is 0. If a tuple of + length 3, it is used to fill R, G, B channels respectively. + This value is only used when the padding_mode is constant + padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 
+ + - constant: pads with a constant value, this value is specified with fill + + - edge: pads with the last value on the edge of the image + + - reflect: pads with reflection of image (without repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode + will result in [3, 2, 1, 2, 3, 4, 3, 2] + + - symmetric: pads with reflection of image (repeating the last value on the edge) + + padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode + will result in [2, 1, 1, 2, 3, 4, 4, 3] + + """ + + def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + @staticmethod + def get_params(img, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. + """ + w, h = img.size + th, tw = output_size + if w == tw and h == th: + return 0, 0, h, w + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be cropped. + + Returns: + PIL Image: Cropped image. + """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + # pad the width if needed + if self.pad_if_needed and img.size[0] < self.size[1]: + img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and img.size[1] < self.size[0]: + img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) + + if params is None: # no sync + i, j, h, w = self.get_params(img, self.size) + else: + i, j, h, w = params + if return_params: + return F.crop(img, i, j, h, w), (i,j,h,w) + return F.crop(img, i, j, h, w) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) + + +class RandomHorizontalFlip(object): + """Horizontally flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. + """ + if params is None: + p = random.random() + else: + p = params + + if p < self.p: + if return_params: + return F.hflip(img), p + return F.hflip(img) + if return_params: + return img, p + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomVerticalFlip(object): + """Vertically flip the given PIL Image randomly with a given probability. + + Args: + p (float): probability of the image being flipped. Default value is 0.5 + """ + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be flipped. + + Returns: + PIL Image: Randomly flipped image. 
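+
+        Example (sketch of the ``params``/``return_params`` replay convention used
+        throughout this module; ``image`` and ``mask`` are assumed to be PIL Images
+        of the same size):
+
+            >>> flip = RandomVerticalFlip(p=0.5)
+            >>> image, p = flip(image, return_params=True)  # sample the flip decision once
+            >>> mask = flip(mask, params=p)                 # replay it on the paired target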
+ """ + if params is None: # no sync + p = random.random() + else: + p = params + if p < self.p: + return F.vflip(img) + if return_params: + return img, p + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + + +class RandomPerspective(object): + """Performs Perspective transformation of the given PIL Image randomly with a given probability. + + Args: + interpolation : Default- Image.BICUBIC + + p (float): probability of the image being perspectively transformed. Default value is 0.5 + + distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5. + + """ + + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): + self.p = p + self.interpolation = interpolation + self.distortion_scale = distortion_scale + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be Perspectively transformed. + + Returns: + PIL Image: Random perspectivley transformed image. + """ + if not F._is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if params is None : # no sync or the first transform + p = random.random() + width, height = img.size + startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + else: + p, startpoints, endpoints = params + + if p scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if (in_ratio < min(ratio)): + w = img.size[0] + h = w / min(ratio) + elif (in_ratio > max(ratio)): + h = img.size[1] + w = h * max(ratio) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + + if params is None: # no sync + i, j, h, w = self.get_params(img, self.scale, self.ratio) + else: + i, j, h, w = params + if return_params: + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation), (i,j,h,w) + return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) + + def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class RandomSizedCrop(RandomResizedCrop): + """ + Note: This transform is deprecated in favor of RandomResizedCrop. + """ + def __init__(self, *args, **kwargs): + warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + + "please use transforms.RandomResizedCrop instead.") + super(RandomSizedCrop, self).__init__(*args, **kwargs) + + +class FiveCrop(object): + """Crop the given PIL Image into four corners and the central crop + + .. Note: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an ``int`` + instead of sequence like (h, w), a square crop of size (size, size) is made. + + Example: + >>> transform = Compose([ + >>> FiveCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + self.size = size + + def __call__(self, img, params=None, return_params=False): + if return_params: + return F.five_crop(img, self.size), None + return F.five_crop(img, self.size) + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class TenCrop(object): + """Crop the given PIL Image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default) + + .. Note: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. 
+ vertical_flip (bool): Use vertical flipping instead of horizontal + + Example: + >>> transform = Compose([ + >>> TenCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size, vertical_flip=False): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + self.size = size + self.vertical_flip = vertical_flip + + def __call__(self, img, params=None, return_params=False): + if return_params: + return F.ten_crop(img, self.size, self.vertical_flip), None + return F.ten_crop(img, self.size, self.vertical_flip) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + + +class LinearTransformation(object): + """Transform a tensor image with a square transformation matrix and a mean_vector computed + offline. + Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and + subtract mean_vector from it which is then followed by computing the dot + product with the transformation matrix and then reshaping the tensor to its + original shape. + + Applications: + whitening transformation: Suppose X is a column vector zero-centered data. + Then compute the data covariance matrix [D x D] with torch.mm(X.t(), X), + perform SVD on this matrix and pass it as transformation_matrix. + + Args: + transformation_matrix (Tensor): tensor [D x D], D = C x H x W + mean_vector (Tensor): tensor [D], D = C x H x W + """ + + def __init__(self, transformation_matrix, mean_vector): + if transformation_matrix.size(0) != transformation_matrix.size(1): + raise ValueError("transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) + + if mean_vector.size(0) != transformation_matrix.size(0): + raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{} x {}]" + .format(transformation_matrix.size())) + + self.transformation_matrix = transformation_matrix + self.mean_vector = mean_vector + + def __call__(self, tensor, params=None, return_params=False): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be whitened. + + Returns: + Tensor: Transformed image. + """ + if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): + raise ValueError("tensor and transformation matrix have incompatible shape." 
+ + "[{} x {} x {}] != ".format(*tensor.size()) + + "{}".format(self.transformation_matrix.size(0))) + flat_tensor = tensor.view(1, -1) - self.mean_vector + transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) + tensor = transformed_tensor.view(tensor.size()) + if return_params: + return tensor, None + return tensor + + def __repr__(self): + format_string = self.__class__.__name__ + '(transformation_matrix=' + format_string += (str(self.transformation_matrix.tolist()) + ')') + format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') + return format_string + + +class ColorJitter(object): + """Randomly change the brightness, contrast and saturation of an image. + + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [center - value, center + value] + if clip_first_on_zero: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError("{} values should be between {}".format(name, bound)) + else: + raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def get_params(brightness, contrast, saturation, hue): + """Get a randomized transform to be applied on image. + + Arguments are same as that of __init__. + + Returns: + Transform which randomly adjusts brightness, contrast and + saturation in a random order. 
+ """ + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append(Lambda(lambda img: F.adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append(Lambda(lambda img: F.adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append(Lambda(lambda img: F.adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = Compose(transforms) + + return transform + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Input image. + + Returns: + PIL Image: Color jittered image. + """ + sync_group = getattr( self, 'sync_group', None ) + + if params is None: # no sync + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + else: + transform = params + if return_params: + return transform(img), transform + return transform(img) + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + format_string += 'brightness={0}'.format(self.brightness) + format_string += ', contrast={0}'.format(self.contrast) + format_string += ', saturation={0}'.format(self.saturation) + format_string += ', hue={0})'.format(self.hue) + return format_string + + +class RandomRotation(object): + """Rotate the image by angle. + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). + resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + expand (bool, optional): Optional expansion flag. + If true, expands the output to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, resample=False, expand=False, center=None): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError("If degrees is a sequence, it must be of len 2.") + self.degrees = degrees + + self.resample = resample + self.expand = expand + self.center = center + + @staticmethod + def get_params(degrees): + """Get parameters for ``rotate`` for a random rotation. + + Returns: + sequence: params to be passed to ``rotate`` for random rotation. + """ + angle = random.uniform(degrees[0], degrees[1]) + + return angle + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be rotated. + + Returns: + PIL Image: Rotated image. 
+ """ + sync_group = getattr(self, 'sync_group', None) + if params is None: # no sync + angle = self.get_params(self.degrees) + else: + angle = params + if return_params: + return F.rotate(img, angle, self.resample, self.expand, self.center), angle + return F.rotate(img, angle, self.resample, self.expand, self.center) + + def __repr__(self): + format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) + format_string += ', resample={0}'.format(self.resample) + format_string += ', expand={0}'.format(self.expand) + if self.center is not None: + format_string += ', center={0}'.format(self.center) + format_string += ')' + return format_string + + +class RandomAffine(object): + """Random affine transformation of the image keeping center invariant + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Set to 0 to deactivate rotations. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or float or int, optional): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Will not apply shear by default + resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): + sync_group = None + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ + "degrees should be a list or tuple and it must be of length 2." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError("If shear is a single number, it must be positive.") + self.shear = (-shear, shear) + else: + assert isinstance(shear, (tuple, list)) and len(shear) == 2, \ + "shear should be a list or tuple and it must be of length 2." 
+                self.shear = shear
+        else:
+            self.shear = shear
+
+        self.resample = resample
+        self.fillcolor = fillcolor
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, img_size):
+        """Get parameters for affine transformation
+
+        Returns:
+            sequence: params to be passed to the affine transformation
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+        if translate is not None:
+            max_dx = translate[0] * img_size[0]
+            max_dy = translate[1] * img_size[1]
+            translations = (np.round(random.uniform(-max_dx, max_dx)),
+                            np.round(random.uniform(-max_dy, max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = random.uniform(scale_ranges[0], scale_ranges[1])
+        else:
+            scale = 1.0
+
+        if shears is not None:
+            shear = random.uniform(shears[0], shears[1])
+        else:
+            shear = 0.0
+
+        return angle, translations, scale, shear
+
+    def __call__(self, img, params=None, return_params=False):
+        """
+        Args:
+            img (PIL Image): Image to be transformed.
+
+        Returns:
+            PIL Image: Affine transformed image.
+        """
+        if params is None:  # no sync
+            params = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
+        if return_params:
+            return F.affine(img, *params, resample=self.resample, fillcolor=self.fillcolor), params
+        return F.affine(img, *params, resample=self.resample, fillcolor=self.fillcolor)
+
+    def __repr__(self):
+        s = '{name}(degrees={degrees}'
+        if self.translate is not None:
+            s += ', translate={translate}'
+        if self.scale is not None:
+            s += ', scale={scale}'
+        if self.shear is not None:
+            s += ', shear={shear}'
+        if self.resample > 0:
+            s += ', resample={resample}'
+        if self.fillcolor != 0:
+            s += ', fillcolor={fillcolor}'
+        s += ')'
+        d = dict(self.__dict__)
+        d['resample'] = _pil_interpolation_to_str[d['resample']]
+        return s.format(name=self.__class__.__name__, **d)
+
+class ToRGB(object):
+    def __call__(self, img, params=None, return_params=False):
+        """
+        Args:
+            img (PIL Image): Image to be converted to RGB.
+
+        Returns:
+            PIL Image: RGB version of the input image.
+        """
+        if return_params:
+            return img.convert('RGB'), None
+        return img.convert('RGB')
+
+class ToGRAY(object):
+    def __call__(self, img, params=None, return_params=False):
+        """
+        Args:
+            img (PIL Image): Image to be converted to grayscale.
+
+        Returns:
+            PIL Image: Grayscale (mode 'L') version of the input image.
+        """
+        if return_params:
+            return img.convert('L'), None
+        return img.convert('L')
+
+class Grayscale(object):
+    """Convert image to grayscale.
+
+    Args:
+        num_output_channels (int): (1 or 3) number of channels desired for output image
+
+    Returns:
+        PIL Image: Grayscale version of the input.
+            - If num_output_channels == 1 : returned image is single channel
+            - If num_output_channels == 3 : returned image is 3 channel with r == g == b
+
+    """
+
+    def __init__(self, num_output_channels=1):
+        self.num_output_channels = num_output_channels
+
+    def __call__(self, img, params=None, return_params=False):
+        """
+        Args:
+            img (PIL Image): Image to be converted to grayscale.
+
+        Returns:
+            PIL Image: Grayscaled image.
+        """
+        if return_params:
+            return F.to_grayscale(img, num_output_channels=self.num_output_channels), None
+        return F.to_grayscale(img, num_output_channels=self.num_output_channels)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels)
+
+
+class RandomGrayscale(object):
+    """Randomly convert image to grayscale with a probability of p (default 0.1).
+ + Args: + p (float): probability that image should be converted to grayscale. + + Returns: + PIL Image: Grayscale version of the input image with probability p and unchanged + with probability (1-p). + - If input image is 1 channel: grayscale version is 1 channel + - If input image is 3 channel: grayscale version is 3 channel with r == g == b + + """ + + def __init__(self, p=0.1): + self.p = p + sync_group = None + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (PIL Image): Image to be converted to grayscale. + + Returns: + PIL Image: Randomly grayscaled image. + """ + num_output_channels = 1 if img.mode == 'L' else 3 + if params is None: # no sync + p = random.random() + else: + p = params + if p < self.p: + if return_params: + return F.to_grayscale(img, num_output_channels=num_output_channels), p + return F.to_grayscale(img, num_output_channels=num_output_channels) + + if return_params: + return img, p + return img + + def __repr__(self): + return self.__class__.__name__ + '(p={0})'.format(self.p) + +class FlipChannels(object): + def __call__(self, img, params=None, return_params=False): + if return_params: + return F.flip_channels( img ), None + return F.flip_channels( img ) + + def __repr__(self): + return self.__class__.__name__ + '()'.format() + +class RandomErasing(object): + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + Args: + p: probability that the random erasing operation will be performed. + scale: range of proportion of erased area against input image. + ratio: range of aspect ratio of erased area. + value: erasing value. Default is 0. If a single int, it is used to + erase all pixels. If a tuple of length 3, it is used to erase + R, G, B channels respectively. + If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace. Default set to False. + Returns: + Erased Image. + # Examples: + >>> transform = transforms.Compose([ + >>> transforms.RandomHorizontalFlip(), + >>> transforms.ToTensor(), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> transforms.RandomErasing(), + >>> ]) + """ + + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): + assert isinstance(value, (numbers.Number, str, tuple, list)) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("range of scale should be between 0 and 1") + if p < 0 or p > 1: + raise ValueError("range of random erasing probability should be between 0 and 1") + + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value + self.inplace = inplace + + @staticmethod + def get_params(img, scale, ratio, value=0): + """Get parameters for ``erase`` for a random erasing. + Args: + img (Tensor): Tensor image of size (C, H, W) to be erased. + scale: range of proportion of erased area against input image. + ratio: range of aspect ratio of erased area. + Returns: + tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing. 
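+
+        Example (sketch; draws one erasing box for a normalized tensor image ``img``):
+
+            >>> eraser = RandomErasing(p=1.0)
+            >>> i, j, h, w, v = eraser.get_params(img, scale=eraser.scale, ratio=eraser.ratio, value=eraser.value)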
+ """ + img_c, img_h, img_w = img.shape + area = img_h * img_w + + for attempt in range(10): + erase_area = random.uniform(scale[0], scale[1]) * area + aspect_ratio = random.uniform(ratio[0], ratio[1]) + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + + if h < img_h and w < img_w: + i = random.randint(0, img_h - h) + j = random.randint(0, img_w - w) + if isinstance(value, numbers.Number): + v = value + elif isinstance(value, torch._six.string_classes): + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + elif isinstance(value, (list, tuple)): + v = torch.tensor(value, dtype=torch.float32).view(-1, 1, 1).expand(-1, h, w) + return i, j, h, w, v + + # Return original image + return 0, 0, img_h, img_w, img + + def __call__(self, img, params=None, return_params=False): + """ + Args: + img (Tensor): Tensor image of size (C, H, W) to be erased. + Returns: + img (Tensor): Erased Tensor image. + """ + if params is None: # no sync + p = random.uniform(0, 1) + x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value) + else: + p, x, y, h, w, v = params + + if p