
1. Add an AdamW optimizer; 2. Fix the metric_key bug in Trainer; 3. Rework static embedding initialization; 4. Add reduction support to CrossEntropyLoss

tags/v0.4.10
yh_cc 5 years ago
commit 84b18890a1
8 changed files with 336 additions and 109 deletions
  1. fastNLP/core/callback.py  +2 -2
  2. fastNLP/core/losses.py  +10 -5
  3. fastNLP/core/optimizer.py  +110 -0
  4. fastNLP/core/trainer.py  +2 -2
  5. fastNLP/io/file_reader.py  +2 -2
  6. fastNLP/modules/encoder/_bert.py  +172 -73
  7. fastNLP/modules/encoder/embedding.py  +36 -23
  8. test/test_tutorials.py  +2 -2

fastNLP/core/callback.py  +2 -2

@@ -113,7 +113,7 @@ class Callback(object):
@property
def n_steps(self):
"""Trainer一共会运行多少步"""
"""Trainer一共会采多少个batch。当Trainer中update_every设置为非1的值时,该值不等于update的次数"""
return self._trainer.n_steps
@property
@@ -181,7 +181,7 @@ class Callback(object):
:param dict batch_x: batch of the fields set as input in the DataSet.
:param dict batch_y: batch of the fields set as target in the DataSet.
:param list(int) indices: the indices used by this sampling; the Instances drawn for this batch can be retrieved via DataSet[indices], which in some
cases helps locate which Sample caused an error. Only available when the Trainer's prefetch is False
cases helps locate which Sample caused an error. Only effective when num_workers=0
:return:
"""
pass
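A minimal sketch (not part of this commit) of a Callback that relies on the two behaviours documented above: n_steps counts sampled batches rather than parameter updates, and indices is only populated when num_workers=0. The class name is hypothetical.

from fastNLP.core.callback import Callback

class BatchInspector(Callback):
    """Hypothetical callback, shown only to illustrate the documented API."""

    def on_train_begin(self):
        # n_steps: total number of batches the Trainer will sample,
        # not the number of updates when update_every > 1
        print("total batches:", self.n_steps)

    def on_batch_begin(self, batch_x, batch_y, indices):
        # indices is only meaningful when num_workers=0; it allows recovering
        # the sampled Instances via DataSet[indices] when a batch fails
        if indices is not None:
            self.last_indices = indices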


fastNLP/core/losses.py  +10 -5

@@ -226,6 +226,7 @@ class CrossEntropyLoss(LossBase):
:param seq_len: lengths of the sentences; tokens beyond the length do not contribute to the loss.
:param padding_idx: index of the padding token; entries in target equal to padding_idx are ignored when computing the loss, and this can be used instead of
passing seq_len.
:param str reduction: supports 'elementwise_mean' and 'sum'.

Example::

@@ -233,21 +234,25 @@ class CrossEntropyLoss(LossBase):
"""
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100):
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='elementwise_mean'):
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.padding_idx = padding_idx
assert reduction in ('elementwise_mean', 'sum')
self.reduction = reduction
def get_loss(self, pred, target, seq_len=None):
if pred.dim()>2:
pred = pred.view(-1, pred.size(-1))
target = target.view(-1)
if pred.size(1)!=target.size(1):
pred = pred.transpose(1, 2)
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)
if seq_len is not None:
mask = seq_len_to_mask(seq_len).view(-1).eq(0)
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0)
target = target.masked_fill(mask, self.padding_idx)

return F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx)
ignore_index=self.padding_idx, reduction=self.reduction)


class L1Loss(LossBase):
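A hedged usage sketch of the new reduction argument: only 'elementwise_mean' and 'sum' pass the assert above, and the pred/target field names below are the usual illustrative mapping.

from fastNLP.core.losses import CrossEntropyLoss

loss_mean = CrossEntropyLoss(pred="output", target="label_seq")                   # default: 'elementwise_mean'
loss_sum = CrossEntropyLoss(pred="output", target="label_seq", reduction="sum")   # summed over non-ignored tokens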


fastNLP/core/optimizer.py  +110 -0

@@ -9,6 +9,9 @@ __all__ = [
]

import torch
import math
import torch
from torch.optim.optimizer import Optimizer as TorchOptimizer


class Optimizer(object):
@@ -97,3 +100,110 @@ class Adam(Optimizer):
return torch.optim.Adam(self._get_require_grads_param(model_params), **self.settings)
else:
return torch.optim.Adam(self._get_require_grads_param(self.model_params), **self.settings)


class AdamW(TorchOptimizer):
r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay coefficient (default: 1e-2)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=1e-2, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamW, self).__init__(params, defaults)

def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue

# Perform stepweight decay
p.data.mul_(1 - group['lr'] * group['weight_decay'])

# Perform optimization step
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

p.data.addcdiv_(-step_size, exp_avg, denom)

return loss
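A minimal sketch of using the new AdamW; because it subclasses torch.optim.Optimizer, the instance can also be handed to a Trainer through its optimizer argument. The linear model below is a placeholder.

import torch
import torch.nn as nn
from fastNLP.core.optimizer import AdamW

model = nn.Linear(128, 2)  # placeholder model, for illustration only
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()       # decoupled weight decay is applied before the Adam update
optimizer.zero_grad()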

fastNLP/core/trainer.py  +2 -2

@@ -454,7 +454,7 @@ class Trainer(object):

if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
metric_key=metric_key, check_level=check_code_level,
metric_key=self.metric_key, check_level=check_code_level,
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
# _check_code is the method fastNLP uses to help you verify that your code is correct. If you see this comment in an error traceback, please check your code carefully.
self.model = _move_model_to_device(model, device=device)
@@ -473,7 +473,7 @@ class Trainer(object):
self.best_dev_step = None
self.best_dev_perf = None
self.n_steps = (len(self.train_data) // self.batch_size + int(
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs

if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
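A readable restatement (a sketch, not library code) of the n_steps computation above, assuming drop_last only discards the trailing partial batch of each epoch.

def total_batches(n_samples, batch_size, n_epochs, drop_last=False):
    # batches sampled per epoch
    batches_per_epoch = n_samples // batch_size
    if not drop_last and n_samples % batch_size != 0:
        batches_per_epoch += 1  # the incomplete trailing batch is still sampled
    return batches_per_epoch * n_epochs

assert total_batches(100, 32, n_epochs=3) == 4 * 3
assert total_batches(100, 32, n_epochs=3, drop_last=True) == 3 * 3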


fastNLP/io/file_reader.py  +2 -2

@@ -104,7 +104,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
continue
raise ValueError('invalid instance at line: {}'.format(line_idx))
raise ValueError('invalid instance ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
@@ -117,5 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
print('invalid instance at line: {}'.format(line_idx))
print('invalid instance ends at line: {}'.format(line_idx))
raise e

fastNLP/modules/encoder/_bert.py  +172 -73

@@ -2,7 +2,8 @@


"""
The code on this page is largely based on https://github.com/huggingface/pytorch-pretrained-BERT
The code on this page is largely based on (copy-pasted from) https://github.com/huggingface/pytorch-pretrained-BERT; if you find this code
useful, please cite them as well.
"""


@@ -11,7 +12,6 @@ from ...core.vocabulary import Vocabulary
import collections

import unicodedata
from ...io.file_utils import _get_base_url, cached_path
import numpy as np
from itertools import chain
import copy
@@ -22,9 +22,105 @@ import os
import torch
from torch import nn
import glob
import sys

CONFIG_FILE = 'bert_config.json'
MODEL_WEIGHTS = 'pytorch_model.bin'

class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12):
"""Constructs BertConfig.

Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")

@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config

@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))

def __repr__(self):
return str(self.to_json_string())

def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output

def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

def to_json_file(self, json_file_path):
""" Save this instance to a json file."""
with open(json_file_path, "w", encoding='utf-8') as writer:
writer.write(self.to_json_string())


def gelu(x):
@@ -40,6 +136,8 @@ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
@@ -53,16 +151,18 @@ class BertLayerNorm(nn.Module):


class BertEmbeddings(nn.Module):
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
@@ -82,21 +182,21 @@ class BertEmbeddings(nn.Module):


class BertSelfAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -133,11 +233,11 @@ class BertSelfAttention(nn.Module):


class BertSelfOutput(nn.Module):
def __init__(self, hidden_size, hidden_dropout_prob):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -147,10 +247,10 @@ class BertSelfOutput(nn.Module):


class BertAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)

def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
@@ -159,11 +259,13 @@ class BertAttention(nn.Module):


class BertIntermediate(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_act):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = ACT2FN[hidden_act] \
if isinstance(hidden_act, str) else hidden_act
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
@@ -172,11 +274,11 @@ class BertIntermediate(nn.Module):


class BertOutput(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
@@ -186,13 +288,11 @@ class BertOutput(nn.Module):


class BertLayer(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob)
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)

def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
@@ -202,13 +302,10 @@ class BertLayer(nn.Module):


class BertEncoder(nn.Module):
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob,
intermediate_size, hidden_act):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
@@ -222,9 +319,9 @@ class BertEncoder(nn.Module):


class BertPooler(nn.Module):
def __init__(self, hidden_size):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states):
@@ -272,34 +369,30 @@ class BertModel(nn.Module):
:param int initializer_range: range used for weight initialization, default 0.02
"""

def __init__(self, vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
def __init__(self, config, *inputs, **kwargs):
super(BertModel, self).__init__()
self.hidden_size = hidden_size
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings,
type_vocab_size, hidden_dropout_prob)
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads,
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size,
hidden_act)
self.pooler = BertPooler(hidden_size)
self.initializer_range = initializer_range

if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
super(BertModel, self).__init__()
self.config = config
self.hidden_size = self.config.hidden_size
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)

def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@@ -338,14 +431,19 @@ class BertModel(nn.Module):
return encoded_layers, pooled_output

@classmethod
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs):
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
# Load config
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
config = json.load(open(config_file, "r"))
# config = BertConfig.from_json_file(config_file)
config = BertConfig.from_json_file(config_file)
# logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(*inputs, **config, **kwargs)
model = cls(config, *inputs, **kwargs)
if state_dict is None:
files = glob.glob(os.path.join(pretrained_model_dir, '*.bin'))
if len(files)==0:
@@ -353,7 +451,7 @@ class BertModel(nn.Module):
elif len(files)>1:
raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}")
weights_path = files[0]
state_dict = torch.load(weights_path)
state_dict = torch.load(weights_path, map_location='cpu')

old_keys = []
new_keys = []
@@ -840,6 +938,7 @@ class _WordBertModel(nn.Module):
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]]))
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i)
attn_masks[i, :len(word_pieces_i)+2].fill_(1)
# TODO truncate the parts that exceed the maximum length.
# 2. Get the hidden states and pool them according to word_pieces
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
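A minimal sketch of the refactored construction path: BertModel now takes a single BertConfig instead of loose keyword arguments, and from_pretrained rebuilds the config from bert_config.json before loading the *.bin weights onto CPU. The checkpoint directory below is a placeholder.

from fastNLP.modules.encoder._bert import BertConfig, BertModel

# build from an in-memory config (deliberately small for illustration)
config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=3)
model = BertModel(config)

# or restore a directory that contains bert_config.json and a single *.bin file
# (the path is hypothetical):
# model = BertModel.from_pretrained('/path/to/bert-base-uncased')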


fastNLP/modules/encoder/embedding.py  +36 -23

@@ -202,18 +202,12 @@ class StaticEmbedding(TokenEmbedding):
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

# load the embeddings
embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
normalize=normalize)
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
padding_idx=vocab.padding_idx,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=embedding)
if vocab._no_create_word_length > 0:  # a mapping is needed so that indices from dev/test point to unk
words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
for word, idx in vocab:
if vocab._is_word_no_create_entry(word) and not hit_flags[idx]:
words_to_words[idx] = vocab.unknown_idx
self.words_to_words = words_to_words
self._embed_size = self.embedding.weight.size(1)
self.requires_grad = requires_grad

@@ -268,10 +262,8 @@ class StaticEmbedding(TokenEmbedding):
else:
dim = len(parts) - 1
f.seek(0)
matrix = torch.zeros(len(vocab), dim)
if init_method is not None:
init_method(matrix)
hit_flags = np.zeros(len(vocab), dtype=bool)
matrix = {}
found_count = 0
for idx, line in enumerate(f, start_idx):
try:
parts = line.strip().split()
@@ -285,28 +277,49 @@ class StaticEmbedding(TokenEmbedding):
if word in vocab:
index = vocab.to_index(word)
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
hit_flags[index] = True
found_count += 1
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
print("Error occurred at the {} line.".format(idx))
raise e
found_count = sum(hit_flags)
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
if init_method is None:
if len(vocab)-found_count>0 and found_count>0:  # some words were not found
found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()]
mean = found_vecs.mean(dim=0, keepdim=True)
std = found_vecs.std(dim=0, keepdim=True)
unfound_vec_num = np.sum(hit_flags==False)
unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean
matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs
for word, index in vocab:
if index not in matrix and not vocab._is_word_no_create_entry(word):
if vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize with it
matrix[index] = matrix[vocab.unknown_idx]
else:
matrix[index] = None

vectors = torch.zeros(len(matrix), dim)
if init_method:
init_method(vectors)
else:
nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim))

if vocab._no_create_word_length>0:
if vocab.unknown is None:  # create a dedicated unknown entry
unknown_idx = len(matrix)
vectors = torch.cat([vectors, torch.zeros(1, dim)], dim=0).contiguous()
else:
unknown_idx = vocab.unknown_idx
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
requires_grad=False)
for order, (index, vec) in enumerate(matrix.items()):
if vec is not None:
vectors[order] = vec
words_to_words[index] = order
self.words_to_words = words_to_words
else:
for index, vec in matrix.items():
if vec is not None:
vectors[index] = vec

if normalize:
matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12)
vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12)

return matrix, hit_flags
return vectors

def forward(self, words):
"""


test/test_tutorials.py  +2 -2

@@ -79,7 +79,7 @@ class TestTutorial(unittest.TestCase):
train_data.rename_field('label', 'label_seq')
test_data.rename_field('label', 'label_seq')

loss = CrossEntropyLoss(pred="output", target="label_seq")
loss = CrossEntropyLoss(target="label_seq")
metric = AccuracyMetric(target="label_seq")

# 实例化Trainer,传入模型和数据,进行训练
@@ -91,7 +91,7 @@ class TestTutorial(unittest.TestCase):

# 用train_data训练,在test_data验证
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
loss=CrossEntropyLoss(pred="output", target="label_seq"),
loss=CrossEntropyLoss(target="label_seq"),
metrics=AccuracyMetric(target="label_seq"),
save_path=None,
batch_size=32,

