| @@ -57,7 +57,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): | |||||
| r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', | r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight', | ||||
| ] | ] | ||||
| def __init__(self, config: T5Config): | |||||
| def __init__(self, config: T5Config, **kwargs): | |||||
| super().__init__(config) | super().__init__(config) | ||||
| self.model_dim = config.d_model | self.model_dim = config.d_model | ||||
| @@ -24,7 +24,7 @@ class BertForDocumentSegmentation(BertPreTrainedModel): | |||||
| _keys_to_ignore_on_load_unexpected = [r'pooler'] | _keys_to_ignore_on_load_unexpected = [r'pooler'] | ||||
| def __init__(self, config): | |||||
| def __init__(self, config, **kwargs): | |||||
| super().__init__(config) | super().__init__(config) | ||||
| self.num_labels = config.num_labels | self.num_labels = config.num_labels | ||||
| self.sentence_pooler_type = None | self.sentence_pooler_type = None | ||||
| @@ -11,7 +11,7 @@ from .backbone import BertModel, BertPreTrainedModel | |||||
| @MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) | @MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert) | ||||
| class BertForSentenceEmbedding(BertPreTrainedModel): | class BertForSentenceEmbedding(BertPreTrainedModel): | ||||
| def __init__(self, config): | |||||
| def __init__(self, config, **kwargs): | |||||
| super().__init__(config) | super().__init__(config) | ||||
| self.config = config | self.config = config | ||||
| setattr(self, self.base_model_prefix, | setattr(self, self.base_model_prefix, | ||||
| @@ -66,7 +66,7 @@ class BertForSequenceClassification(BertPreTrainedModel): | |||||
| weights. | weights. | ||||
| """ | """ | ||||
| def __init__(self, config): | |||||
| def __init__(self, config, **kwargs): | |||||
| super().__init__(config) | super().__init__(config) | ||||
| self.num_labels = config.num_labels | self.num_labels = config.num_labels | ||||
| self.config = config | self.config = config | ||||
| @@ -25,7 +25,7 @@ __all__ = ['PoNetForDocumentSegmentation'] | |||||
| class PoNetForDocumentSegmentation(PoNetPreTrainedModel): | class PoNetForDocumentSegmentation(PoNetPreTrainedModel): | ||||
| _keys_to_ignore_on_load_unexpected = [r'pooler'] | _keys_to_ignore_on_load_unexpected = [r'pooler'] | ||||
| def __init__(self, config): | |||||
| def __init__(self, config, **kwargs): | |||||
| super().__init__(config) | super().__init__(config) | ||||
| self.num_labels = config.num_labels | self.num_labels = config.num_labels | ||||