|
|
@@ -45,42 +45,24 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
 
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        src_item = sample['source']
-        ref = data[self.column_map['ref']]
-        predict_objects = data[self.column_map['predict_objects']]
-
-        ref_dict = {
-            item.split('|!+')[1]: float(item.split('|!+')[0])
-            for item in ref.split('&&')
-        }
-        answer = max(ref_dict, key=ref_dict.get)
-        sample['conf'] = torch.tensor([ref_dict[answer]])
         tgt_item = self.tokenize_text(
-            ' {}'.format(answer), add_bos=False, add_eos=False)
-
-        if self.add_object and predict_objects is not None:
-            predict_object_seq = ' '.join(
-                predict_objects.strip().split('&&')[:self.max_object_length])
-            predict_object_item = self.tokenize_text(
-                ' object: {}'.format(predict_object_seq), add_bos=False)
-            src_item = torch.cat([src_item, predict_object_item[:-1]])
+            ' {}'.format(sample['label']), add_bos=False, add_eos=False)
 
         if self.prompt_type == 'none':
             prev_output_item = torch.cat([self.bos_item, tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'src':
-            prev_output_item = torch.cat([src_item, tgt_item])
+            prev_output_item = torch.cat([sample['source'], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'prev_output':
-            prev_output_item = torch.cat([src_item[:-1], tgt_item])
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         else:
             raise NotImplementedError
-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
 
         sample['prev_output_tokens'] = prev_output_item
         sample['target'] = target_item
-        sample['ref_dict'] = ref_dict
 
         if self.constraint_trie is not None:
             constraint_mask = torch.zeros(
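Note on this hunk: `_build_train_sample` now reuses the answer already picked by `_build_infer_sample` (`sample['label']`) instead of re-parsing the `ref` column (the deleted `ref_dict` comprehension decodes `'&&'`-joined `confidence|!+answer` pairs, e.g. `'1.0|!+yes&&0.2|!+no'`), and it pads the target with `self.tokenizer.pad_token_id` rather than the old `self.tgt_dict.pad()`. A minimal standalone sketch of the `prompt_type == 'src'` branch, using hypothetical token ids (real values come from the model's tokenizer):

```python
import torch

eos_item = torch.tensor([2])
pad_token_id = 1

src_item = torch.tensor([0, 11, 12, 13, 2])  # hypothetical encoded question prompt
tgt_item = torch.tensor([21, 22])            # hypothetical encoded answer

# prompt_type == 'src': the decoder re-reads the full source before the answer.
prev_output_tokens = torch.cat([src_item, tgt_item])
# Shift by one for teacher forcing, then terminate with EOS.
target = torch.cat([prev_output_tokens[1:], eos_item])
# Mask everything before the answer span so the prompt is excluded from the loss.
target[:-len(tgt_item) - 1] = pad_token_id

print(prev_output_tokens)  # tensor([ 0, 11, 12, 13,  2, 21, 22])
print(target)              # tensor([ 1,  1,  1,  1, 21, 22,  2])
```

The one-position shift between `prev_output_tokens` and `target` is standard teacher forcing; the pad fill keeps only the answer tokens (plus the final EOS) contributing to the loss.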
|
|
@@ -101,7 +83,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
-        text = ' {}'.format(data[self.column_map['text']])
+        text = ' {}'.format(data[self.column_map['query']])
         inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
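Note on this hunk: the question is now looked up under the `query` key of `column_map` instead of `text`. A hedged sketch of how a dataset whose question column has a different name would be wired through (the mapping and data values here are hypothetical):

```python
# Hypothetical column_map: internal keys on the left, dataset column names on
# the right; only 'query' matters for this hunk.
column_map = {'image': 'image', 'query': 'question'}
data = {'image': 'vqa/0001.jpg', 'question': 'what color is the car?'}

# Mirrors the new lookup in _build_infer_sample; the leading space matches the
# prompt formatting used elsewhere in this preprocessor.
text = ' {}'.format(data[column_map['query']])
assert text == ' what color is the car?'
```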
|
|
|