
mnli finetune done

master
行嗔 3 years ago
parent commit 5d83f62312
26 changed files with 322 additions and 281 deletions
  1. +44 -0 modelscope/metrics/accuracy_metric.py
  2. +4 -2 modelscope/models/multi_modal/ofa/generate/sequence_generator.py
  3. +1 -0 modelscope/models/multi_modal/ofa/utils/__init__.py
  4. +42 -35 modelscope/models/multi_modal/ofa_for_all_tasks.py
  5. +4 -3 modelscope/msdatasets/ms_dataset.py
  6. +2 -0 modelscope/pipelines/cv/image_classification_pipeline.py
  7. +6 -3 modelscope/preprocessors/multi_modal.py
  8. +25 -8 modelscope/preprocessors/ofa/base.py
  9. +16 -1 modelscope/preprocessors/ofa/image_captioning.py
  10. +1 -1 modelscope/preprocessors/ofa/image_classification.py
  11. +1 -1 modelscope/preprocessors/ofa/summarization.py
  12. +42 -8 modelscope/preprocessors/ofa/text_classification.py
  13. +1 -1 modelscope/preprocessors/ofa/text_to_image_synthesis.py
  14. +2 -0 modelscope/preprocessors/ofa/utils/collate.py
  15. +1 -1 modelscope/preprocessors/ofa/visual_entailment.py
  16. +1 -1 modelscope/preprocessors/ofa/visual_grounding.py
  17. +1 -1 modelscope/preprocessors/ofa/visual_question_answering.py
  18. +35 -33 modelscope/trainers/multi_modal/ofa/ofa_trainer.py
  19. +0 -120 modelscope/trainers/multi_modal/ofa/ofa_trainer_old.py
  20. +1 -37 modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
  21. +9 -9 modelscope/trainers/trainer.py
  22. +29 -15 modelscope/trainers/utils/inference.py
  23. +15 -1 modelscope/utils/device.py
  24. +17 -0 modelscope/utils/multi_modal/forked_pdb.py
  25. +21 -0 tests/pipelines/test_ofa_tasks.py
  26. +1 -0 tests/trainers/test_ofa_trainer.py

+44 -0 modelscope/metrics/accuracy_metric.py

@@ -0,0 +1,44 @@
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy)
+class AccuracyMetric(Metric):
+    """The metric computation class for sequence classification tasks.
+
+    This metric class calculates accuracy over all input batches.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        assert type(ground_truths) == type(eval_results)
+        if isinstance(ground_truths, list):
+            self.preds.extend(eval_results)
+            self.labels.extend(ground_truths)
+        elif isinstance(ground_truths, np.ndarray):
+            self.preds.extend(eval_results.tolist())
+            self.labels.extend(ground_truths.tolist())
+        else:
+            raise TypeError('only list or np.ndarray is supported')
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        return {
+            MetricKeys.ACCURACY: (np.asarray([
+                pred == ref for pred, ref in zip(self.preds, self.labels)
+            ])).mean().item()
+        }
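
Note: a minimal sketch of how this metric is exercised (the key names 'labels'/'accuracy' and the add/evaluate call order are assumptions based on the imports above, not part of the commit):

    # Hypothetical driver; assumes OutputKeys.LABELS == 'labels' and
    # MetricKeys.ACCURACY == 'accuracy'.
    metric = AccuracyMetric()
    metric.add(outputs={'labels': ['entailment']}, inputs={'labels': ['entailment']})
    metric.add(outputs={'labels': ['neutral']}, inputs={'labels': ['contradiction']})
    print(metric.evaluate())  # -> {'accuracy': 0.5}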

+4 -2 modelscope/models/multi_modal/ofa/generate/sequence_generator.py

@@ -409,10 +409,12 @@ class SequenceGenerator(nn.Module):
                     out_prefix = p_toks_len_beam < (
                         step + no_repeat_ngram_size - 1)
                 else:
-                    out_prefix = [True] * bsz * beam_size
+                    out_prefix = torch.ones(bsz * beam_size).bool()
                 ngram_blocker_tokens = tokens[out_prefix]
                 ngram_blocker_lprobs = lprobs[out_prefix]
-                ngram_blocker_bsz = out_prefix.sum() // beam_size
+                ngram_blocker_bsz = torch.div(
+                    out_prefix.sum(), beam_size, rounding_mode='trunc')

                 lprobs[out_prefix] = self.repeat_ngram_blocker(
                     tokens=ngram_blocker_tokens,
                     lprobs=ngram_blocker_lprobs,
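
Note: `torch.div(..., rounding_mode='trunc')` replaces tensor `//`, whose underlying floor_divide was deprecated around PyTorch 1.8 for truncating instead of flooring; a quick standalone check:

    import torch

    n = torch.tensor(10)
    print(torch.div(n, 4, rounding_mode='trunc'))  # tensor(2), no deprecation warning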


+1 -0 modelscope/models/multi_modal/ofa/utils/__init__.py

@@ -0,0 +1 @@
+from .constant import OFA_TASK_KEY_MAPPING

+42 -35 modelscope/models/multi_modal/ofa_for_all_tasks.py

@@ -10,7 +10,6 @@ import torch.nn.functional as F

 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
-from modelscope.models.base import Tensor
 from modelscope.models.builder import MODELS
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors.ofa.utils.collate import collate_tokens
@@ -38,7 +37,9 @@ class OfaForAllTasks(TorchModel):

     def __init__(self, model_dir, *args, **kwargs):
         super().__init__(model_dir=model_dir, *args, **kwargs)
-        model = OFAModel.from_pretrained(model_dir)
+        sd = torch.load(osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
+        sd = sd if 'meta' not in sd else sd['state_dict']
+        model = OFAModel.from_pretrained(model_dir, state_dict=sd)
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
         self.model = model.module if hasattr(model, 'module') else model
@@ -65,10 +66,9 @@ class OfaForAllTasks(TorchModel):
         self.gen_type = self.cfg.model.get('gen_type', 'generation')
         assert self.gen_type in ['generation', 'traverse'], \
             'model.gen_type must be in ["generation", "traverse"]'
-        self._device = torch.device('cuda') if torch.cuda.is_available() \
-            else torch.device('cpu')
-        self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id
-                                          ]).to(self._device)
+        self.bos_item = torch.LongTensor([self.tokenizer.bos_token_id])
+        self.pad_item = torch.LongTensor([self.tokenizer.pad_token_id])
+        self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id])
         self.index2ans = {}
         self.ans2label_dict = {}
         self.load_ans2label()
@@ -89,7 +89,8 @@ class OfaForAllTasks(TorchModel):
             self.val_masks_l = []
             self.build_trie()
             sg_args['constraint_trie'] = self.constraint_trie
-            self.model.to(self._device)
+        else:
+            self.constraint_trie = None
         self.generator = sg.SequenceGenerator(**sg_args)
         inference_d = {
             'generation': self._text_gen_inference,
@@ -106,42 +107,52 @@ class OfaForAllTasks(TorchModel):
         }

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        input = move_to_device(input, self.model.device)
         if self.model.training:
             return self.model(**input['net_input'])
         else:
             return self.inference(input)

     def inference(self, input: Dict[str, Any]) -> Dict[str, Any]:
         ret = self.task_inference_mapping[self.cfg.task](input)
-        ret['samples'] = input['samples']
+        if 'samples' in input:
+            ret['samples'] = input['samples']
         for key in [
                 OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
                 OutputKeys.LABELS, OutputKeys.SCORES
         ]:
             if key in ret and len(ret[key]) == 1:
                 ret[key] = ret[key][0]
+            if key not in ret:
+                ret[key] = None
         return ret

-    def postprocess(self, input: Dict[str, Tensor],
-                    **kwargs) -> Dict[str, Tensor]:
+    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         if self.cfg.task == Tasks.image_captioning:
             caption = input[OutputKeys.CAPTION]
-            caption = caption.translate(self.transtab).strip()
-            input[OutputKeys.CAPTION] = caption
+            result_l = list()
+            for cap in caption:
+                result_l.append(cap.translate(self.transtab).strip())
+            input[OutputKeys.CAPTION] = result_l

         return input

     def _text_gen_inference(self, input):
-        input = move_to_device(input, self._device)
-        if 'prefix_tokens' in input:
-            gen_output = self.generator.generate(
-                [self.model], input, prefix_tokens=input['prefix_tokens'])
-        else:
-            gen_output = self.generator.generate([self.model], input)
+        gen_outputs = self.generator.generate([self.model],
+                                              input,
+                                              prefix_tokens=input.get(
+                                                  'prefix_tokens', None))
         gen_l = list()
-        for i in range(len(gen_output)):
-            if 'prefix_tokens' in input:
-                prefix_tokens = input['prefix_tokens']
-                gen_l.append(
-                    gen_output[i][0]['tokens'][len(prefix_tokens[i]):])
-            else:
-                gen_l.append(gen_output[i][0]['tokens'])
+        for idx, gen_out in enumerate(gen_outputs):
+            if len(gen_out) > 0:
+                decode_tokens = gen_out[0]['tokens']
+                if 'prefix_tokens' in input:
+                    prefix_len = input['prefix_tokens'][idx].ne(
+                        self.pad_item.to(self.model.device)).sum()
+                    decode_tokens = decode_tokens[prefix_len:]
+                gen_l.append(decode_tokens)
+            else:
+                gen_l.append('')
         result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True)
+        result = [item.strip() for item in result]
+        # text generation tasks have no score
+        ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result}
+        if self.cfg.task.endswith('classification'):
@@ -149,7 +160,6 @@ class OfaForAllTasks(TorchModel):
         return ret

     def _visual_grounding_inference(self, input):
-        input = move_to_device(input, self._device)
         gen_output = self.generator.generate([self.model], input)
         tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
         region_coord_l = list()
@@ -163,13 +173,12 @@ class OfaForAllTasks(TorchModel):
         region_tensor[:, ::2] /= input['w_resize_ratios']
         region_tensor[:, 1::2] /= input['h_resize_ratios']
         return {
-            OutputKeys.BOXES: move_to_device(region_tensor,
-                                             torch.device('cpu')),
+            OutputKeys.BOXES:
+            move_to_device(region_tensor, torch.device('cpu')).tolist(),
             OutputKeys.SCORES: [1.0] * region_tensor.shape[0]
         }

     def _traverse_inference(self, input):
-        input = move_to_device(input, self._device)
         encoder_input = dict()
         for key in input['net_input'].keys():
             encoder_input[key] = input['net_input'][key]
@@ -193,19 +202,19 @@ class OfaForAllTasks(TorchModel):
             torch.cat([
                 torch.zeros(
                     len(decoder_prompt) - 1,
-                    valid_constraint_mask.size(1)).bool().to(self._device),
+                    valid_constraint_mask.size(1)).bool(),
                 valid_constraint_mask], dim=0)  # yapf: disable
             for decoder_prompt in input['decoder_prompts']  # yapf: disable
             for valid_constraint_mask in val_masks]  # yapf: disable
         valid_tgt = collate_tokens(
             valid_tgt_items,
-            pad_idx=self.tokenizer.pad_token_id).to(self._device)
+            pad_idx=self.tokenizer.pad_token_id).to(self.model.device)
         valid_prev_output = collate_tokens(
             valid_prev_items,
-            pad_idx=self.tokenizer.pad_token_id).to(self._device)
+            pad_idx=self.tokenizer.pad_token_id).to(self.model.device)
         val_masks = collate_tokens(
             valid_constraint_mask_items,
-            pad_idx=self.tokenizer.pad_token_id).to(self._device)
+            pad_idx=self.tokenizer.pad_token_id).to(self.model.device)
         new_encoder_out = {
             'last_hidden_state':
             encoder_out['last_hidden_state'].repeat_interleave(
@@ -280,8 +289,6 @@ class OfaForAllTasks(TorchModel):
             self.val_masks_l += [
                 constraint_mask_list[i:i + self.val_batch_size]
             ]
-        self.val_ans_l = move_to_device(self.val_ans_l, self._device)
-        self.val_masks_l = move_to_device(self.val_masks_l, self._device)

     def load_ans2label(self):
         if self.cfg.model.get('answer2label', None):


+4 -3 modelscope/msdatasets/ms_dataset.py

@@ -75,7 +75,7 @@ class MsIterableDataset(torch.utils.data.IterableDataset):
             }
             for preprocessor in self.preprocessor_list:
                 res.update({
-                    k: torch.tensor(v)
+                    k: v  # k: torch.tensor(v)
                     for k, v in preprocessor(item_dict).items()
                     if k in self.retained_columns
                 })
@@ -350,14 +350,15 @@ class MsDataset:

         def is_numpy_number(value):
             return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
-                value.dtype, np.floating)
+                value.dtype, np.floating) or np.issubdtype(
+                    value.dtype, np.bool_)

         retained_columns = []
         for k in sample_res.keys():
             if not is_numpy_number(sample_res[k]):
                 logger.warning(
                     f'Data of column {k} is non-numeric, will be removed')
-                continue
+                # continue
             retained_columns.append(k)

         return MsIterableDataset(self._hf_ds, preprocessor_list,


+2 -0 modelscope/pipelines/cv/image_classification_pipeline.py

@@ -13,6 +13,7 @@ from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image
 from modelscope.utils.constant import Tasks
+from modelscope.utils.device import get_device
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -36,6 +37,7 @@ class ImageClassificationPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
+        pipe_model.to(get_device())
         if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)


+6 -3 modelscope/preprocessors/multi_modal.py

@@ -84,7 +84,12 @@ class OfaPreprocessor(Preprocessor):

     def _compatible_with_pretrain(self, data):
         if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
-            image = load_image(data['image'])
+            if isinstance(data['image'], str):
+                image = load_image(data['image'])
+            else:
+                image = data['image']
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
             img_buffer = BytesIO()
             image.save(img_buffer, format='JPEG')
             data['image'] = Image.open(img_buffer)
@@ -102,8 +107,6 @@ class OfaPreprocessor(Preprocessor):
         for k, v in data.items():
             str_data[k] = str(v)
         sample['sample'] = str_data
-        # import pdb
-        # pdb.set_trace()
         if self.no_collate:
             return sample
         else:


+25 -8 modelscope/preprocessors/ofa/base.py

@@ -42,6 +42,7 @@ class OfaBasePreprocessor:
             for key, value in tokenizer.get_vocab().items()
         }
         self.max_src_length = cfg.model.get('max_src_length', 256)
+        self.max_tgt_length = cfg.model.get('max_tgt_length', 256)
         self.max_image_size = cfg.model.get('max_image_size', 512)
         self.language = self.cfg.model.get('language', 'en')
         self.prompt_type = self.cfg.model.get('prompt_type', 'none')
@@ -58,22 +59,23 @@ class OfaBasePreprocessor:
         self.std = [0.5, 0.5, 0.5]
         self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
         self.constraint_trie = None
-        self.index2ans = {}
-        if self.cfg.model.get('answer2label', False):
+        if self.cfg.model.get('answer2label', None):
             ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)
             with open(ans2label_file, 'r') as reader:
                 ans2label_dict = json.load(reader)
+            self.ans2label = ans2label_dict
+            self.label2ans = {v: k for k, v in self.ans2label.items()}
             self.constraint_trie = Trie(tokenizer.eos_token_id)
             for i, answer in enumerate(ans2label_dict.keys()):
-                answer_item = tokenizer(
-                    ' ' + answer,
-                    return_tensors='pt',
-                    add_special_tokens=False).input_ids.squeeze(0)
+                answer_item = self.tokenize_text(
+                    ' ' + answer, add_bos=False, add_eos=False)
                 self.constraint_trie.insert([tokenizer.bos_token_id]
                                             + answer_item.tolist()
                                             + [tokenizer.eos_token_id])

-    def get_inputs(self, text, add_bos=True, add_eos=True):
+    def tokenize_text(self, text, add_bos=True, add_eos=True):
+        if text is None:
+            return None
         inputs = self.tokenizer(
             text,
             max_length=self.max_src_length,
@@ -88,7 +90,7 @@ class OfaBasePreprocessor:

     @staticmethod
     def pre_caption(caption, max_words=None):
-        caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\
+        caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \
             .replace('/', ' ').replace('<person>', 'person')

         caption = re.sub(
@@ -126,3 +128,18 @@ class OfaBasePreprocessor:
             question = ' '.join(question_words[:max_ques_words])

         return question
+
+    def add_constraint_mask(self, sample):
+        target_itm = sample['target']
+        len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item()
+        if self.constraint_trie:
+            constraint_mask = torch.zeros(
+                (len(target_itm), len(self.tgt_dict))).bool()
+            start_idx = len(target_itm) - len_label_itm
+            for i in range(start_idx, len(target_itm)):
+                constraint_prefix_token = self.bos_item.tolist(
+                ) + target_itm[start_idx:i].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
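
Note: for intuition, the trie restricts decoding to valid answer strings by mapping each decoded prefix to its allowed next tokens; a toy sketch reusing the `insert`/`get_next_layer` interface seen in this commit (the token ids are invented):

    # Hypothetical: 0=bos, 2=eos, 11/12/14 stand in for subword ids.
    trie = Trie(2)                        # constructed with the eos id, as above
    trie.insert([0, 11, 12, 2])           # bos + tokens of one answer + eos
    trie.insert([0, 11, 14, 2])           # bos + tokens of another answer + eos
    print(trie.get_next_layer([0, 11]))   # allowed continuations, e.g. [12, 14]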

+16 -1 modelscope/preprocessors/ofa/image_captioning.py

@@ -38,14 +38,29 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = data['image'] if isinstance(
             data['image'], Image.Image) else load_image(data['image'])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', ' what does the image describe?')
-        inputs = self.get_inputs(prompt)
+        inputs = self.tokenize_text(prompt)
         sample = {
             'source': inputs,
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
         return sample
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data['target']
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target)
+        return sample
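
Note: at train time the preprocessor now expects a target caption next to the image; roughly (the caption text is illustrative):

    data = {
        'image': 'data/test/images/image_captioning.png',
        'target': 'two dogs run across a snowy field',
    }
    # _build_train_sample() reuses the inference sample and stores the
    # punctuation-stripped caption, truncated to max_tgt_length tokens,
    # under sample['target'].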

+1 -1 modelscope/preprocessors/ofa/image_classification.py

@@ -42,7 +42,7 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
             data['image'], Image.Image) else load_image(data['image'])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', ' what does the image describe?')
-        inputs = self.get_inputs(prompt)
+        inputs = self.tokenize_text(prompt)
         sample = {
             'source': inputs,
             'patch_image': patch_image,


+1 -1 modelscope/preprocessors/ofa/summarization.py

@@ -31,7 +31,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
         prompt = self.cfg.model.get(
             'prompt', ' " {} " Summarize the article with a title: ')
         text = prompt.format(source)
-        inputs = self.get_inputs(text)
+        inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
         elif self.prompt_type == 'prev_output':


+42 -8 modelscope/preprocessors/ofa/text_classification.py

@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict

+import torch
+
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor

@@ -24,24 +26,56 @@ class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
                   self).__init__(cfg, model_dir, mode, *args, **kwargs)

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_instruction(self, data):
         text1 = ' '.join(
             data['text'].lower().strip().split()[:self.max_src_length])
         text2 = ' '.join(
             data['text2'].lower().strip().split()[:self.max_src_length])
         prompt = ' can text1 " {} " imply text2 " {} "?'
         text = prompt.format(text1, text2)
-        inputs = self.get_inputs(text)
+        instruction_itm = self.tokenize_text(text)
+        return instruction_itm
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        instruction_itm = self._build_instruction(data)
+        assert 'label' in data, 'there must be a `label` column in the train phase'
+        label = data['label']
+        if self.label2ans:
+            label = self.label2ans[label]  # ans
+        label_itm = self.tokenize_text(f' {label}', add_bos=False)
+        if self.prompt_type == 'none':
+            target_itm = label_itm
+        elif self.prompt_type == 'prev_output':
+            target_itm = torch.cat([instruction_itm[1:-1], label_itm])
+        else:
+            raise NotImplementedError
+        prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]])
+        target_itm[:-len(label_itm)] = self.pad_item
+        sample = {
+            'source': instruction_itm,
+            'target': target_itm,
+            'prev_output_tokens': prev_output_itm,
+        }
+        self.add_constraint_mask(sample)
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        instruction_itm = self._build_instruction(data)
         if self.prompt_type == 'none':
-            decoder_prompt = self.bos_item
-        elif self.prompt_type == 'src':
-            decoder_prompt = inputs
+            prefix_token = []
         elif self.prompt_type == 'prev_output':
-            decoder_prompt = inputs[:-1]
+            prefix_token = instruction_itm[:-1]  # remove eos
         else:
             raise NotImplementedError
         sample = {
-            'source': inputs,
-            'decoder_prompt': decoder_prompt,
-            'prefix_token': decoder_prompt[:-1],
+            'source': instruction_itm,
+            'prefix_token': prefix_token,
         }
+        if 'label' in data:
+            sample['label'] = self.label2ans[data['label']]
         return sample
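
Note: with the MNLI-style prompt above, an input pair is flattened into one instruction; illustratively (the label id mapping depends on the model's ans2label file):

    data = {
        'text': 'A soccer game with multiple males playing.',
        'text2': 'Some men are playing a sport.',
        'label': 0,  # mapped through self.label2ans to an answer word
    }
    # _build_instruction() tokenizes:
    # ' can text1 " a soccer game with multiple males playing. " imply
    #   text2 " some men are playing a sport. "?'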

+1 -1 modelscope/preprocessors/ofa/text_to_image_synthesis.py

@@ -30,7 +30,7 @@ class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
         source = ' '.join(
             data['text'].lower().strip().split()[:self.max_src_length])
         source = 'what is the complete image? caption: {}'.format(source)
-        inputs = self.get_inputs(source)
+        inputs = self.tokenize_text(source)
         sample = {
             'source': inputs,
             'patch_images': None,


+2 -0 modelscope/preprocessors/ofa/utils/collate.py

@@ -47,6 +47,8 @@ def collate_fn(samples, pad_idx, eos_idx):
         batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0)
     if samples[0].get('ref_dict', None) is not None:
         batch['ref_dict'] = np.array([s['ref_dict'] for s in samples])
+    if samples[0].get('label', None) is not None:
+        batch['labels'] = np.array([s['label'] for s in samples]).tolist()
     if samples[0].get('constraint_mask', None) is not None:
         batch['constraint_masks'] = merge('constraint_mask')
     if samples[0].get('decoder_prompt', None) is not None:


+1 -1 modelscope/preprocessors/ofa/visual_entailment.py

@@ -53,7 +53,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         prompt = self.cfg.model.get(
             'prompt', ' can image and text1 " {} " imply text2 " {} "?')
         text = prompt.format(caption, hypothesis)
-        inputs = self.get_inputs(text)
+        inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
         elif self.prompt_type == 'src':


+1 -1 modelscope/preprocessors/ofa/visual_grounding.py

@@ -48,7 +48,7 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
         prompt = self.cfg.model.get(
             'prompt', ' which region does the text " {} " describe?')
         text = prompt.format(src_caption)
-        src_item = self.get_inputs(text)
+        src_item = self.tokenize_text(text)
         sample = {
             'source': src_item,
             'patch_image': patch_image,


+1 -1 modelscope/preprocessors/ofa/visual_question_answering.py

@@ -42,7 +42,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
             data['image'], Image.Image) else load_image(data['image'])
         patch_image = self.patch_resize_transform(image)
         text = ' {}'.format(data['text'])
-        inputs = self.get_inputs(text)
+        inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
         elif self.prompt_type == 'src':


+35 -33 modelscope/trainers/multi_modal/ofa/ofa_trainer.py

@@ -3,6 +3,7 @@ from functools import partial
 from typing import Dict, Optional

 from datasets import load_dataset
+from torch import distributed as dist

 from modelscope.metainfo import Trainers
 from modelscope.models.base import Model
@@ -15,7 +16,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ConfigKeys, ModeKeys, ModelFile
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                OFADataset, get_schedule)
+                                get_schedule)


 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
@@ -36,31 +37,13 @@ class OFATrainer(EpochBasedTrainer):
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
-                model_dir=model_dir, model=ModeKeys.TRAIN, no_collate=True),
+                model_dir=model_dir, mode=ModeKeys.TRAIN, no_collate=True),
             ConfigKeys.val:
             OfaPreprocessor(
-                model_dir=model_dir, model=ModeKeys.EVAL, no_collate=True),
+                model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True),
         }
-        # train_dataset = dataset['train'].to_torch_dataset(
-        #     preprocessors=OfaPreprocessor(model_dir=model_dir, model=ModeKeys.TRAIN, no_collate=True),
-        # )
-        # valid_dataset = dataset['valid'].to_torch_dataset(
-        #     preprocessors=OfaPreprocessor(model_dir=model_dir, model=ModeKeys.TRAIN, no_collate=True),
-        # )
-        # train_dataset = OFADataset(
-        #     file_path=cfg.dataset.train_set,
-        #     selected_id_keys=cfg.dataset.selected_id_keys,
-        #     preprocessor=OfaPreprocessor(
-        #         model_dir=model_dir, mode=ModeKeys.TRAIN),
-        # )
-        # val_dataset = OFADataset(
-        #     file_path=cfg.dataset.valid_set,
-        #     selected_id_keys=cfg.dataset.selected_id_keys,
-        #     preprocessor=OfaPreprocessor(
-        #         model_dir=model_dir, mode=ModeKeys.EVAL),
-        # )
         epoch_steps = len(dataset['train']) // (
-            cfg.train.gradient_accumulation_steps
+            cfg.train.optimizer_hook.cumulative_iters
             * cfg.train.dataloader.batch_size_per_gpu)
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
@@ -78,6 +61,11 @@ class OFATrainer(EpochBasedTrainer):
             pad_idx=model.tokenizer.pad_token_id,
             eos_idx=model.tokenizer.eos_token_id,
         )
+        if 'launcher' not in kwargs and cfg.train.get('launcher', None):
+            kwargs['launcher'] = cfg.train.launcher
+        if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
+            kwargs['use_fp16'] = cfg.train.use_fp16
+
         super().__init__(
             cfg_file=cfg_file,
             model=model,
@@ -91,14 +79,28 @@ class OFATrainer(EpochBasedTrainer):
             **kwargs,
         )

-    # def train(self, *args, **kwargs):
-    #     pass
-
     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
                  *args,
                  **kwargs) -> Dict[str, float]:
         pass

-    def prediction_step(self, model, inputs):
-        pass
+    def train_step(self, model, inputs):
+        model.train()
+        model_outputs = model.forward(inputs)
+        loss, sample_size, logging_output = self.criterion(
+            model_outputs, inputs)
+        train_outputs = {'loss': loss}
+        # add model output info to log
+        if 'log_vars' not in train_outputs:
+            default_keys_pattern = ['loss']
+            match_keys = set([])
+            for key_p in default_keys_pattern:
+                match_keys.update(
+                    [key for key in train_outputs.keys() if key_p in key])
+            log_vars = {}
+            for key in match_keys:
+                value = train_outputs.get(key, None)
+                if value is not None:
+                    if dist.is_available() and dist.is_initialized():
+                        value = value.data.clone()
+                        dist.all_reduce(value.div_(dist.get_world_size()))
+                    log_vars.update({key: value.item()})
+            self.log_buffer.update(log_vars)
+        else:
+            self.log_buffer.update(train_outputs['log_vars'])
+        self.train_outputs = train_outputs
+0 -120 modelscope/trainers/multi_modal/ofa/ofa_trainer_old.py

@@ -1,120 +0,0 @@
-import os
-from os import path as osp
-from typing import Dict, Optional
-
-import torch
-import torch.distributed as dist
-import transformers
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from modelscope.metainfo import Trainers
-from modelscope.models.base import Model
-from modelscope.preprocessors.multi_modal import OfaPreprocessor
-from modelscope.preprocessors.ofa.utils.collate import collate_fn
-from modelscope.trainers.base import BaseTrainer
-from modelscope.trainers.builder import TRAINERS
-from modelscope.utils.constant import ModeKeys, ModelFile
-from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import init_dist
-from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                OFADataset, get_schedule)
-
-logger = get_logger()
-
-
-@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
-class OFAOldTrainer(BaseTrainer):
-
-    def __init__(self, model: str, *args, **kwargs):
-        model = Model.from_pretrained(model)
-        super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION))
-        self.model_dir = model.model_dir
-        self.model = model.model
-        self.device_id = 0
-        self.total_epoch = self.cfg.train.epoch
-        self.train_batch_size = self.cfg.train.batch_size
-        self.val_batch_size = self.cfg.evaluation.batch_size
-        self.save_dir = self.cfg.train.save_dir
-        init_dist(launcher='pytorch')
-        self.train_dataset = OFADataset(
-            file_path=self.cfg.dataset.train_set,
-            selected_id_keys=self.cfg.dataset.selected_id_keys,
-            preprocessor=OfaPreprocessor(
-                model_dir=self.model_dir, split=ModeKeys.TRAIN),
-        )
-        self.val_dataset = OFADataset(
-            file_path=self.cfg.dataset.valid_set,
-            selected_id_keys=self.cfg.dataset.selected_id_keys,
-            preprocessor=OfaPreprocessor(
-                model_dir=self.model_dir, split=ModeKeys.EVAL),
-        )
-        epoch_steps = len(
-            self.train_dataset) // self.cfg.train.gradient_accumulation_steps
-        self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch
-        self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
-            self.cfg.train.criterion)
-
-    def train(self, *args, **kwargs):
-        assert dist.is_initialized()
-
-        self.model.train()
-        self.model.to(self.device_id)
-        ddp_model = torch.nn.parallel.DistributedDataParallel(
-            self.model, device_ids=[
-                self.device_id,
-            ])
-
-        optimizer = transformers.AdamW(
-            self.model.parameters(),
-            lr=self.cfg.train.lr,
-            weight_decay=self.cfg.train.weight_decay,
-            correct_bias=False,
-        )
-        scheduler_class, scheduler_args = get_schedule(self.cfg.train)
-        if scheduler_class is not None:
-            lr_scheduler = scheduler_class(**{'optimizer': optimizer},
-                                           **scheduler_args)
-        else:
-            lr_scheduler = None
-        for epoch in range(self.total_epoch):
-            train_sampler = DistributedSampler(
-                dataset=self.train_dataset, shuffle=True)
-            train_sampler.set_epoch(epoch)
-
-            train_params = {
-                'pin_memory': True,
-                'collate_fn': collate_fn,
-                'batch_size': self.train_batch_size,
-                'shuffle': False,
-                'drop_last': True,
-                'sampler': train_sampler,
-                'num_workers': 2,
-            }
-
-            train_loader = DataLoader(self.train_dataset, **train_params)
-
-            for idx, batch in enumerate(train_loader, start=1):
-                model_outputs = ddp_model(**batch)
-                loss, sample_size, logging_output = self.criterion(
-                    model_outputs, batch)
-                loss.backward()
-                optimizer.zero_grad()
-                if lr_scheduler is not None:
-                    lr_scheduler.step()
-                optimizer.step()
-                optimizer.zero_grad()
-                if idx % 10 == 0:
-                    logger.info(
-                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
-                            epoch, idx, len(train_loader), loss.item()))
-            if dist.get_rank() == 0:
-                os.makedirs(self.ckpt_dir, exist_ok=True)
-                torch.save(ddp_model.module.state_dict(),
-                           f'{self.ckpt_dir}/epoch{epoch}.bin')
-
-    def evaluate(self,
-                 checkpoint_path: Optional[str] = None,
-                 *args,
-                 **kwargs) -> Dict[str, float]:
-        pass

+1 -37 modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py

@@ -172,47 +172,11 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
-        if isinstance(sample, list):
-            if self.sample_patch_num > 0:
-                sample[0]['net_input'][
-                    'sample_patch_num'] = self.sample_patch_num
-            loss_v1, sample_size_v1, logging_output_v1 = self.forward(
-                output[0], sample[0], update_num, reduce)
-            loss_v2, sample_size_v2, logging_output_v2 = self.forward(
-                output[1], sample[1], update_num, reduce)
-            loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
-            sample_size = 1
-            logging_output = {
-                'loss':
-                loss.data,
-                'loss_v1':
-                loss_v1.data,
-                'loss_v2':
-                loss_v2.data,
-                'nll_loss':
-                logging_output_v1['nll_loss'].data / sample_size_v1
-                + logging_output_v2['nll_loss'].data / sample_size_v2,
-                'ntokens':
-                logging_output_v1['ntokens'] + logging_output_v2['ntokens'],
-                'nsentences':
-                logging_output_v1['nsentences']
-                + logging_output_v2['nsentences'],
-                'sample_size':
-                1,
-                'sample_size_v1':
-                sample_size_v1,
-                'sample_size_v2':
-                sample_size_v2,
-            }
-            return loss, sample_size, logging_output
-
        if self.use_rdrop:
            construct_rdrop_sample(sample)

-        net_output = output
-        # model(**sample["net_input"])
        loss, nll_loss, ntokens = self.compute_loss(
-            net_output, sample, update_num, reduce=reduce)
+            output, sample, update_num, reduce=reduce)
        sample_size = (
            sample['target'].size(0) if self.sentence_avg else ntokens)
        logging_output = {


+9 -9 modelscope/trainers/trainer.py

@@ -12,6 +12,7 @@ import numpy as np
 import torch
 from torch import distributed as dist
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.distributed import DistributedSampler
@@ -159,8 +160,6 @@ class EpochBasedTrainer(BaseTrainer):
             train_dataset,
             mode=ModeKeys.TRAIN,
             preprocessor=self.train_preprocessor)
-        # import pdb
-        # pdb.set_trace()
         self.eval_dataset = self.to_task_dataset(
             eval_dataset,
             mode=ModeKeys.EVAL,
@@ -200,7 +199,6 @@ class EpochBasedTrainer(BaseTrainer):
             self._max_epochs = self.cfg.train.max_epochs
         else:
             self._max_epochs = kwargs['max_epochs']
-
         self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None)
         self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None)
         if self._train_iters_per_epoch is None and hasattr(
@@ -220,12 +218,12 @@ class EpochBasedTrainer(BaseTrainer):
             init_dist(kwargs['launcher'])

         self._dist = get_dist_info()[1] > 1
-
         # model placement
         if self.device.type == 'cuda':
             self.model.to(self.device)
             if not is_parallel(self.model) and self._dist:
                 self.model = self.to_parallel(self.model)
+                self.device = self.model.device

     def rebuild_config(self, cfg: Config):
         """A method used to rebuild the config, any subclass can override this method.
@@ -429,7 +427,7 @@ class EpochBasedTrainer(BaseTrainer):
         self.register_hook_from_cfg(self.cfg.train.hooks)
         self.train_loop(self.train_dataloader)

-    def evaluate(self, checkpoint_path=None):
+    def evaluate(self, checkpoint_path=None, *args, **kwargs):
         self.model.eval()
         self._mode = ModeKeys.EVAL
@@ -475,12 +473,12 @@ class EpochBasedTrainer(BaseTrainer):
             self.cfg.parallel.update(
                 dict(module=model, device_ids=[torch.cuda.current_device()]))
             return build_parallel(self.cfg.parallel)
+        model.to(f'cuda:{torch.cuda.current_device()}')
         dp_cfg = dict(
             type='DistributedDataParallel',
             module=model,
+            find_unused_parameters=True,
             device_ids=[torch.cuda.current_device()])

         return build_parallel(dp_cfg)
@@ -504,8 +502,10 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         # call model forward but not __call__ to skip postprocess
+        forward_func = model.module.forward if \
+            isinstance(model, DistributedDataParallel) else model.forward
         if isinstance(inputs,
-                      Mapping) and not func_receive_dict_inputs(model.forward):
+                      Mapping) and not func_receive_dict_inputs(forward_func):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
@@ -751,7 +751,7 @@ class EpochBasedTrainer(BaseTrainer):
             batch_size = batch_size_per_gpu
             num_workers = workers_per_gpu

-        if dist:
+        if dist and not isinstance(dataset, torch.utils.data.IterableDataset):
             sampler = DistributedSampler(
                 dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
         else:
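
Note: the `IterableDataset` guard above is needed because `DistributedSampler` computes its length from `len(dataset)` at construction time, which iterable-style datasets do not define:

    import torch
    from torch.utils.data import IterableDataset
    from torch.utils.data.distributed import DistributedSampler

    class Stream(IterableDataset):
        def __iter__(self):
            yield from range(8)

    try:
        DistributedSampler(Stream(), num_replicas=2, rank=0)
    except TypeError as err:
        print(err)  # object of type 'Stream' has no len()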


+29 -15 modelscope/trainers/utils/inference.py

@@ -9,6 +9,7 @@ from collections.abc import Mapping

 import torch
 from torch import distributed as dist
+from torch.nn.parallel import DistributedDataParallel
 from tqdm import tqdm

 from modelscope.utils.data_utils import to_device
@@ -68,7 +69,10 @@ def single_gpu_test(model,
             batch_size = 1  # iteration count
         else:
             if isinstance(data, dict):
-                batch_size = len(next(iter(data.values())))
+                if 'nsentences' in data:
+                    batch_size = data['nsentences']
+                else:
+                    batch_size = len(next(iter(data.values())))
             else:
                 batch_size = len(data)
         for _ in range(batch_size):
@@ -142,28 +146,38 @@ def multi_gpu_test(model,
             data = to_device(data, device)
             data_list.append(data)
             with torch.no_grad():
-                if isinstance(data, Mapping) and not func_receive_dict_inputs(
-                        model.forward):
+                forward_func = model.module.forward if \
+                    isinstance(model, DistributedDataParallel) else model.forward
+                if isinstance(data, Mapping
+                              ) and not func_receive_dict_inputs(forward_func):
                     result = model.forward(**data)
                 else:
                     result = model.forward(data)
             results.append(result)

-        if rank == 0:
-            if isinstance(data, dict):
-                batch_size = len(next(iter(data.values())))
+        if isinstance(data, dict):
+            if 'nsentences' in data:
+                batch_size = data['nsentences']
             else:
-                batch_size = len(data)
-            if progress_with_iters:
-                total_samples += batch_size * world_size
-                batch_size = 1  # iteration count
+                batch_size = len(next(iter(data.values())))
+        else:
+            batch_size = len(data)
+        if i >= (data_len // world_size) - 1:
+            total_samples = torch.LongTensor([batch_size]).to(model.device)
+            dist.all_reduce(total_samples, op=dist.reduce_op.SUM)
+            total_samples = total_samples.item()
+        else:
+            total_samples = batch_size * world_size
+        if progress_with_iters:
+            iter_cnt_all = world_size
+        else:
+            iter_cnt_all = total_samples
+        count += iter_cnt_all

-            batch_size_all = batch_size * world_size
-            count += batch_size_all
+        if rank == 0:
             if count > data_len:
-                batch_size_all = data_len - (count - batch_size_all)
-            for _ in range(batch_size_all):
+                iter_cnt_all = data_len - (count - iter_cnt_all)
+            for _ in range(iter_cnt_all):
                 pbar.update()

         if progress_with_iters and (i + 1) >= data_len:


+15 -1 modelscope/utils/device.py

@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 from contextlib import contextmanager

 from modelscope.utils.constant import Devices, Frameworks
@@ -105,3 +105,17 @@ def create_device(device_name):
             device = torch.device('cpu')

     return device
+
+
+def get_device():
+    import torch
+    from torch import distributed as dist
+    if torch.cuda.is_available():
+        if dist.is_available() and dist.is_initialized(
+        ) and 'LOCAL_RANK' in os.environ:
+            device_id = f"cuda:{os.environ['LOCAL_RANK']}"
+        else:
+            device_id = 'cuda:0'
+    else:
+        device_id = 'cpu'
+    return torch.device(device_id)
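
Note: callers use this as in the pipeline change above; a sketch:

    import torch
    from modelscope.utils.device import get_device

    device = get_device()  # cuda:<LOCAL_RANK> inside an initialized DDP worker,
                           # cuda:0 in a plain single-GPU process, otherwise cpu
    torch.nn.Linear(4, 2).to(device)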

+17 -0 modelscope/utils/multi_modal/forked_pdb.py

@@ -0,0 +1,17 @@
+import pdb
+import sys
+
+
+class ForkedPdb(pdb.Pdb):
+    """A Pdb subclass that may be used
+    from a forked multiprocessing child
+
+    """
+
+    def interaction(self, *args, **kwargs):
+        _stdin = sys.stdin
+        try:
+            sys.stdin = open('/dev/stdin')
+            pdb.Pdb.interaction(self, *args, **kwargs)
+        finally:
+            sys.stdin = _stdin
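
Note: the usual entry point is the `set_trace` inherited from `pdb.Pdb`, e.g. inside a forked DataLoader worker:

    from modelscope.utils.multi_modal.forked_pdb import ForkedPdb

    def worker_fn(batch):
        ForkedPdb().set_trace()  # interactive prompt works in the child,
        return batch             # since stdin is rebound to /dev/stdin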

+21 -0 tests/pipelines/test_ofa_tasks.py

@@ -252,6 +252,27 @@ class OfaTasksTest(unittest.TestCase):
         result[OutputKeys.OUTPUT_IMG].save('result.png')
         print(f'Output written to {osp.abspath("result.png")}')

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_visual_question_answering_huge_with_name(self):
+        model = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_visual-question-answering_pretrain_huge_en'
+        ofa_pipe = pipeline(Tasks.visual_question_answering, model=model)
+        image = 'data/test/images/visual_question_answering.png'
+        text = 'what is grown on the plant?'
+        input = {'image': image, 'text': text}
+        result = ofa_pipe(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_image_captioning_huge_with_name(self):
+        model = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_image-caption_coco_huge_en'
+        img_captioning = pipeline(
+            task=Tasks.image_captioning,
+            model=model,
+        )
+        result = img_captioning(
+            {'image': 'data/test/images/image_captioning.png'})
+        print(result[OutputKeys.CAPTION])
+

 if __name__ == '__main__':
     unittest.main()

+1 -0 tests/trainers/test_ofa_trainer.py

@@ -11,6 +11,7 @@ class TestOfaTrainer(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
         model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
+        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
         self.trainer = OFATrainer(model_id)
         self.trainer.train()
