import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init


def initial_parameter(net, initial_method=None):
    """Initialize the weights of a PyTorch model.

    :param net: a PyTorch model
    :param str initial_method: one of the following initialization schemes.
        Any other value, including ``None``, falls back to ``xavier_normal``.

        - xavier_uniform
        - xavier_normal (default)
        - kaiming_normal, or msra
        - kaiming_uniform
        - orthogonal
        - sparse
        - normal
        - uniform

    """
    if initial_method == 'xavier_uniform':
        init_method = init.xavier_uniform_
    elif initial_method == 'xavier_normal':
        init_method = init.xavier_normal_
    elif initial_method == 'kaiming_normal' or initial_method == 'msra':
        init_method = init.kaiming_normal_
    elif initial_method == 'kaiming_uniform':
        init_method = init.kaiming_uniform_
    elif initial_method == 'orthogonal':
        init_method = init.orthogonal_
    elif initial_method == 'sparse':
        init_method = init.sparse_
    elif initial_method == 'normal':
        init_method = init.normal_
    elif initial_method == 'uniform':
        init_method = init.uniform_
    else:
        init_method = init.xavier_normal_

    def weights_init(m):
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            # convolution layers: weights get the chosen scheme, biases a normal distribution
            if initial_method is not None:
                init_method(m.weight.data)
            else:
                init.xavier_normal_(m.weight.data)
            if m.bias is not None:
                init.normal_(m.bias.data)
        elif isinstance(m, nn.LSTM):
            # recurrent layers: matrices get the chosen scheme, 1-d parameters (biases) a normal distribution
            for w in m.parameters():
                if len(w.data.size()) > 1:
                    init_method(w.data)  # weight
                else:
                    init.normal_(w.data)  # bias
        elif m is not None and hasattr(m, 'weight') and \
                hasattr(m.weight, "requires_grad") and m.weight.data.dim() > 1:
            # other modules with a matrix-shaped weight, e.g. nn.Linear or nn.Embedding;
            # modules with 1-d weights (e.g. normalization layers) fall through to the branch below
            init_method(m.weight.data)
        else:
            # containers and remaining modules: walk their parameters directly,
            # which also covers biases not handled by the branches above
            for w in m.parameters():
                if w.requires_grad:
                    if len(w.data.size()) > 1:
                        init_method(w.data)  # weight
                    else:
                        init.normal_(w.data)  # bias

    net.apply(weights_init)
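
# Illustrative usage of initial_parameter (a minimal sketch; the small feed-forward
# net below is assumed for demonstration and is not part of this module):
#
#     >>> net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
#     >>> initial_parameter(net, initial_method='kaiming_normal')
#     >>> initial_parameter(net)  # unspecified method falls back to xavier_normal
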
def seq_mask(seq_len, max_len):
    """
    Create a sequence mask.

    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
    :param max_len: int, the maximum sequence length in a batch.
    :return: mask, torch.BoolTensor (torch.ByteTensor on older PyTorch versions), [batch_size, max_len]

    """
    if not isinstance(seq_len, torch.Tensor):
        seq_len = torch.LongTensor(seq_len)
    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
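
# Illustrative usage of seq_mask (a minimal sketch; the lengths below are made up):
#
#     >>> seq_mask([2, 4, 3], max_len=4)
#     tensor([[ True,  True, False, False],
#             [ True,  True,  True,  True],
#             [ True,  True,  True, False]])
#
# Positions beyond each sequence's length are False and can be used to mask padding.
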
def get_embeddings(init_embed):
    """
    Build a word embedding module from ``init_embed``.

    :param init_embed: either a tuple of ``(num_embeddings, embedding_dim)``, i.e. the
        vocabulary size and the dimension of each word vector, in which case a fresh
        ``nn.Embedding`` is created; an ``nn.Embedding`` object, which is used as-is;
        or a ``torch.Tensor`` / ``numpy.ndarray`` of pretrained vectors, which is wrapped
        with ``nn.Embedding.from_pretrained`` and kept trainable (``freeze=False``).
    :return nn.Embedding embeddings:
    """
    if isinstance(init_embed, tuple):
        res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
    elif isinstance(init_embed, nn.Embedding):
        res = init_embed
    elif isinstance(init_embed, torch.Tensor):
        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
    elif isinstance(init_embed, np.ndarray):
        init_embed = torch.tensor(init_embed, dtype=torch.float32)
        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
    else:
        raise TypeError('invalid init_embed type: {}'.format(type(init_embed)))
    return res
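
# Illustrative usage of get_embeddings (a minimal sketch; the sizes below are made up):
#
#     >>> embed = get_embeddings((1000, 50))              # fresh embedding: 1000 words, dim 50
#     >>> embed = get_embeddings(np.zeros((1000, 50)))     # trainable embedding from pretrained vectors
#     >>> embed = get_embeddings(nn.Embedding(1000, 50))   # an existing nn.Embedding is returned unchanged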