|
|
@@ -1,10 +1,20 @@ |
|
|
|
""" |
|
|
|
.. todo:: |
|
|
|
doc |
|
|
|
""" |
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
"initial_parameter", |
|
|
|
"summary" |
|
|
|
] |
|
|
|
|
|
|
|
import os |
|
|
|
from functools import reduce |
|
|
|
|
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.init as init |
|
|
|
import glob |
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
def initial_parameter(net, initial_method=None): |
|
|
|
"""A method used to initialize the weights of PyTorch models. |
|
|
@@ -40,7 +50,7 @@ def initial_parameter(net, initial_method=None): |
|
|
|
init_method = init.uniform_ |
|
|
|
else: |
|
|
|
init_method = init.xavier_normal_ |
|
|
|
|
|
|
|
|
|
|
|
def weights_init(m): |
|
|
|
# classname = m.__class__.__name__ |
|
|
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn |
|
|
@@ -66,7 +76,7 @@ def initial_parameter(net, initial_method=None): |
|
|
|
else: |
|
|
|
init.normal_(w.data) # bias |
|
|
|
# print("init else") |
|
|
|
|
|
|
|
|
|
|
|
net.apply(weights_init) |
|
|
|
|
|
|
|
|
|
|
@@ -79,11 +89,11 @@ def summary(model: nn.Module): |
|
|
|
""" |
|
|
|
train = [] |
|
|
|
nontrain = [] |
|
|
|
|
|
|
|
|
|
|
|
def layer_summary(module: nn.Module): |
|
|
|
def count_size(sizes): |
|
|
|
return reduce(lambda x, y: x*y, sizes) |
|
|
|
|
|
|
|
return reduce(lambda x, y: x * y, sizes) |
|
|
|
|
|
|
|
for p in module.parameters(recurse=False): |
|
|
|
if p.requires_grad: |
|
|
|
train.append(count_size(p.shape)) |
|
|
@@ -91,7 +101,7 @@ def summary(model: nn.Module): |
|
|
|
nontrain.append(count_size(p.shape)) |
|
|
|
for subm in module.children(): |
|
|
|
layer_summary(subm) |
|
|
|
|
|
|
|
|
|
|
|
layer_summary(model) |
|
|
|
total_train = sum(train) |
|
|
|
total_nontrain = sum(nontrain) |
|
|
@@ -101,7 +111,7 @@ def summary(model: nn.Module): |
|
|
|
strings.append('Trainable params: {:,}'.format(total_train)) |
|
|
|
strings.append('Non-trainable params: {:,}'.format(total_nontrain)) |
|
|
|
max_len = len(max(strings, key=len)) |
|
|
|
bar = '-'*(max_len + 3) |
|
|
|
bar = '-' * (max_len + 3) |
|
|
|
strings = [bar] + strings + [bar] |
|
|
|
print('\n'.join(strings)) |
|
|
|
return total, total_train, total_nontrain |
|
|
@@ -128,9 +138,9 @@ def _get_file_name_base_on_postfix(dir_path, postfix): |
|
|
|
:param postfix: 形如".bin", ".json"等 |
|
|
|
:return: str,文件的路径 |
|
|
|
""" |
|
|
|
files = list(filter(lambda filename:filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) |
|
|
|
files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path)))) |
|
|
|
if len(files) == 0: |
|
|
|
raise FileNotFoundError(f"There is no file endswith *{postfix} file in {dir_path}") |
|
|
|
elif len(files) > 1: |
|
|
|
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") |
|
|
|
return os.path.join(dir_path, files[0]) |
|
|
|
return os.path.join(dir_path, files[0]) |