|
|
@@ -7,12 +7,12 @@ from typing import Any, Dict |
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
|
|
|
|
|
from ...utils.config import ConfigDict |
|
|
|
from ...utils.constant import Fields, Tasks |
|
|
|
from ...utils.logger import get_logger |
|
|
|
from ...utils.utils import if_func_recieve_dict_inputs |
|
|
|
from ..base import TorchModel |
|
|
|
from ..builder import build_backbone, build_head |
|
|
|
from modelscope.models.base import TorchModel |
|
|
|
from modelscope.models.builder import build_backbone, build_head |
|
|
|
from modelscope.utils.config import ConfigDict |
|
|
|
from modelscope.utils.constant import Fields, Tasks |
|
|
|
from modelscope.utils.logger import get_logger |
|
|
|
from modelscope.utils.utils import if_func_receive_dict_inputs |
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
@@ -424,7 +424,7 @@ class SingleBackboneTaskModelBase(BaseTaskModel): |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
"""default forward method is the backbone-only forward""" |
|
|
|
if if_func_recieve_dict_inputs(self.backbone.forward, input): |
|
|
|
if if_func_receive_dict_inputs(self.backbone.forward, input): |
|
|
|
outputs = self.backbone.forward(input) |
|
|
|
else: |
|
|
|
outputs = self.backbone.forward(**input) |
|
|
@@ -472,13 +472,13 @@ class EncoderDecoderTaskModelBase(BaseTaskModel): |
|
|
|
return getattr(self, self._decoder_prefix) |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
if if_func_recieve_dict_inputs(self.encoder_.forward, input): |
|
|
|
if if_func_receive_dict_inputs(self.encoder_.forward, input): |
|
|
|
encoder_outputs = self.encoder_.forward(input) |
|
|
|
else: |
|
|
|
encoder_outputs = self.encoder_.forward(**input) |
|
|
|
decoder_inputs = self.project_decoder_inputs_and_mediate( |
|
|
|
input, encoder_outputs) |
|
|
|
if if_func_recieve_dict_inputs(self.decoder_.forward, input): |
|
|
|
if if_func_receive_dict_inputs(self.decoder_.forward, input): |
|
|
|
outputs = self.decoder_.forward(decoder_inputs) |
|
|
|
else: |
|
|
|
outputs = self.decoder_.forward(**decoder_inputs) |
|
|
|