
fix finetune-task

master
翎航, 2 years ago
parent commit: ca946067e6
3 changed files with 14 additions and 28 deletions
  1. modelscope/preprocessors/ofa/summarization.py (+2, -1)
  2. modelscope/preprocessors/ofa/visual_entailment.py (+7, -4)
  3. modelscope/preprocessors/ofa/visual_question_answering.py (+5, -23)

modelscope/preprocessors/ofa/summarization.py (+2, -1)

@@ -72,6 +72,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
             'noise_ratio', 0.0)
         target[noise_indices] = torch.randint(
             4,
-            len(self.src_dict) - self.code_dict_size - self.num_bins,
+            len(self.src_dict) - self.cfg.model.get('num_codes', 8192)
+            - self.cfg.model.get('num_bins', 1000),
             size=(noise_indices.sum(), ))
         return target
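
Note: the new upper bound drops the fairseq-style attributes code_dict_size and num_bins in favor of config lookups with OFA's usual defaults (8192 image codes, 1000 location bins), so noise tokens are still drawn only from the plain-text region of the vocabulary. A minimal sketch of the resulting behavior, with vocab_size, model_cfg, and the example target tensor as hypothetical stand-ins for the preprocessor's state:

import torch

# Hypothetical stand-ins for the preprocessor's state.
vocab_size = 59457                  # plays the role of len(self.src_dict)
model_cfg = {'noise_ratio': 0.2}    # plays the role of self.cfg.model

target = torch.tensor([10, 11, 12, 13, 14])
noise_indices = torch.rand(target.shape) < model_cfg.get('noise_ratio', 0.0)

# Replacement ids come from [4, vocab_size - num_codes - num_bins):
# ids below 4 are special tokens, and the top of the vocabulary holds
# the image codes and location bins, which are never valid text.
high = (vocab_size - model_cfg.get('num_codes', 8192)
        - model_cfg.get('num_bins', 1000))
target[noise_indices] = torch.randint(4, high, size=(noise_indices.sum(), ))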

modelscope/preprocessors/ofa/visual_entailment.py (+7, -4)

@@ -61,7 +61,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         else:
             raise NotImplementedError

-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id
         sample['target'] = target_item
         sample['prev_output_tokens'] = prev_output_item

@@ -85,14 +85,17 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         if 'text2' not in data:
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get('prompt',
                                         ' does the image describe " {} "?')
             text = prompt.format(hypothesis)
         else:
             assert 'text' in data, f'text must be in the input {data.keys()}'
-            caption = self.pre_caption(data['text2'], self.max_src_length)
-            hypothesis = self.pre_caption(data['text'], self.max_src_length)
+            caption = self.pre_caption(data[self.column_map['text2']],
+                                       self.max_src_length)
+            hypothesis = self.pre_caption(data[self.column_map['text']],
+                                          self.max_src_length)
             prompt = self.cfg.model.get(
                 'prompt', ' can image and text1 " {} " imply text2 " {} "?')
             text = prompt.format(caption, hypothesis)
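
Both edits here follow one pattern: the pad id now comes from the Hugging Face tokenizer (self.tokenizer.pad_token_id) rather than a fairseq dictionary, and raw column names are resolved through self.column_map, so a fine-tuning dataset may name its columns freely. A minimal sketch of the column_map indirection; the mapping and the sample row are invented for illustration:

# Hypothetical mapping from OFA's canonical keys to a dataset's own columns.
column_map = {'image': 'img_path', 'text': 'hypothesis', 'text2': 'premise'}

data = {
    'img_path': 'demo.jpg',
    'hypothesis': 'two dogs are playing in the snow',
    'premise': 'animals are outdoors',
}

# data[column_map['text']] resolves the canonical key 'text' to the
# dataset's column name, so the prompt assembles the same way after
# a rename.
caption = data[column_map['text2']]
hypothesis = data[column_map['text']]
prompt = ' can image and text1 " {} " imply text2 " {} "?'
print(prompt.format(caption, hypothesis))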


modelscope/preprocessors/ofa/visual_question_answering.py (+5, -23)

@@ -45,42 +45,24 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):

     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        src_item = sample['source']
-        ref = data[self.column_map['ref']]
-        predict_objects = data[self.column_map['predict_objects']]
-
-        ref_dict = {
-            item.split('|!+')[1]: float(item.split('|!+')[0])
-            for item in ref.split('&&')
-        }
-        answer = max(ref_dict, key=ref_dict.get)
-        sample['conf'] = torch.tensor([ref_dict[answer]])
         tgt_item = self.tokenize_text(
-            ' {}'.format(answer), add_bos=False, add_eos=False)
-
-        if self.add_object and predict_objects is not None:
-            predict_object_seq = ' '.join(
-                predict_objects.strip().split('&&')[:self.max_object_length])
-            predict_object_item = self.tokenize_text(
-                ' object: {}'.format(predict_object_seq), add_bos=False)
-            src_item = torch.cat([src_item, predict_object_item[:-1]])
+            ' {}'.format(sample['label']), add_bos=False, add_eos=False)

         if self.prompt_type == 'none':
             prev_output_item = torch.cat([self.bos_item, tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'src':
-            prev_output_item = torch.cat([src_item, tgt_item])
+            prev_output_item = torch.cat([sample['source'], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         elif self.prompt_type == 'prev_output':
-            prev_output_item = torch.cat([src_item[:-1], tgt_item])
+            prev_output_item = torch.cat([sample['source'][:-1], tgt_item])
             target_item = torch.cat([prev_output_item[1:], self.eos_item])
         else:
             raise NotImplementedError
-        target_item[:-len(tgt_item) - 1] = self.tgt_dict.pad()
+        target_item[:-len(tgt_item) - 1] = self.tokenizer.pad_token_id

         sample['prev_output_tokens'] = prev_output_item
         sample['target'] = target_item
-        sample['ref_dict'] = ref_dict

         if self.constraint_trie is not None:
             constraint_mask = torch.zeros(
@@ -101,7 +83,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
-        text = ' {}'.format(data[self.column_map['text']])
+        text = ' {}'.format(data[self.column_map['query']])
         inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
