@@ -464,6 +464,24 @@ class BertModel(nn.Module):
             logger.info('DistilBert has no pooler, will use the hidden state of the [CLS] token as the pooled output.')
         self.apply(self.init_bert_weights)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def init_bert_weights(self, module):
         r""" Initialize the weights.
         """
@@ -510,6 +528,7 @@ class BertModel(nn.Module):
         # effectively the same as removing these entirely.
         # this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(self.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
         embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
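
For context on the two mask lines above (an illustration, not part of the diff): casting through `self.dtype` keeps the additive mask in the model's parameter precision (the `# fp16 compatibility` note), and `(1.0 - mask) * -10000.0` turns padded positions into large negative biases that softmax maps to roughly zero attention. A small self-contained sketch of that convention, which applies equally to the `GPT2Model` hunks below:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(1, 1, 4, 4)                    # raw attention scores: (batch, heads, query, key)
attention_mask = torch.tensor([[1., 1., 1., 0.]])   # last key position is padding
extended = attention_mask[:, None, None, :]         # broadcastable shape: (batch, 1, 1, key)
extended = (1.0 - extended) * -10000.0              # kept tokens -> 0.0, padded tokens -> -10000.0
probs = F.softmax(scores + extended, dim=-1)
print(probs[0, 0, :, -1])                           # ~0 probability on the padded key for every query
```
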
@@ -787,6 +787,24 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                 head_mask=None, output_attentions=True):
         """
@@ -836,6 +854,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # effectively the same as removing these entirely.
         # this will cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        attention_mask = attention_mask.to(self.dtype)
         attention_mask = (1.0 - attention_mask) * -10000.0
         # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)