@@ -26,6 +26,7 @@ class MultimodalTransformer(nn.Module):
                  num_decoder_layers=3,
                  text_encoder_type='roberta-base',
                  freeze_text_encoder=True,
+                 transformer_cfg_dir=None,
                  **kwargs):
         super().__init__()
         self.d_model = kwargs['d_model']
@@ -40,10 +41,12 @@ class MultimodalTransformer(nn.Module):
         self.pos_encoder_2d = PositionEmbeddingSine2D()
         self._reset_parameters()
 
-        self.text_encoder = RobertaModel.from_pretrained(text_encoder_type)
+        if text_encoder_type != 'roberta-base':
+            transformer_cfg_dir = text_encoder_type
+        self.text_encoder = RobertaModel.from_pretrained(transformer_cfg_dir)
         self.text_encoder.pooler = None  # this pooler is never used, this is a hack to avoid DDP problems...
         self.tokenizer = RobertaTokenizerFast.from_pretrained(
-            text_encoder_type)
+            transformer_cfg_dir)
         self.freeze_text_encoder = freeze_text_encoder
         if freeze_text_encoder:
             for p in self.text_encoder.parameters():
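
The net effect of the patch is a small source-resolution step before the RoBERTa model and tokenizer are loaded: a non-default `text_encoder_type` is used directly, otherwise the local `transformer_cfg_dir` is handed to `from_pretrained`. A minimal standalone sketch of that logic follows; the helper name `resolve_text_encoder_source` is ours rather than part of the repository, and the final fallback to `text_encoder_type` is an added assumption so the default `('roberta-base', None)` combination still resolves to the hub identifier instead of passing `None` to `from_pretrained`.

```python
from transformers import RobertaModel, RobertaTokenizerFast


def resolve_text_encoder_source(text_encoder_type='roberta-base',
                                transformer_cfg_dir=None):
    """Pick the argument handed to from_pretrained().

    Mirrors the patched constructor: a non-default text_encoder_type wins,
    otherwise the local transformer_cfg_dir is used. The final fallback to
    text_encoder_type is our assumption, covering the case where neither a
    custom encoder type nor a local directory is supplied.
    """
    if text_encoder_type != 'roberta-base':
        return text_encoder_type
    if transformer_cfg_dir is not None:
        return transformer_cfg_dir
    return text_encoder_type


if __name__ == '__main__':
    # With the defaults this resolves to 'roberta-base' and downloads from the hub;
    # pass transformer_cfg_dir='/path/to/roberta-base' to load from a local directory instead.
    source = resolve_text_encoder_source()
    text_encoder = RobertaModel.from_pretrained(source)
    tokenizer = RobertaTokenizerFast.from_pretrained(source)
    text_encoder.pooler = None  # pooler is unused; dropping it avoids unused-parameter issues under DDP
```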