|
|
@@ -67,7 +67,7 @@ class GEMMForMultiModalEmbedding(TorchModel): |
|
|
|
return img_tensor |
|
|
|
|
|
|
|
def parse_text(self, text_str): |
|
|
|
if text_str is None: |
|
|
|
if text_str is None or len(text_str) == 0: |
|
|
|
return None |
|
|
|
if isinstance(text_str, str): |
|
|
|
text_ids_tensor = self.gemm_model.tokenize(text_str) |
|
|
@@ -79,9 +79,12 @@ class GEMMForMultiModalEmbedding(TorchModel): |
|
|
|
return text_ids_tensor.view(1, -1) |
|
|
|
|
|
|
|
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
image = self.parse_image(input.get('image', input.get('img', None))) |
|
|
|
text = self.parse_text(input.get('text', input.get('txt', None))) |
|
|
|
captioning = input.get('captioning', False) is True |
|
|
|
image_input = input.get('image', input.get('img', None)) |
|
|
|
text_input = input.get('text', input.get('txt', None)) |
|
|
|
captioning_input = input.get('captioning', None) |
|
|
|
image = self.parse_image(image_input) |
|
|
|
text = self.parse_text(text_input) |
|
|
|
captioning = captioning_input is True or text_input == '' |
|
|
|
out = self.gemm_model(image, text, captioning) |
|
|
|
output = { |
|
|
|
OutputKeys.IMG_EMBEDDING: out.get('image_feature', None), |
|
|
|