From 87e21a26a31b27a60c832fa8103ef3aaf619b059 Mon Sep 17 00:00:00 2001 From: yunfan Date: Thu, 16 May 2019 14:17:53 +0800 Subject: [PATCH] add model summary --- fastNLP/modules/utils.py | 47 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 047ebb78..23768b03 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -1,3 +1,5 @@ +from functools import reduce +from collections import OrderedDict import numpy as np import torch import torch.nn as nn @@ -71,7 +73,7 @@ def initial_parameter(net, initial_method=None): def get_embeddings(init_embed): """ 得到词嵌入 - + .. todo:: 补上文档 @@ -81,7 +83,8 @@ def get_embeddings(init_embed): :return nn.Embedding embeddings: """ if isinstance(init_embed, tuple): - res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1]) + res = nn.Embedding( + num_embeddings=init_embed[0], embedding_dim=init_embed[1]) elif isinstance(init_embed, nn.Embedding): res = init_embed elif isinstance(init_embed, torch.Tensor): @@ -90,5 +93,43 @@ def get_embeddings(init_embed): init_embed = torch.tensor(init_embed, dtype=torch.float32) res = nn.Embedding.from_pretrained(init_embed, freeze=False) else: - raise TypeError('invalid init_embed type: {}'.format((type(init_embed)))) + raise TypeError( + 'invalid init_embed type: {}'.format((type(init_embed)))) return res + + +def summary(model: nn.Module): + """ + 得到模型的总参数量 + + :param model: PyTorch 模型 + :return tuple: 包含总参数量,可训练参数量,不可训练参数量 + """ + train = [] + nontrain = [] + + def layer_summary(module: nn.Module): + def count_size(sizes): + return reduce(lambda x, y: x*y, sizes) + + for p in module.parameters(recurse=False): + if p.requires_grad: + train.append(count_size(p.shape)) + else: + nontrain.append(count_size(p.shape)) + for subm in module.children(): + layer_summary(subm) + + layer_summary(model) + total_train = sum(train) + total_nontrain = sum(nontrain) + 
total = total_train + total_nontrain + strings = [] + strings.append('Total params: {:,}'.format(total)) + strings.append('Trainable params: {:,}'.format(total_train)) + strings.append('Non-trainable params: {:,}'.format(total_nontrain)) + max_len = len(max(strings, key=len)) + bar = '-'*(max_len + 3) + strings = [bar] + strings + [bar] + print('\n'.join(strings)) + return total, total_train, total_nontrain