|
|
@@ -89,6 +89,7 @@ def summary(model: nn.Module): |
|
|
|
""" |
|
|
|
train = [] |
|
|
|
nontrain = [] |
|
|
|
buffer = [] |
|
|
|
|
|
|
|
def layer_summary(module: nn.Module): |
|
|
|
def count_size(sizes): |
|
|
@@ -99,6 +100,8 @@ def summary(model: nn.Module): |
|
|
|
train.append(count_size(p.shape)) |
|
|
|
else: |
|
|
|
nontrain.append(count_size(p.shape)) |
|
|
|
for p in module.buffers(): |
|
|
|
buffer.append(count_size(p)) |
|
|
|
for subm in module.children(): |
|
|
|
layer_summary(subm) |
|
|
|
|
|
|
@@ -110,6 +113,7 @@ def summary(model: nn.Module): |
|
|
|
strings.append('Total params: {:,}'.format(total)) |
|
|
|
strings.append('Trainable params: {:,}'.format(total_train)) |
|
|
|
strings.append('Non-trainable params: {:,}'.format(total_nontrain)) |
|
|
|
strings.append("Buffer params: {:,}".format(sum(buffer))) |
|
|
|
max_len = len(max(strings, key=len)) |
|
|
|
bar = '-' * (max_len + 3) |
|
|
|
strings = [bar] + strings + [bar] |
|
|
|