diff --git a/README.md b/README.md
index c1b5db3e..bb62fc38 100644
--- a/README.md
+++ b/README.md
@@ -94,3 +94,6 @@ Check out models' performance, usage and source code here.
readers & savers |
+
+
+*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
\ No newline at end of file
diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index b3858020..e0dd9d90 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -148,11 +148,10 @@ class VarRNNBase(nn.Module):
seq_len = x.size(1) if self.batch_first else x.size(0)
max_batch_size = x.size(0) if self.batch_first else x.size(1)
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
- _tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
- x, batch_sizes = _tmp.data, _tmp.batch_sizes
+            x = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
else:
- max_batch_size = int(x.batch_sizes[0])
- x, batch_sizes = x.data, x.batch_sizes
+            max_batch_size = int(x.batch_sizes[0])
+        x, batch_sizes = x.data, x.batch_sizes  # unpack once, for both branches
if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
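
With this fix, `forward` accepts either a padded tensor or a `PackedSequence`, packing the former internally and unpacking once after the branch. A minimal sketch of the padded-tensor path, assuming `VarLSTM` takes `nn.LSTM`-style constructor arguments (sizes here are illustrative only):

```python
import torch
from fastNLP.modules.encoder.variational_rnn import VarLSTM

# Hypothetical values; constructor assumed to mirror nn.LSTM's arguments.
lstm = VarLSTM(input_size=16, hidden_size=32, batch_first=True)
x = torch.randn(4, 10, 16)  # (batch, seq_len, features) -- packed internally by forward()
output, hx = lstm(x)        # for an LSTM, hx should be the (h_n, c_n) pair
print(output.shape)         # expected: torch.Size([4, 10, 32])
```
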
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 3dfe1969..0aba7e62 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -1,3 +1,4 @@
+from functools import reduce
import numpy as np
import torch
import torch.nn as nn
@@ -78,7 +80,8 @@ def get_embeddings(init_embed):
:return nn.Embedding embeddings:
"""
if isinstance(init_embed, tuple):
- res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+ res = nn.Embedding(
+ num_embeddings=init_embed[0], embedding_dim=init_embed[1])
elif isinstance(init_embed, nn.Embedding):
res = init_embed
elif isinstance(init_embed, torch.Tensor):
@@ -87,5 +90,43 @@ def get_embeddings(init_embed):
init_embed = torch.tensor(init_embed, dtype=torch.float32)
res = nn.Embedding.from_pretrained(init_embed, freeze=False)
else:
- raise TypeError('invalid init_embed type: {}'.format((type(init_embed))))
+ raise TypeError(
+            'invalid init_embed type: {}'.format(type(init_embed)))
return res
+
+
+def summary(model: nn.Module):
+ """
+    Count the total number of parameters in a model.
+
+    :param model: the PyTorch model
+    :return tuple: (total params, trainable params, non-trainable params)
+ """
+ train = []
+ nontrain = []
+
+ def layer_summary(module: nn.Module):
+        def count_size(sizes):
+            return reduce(lambda x, y: x * y, sizes, 1)  # start at 1 so 0-dim (scalar) params count as 1
+
+        for p in module.parameters(recurse=False):  # params owned directly by this module
+            if p.requires_grad:
+                train.append(count_size(p.shape))
+            else:
+                nontrain.append(count_size(p.shape))
+        for subm in module.children():
+            layer_summary(subm)  # recurse into submodules
+
+ layer_summary(model)
+ total_train = sum(train)
+ total_nontrain = sum(nontrain)
+ total = total_train + total_nontrain
+ strings = []
+ strings.append('Total params: {:,}'.format(total))
+ strings.append('Trainable params: {:,}'.format(total_train))
+ strings.append('Non-trainable params: {:,}'.format(total_nontrain))
+ max_len = len(max(strings, key=len))
+ bar = '-'*(max_len + 3)
+ strings = [bar] + strings + [bar]
+ print('\n'.join(strings))
+ return total, total_train, total_nontrain
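
Two hedged usage sketches for the `utils.py` additions. First, the input types `get_embeddings` accepts; the `torch.Tensor` branch body is elided between hunks, so the `from_pretrained(..., freeze=False)` behavior is assumed from the adjacent branch shown above:

```python
import torch
import torch.nn as nn
from fastNLP.modules.utils import get_embeddings

emb = get_embeddings((1000, 50))              # tuple -> randomly initialized 1000 x 50 table
emb = get_embeddings(torch.randn(1000, 50))   # Tensor -> assumed from_pretrained, freeze=False
emb = get_embeddings(nn.Embedding(1000, 50))  # nn.Embedding -> returned unchanged
```

Second, a quick check of `summary` on a small model; the expected counts are 10·20 + 20 = 220 and 20·2 + 2 = 42, i.e. 262 in total:

```python
import torch.nn as nn
from fastNLP.modules.utils import summary

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
total, trainable, nontrain = summary(model)
# --------------------------
# Total params: 262
# Trainable params: 262
# Non-trainable params: 0
# --------------------------
assert (total, trainable, nontrain) == (262, 262, 0)
```
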