Browse Source

1. summary统计buffer的数量

tags/v0.4.10
yh 5 years ago
parent
commit
92823fedb9
2 changed files with 7 additions and 0 deletions
  1. +3
    -0
      fastNLP/modules/__init__.py
  2. +4
    -0
      fastNLP/modules/utils.py

+ 3
- 0
fastNLP/modules/__init__.py View File

@@ -48,6 +48,8 @@ __all__ = [
"allowed_transitions",

"TimestepDropout",

'summary'
]

from . import decoder
@@ -55,6 +57,7 @@ from . import encoder
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import summary

import sys
from ..doc_utils import doc_process


+ 4
- 0
fastNLP/modules/utils.py View File

@@ -89,6 +89,7 @@ def summary(model: nn.Module):
"""
train = []
nontrain = []
buffer = []
def layer_summary(module: nn.Module):
def count_size(sizes):
@@ -99,6 +100,8 @@ def summary(model: nn.Module):
train.append(count_size(p.shape))
else:
nontrain.append(count_size(p.shape))
for p in module.buffers():
buffer.append(count_size(p))
for subm in module.children():
layer_summary(subm)
@@ -110,6 +113,7 @@ def summary(model: nn.Module):
strings.append('Total params: {:,}'.format(total))
strings.append('Trainable params: {:,}'.format(total_train))
strings.append('Non-trainable params: {:,}'.format(total_nontrain))
strings.append("Buffer params: {:,}".format(sum(buffer)))
max_len = len(max(strings, key=len))
bar = '-' * (max_len + 3)
strings = [bar] + strings + [bar]


Loading…
Cancel
Save