|
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT License.
- # Written by Hao Du and Houwen Peng
- # email: haodu8-c@my.cityu.edu.hk and houwen.peng@microsoft.com
-
- import torch
-
- from ptflops import get_model_complexity_info
-
-
- class FlopsEst(object):
- def __init__(self, model, input_shape=(2, 3, 224, 224), device='cpu'):
- self.block_num = len(model.blocks)
- self.choice_num = len(model.blocks[0])
- self.flops_dict = {}
- self.params_dict = {}
-
- if device == 'cpu':
- model = model.cpu()
- else:
- model = model.cuda()
-
- self.params_fixed = 0
- self.flops_fixed = 0
-
- input = torch.randn(input_shape)
-
- flops, params = get_model_complexity_info(
- model.conv_stem, (3, 224, 224), as_strings=False, print_per_layer_stat=False)
- self.params_fixed += params / 1e6
- self.flops_fixed += flops / 1e6
-
- input = model.conv_stem(input)
-
- for block_id, block in enumerate(model.blocks):
- self.flops_dict[block_id] = {}
- self.params_dict[block_id] = {}
- for module_id, module in enumerate(block):
- flops, params = get_model_complexity_info(module, tuple(
- input.shape[1:]), as_strings=False, print_per_layer_stat=False)
- # Flops(M)
- self.flops_dict[block_id][module_id] = flops / 1e6
- # Params(M)
- self.params_dict[block_id][module_id] = params / 1e6
-
- input = module(input)
-
- # conv_last
- flops, params = get_model_complexity_info(model.global_pool, tuple(
- input.shape[1:]), as_strings=False, print_per_layer_stat=False)
- self.params_fixed += params / 1e6
- self.flops_fixed += flops / 1e6
-
- input = model.global_pool(input)
-
- # globalpool
- flops, params = get_model_complexity_info(model.conv_head, tuple(
- input.shape[1:]), as_strings=False, print_per_layer_stat=False)
- self.params_fixed += params / 1e6
- self.flops_fixed += flops / 1e6
-
- # return params (M)
- def get_params(self, arch):
- params = 0
- for block_id, block in enumerate(arch):
- if block == -1:
- continue
- params += self.params_dict[block_id][block]
- return params + self.params_fixed
-
- # return flops (M)
- def get_flops(self, arch):
- flops = 0
- for block_id, block in enumerate(arch):
- if block == 'LayerChoice1' or block_id == 'LayerChoice23':
- continue
- for idx, choice in enumerate(arch[block]):
- flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
- return flops + self.flops_fixed
|