
commit 802ad8d1d3 (tags/v0.4.10)
Author: yh_cc, 5 years ago

    conflict merge

3 changed files with 49 additions and 6 deletions
  1. README.md (+3, -0)
  2. fastNLP/modules/encoder/variational_rnn.py (+3, -4)
  3. fastNLP/modules/utils.py (+43, -2)

README.md (+3, -0)

@@ -94,3 +94,6 @@ Check out models' performance, usage and source code here.
 <td> readers & savers </td>
 </tr>
 </table>
+
+
+*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*

fastNLP/modules/encoder/variational_rnn.py (+3, -4)

@@ -148,11 +148,10 @@ class VarRNNBase(nn.Module):
             seq_len = x.size(1) if self.batch_first else x.size(0)
             max_batch_size = x.size(0) if self.batch_first else x.size(1)
             seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
-            _tmp = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first)
-            x, batch_sizes = _tmp.data, _tmp.batch_sizes
+            input = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first)
         else:
-            max_batch_size = int(x.batch_sizes[0])
-            x, batch_sizes = x.data, x.batch_sizes
+            max_batch_size = int(input.batch_sizes[0])
+        input, batch_sizes = input.data, input.batch_sizes

         if hx is None:
             hx = x.new_zeros(self.num_layers * self.num_directions,
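For context, the rewritten branch packs a padded batch and then reads the flattened tensor and per-step batch sizes off the resulting PackedSequence. A minimal standalone sketch of that pattern (tensor shapes and variable names are illustrative, not fastNLP code):

```python
# Standalone sketch of the pack-then-unpack pattern used above.
# Shapes and variable names are illustrative, not fastNLP code.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

batch_first = True
x = torch.randn(4, 7, 16)  # (batch, seq_len, feature)
seq_len = x.size(1) if batch_first else x.size(0)
max_batch_size = x.size(0) if batch_first else x.size(1)

# Treat every sequence as full-length, exactly as the diff does.
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)])
packed = pack_padded_sequence(x, seq_lens, batch_first=batch_first)

# A PackedSequence carries the flattened steps plus per-step batch sizes.
data, batch_sizes = packed.data, packed.batch_sizes
print(data.shape)           # torch.Size([28, 16]) -- 4 * 7 steps flattened
print(int(batch_sizes[0]))  # 4 == max_batch_size
```

Making both branches leave `input` as a PackedSequence before a single shared unpacking line is what lets the change drop a net line while handling packed and unpacked inputs uniformly.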


fastNLP/modules/utils.py (+43, -2)

@@ -1,3 +1,5 @@
+from functools import reduce
+from collections import OrderedDict
 import numpy as np
 import torch
 import torch.nn as nn
@@ -78,7 +80,8 @@ def get_embeddings(init_embed):
     :return nn.Embedding embeddings:
     """
     if isinstance(init_embed, tuple):
-        res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+        res = nn.Embedding(
+            num_embeddings=init_embed[0], embedding_dim=init_embed[1])
     elif isinstance(init_embed, nn.Embedding):
         res = init_embed
     elif isinstance(init_embed, torch.Tensor):
@@ -87,5 +90,43 @@ def get_embeddings(init_embed):
         init_embed = torch.tensor(init_embed, dtype=torch.float32)
         res = nn.Embedding.from_pretrained(init_embed, freeze=False)
     else:
-        raise TypeError('invalid init_embed type: {}'.format((type(init_embed))))
+        raise TypeError(
+            'invalid init_embed type: {}'.format((type(init_embed))))
     return res
+
+
+def summary(model: nn.Module):
+    """
+    Get the total number of parameters in a model.
+
+    :param model: a PyTorch model
+    :return tuple: total params, trainable params, non-trainable params
+    """
+    train = []
+    nontrain = []
+
+    def layer_summary(module: nn.Module):
+        def count_size(sizes):
+            return reduce(lambda x, y: x*y, sizes)
+
+        for p in module.parameters(recurse=False):
+            if p.requires_grad:
+                train.append(count_size(p.shape))
+            else:
+                nontrain.append(count_size(p.shape))
+        for subm in module.children():
+            layer_summary(subm)
+
+    layer_summary(model)
+    total_train = sum(train)
+    total_nontrain = sum(nontrain)
+    total = total_train + total_nontrain
+    strings = []
+    strings.append('Total params: {:,}'.format(total))
+    strings.append('Trainable params: {:,}'.format(total_train))
+    strings.append('Non-trainable params: {:,}'.format(total_nontrain))
+    max_len = len(max(strings, key=len))
+    bar = '-'*(max_len + 3)
+    strings = [bar] + strings + [bar]
+    print('\n'.join(strings))
+    return total, total_train, total_nontrain
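To make the get_embeddings contract above concrete, here is a hedged usage sketch covering the three input forms visible in the diff (the vocabulary size and dimension are made-up numbers):

```python
# Illustrative only; mirrors the branches shown in the diff above.
import torch
import torch.nn as nn
from fastNLP.modules.utils import get_embeddings

emb = get_embeddings((1000, 50))               # tuple -> new nn.Embedding(1000, 50)
same = get_embeddings(nn.Embedding(1000, 50))  # nn.Embedding -> returned unchanged
pre = get_embeddings(torch.randn(1000, 50))    # Tensor -> from_pretrained, freeze=False

print(emb.num_embeddings, emb.embedding_dim)   # 1000 50
print(pre.weight.requires_grad)                # True (freeze=False keeps it trainable)
```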

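And a quick sanity check of the new summary helper; the expected counts follow directly from the layer shapes (a sketch, assuming the import path shown in the file header):

```python
import torch.nn as nn
from fastNLP.modules.utils import summary

# Two linear layers: (10*20 + 20) + (20*3 + 3) = 220 + 63 = 283 parameters.
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 3))
model[2].weight.requires_grad = False  # freeze the second weight matrix (20*3 = 60)

total, trainable, nontrain = summary(model)
assert (total, trainable, nontrain) == (283, 223, 60)
```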