@@ -787,6 +787,24 @@ class GPT2Model(GPT2PreTrainedModel):
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)
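For reference, heads_to_prune maps a layer index to the list of attention-head indices to remove in that layer. A hypothetical call, assuming the usual transformers entry point prune_heads(), which dispatches to the loop above:

    # Prune heads 0 and 2 in layer 1, and head 3 in layer 4.
    model.prune_heads({1: [0, 2], 4: [3]})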
    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5

            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype
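The StopIteration branch exists because nn.DataParallel in PyTorch 1.5 builds replicas whose parameters are re-attached as plain tensor attributes, so next(self.parameters()) finds nothing (see the PyTorch issue linked in the next hunk). A minimal standalone sketch of the failure and the fallback, using a made-up ReplicaLike module rather than a real replica:

    import torch
    import torch.nn as nn

    class ReplicaLike(nn.Module):
        # Mimics a DataParallel replica in PyTorch 1.5: the weight is a plain
        # tensor attribute, not an nn.Parameter, so .parameters() yields nothing.
        def __init__(self):
            super().__init__()
            self.weight = torch.zeros(2, 2, dtype=torch.float16)

    m = ReplicaLike()
    try:
        dtype = next(m.parameters()).dtype  # raises StopIteration: no registered parameters
    except StopIteration:
        # Same idea as the dtype property above: scan the attributes for tensors.
        tensors = [v for v in m.__dict__.values() if torch.is_tensor(v)]
        dtype = tensors[0].dtype
    print(dtype)  # torch.float16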
    def forward(self, input_ids, state=None, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, output_attentions=True):
        """
@@ -836,6 +854,7 @@ class GPT2Model(GPT2PreTrainedModel):
            # effectively the same as removing these entirely.
            # this would cause an issue with DataParallel: https://github.com/pytorch/pytorch/issues/40457#issuecomment-648396469
            # attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = attention_mask.to(self.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0
            # attention_mask = attention_mask.masked_fill(attention_mask.eq(0), -10000.0)
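For intuition, the mask arithmetic in this hunk turns a 1/0 padding mask into an additive bias on the pre-softmax attention scores: kept positions contribute 0, padded positions contribute -10000, which drives their softmax weights to roughly zero while staying finite in fp16. A small standalone example of the same two lines:

    import torch

    # 1.0 marks real tokens, 0.0 marks padding (batch of one, five positions).
    attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])
    attention_mask = attention_mask.to(torch.float16)   # stands in for self.dtype here
    attention_mask = (1.0 - attention_mask) * -10000.0
    # Kept positions -> 0.0, padded positions -> -10000.0; added to the raw
    # attention scores, softmax then assigns the padded positions ~zero weight.
    print(attention_mask)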