From 84b18890a1d8418043e8f673ac77b172fedc95e2 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Wed, 3 Jul 2019 10:20:15 +0800
Subject: [PATCH] 1. Add an AdamW optimizer; 2. Fix the metric_key bug in
 Trainer; 3. Revise static embedding initialization; 4. Add reduction support
 to CrossEntropyLoss
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callback.py             |   4 +-
 fastNLP/core/losses.py               |  15 +-
 fastNLP/core/optimizer.py            | 110 ++++++++++++
 fastNLP/core/trainer.py              |   4 +-
 fastNLP/io/file_reader.py            |   4 +-
 fastNLP/modules/encoder/_bert.py     | 245 +++++++++++++++++++--------
 fastNLP/modules/encoder/embedding.py |  59 ++++---
 test/test_tutorials.py               |   4 +-
 8 files changed, 336 insertions(+), 109 deletions(-)

diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 5dfd889b..9c6b01d6 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -113,7 +113,7 @@ class Callback(object):
     
     @property
     def n_steps(self):
-        """Trainer一共会运行多少步"""
+        """Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数"""
         return self._trainer.n_steps
     
     @property
@@ -181,7 +181,7 @@ class Callback(object):
         :param dict batch_x: DataSet中被设置为input的field的batch。
         :param dict batch_y: DataSet中被设置为target的field的batch。
         :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些
-            情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。
+            情况下可以帮助定位是哪个Sample导致了错误。仅当num_workers=0时有效。
         :return:
         """
         pass
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 526bf37a..46a72802 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -226,6 +226,7 @@ class CrossEntropyLoss(LossBase):
     :param seq_len: 句子的长度, 长度之外的token不会计算loss。。
     :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
         传入seq_len.
+    :param str reduction: 支持'elementwise_mean'和'sum'.
Example:: @@ -233,21 +234,25 @@ class CrossEntropyLoss(LossBase): """ - def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): + def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='elementwise_mean'): super(CrossEntropyLoss, self).__init__() self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.padding_idx = padding_idx + assert reduction in ('elementwise_mean', 'sum') + self.reduction = reduction def get_loss(self, pred, target, seq_len=None): if pred.dim()>2: - pred = pred.view(-1, pred.size(-1)) - target = target.view(-1) + if pred.size(1)!=target.size(1): + pred = pred.transpose(1, 2) + pred = pred.reshape(-1, pred.size(-1)) + target = target.reshape(-1) if seq_len is not None: - mask = seq_len_to_mask(seq_len).view(-1).eq(0) + mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, - ignore_index=self.padding_idx) + ignore_index=self.padding_idx, reduction=self.reduction) class L1Loss(LossBase): diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index 0849b35d..1fe035bf 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -9,6 +9,9 @@ __all__ = [ ] import torch +import math +import torch +from torch.optim.optimizer import Optimizer as TorchOptimizer class Optimizer(object): @@ -97,3 +100,110 @@ class Adam(Optimizer): return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings) else: return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings) + + +class AdamW(TorchOptimizer): + r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入 + The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. + The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=1e-2, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + # Perform optimization step + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + p.data.addcdiv_(-step_size, exp_avg, denom) + + return loss diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 6edeb4a0..eabda99c 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -454,7 +454,7 @@ class Trainer(object): if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, - metric_key=metric_key, check_level=check_code_level, + metric_key=self.metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 self.model = _move_model_to_device(model, device=device) @@ -473,7 +473,7 @@ class Trainer(object): self.best_dev_step = None self.best_dev_perf = None self.n_steps = (len(self.train_data) // self.batch_size + int( - len(self.train_data) % self.batch_size != 0)) * self.n_epochs + len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 34b5d7c0..0ae0a319 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -104,7 +104,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): except Exception as e: if dropna: continue - raise ValueError('invalid instance at line: {}'.format(line_idx)) + raise ValueError('invalid instance ends at line: {}'.format(line_idx)) elif line.startswith('#'): continue else: @@ -117,5 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): except Exception as e: if dropna: return - print('invalid instance at line: {}'.format(line_idx)) + print('invalid instance ends at line: {}'.format(line_idx)) raise e diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py index 254917e5..4669b511 100644 --- a/fastNLP/modules/encoder/_bert.py +++ b/fastNLP/modules/encoder/_bert.py @@ -2,7 +2,8 @@ """ -这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 +这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你 + 有用,也请引用一下他们。 """ @@ -11,7 +12,6 @@ from ...core.vocabulary import Vocabulary import collections import unicodedata -from ...io.file_utils import _get_base_url, cached_path import numpy as np from itertools import chain import copy @@ -22,9 +22,105 @@ import os import torch from torch import nn import glob +import sys CONFIG_FILE = 'bert_config.json' -MODEL_WEIGHTS = 'pytorch_model.bin' + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. 
+ num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + """ + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) def gelu(x): @@ -40,6 +136,8 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). 
+ """ super(BertLayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) @@ -53,16 +151,18 @@ class BertLayerNorm(nn.Module): class BertEmbeddings(nn.Module): - def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): seq_length = input_ids.size(1) @@ -82,21 +182,21 @@ class BertEmbeddings(nn.Module): class BertSelfAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): + def __init__(self, config): super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: + if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads)) - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.dropout = nn.Dropout(attention_probs_dropout_prob) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) @@ -133,11 +233,11 @@ class BertSelfAttention(nn.Module): class BertSelfOutput(nn.Module): - def __init__(self, hidden_size, hidden_dropout_prob): + def __init__(self, config): super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, 
input_tensor): hidden_states = self.dense(hidden_states) @@ -147,10 +247,10 @@ class BertSelfOutput(nn.Module): class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + def __init__(self, config): super(BertAttention, self).__init__() - self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) - self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) def forward(self, input_tensor, attention_mask): self_output = self.self(input_tensor, attention_mask) @@ -159,11 +259,13 @@ class BertAttention(nn.Module): class BertIntermediate(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_act): + def __init__(self, config): super(BertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = ACT2FN[hidden_act] \ - if isinstance(hidden_act, str) else hidden_act + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -172,11 +274,11 @@ class BertIntermediate(nn.Module): class BertOutput(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + def __init__(self, config): super(BertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) @@ -186,13 +288,11 @@ class BertOutput(nn.Module): class BertLayer(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act): + def __init__(self, config): super(BertLayer, self).__init__() - self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob) - self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) - self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) def forward(self, hidden_states, attention_mask): attention_output = self.attention(hidden_states, attention_mask) @@ -202,13 +302,10 @@ class BertLayer(nn.Module): class BertEncoder(nn.Module): - def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob, - intermediate_size, hidden_act): + def __init__(self, config): super(BertEncoder, self).__init__() - layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in 
range(config.num_hidden_layers)]) def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): all_encoder_layers = [] @@ -222,9 +319,9 @@ class BertEncoder(nn.Module): class BertPooler(nn.Module): - def __init__(self, hidden_size): + def __init__(self, config): super(BertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) + self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): @@ -272,34 +369,30 @@ class BertModel(nn.Module): :param int initializer_range: 初始化权重范围,默认值为0.02 """ - def __init__(self, vocab_size=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02): + def __init__(self, config, *inputs, **kwargs): super(BertModel, self).__init__() - self.hidden_size = hidden_size - self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, - type_vocab_size, hidden_dropout_prob) - self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, - attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, - hidden_act) - self.pooler = BertPooler(hidden_size) - self.initializer_range = initializer_range - + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + super(BertModel, self).__init__() + self.config = config + self.hidden_size = self.config.hidden_size + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def init_bert_weights(self, module): + """ Initialize the weights. + """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.initializer_range) + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, BertLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -338,14 +431,19 @@ class BertModel(nn.Module): return encoded_layers, pooled_output @classmethod - def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): + def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs): + state_dict = kwargs.get('state_dict', None) + kwargs.pop('state_dict', None) + cache_dir = kwargs.get('cache_dir', None) + kwargs.pop('cache_dir', None) + from_tf = kwargs.get('from_tf', False) + kwargs.pop('from_tf', None) # Load config config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) - config = json.load(open(config_file, "r")) - # config = BertConfig.from_json_file(config_file) + config = BertConfig.from_json_file(config_file) # logger.info("Model config {}".format(config)) # Instantiate model. 
- model = cls(*inputs, **config, **kwargs) + model = cls(config, *inputs, **kwargs) if state_dict is None: files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) if len(files)==0: @@ -353,7 +451,7 @@ class BertModel(nn.Module): elif len(files)>1: raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") weights_path = files[0] - state_dict = torch.load(weights_path) + state_dict = torch.load(weights_path, map_location='cpu') old_keys = [] new_keys = [] @@ -840,6 +938,7 @@ class _WordBertModel(nn.Module): word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) attn_masks[i, :len(word_pieces_i)+2].fill_(1) + # TODO 截掉长度超过的部分。 # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index c48cb806..005cfe75 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -202,18 +202,12 @@ class StaticEmbedding(TokenEmbedding): raise ValueError(f"Cannot recognize {model_dir_or_name}.") # 读取embedding - embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, + embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, normalize=normalize) self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], padding_idx=vocab.padding_idx, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False, _weight=embedding) - if vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk - words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) - for word, idx in vocab: - if vocab._is_word_no_create_entry(word) and not hit_flags[idx]: - words_to_words[idx] = vocab.unknown_idx - self.words_to_words = words_to_words self._embed_size = self.embedding.weight.size(1) self.requires_grad = requires_grad @@ -268,10 +262,8 @@ class StaticEmbedding(TokenEmbedding): else: dim = len(parts) - 1 f.seek(0) - matrix = torch.zeros(len(vocab), dim) - if init_method is not None: - init_method(matrix) - hit_flags = np.zeros(len(vocab), dtype=bool) + matrix = {} + found_count = 0 for idx, line in enumerate(f, start_idx): try: parts = line.strip().split() @@ -285,28 +277,49 @@ class StaticEmbedding(TokenEmbedding): if word in vocab: index = vocab.to_index(word) matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) - hit_flags[index] = True + found_count += 1 except Exception as e: if error == 'ignore': warnings.warn("Error occurred at the {} line.".format(idx)) else: print("Error occurred at the {} line.".format(idx)) raise e - found_count = sum(hit_flags) print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) - if init_method is None: - if len(vocab)-found_count>0 and found_count>0: # 有的没找到 - found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()] - mean = found_vecs.mean(dim=0, keepdim=True) - std = found_vecs.std(dim=0, keepdim=True) - unfound_vec_num = np.sum(hit_flags==False) - unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean - matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs + for word, index in vocab: + if index not in matrix and not 
vocab._is_word_no_create_entry(word): + if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化 + matrix[index] = matrix[vocab.unknown_idx] + else: + matrix[index] = None + + vectors = torch.zeros(len(matrix), dim) + if init_method: + init_method(vectors) + else: + nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim)) + + if vocab._no_create_word_length>0: + if vocab.unknown is None: # 创建一个专门的unknown + unknown_idx = len(matrix) + vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous() + else: + unknown_idx = vocab.unknown_idx + words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), + requires_grad=False) + for order, (index, vec) in enumerate(matrix.items()): + if vec is not None: + vectors[order] = vec + words_to_words[index] = order + self.words_to_words = words_to_words + else: + for index, vec in matrix.items(): + if vec is not None: + vectors[index] = vec if normalize: - matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) + vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12) - return matrix, hit_flags + return vectors def forward(self, words): """ diff --git a/test/test_tutorials.py b/test/test_tutorials.py index 87910c3d..6f4a8347 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -79,7 +79,7 @@ class TestTutorial(unittest.TestCase): train_data.rename_field('label', 'label_seq') test_data.rename_field('label', 'label_seq') - loss = CrossEntropyLoss(pred="output", target="label_seq") + loss = CrossEntropyLoss(target="label_seq") metric = AccuracyMetric(target="label_seq") # 实例化Trainer,传入模型和数据,进行训练 @@ -91,7 +91,7 @@ class TestTutorial(unittest.TestCase): # 用train_data训练,在test_data验证 trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, - loss=CrossEntropyLoss(pred="output", target="label_seq"), + loss=CrossEntropyLoss(target="label_seq"), metrics=AccuracyMetric(target="label_seq"), save_path=None, batch_size=32,
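
For reference, a minimal usage sketch of the two user-facing additions in this patch: the AdamW optimizer and the reduction argument of CrossEntropyLoss. The import paths follow the modules touched above (fastNLP/core/optimizer.py and fastNLP/core/losses.py); whether these classes are also re-exported from the fastNLP package root is an assumption not shown in this diff, and the nn.Linear model is only a stand-in.

Example::

    import torch
    from torch import nn

    from fastNLP.core.losses import CrossEntropyLoss   # gains `reduction` in this patch
    from fastNLP.core.optimizer import AdamW           # added in this patch

    model = nn.Linear(10, 5)                           # stand-in for a real fastNLP model
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    # reduction defaults to 'elementwise_mean'; 'sum' is the other supported value.
    loss_fn = CrossEntropyLoss(pred="output", target="label_seq", reduction="sum")

    pred = model(torch.randn(4, 10))                   # (batch_size, num_classes)
    target = torch.randint(0, 5, (4,))                 # (batch_size,)

    # Inside Trainer the pred/target fields are matched by name; calling
    # get_loss directly just shows the tensor-level behaviour.
    loss = loss_fn.get_loss(pred, target)
    loss.backward()

    # AdamW decouples weight decay from the gradient update:
    # p is scaled by (1 - lr * weight_decay) before the Adam step.
    optimizer.step()
    optimizer.zero_grad()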