@@ -464,6 +464,24 @@ class BertModel(nn.Module):
             logger.info('DistilBert has NOT pooler, will use hidden states of [CLS] token as pooled output.')
         self.apply(self.init_bert_weights)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def init_bert_weights(self, module):
         r""" Initialize the weights.
         """
@@ -510,6 +528,7 @@ class BertModel(nn.Module):
         # effectively the same as removing these entirely.
         # this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(self.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
         embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
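Note on the `extended_attention_mask.to(self.dtype)` line: casting the mask to the model's dtype before building the additive mask keeps the whole term in half precision when the model runs in fp16, instead of the int64 mask silently promoting the result to float32. A minimal sketch of the effect (shapes and values are arbitrary, not taken from the patch):

import torch

model_dtype = torch.float16                 # what self.dtype returns for an fp16 model
attention_mask = torch.tensor([[1, 1, 0]])  # int64 padding mask: 1 = real token, 0 = padding

# Broadcastable [batch, 1, 1, seq] mask, cast to the model dtype as in the patch
extended = attention_mask[:, None, None, :].to(model_dtype)
extended = (1.0 - extended) * -10000.0      # 0.0 for real tokens, -10000.0 for padding

print(extended.dtype)                       # torch.float16
# Without the .to(...) cast, (1.0 - int64_mask) would come out as float32 instead.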
@@ -787,6 +787,24 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                 head_mask=None, output_attentions=True):
         """
@@ -836,6 +854,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # effectively the same as removing these entirely.
         # this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        attention_mask = attention_mask.to(self.dtype)
         attention_mask = (1.0 - attention_mask) * -10000.0
         # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)
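For context on the `except StopIteration` branch shared by both `dtype` properties: replicas created by `nn.DataParallel` in PyTorch 1.5 hold their weights as plain tensor attributes rather than registered parameters (see the linked pytorch/pytorch#40457 comment), so `next(self.parameters())` raises `StopIteration` and the tensor-attribute scan recovers the dtype instead. A self-contained sketch, using a hypothetical `ReplicaLike` module to stand in for such a replica (not part of the patch):

import torch
import torch.nn as nn


class ReplicaLike(nn.Module):
    """Hypothetical stand-in for a DataParallel replica: the weight is a plain
    tensor attribute, so the module has no registered parameters."""

    def __init__(self):
        super().__init__()
        self.weight = torch.zeros(2, 2, dtype=torch.float16)  # plain tensor, not nn.Parameter

    @property
    def dtype(self):
        # Same fallback pattern as the patch
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            def find_tensor_attributes(module: nn.Module):
                return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype


print(ReplicaLike().dtype)  # torch.float16, recovered from the plain tensor attribute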