diff --git a/README.md b/README.md
index c1b5db3e..bb62fc38 100644
--- a/README.md
+++ b/README.md
@@ -94,3 +94,6 @@
 Check out models' performance, usage and source code here.
 
 readers & savers
+
+
+*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
\ No newline at end of file
diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index b3858020..e0dd9d90 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -148,11 +148,10 @@ class VarRNNBase(nn.Module):
             seq_len = x.size(1) if self.batch_first else x.size(0)
             max_batch_size = x.size(0) if self.batch_first else x.size(1)
             seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            _tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
-            x, batch_sizes = _tmp.data, _tmp.batch_sizes
+            input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
         else:
-            max_batch_size = int(x.batch_sizes[0])
-            x, batch_sizes = x.data, x.batch_sizes
+            max_batch_size = int(input.batch_sizes[0])
+            input, batch_sizes = input.data, input.batch_sizes
 
         if hx is None:
             hx = x.new_zeros(self.num_layers * self.num_directions,
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 3dfe1969..0aba7e62 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -1,3 +1,5 @@
+from functools import reduce
+from collections import OrderedDict
 import numpy as np
 import torch
 import torch.nn as nn
@@ -78,7 +80,8 @@ def get_embeddings(init_embed):
     :return nn.Embedding embeddings:
     """
     if isinstance(init_embed, tuple):
-        res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+        res = nn.Embedding(
+            num_embeddings=init_embed[0], embedding_dim=init_embed[1])
     elif isinstance(init_embed, nn.Embedding):
         res = init_embed
     elif isinstance(init_embed, torch.Tensor):
@@ -87,5 +90,43 @@
         init_embed = torch.tensor(init_embed, dtype=torch.float32)
         res = nn.Embedding.from_pretrained(init_embed, freeze=False)
     else:
-        raise TypeError('invalid init_embed type: {}'.format((type(init_embed))))
+        raise TypeError(
+            'invalid init_embed type: {}'.format((type(init_embed))))
     return res
+
+
+def summary(model: nn.Module):
+    """
+    Get the total number of parameters in the model.
+
+    :param model: the PyTorch model
+    :return tuple: (total params, trainable params, non-trainable params)
+    """
+    train = []
+    nontrain = []
+
+    def layer_summary(module: nn.Module):
+        def count_size(sizes):
+            return reduce(lambda x, y: x*y, sizes)
+
+        for p in module.parameters(recurse=False):
+            if p.requires_grad:
+                train.append(count_size(p.shape))
+            else:
+                nontrain.append(count_size(p.shape))
+        for subm in module.children():
+            layer_summary(subm)
+
+    layer_summary(model)
+    total_train = sum(train)
+    total_nontrain = sum(nontrain)
+    total = total_train + total_nontrain
+    strings = []
+    strings.append('Total params: {:,}'.format(total))
+    strings.append('Trainable params: {:,}'.format(total_train))
+    strings.append('Non-trainable params: {:,}'.format(total_nontrain))
+    max_len = len(max(strings, key=len))
+    bar = '-'*(max_len + 3)
+    strings = [bar] + strings + [bar]
+    print('\n'.join(strings))
+    return total, total_train, total_nontrain
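For a quick sanity check of the new `summary()` helper added to `fastNLP/modules/utils.py`, here is a minimal sketch of how it could be exercised. The toy two-layer model and the frozen embedding are illustrative assumptions, not part of the patch:

```python
import torch.nn as nn
from fastNLP.modules.utils import summary

# A tiny model: 100 * 16 = 1,600 embedding params, 16 * 4 + 4 = 68 linear params.
model = nn.Sequential(
    nn.Embedding(num_embeddings=100, embedding_dim=16),
    nn.Linear(16, 4),
)
# Freeze the embedding so the non-trainable bucket is exercised too.
model[0].weight.requires_grad = False

total, trainable, frozen = summary(model)
# summary() prints a small framed report, roughly:
# ------------------------------
# Total params: 1,668
# Trainable params: 68
# Non-trainable params: 1,600
# ------------------------------
assert total == trainable + frozen == 1668
```

Because `layer_summary` recurses through `module.children()` while counting parameters with `recurse=False` at each level, every parameter is counted exactly once even for deeply nested modules.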