diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py index d72d2022..51d5aaac 100644 --- a/fastNLP/modules/__init__.py +++ b/fastNLP/modules/__init__.py @@ -48,6 +48,8 @@ __all__ = [ "allowed_transitions", "TimestepDropout", + + 'summary' ] from . import decoder @@ -55,6 +57,7 @@ from . import encoder from .decoder import * from .dropout import TimestepDropout from .encoder import * +from .utils import summary import sys from ..doc_utils import doc_process diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 09574782..54993479 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -89,6 +89,7 @@ def summary(model: nn.Module): """ train = [] nontrain = [] + buffer = [] def layer_summary(module: nn.Module): def count_size(sizes): @@ -99,6 +100,8 @@ def summary(model: nn.Module): train.append(count_size(p.shape)) else: nontrain.append(count_size(p.shape)) + for p in module.buffers(recurse=False): + buffer.append(count_size(p.shape)) for subm in module.children(): layer_summary(subm) @@ -110,6 +113,7 @@ def summary(model: nn.Module): strings.append('Total params: {:,}'.format(total)) strings.append('Trainable params: {:,}'.format(total_train)) strings.append('Non-trainable params: {:,}'.format(total_nontrain)) + strings.append("Buffer params: {:,}".format(sum(buffer))) max_len = len(max(strings, key=len)) bar = '-' * (max_len + 3) strings = [bar] + strings + [bar]