From 057fa63d7ef039dabff62f2d2d8f3387bc98268f Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Fri, 11 Dec 2020 14:15:41 +0800
Subject: [PATCH] Fix the dtype issue
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/modules/encoder/bert.py | 19 +++++++++++++++++++
 fastNLP/modules/encoder/gpt2.py | 19 +++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 7245b577..28c47eb6 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -464,6 +464,24 @@ class BertModel(nn.Module):
             logger.info('DistilBert has NOT pooler, will use hidden states of [CLS] token as pooled output.')
         self.apply(self.init_bert_weights)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def init_bert_weights(self, module):
         r""" Initialize the weights.
         """
@@ -510,6 +528,7 @@ class BertModel(nn.Module):
         # effectively the same as removing these entirely.
         # this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(self.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
         embedding_output = self.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
diff --git a/fastNLP/modules/encoder/gpt2.py b/fastNLP/modules/encoder/gpt2.py
index 076d98eb..f6bf6dde 100644
--- a/fastNLP/modules/encoder/gpt2.py
+++ b/fastNLP/modules/encoder/gpt2.py
@@ -787,6 +787,24 @@ class GPT2Model(GPT2PreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
 
+    @property
+    def dtype(self):
+        """
+        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        try:
+            return next(self.parameters()).dtype
+        except StopIteration:
+            # For nn.DataParallel compatibility in PyTorch 1.5
+
+            def find_tensor_attributes(module: nn.Module):
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
+
+            gen = self._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
+
     def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                 head_mask=None, output_attentions=True):
         """
@@ -836,6 +854,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # effectively the same as removing these entirely.
         # this will case an issue when DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
         # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+        attention_mask = attention_mask.to(self.dtype)
         attention_mask = (1.0 - attention_mask) * -10000.0
         # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)
 
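
Note (not part of the patch): below is a minimal, self-contained sketch of the pattern this commit adds to BertModel and GPT2Model: a dtype property that infers the module's precision from its parameters, with a fallback that scans tensor attributes for nn.DataParallel replicas on PyTorch 1.5, so that attention masks can be cast with .to(self.dtype) under fp16. The ToyEncoder module is hypothetical and used only for illustration; it is not part of fastNLP.

import torch
from torch import nn


class ToyEncoder(nn.Module):
    # Hypothetical module, for illustration only; mirrors the property this
    # commit adds to BertModel and GPT2Model.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def dtype(self):
        """The dtype of the module, assuming all parameters share one dtype."""
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # Fallback for nn.DataParallel replicas on PyTorch 1.5, whose
            # parameters() can be empty: look at plain tensor attributes instead.
            def find_tensor_attributes(module: nn.Module):
                return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            return next(gen)[1].dtype


model = ToyEncoder().half()          # fp16 weights
mask = torch.ones(2, 4)              # fp32 attention mask from the data pipeline
mask = mask.to(model.dtype)          # cast to the model's dtype, as the patch does
mask = (1.0 - mask) * -10000.0       # additive attention bias, now in fp16
print(model.dtype, mask.dtype)       # torch.float16 torch.float16

Reading the precision through a property keeps forward() free of next(self.parameters()), which is the call that triggers pytorch/pytorch#40457 under DataParallel, as noted in the comments the patch leaves in place.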