
Fix the dtype issue

yh_cc · 4 years ago · commit 057fa63d7e · tags/v1.0.0alpha
2 changed files with 38 additions and 0 deletions:

  1. fastNLP/modules/encoder/bert.py (+19, -0)
  2. fastNLP/modules/encoder/gpt2.py (+19, -0)

fastNLP/modules/encoder/bert.py (+19, -0)

@@ -464,6 +464,24 @@ class BertModel(nn.Module):
logger.info('DistilBert has no pooler, will use hidden states of [CLS] token as pooled output.')
self.apply(self.init_bert_weights)

@property
def dtype(self):
    """
    :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
    """
    try:
        return next(self.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module):
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype

def init_bert_weights(self, module):
r""" Initialize the weights.
"""
@@ -510,6 +528,7 @@ class BertModel(nn.Module):
# effectively the same as removing these entirely.
# this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
# extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(self.dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
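
Both files add the same `dtype` property, which mirrors the pattern used by Hugging Face Transformers' `ModuleUtilsMixin.dtype`: read the dtype of the first parameter, and fall back to scanning tensor attributes when `parameters()` is empty, as happens on `nn.DataParallel` replicas in PyTorch 1.5. Below is a minimal, self-contained sketch of the same pattern; the `ToyEncoder` module is hypothetical and not part of fastNLP.

import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical module reusing the dtype-lookup pattern added in this commit."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def dtype(self):
        try:
            # Usual case: all parameters share one dtype, so the first is enough.
            return next(self.parameters()).dtype
        except StopIteration:
            # nn.DataParallel replicas in PyTorch 1.5 expose no parameters,
            # so fall back to the first tensor attribute found on any submodule.
            def find_tensor_attributes(module: nn.Module):
                return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype


model = ToyEncoder().half()
mask = torch.ones(2, 1, 1, 8)
# Casting through the property instead of next(self.parameters()).dtype keeps
# fp16 compatibility even when parameters() is empty on a replica.
mask = mask.to(model.dtype)
print(mask.dtype)  # torch.float16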


fastNLP/modules/encoder/gpt2.py (+19, -0)

@@ -787,6 +787,24 @@ class GPT2Model(GPT2PreTrainedModel):
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)

@property
def dtype(self):
    """
    :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
    """
    try:
        return next(self.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module: nn.Module):
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = self._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype

def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, output_attentions=True):
"""
@@ -836,6 +854,7 @@ class GPT2Model(GPT2PreTrainedModel):
# effectively the same as removing these entirely.
# this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
# attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = attention_mask.to(self.dtype)
attention_mask = (1.0 - attention_mask) * -10000.0
# attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)
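
After the cast, the `(1.0 - attention_mask) * -10000.0` line in both hunks turns the {0, 1} padding mask into an additive bias: attended positions contribute 0 and padded positions contribute -10000, which behaves like negative infinity once the attention scores pass through softmax. A small standalone illustration in plain PyTorch follows; the toy shapes and values are only for demonstration.

import torch

# 1 = attend, 0 = padding; shape [batch, 1, 1, seq_len] as in the model code.
attention_mask = torch.tensor([[[[1., 1., 1., 0.]]]]).half()

bias = (1.0 - attention_mask) * -10000.0   # 0 for real tokens, -10000 for padding
scores = torch.zeros(1, 1, 1, 4).half() + bias

probs = torch.softmax(scores.float(), dim=-1)
print(probs)  # roughly [0.3333, 0.3333, 0.3333, 0.0000]: the padded position is ignored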


