Browse Source

add model summary

tags/v0.4.10
yunfan 5 years ago
parent
commit
87e21a26a3
1 changed files with 44 additions and 3 deletions
  1. +44
    -3
      fastNLP/modules/utils.py

+ 44
- 3
fastNLP/modules/utils.py View File

@@ -1,3 +1,5 @@
from functools import reduce
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
@@ -71,7 +73,7 @@ def initial_parameter(net, initial_method=None):
def get_embeddings(init_embed): def get_embeddings(init_embed):
""" """
得到词嵌入 得到词嵌入
.. todo:: .. todo::
补上文档 补上文档


@@ -81,7 +83,8 @@ def get_embeddings(init_embed):
:return nn.Embedding embeddings: :return nn.Embedding embeddings:
""" """
if isinstance(init_embed, tuple): if isinstance(init_embed, tuple):
res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
res = nn.Embedding(
num_embeddings=init_embed[0], embedding_dim=init_embed[1])
elif isinstance(init_embed, nn.Embedding): elif isinstance(init_embed, nn.Embedding):
res = init_embed res = init_embed
elif isinstance(init_embed, torch.Tensor): elif isinstance(init_embed, torch.Tensor):
@@ -90,5 +93,43 @@ def get_embeddings(init_embed):
init_embed = torch.tensor(init_embed, dtype=torch.float32) init_embed = torch.tensor(init_embed, dtype=torch.float32)
res = nn.Embedding.from_pretrained(init_embed, freeze=False) res = nn.Embedding.from_pretrained(init_embed, freeze=False)
else: else:
raise TypeError('invalid init_embed type: {}'.format((type(init_embed))))
raise TypeError(
'invalid init_embed type: {}'.format((type(init_embed))))
return res return res


def summary(model: nn.Module):
    """
    Count and print the number of parameters in a PyTorch model.

    Prints a small banner with total / trainable / non-trainable counts,
    then returns the three numbers.

    :param model: the PyTorch module to summarize
    :return tuple: (total, trainable, non-trainable) parameter counts
    """
    train = []
    nontrain = []

    def layer_summary(module: nn.Module):
        def count_size(sizes):
            # Initializer 1 makes this safe for scalar parameters
            # (empty shape), where reduce() without an initial value
            # would raise TypeError; it also correctly yields 1, which
            # is a scalar's element count.
            return reduce(lambda x, y: x * y, sizes, 1)

        # recurse=False so each parameter is counted exactly once;
        # children are walked explicitly below.
        for p in module.parameters(recurse=False):
            if p.requires_grad:
                train.append(count_size(p.shape))
            else:
                nontrain.append(count_size(p.shape))
        for subm in module.children():
            layer_summary(subm)

    layer_summary(model)
    total_train = sum(train)
    total_nontrain = sum(nontrain)
    total = total_train + total_nontrain
    strings = []
    strings.append('Total params: {:,}'.format(total))
    strings.append('Trainable params: {:,}'.format(total_train))
    strings.append('Non-trainable params: {:,}'.format(total_nontrain))
    # Frame the report with a dashed bar as wide as the longest line (+3).
    max_len = len(max(strings, key=len))
    bar = '-' * (max_len + 3)
    strings = [bar] + strings + [bar]
    print('\n'.join(strings))
    return total, total_train, total_nontrain

Loading…
Cancel
Save