|
|
@@ -1,3 +1,5 @@ |
|
|
|
from functools import reduce |
|
|
|
from collections import OrderedDict |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
@@ -71,7 +73,7 @@ def initial_parameter(net, initial_method=None): |
|
|
|
def get_embeddings(init_embed): |
|
|
|
""" |
|
|
|
得到词嵌入 |
|
|
|
|
|
|
|
|
|
|
|
.. todo:: |
|
|
|
补上文档 |
|
|
|
|
|
|
@@ -81,7 +83,8 @@ def get_embeddings(init_embed): |
|
|
|
:return nn.Embedding embeddings: |
|
|
|
""" |
|
|
|
if isinstance(init_embed, tuple): |
|
|
|
res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1]) |
|
|
|
res = nn.Embedding( |
|
|
|
num_embeddings=init_embed[0], embedding_dim=init_embed[1]) |
|
|
|
elif isinstance(init_embed, nn.Embedding): |
|
|
|
res = init_embed |
|
|
|
elif isinstance(init_embed, torch.Tensor): |
|
|
@@ -90,5 +93,43 @@ def get_embeddings(init_embed): |
|
|
|
init_embed = torch.tensor(init_embed, dtype=torch.float32) |
|
|
|
res = nn.Embedding.from_pretrained(init_embed, freeze=False) |
|
|
|
else: |
|
|
|
raise TypeError('invalid init_embed type: {}'.format((type(init_embed)))) |
|
|
|
raise TypeError( |
|
|
|
'invalid init_embed type: {}'.format((type(init_embed)))) |
|
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
def summary(model: nn.Module):
    """
    Count the parameters of a PyTorch model and print a short report.

    The module tree is walked recursively; at each module only the
    parameters registered directly on it are counted, and they are sorted
    into trainable / non-trainable by their ``requires_grad`` flag.

    :param model: the PyTorch model to inspect
    :return tuple: ``(total, trainable, non_trainable)`` parameter counts
    """
    train = []
    nontrain = []

    def layer_summary(module: nn.Module):
        # recurse=False: only this module's own parameters; sub-modules are
        # covered by the recursion over children() below.
        for p in module.parameters(recurse=False):
            # numel() also handles 0-dim (scalar) parameters, whose empty
            # shape cannot be folded with reduce() without an initial value.
            bucket = train if p.requires_grad else nontrain
            bucket.append(p.numel())
        for subm in module.children():
            layer_summary(subm)

    layer_summary(model)
    total_train = sum(train)
    total_nontrain = sum(nontrain)
    total = total_train + total_nontrain

    strings = [
        'Total params: {:,}'.format(total),
        'Trainable params: {:,}'.format(total_train),
        'Non-trainable params: {:,}'.format(total_nontrain),
    ]
    # Frame the report with rules slightly wider than the longest line.
    max_len = max(len(line) for line in strings)
    bar = '-' * (max_len + 3)
    strings = [bar] + strings + [bar]
    print('\n'.join(strings))
    return total, total_train, total_nontrain