@@ -0,0 +1,44 @@
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(group_key=default_group, module_name=Metrics.accuracy)
+class AccuracyMetric(Metric):
+    """The metric computation class for sequence classification tasks.
+
+    This metric class accumulates predictions over all input batches and
+    computes the overall accuracy.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        assert type(ground_truths) == type(eval_results)
+        if isinstance(ground_truths, list):
+            self.preds.extend(eval_results)
+            self.labels.extend(ground_truths)
+        elif isinstance(ground_truths, np.ndarray):
+            self.preds.extend(eval_results.tolist())
+            self.labels.extend(ground_truths.tolist())
+        else:
+            raise TypeError('only list or np.ndarray is supported')
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        return {
+            MetricKeys.ACCURACY: (np.asarray([
+                pred == ref for pred, ref in zip(self.preds, self.labels)
+            ])).mean().item()
+        }
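
Note: a minimal usage sketch of the metric above (assuming OutputKeys.LABELS resolves to the string 'labels' and that the Metric base class needs no required constructor arguments):

    metric = AccuracyMetric()
    metric.add(outputs={'labels': np.array([1, 0, 1])},
               inputs={'labels': np.array([1, 1, 1])})
    print(metric.evaluate())  # -> {MetricKeys.ACCURACY: 0.666...}, 2 of 3 match
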
@@ -409,10 +409,12 @@ class SequenceGenerator(nn.Module):
                 out_prefix = p_toks_len_beam < (
                     step + no_repeat_ngram_size - 1)
             else:
-                out_prefix = [True] * bsz * beam_size
+                out_prefix = torch.ones(bsz * beam_size).bool()
             ngram_blocker_tokens = tokens[out_prefix]
             ngram_blocker_lprobs = lprobs[out_prefix]
-            ngram_blocker_bsz = out_prefix.sum() // beam_size
+            ngram_blocker_bsz = torch.div(
+                out_prefix.sum(), beam_size, rounding_mode='trunc')
             lprobs[out_prefix] = self.repeat_ngram_blocker(
                 tokens=ngram_blocker_tokens,
                 lprobs=ngram_blocker_lprobs,
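
Note: both replacements serve one fix. `out_prefix` is used as a boolean index and later summed, so a plain Python list (which has no `.sum()`) breaks, and `//` on tensors emitted a floor_divide deprecation warning at the time. The pattern in isolation, with illustrative sizes:

    import torch
    mask = torch.ones(6).bool()      # one flag per (batch, beam) slot
    rows = torch.arange(6)[mask]     # boolean-mask indexing expects a tensor
    bsz = torch.div(mask.sum(), 3, rounding_mode='trunc')  # tensor(2), no warning
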
@@ -0,0 +1 @@
+from .constant import OFA_TASK_KEY_MAPPING
@@ -10,7 +10,6 @@ import torch.nn.functional as F
 from modelscope.metainfo import Models
 from modelscope.models import TorchModel
-from modelscope.models.base import Tensor
 from modelscope.models.builder import MODELS
 from modelscope.outputs import OutputKeys
 from modelscope.preprocessors.ofa.utils.collate import collate_tokens
@@ -38,7 +37,9 @@ class OfaForAllTasks(TorchModel):

     def __init__(self, model_dir, *args, **kwargs):
         super().__init__(model_dir=model_dir, *args, **kwargs)
-        model = OFAModel.from_pretrained(model_dir)
+        sd = torch.load(osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
+        sd = sd if 'meta' not in sd else sd['state_dict']
+        model = OFAModel.from_pretrained(model_dir, state_dict=sd)
         self.cfg = Config.from_file(
             osp.join(model_dir, ModelFile.CONFIGURATION))
         self.model = model.module if hasattr(model, 'module') else model
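
Note: the 'meta' check above implies two accepted checkpoint layouts (inferred from this change, not from documentation): plain weights, i.e. {'encoder.weight': ..., ...}, or a trainer-saved wrapper, i.e. {'meta': {...}, 'state_dict': {...}}; the second branch unwraps the latter before handing the weights to from_pretrained.
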
@@ -65,10 +66,9 @@ class OfaForAllTasks(TorchModel):
         self.gen_type = self.cfg.model.get('gen_type', 'generation')
         assert self.gen_type in ['generation', 'traverse'], \
             'model.gen_type must be in ["generation", "traverse"]'
-        self._device = torch.device('cuda') if torch.cuda.is_available() \
-            else torch.device('cpu')
-        self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id
-                                          ]).to(self._device)
+        self.bos_item = torch.LongTensor([self.tokenizer.bos_token_id])
+        self.pad_item = torch.LongTensor([self.tokenizer.pad_token_id])
+        self.eos_item = torch.LongTensor([self.tokenizer.eos_token_id])
         self.index2ans = {}
         self.ans2label_dict = {}
         self.load_ans2label()
@@ -89,7 +89,8 @@ class OfaForAllTasks(TorchModel):
             self.val_masks_l = []
             self.build_trie()
             sg_args['constraint_trie'] = self.constraint_trie
-            self.model.to(self._device)
+        else:
+            self.constraint_trie = None
         self.generator = sg.SequenceGenerator(**sg_args)
         inference_d = {
             'generation': self._text_gen_inference,
@@ -106,42 +107,52 @@ class OfaForAllTasks(TorchModel):
         }

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        input = move_to_device(input, self.model.device)
+        if self.model.training:
+            return self.model(**input['net_input'])
+        else:
+            return self.inference(input)
+
+    def inference(self, input: Dict[str, Any]) -> Dict[str, Any]:
         ret = self.task_inference_mapping[self.cfg.task](input)
-        ret['samples'] = input['samples']
+        if 'samples' in input:
+            ret['samples'] = input['samples']
         for key in [
                 OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
                 OutputKeys.LABELS, OutputKeys.SCORES
         ]:
-            if key in ret and len(ret[key]) == 1:
-                ret[key] = ret[key][0]
             if key not in ret:
                 ret[key] = None
         return ret

-    def postprocess(self, input: Dict[str, Tensor],
-                    **kwargs) -> Dict[str, Tensor]:
+    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
         if self.cfg.task == Tasks.image_captioning:
             caption = input[OutputKeys.CAPTION]
-            caption = caption.translate(self.transtab).strip()
-            input[OutputKeys.CAPTION] = caption
+            result_l = list()
+            for cap in caption:
+                result_l.append(cap.translate(self.transtab).strip())
+            input[OutputKeys.CAPTION] = result_l
         return input
     def _text_gen_inference(self, input):
-        input = move_to_device(input, self._device)
-        if 'prefix_tokens' in input:
-            gen_output = self.generator.generate(
-                [self.model], input, prefix_tokens=input['prefix_tokens'])
-        else:
-            gen_output = self.generator.generate([self.model], input)
+        gen_outputs = self.generator.generate([self.model],
+                                              input,
+                                              prefix_tokens=input.get(
+                                                  'prefix_tokens', None))
         gen_l = list()
-        for i in range(len(gen_output)):
-            if 'prefix_tokens' in input:
-                prefix_tokens = input['prefix_tokens']
-                gen_l.append(
-                    gen_output[i][0]['tokens'][len(prefix_tokens[i]):])
+        for idx, gen_out in enumerate(gen_outputs):
+            if len(gen_out) > 0:
+                decode_tokens = gen_out[0]['tokens']
+                if 'prefix_tokens' in input:
+                    prefix_len = input['prefix_tokens'][idx].ne(
+                        self.pad_item.to(self.model.device)).sum()
+                    decode_tokens = decode_tokens[prefix_len:]
+                gen_l.append(decode_tokens)
             else:
-                gen_l.append(gen_output[i][0]['tokens'])
+                gen_l.append('')
         result = self.tokenizer.batch_decode(gen_l, skip_special_tokens=True)
+        result = [item.strip() for item in result]
         # text generation tasks have no score
         ret = {OFA_TASK_KEY_MAPPING[self.cfg.task]: result}
         if self.cfg.task.endswith('classification'):
@@ -149,7 +160,6 @@ class OfaForAllTasks(TorchModel):
         return ret

     def _visual_grounding_inference(self, input):
-        input = move_to_device(input, self._device)
         gen_output = self.generator.generate([self.model], input)
         tokens = [gen_output[i][0]['tokens'] for i in range(len(gen_output))]
         region_coord_l = list()
@@ -163,13 +173,12 @@ class OfaForAllTasks(TorchModel):
         region_tensor[:, ::2] /= input['w_resize_ratios']
         region_tensor[:, 1::2] /= input['h_resize_ratios']
         return {
-            OutputKeys.BOXES: move_to_device(region_tensor,
-                                             torch.device('cpu')),
+            OutputKeys.BOXES:
+            move_to_device(region_tensor, torch.device('cpu')).tolist(),
             OutputKeys.SCORES: [1.0] * region_tensor.shape[0]
         }

     def _traverse_inference(self, input):
-        input = move_to_device(input, self._device)
         encoder_input = dict()
         for key in input['net_input'].keys():
             encoder_input[key] = input['net_input'][key]
@@ -193,19 +202,19 @@ class OfaForAllTasks(TorchModel):
             torch.cat([
                 torch.zeros(
                     len(decoder_prompt) - 1,
-                    valid_constraint_mask.size(1)).bool().to(self._device),
+                    valid_constraint_mask.size(1)).bool(),
                 valid_constraint_mask], dim=0)  # yapf: disable
             for decoder_prompt in input['decoder_prompts']  # yapf: disable
             for valid_constraint_mask in val_masks]  # yapf: disable
         valid_tgt = collate_tokens(
             valid_tgt_items,
-            pad_idx=self.tokenizer.pad_token_id).to(self._device)
+            pad_idx=self.tokenizer.pad_token_id).to(self.model.device)
         valid_prev_output = collate_tokens(
             valid_prev_items,
-            pad_idx=self.tokenizer.pad_token_id).to(self._device)
+            pad_idx=self.tokenizer.pad_token_id).to(self.model.device)
         val_masks = collate_tokens(
             valid_constraint_mask_items,
-            pad_idx=self.tokenizer.pad_token_id).to(self._device)
+            pad_idx=self.tokenizer.pad_token_id).to(self.model.device)
         new_encoder_out = {
             'last_hidden_state':
             encoder_out['last_hidden_state'].repeat_interleave(
@@ -280,8 +289,6 @@ class OfaForAllTasks(TorchModel):
             self.val_masks_l += [
                 constraint_mask_list[i:i + self.val_batch_size]
             ]
-        self.val_ans_l = move_to_device(self.val_ans_l, self._device)
-        self.val_masks_l = move_to_device(self.val_masks_l, self._device)

     def load_ans2label(self):
         if self.cfg.model.get('answer2label', None):
@@ -75,7 +75,7 @@ class MsIterableDataset(torch.utils.data.IterableDataset):
             }
             for preprocessor in self.preprocessor_list:
                 res.update({
-                    k: torch.tensor(v)
+                    k: v  # k: torch.tensor(v)
                     for k, v in preprocessor(item_dict).items()
                     if k in self.retained_columns
                 })
@@ -350,14 +350,15 @@ class MsDataset:

         def is_numpy_number(value):
             return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
-                value.dtype, np.floating)
+                value.dtype, np.floating) or np.issubdtype(
+                    value.dtype, np.bool_)

         retained_columns = []
         for k in sample_res.keys():
             if not is_numpy_number(sample_res[k]):
                 logger.warning(
                     f'Data of column {k} is non-numeric, will be removed')
-                continue
+                # continue
             retained_columns.append(k)

         return MsIterableDataset(self._hf_ds, preprocessor_list,
@@ -13,6 +13,7 @@ from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import OfaPreprocessor, Preprocessor, load_image
 from modelscope.utils.constant import Tasks
+from modelscope.utils.device import get_device
 from modelscope.utils.logger import get_logger

 logger = get_logger()
@@ -36,6 +37,7 @@ class ImageClassificationPipeline(Pipeline):
         else:
             raise NotImplementedError
         pipe_model.model.eval()
+        pipe_model.to(get_device())
         if preprocessor is None and isinstance(pipe_model, OfaForAllTasks):
             preprocessor = OfaPreprocessor(model_dir=pipe_model.model_dir)
         super().__init__(model=pipe_model, preprocessor=preprocessor, **kwargs)
@@ -84,7 +84,12 @@ class OfaPreprocessor(Preprocessor):

     def _compatible_with_pretrain(self, data):
         if 'image' in data and self.cfg.model.get('type', None) == 'ofa':
-            image = load_image(data['image'])
+            if isinstance(data['image'], str):
+                image = load_image(data['image'])
+            else:
+                image = data['image']
+            if image.mode != 'RGB':
+                image = image.convert('RGB')
             img_buffer = BytesIO()
             image.save(img_buffer, format='JPEG')
             data['image'] = Image.open(img_buffer)
@@ -102,8 +107,6 @@ class OfaPreprocessor(Preprocessor):
         for k, v in data.items():
             str_data[k] = str(v)
         sample['sample'] = str_data
-        # import pdb
-        # pdb.set_trace()
         if self.no_collate:
             return sample
         else:
@@ -42,6 +42,7 @@ class OfaBasePreprocessor:
             for key, value in tokenizer.get_vocab().items()
         }
         self.max_src_length = cfg.model.get('max_src_length', 256)
+        self.max_tgt_length = cfg.model.get('max_tgt_length', 256)
         self.max_image_size = cfg.model.get('max_image_size', 512)
         self.language = self.cfg.model.get('language', 'en')
         self.prompt_type = self.cfg.model.get('prompt_type', 'none')
@@ -58,22 +59,23 @@ class OfaBasePreprocessor:
         self.std = [0.5, 0.5, 0.5]
         self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
         self.constraint_trie = None
-        self.index2ans = {}
-        if self.cfg.model.get('answer2label', False):
+        if self.cfg.model.get('answer2label', None):
             ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)
             with open(ans2label_file, 'r') as reader:
                 ans2label_dict = json.load(reader)
+            self.ans2label = ans2label_dict
+            self.label2ans = {v: k for k, v in self.ans2label.items()}
             self.constraint_trie = Trie(tokenizer.eos_token_id)
             for i, answer in enumerate(ans2label_dict.keys()):
-                answer_item = tokenizer(
-                    ' ' + answer,
-                    return_tensors='pt',
-                    add_special_tokens=False).input_ids.squeeze(0)
+                answer_item = self.tokenize_text(
+                    ' ' + answer, add_bos=False, add_eos=False)
                 self.constraint_trie.insert([tokenizer.bos_token_id]
                                             + answer_item.tolist()
                                             + [tokenizer.eos_token_id])

-    def get_inputs(self, text, add_bos=True, add_eos=True):
+    def tokenize_text(self, text, add_bos=True, add_eos=True):
+        if text is None:
+            return None
         inputs = self.tokenizer(
             text,
             max_length=self.max_src_length,
@@ -88,7 +90,7 @@ class OfaBasePreprocessor:

     @staticmethod
     def pre_caption(caption, max_words=None):
-        caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ')\
+        caption = caption.lower().lstrip(',.!?*#:;~').replace('-', ' ') \
             .replace('/', ' ').replace('<person>', 'person')

         caption = re.sub(
@@ -126,3 +128,18 @@ class OfaBasePreprocessor:
             question = ' '.join(question_words[:max_ques_words])

         return question
+
+    def add_constraint_mask(self, sample):
+        target_itm = sample['target']
+        len_label_itm = target_itm.ne(self.pad_item).sum(dim=0).item()
+        if self.constraint_trie:
+            constraint_mask = torch.zeros(
+                (len(target_itm), len(self.tgt_dict))).bool()
+            start_idx = len(target_itm) - len_label_itm
+            for i in range(start_idx, len(target_itm)):
+                constraint_prefix_token = self.bos_item.tolist() + \
+                    target_itm[start_idx:i].tolist()
+                constraint_nodes = self.constraint_trie.get_next_layer(
+                    constraint_prefix_token)
+                constraint_mask[i][constraint_nodes] = True
+            sample['constraint_mask'] = constraint_mask
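
Note: add_constraint_mask marks, for each position of the label span, which token ids the answer trie permits next. A toy sketch of the two Trie calls it relies on (token ids are made up; insert and get_next_layer are the methods used above):

    trie = Trie(2)                 # 2 stands in for eos_token_id
    trie.insert([0, 11, 12, 2])    # bos + answer tokens + eos
    trie.get_next_layer([0, 11])   # -> [12]: after prefix [0, 11] only 12 is legal
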
@@ -38,14 +38,29 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = data['image'] if isinstance(
             data['image'], Image.Image) else load_image(data['image'])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', ' what does the image describe?')
-        inputs = self.get_inputs(prompt)
+        inputs = self.tokenize_text(prompt)
         sample = {
             'source': inputs,
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
         return sample
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data['target']
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target)
+        return sample
@@ -42,7 +42,7 @@ class OfaImageClassificationPreprocessor(OfaBasePreprocessor):
             data['image'], Image.Image) else load_image(data['image'])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', ' what does the image describe?')
-        inputs = self.get_inputs(prompt)
+        inputs = self.tokenize_text(prompt)
         sample = {
             'source': inputs,
             'patch_image': patch_image,
@@ -31,7 +31,7 @@ class OfaSummarizationPreprocessor(OfaBasePreprocessor):
         prompt = self.cfg.model.get(
             'prompt', ' " {} " Summarize the article with a title: ')
         text = prompt.format(source)
-        inputs = self.get_inputs(text)
+        inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
         elif self.prompt_type == 'prev_output':
@@ -1,6 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from typing import Any, Dict

+import torch
+
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
@@ -24,24 +26,56 @@ class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
                   self).__init__(cfg, model_dir, mode, *args, **kwargs)

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_instruction(self, data):
         text1 = ' '.join(
             data['text'].lower().strip().split()[:self.max_src_length])
         text2 = ' '.join(
             data['text2'].lower().strip().split()[:self.max_src_length])
         prompt = ' can text1 " {} " imply text2 " {} "?'
         text = prompt.format(text1, text2)
-        inputs = self.get_inputs(text)
+        instruction_itm = self.tokenize_text(text)
+        return instruction_itm
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        instruction_itm = self._build_instruction(data)
+        assert 'label' in data, 'there must be a `label` column in the train phase'
+        label = data['label']
+        if self.label2ans:
+            label = self.label2ans[label]  # ans
+        label_itm = self.tokenize_text(f' {label}', add_bos=False)
+        if self.prompt_type == 'none':
+            target_itm = label_itm
+        elif self.prompt_type == 'prev_output':
+            target_itm = torch.cat([instruction_itm[1:-1], label_itm])
+        else:
+            raise NotImplementedError
+        prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]])
+        target_itm[:-len(label_itm)] = self.pad_item
+        sample = {
+            'source': instruction_itm,
+            'target': target_itm,
+            'prev_output_tokens': prev_output_itm,
+        }
+        self.add_constraint_mask(sample)
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        instruction_itm = self._build_instruction(data)
         if self.prompt_type == 'none':
-            decoder_prompt = self.bos_item
-        elif self.prompt_type == 'src':
-            decoder_prompt = inputs
+            prefix_token = []
         elif self.prompt_type == 'prev_output':
-            decoder_prompt = inputs[:-1]
+            prefix_token = instruction_itm[:-1]  # remove eos
         else:
             raise NotImplementedError
         sample = {
-            'source': inputs,
-            'decoder_prompt': decoder_prompt,
-            'prefix_token': decoder_prompt[:-1],
+            'source': instruction_itm,
+            'prefix_token': prefix_token,
         }
+        if 'label' in data:
+            sample['label'] = self.label2ans[data['label']]
         return sample
@@ -30,7 +30,7 @@ class OfaTextToImageSynthesisPreprocessor(OfaBasePreprocessor):
         source = ' '.join(
             data['text'].lower().strip().split()[:self.max_src_length])
         source = 'what is the complete image? caption: {}'.format(source)
-        inputs = self.get_inputs(source)
+        inputs = self.tokenize_text(source)
         sample = {
             'source': inputs,
             'patch_images': None,
@@ -47,6 +47,8 @@ def collate_fn(samples, pad_idx, eos_idx):
         batch['conf'] = torch.cat([s['conf'] for s in samples], dim=0)
     if samples[0].get('ref_dict', None) is not None:
         batch['ref_dict'] = np.array([s['ref_dict'] for s in samples])
+    if samples[0].get('label', None) is not None:
+        batch['labels'] = np.array([s['label'] for s in samples]).tolist()
     if samples[0].get('constraint_mask', None) is not None:
         batch['constraint_masks'] = merge('constraint_mask')
     if samples[0].get('decoder_prompt', None) is not None:
@@ -53,7 +53,7 @@ class OfaVisualEntailmentPreprocessor(OfaBasePreprocessor):
         prompt = self.cfg.model.get(
             'prompt', ' can image and text1 " {} " imply text2 " {} "?')
         text = prompt.format(caption, hypothesis)
-        inputs = self.get_inputs(text)
+        inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
         elif self.prompt_type == 'src':
@@ -48,7 +48,7 @@ class OfaVisualGroundingPreprocessor(OfaBasePreprocessor):
         prompt = self.cfg.model.get(
             'prompt', ' which region does the text " {} " describe?')
         text = prompt.format(src_caption)
-        src_item = self.get_inputs(text)
+        src_item = self.tokenize_text(text)
         sample = {
             'source': src_item,
             'patch_image': patch_image,
@@ -42,7 +42,7 @@ class OfaVisualQuestionAnsweringPreprocessor(OfaBasePreprocessor):
             data['image'], Image.Image) else load_image(data['image'])
         patch_image = self.patch_resize_transform(image)
         text = ' {}'.format(data['text'])
-        inputs = self.get_inputs(text)
+        inputs = self.tokenize_text(text)
         if self.prompt_type == 'none':
             decoder_prompt = self.bos_item
         elif self.prompt_type == 'src':
@@ -3,6 +3,7 @@ from functools import partial
 from typing import Dict, Optional

 from datasets import load_dataset
+from torch import distributed as dist

 from modelscope.metainfo import Trainers
 from modelscope.models.base import Model

@@ -15,7 +16,7 @@ from modelscope.trainers.optimizer.builder import build_optimizer
 from modelscope.utils.config import Config
 from modelscope.utils.constant import ConfigKeys, ModeKeys, ModelFile
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                OFADataset, get_schedule)
+                                get_schedule)


 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
@@ -36,31 +37,13 @@ class OFATrainer(EpochBasedTrainer):
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
-                model_dir=model_dir, model=ModeKeys.TRAIN, no_collate=True),
+                model_dir=model_dir, mode=ModeKeys.TRAIN, no_collate=True),
             ConfigKeys.val:
             OfaPreprocessor(
-                model_dir=model_dir, model=ModeKeys.EVAL, no_collate=True),
+                model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True),
         }
-        # train_dataset = dataset['train'].to_torch_dataset(
-        #     preprocessors=OfaPreprocessor(model_dir=model_dir, model=ModeKeys.TRAIN, no_collate=True),
-        # )
-        # valid_dataset = dataset['valid'].to_torch_dataset(
-        #     preprocessors=OfaPreprocessor(model_dir=model_dir, model=ModeKeys.TRAIN, no_collate=True),
-        # )
-        # train_dataset = OFADataset(
-        #     file_path=cfg.dataset.train_set,
-        #     selected_id_keys=cfg.dataset.selected_id_keys,
-        #     preprocessor=OfaPreprocessor(
-        #         model_dir=model_dir, mode=ModeKeys.TRAIN),
-        # )
-        # val_dataset = OFADataset(
-        #     file_path=cfg.dataset.valid_set,
-        #     selected_id_keys=cfg.dataset.selected_id_keys,
-        #     preprocessor=OfaPreprocessor(
-        #         model_dir=model_dir, mode=ModeKeys.EVAL),
-        # )
         epoch_steps = len(dataset['train']) // (
-            cfg.train.gradient_accumulation_steps
+            cfg.train.optimizer_hook.cumulative_iters
             * cfg.train.dataloader.batch_size_per_gpu)
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
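
Note: the accumulation factor now comes from optimizer_hook.cumulative_iters instead of a standalone gradient_accumulation_steps key. With illustrative numbers, 10,000 training samples, batch_size_per_gpu = 8, and cumulative_iters = 4 give epoch_steps = 10000 // (4 * 8) = 312 optimizer steps per epoch, so num_train_steps = 312 * max_epochs.
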
@@ -78,6 +61,11 @@ class OFATrainer(EpochBasedTrainer):
             pad_idx=model.tokenizer.pad_token_id,
             eos_idx=model.tokenizer.eos_token_id,
         )
+        if 'launcher' not in kwargs and cfg.train.get('launcher', None):
+            kwargs['launcher'] = cfg.train.launcher
+        if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
+            kwargs['use_fp16'] = cfg.train.use_fp16
         super().__init__(
             cfg_file=cfg_file,
             model=model,
@@ -91,14 +79,28 @@ class OFATrainer(EpochBasedTrainer):
             **kwargs,
         )

-    # def train(self, *args, **kwargs):
-    #     pass
-
-    def evaluate(self,
-                 checkpoint_path: Optional[str] = None,
-                 *args,
-                 **kwargs) -> Dict[str, float]:
-        pass
-
-    def prediction_step(self, model, inputs):
-        pass
+    def train_step(self, model, inputs):
+        model.train()
+        model_outputs = model.forward(inputs)
+        loss, sample_size, logging_output = self.criterion(
+            model_outputs, inputs)
+        train_outputs = {'loss': loss}
+        # add model output info to log
+        if 'log_vars' not in train_outputs:
+            default_keys_pattern = ['loss']
+            match_keys = set([])
+            for key_p in default_keys_pattern:
+                match_keys.update(
+                    [key for key in train_outputs.keys() if key_p in key])
+            log_vars = {}
+            for key in match_keys:
+                value = train_outputs.get(key, None)
+                if value is not None:
+                    if dist.is_available() and dist.is_initialized():
+                        value = value.data.clone()
+                        dist.all_reduce(value.div_(dist.get_world_size()))
+                    log_vars.update({key: value.item()})
+            self.log_buffer.update(log_vars)
+        else:
+            self.log_buffer.update(train_outputs['log_vars'])
+        self.train_outputs = train_outputs
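
Note: the logging branch above averages each matched scalar across workers before it reaches the log buffer. The core pattern in isolation (a sketch that assumes an initialized process group):

    import torch
    from torch import distributed as dist

    def average_scalar(value: torch.Tensor) -> float:
        # Divide locally, then sum across ranks: sum_i(x_i / N) equals the mean.
        value = value.data.clone()
        dist.all_reduce(value.div_(dist.get_world_size()))
        return value.item()
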
@@ -1,120 +0,0 @@
-import os
-from os import path as osp
-from typing import Dict, Optional
-
-import torch
-import torch.distributed as dist
-import transformers
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from modelscope.metainfo import Trainers
-from modelscope.models.base import Model
-from modelscope.preprocessors.multi_modal import OfaPreprocessor
-from modelscope.preprocessors.ofa.utils.collate import collate_fn
-from modelscope.trainers.base import BaseTrainer
-from modelscope.trainers.builder import TRAINERS
-from modelscope.utils.constant import ModeKeys, ModelFile
-from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import init_dist
-from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
-                                OFADataset, get_schedule)
-
-logger = get_logger()
-
-
-@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
-class OFAOldTrainer(BaseTrainer):
-
-    def __init__(self, model: str, *args, **kwargs):
-        model = Model.from_pretrained(model)
-        super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION))
-        self.model_dir = model.model_dir
-        self.model = model.model
-        self.device_id = 0
-        self.total_epoch = self.cfg.train.epoch
-        self.train_batch_size = self.cfg.train.batch_size
-        self.val_batch_size = self.cfg.evaluation.batch_size
-        self.save_dir = self.cfg.train.save_dir
-        init_dist(launcher='pytorch')
-        self.train_dataset = OFADataset(
-            file_path=self.cfg.dataset.train_set,
-            selected_id_keys=self.cfg.dataset.selected_id_keys,
-            preprocessor=OfaPreprocessor(
-                model_dir=self.model_dir, split=ModeKeys.TRAIN),
-        )
-        self.val_dataset = OFADataset(
-            file_path=self.cfg.dataset.valid_set,
-            selected_id_keys=self.cfg.dataset.selected_id_keys,
-            preprocessor=OfaPreprocessor(
-                model_dir=self.model_dir, split=ModeKeys.EVAL),
-        )
-        epoch_steps = len(
-            self.train_dataset) // self.cfg.train.gradient_accumulation_steps
-        self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch
-        self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
-            self.cfg.train.criterion)
-
-    def train(self, *args, **kwargs):
-        assert dist.is_initialized()
-        self.model.train()
-        self.model.to(self.device_id)
-        ddp_model = torch.nn.parallel.DistributedDataParallel(
-            self.model, device_ids=[
-                self.device_id,
-            ])
-
-        optimizer = transformers.AdamW(
-            self.model.parameters(),
-            lr=self.cfg.train.lr,
-            weight_decay=self.cfg.train.weight_decay,
-            correct_bias=False,
-        )
-        scheduler_class, scheduler_args = get_schedule(self.cfg.train)
-        if scheduler_class is not None:
-            lr_scheduler = scheduler_class(**{'optimizer': optimizer},
-                                           **scheduler_args)
-        else:
-            lr_scheduler = None
-
-        for epoch in range(self.total_epoch):
-            train_sampler = DistributedSampler(
-                dataset=self.train_dataset, shuffle=True)
-            train_sampler.set_epoch(epoch)
-            train_params = {
-                'pin_memory': True,
-                'collate_fn': collate_fn,
-                'batch_size': self.train_batch_size,
-                'shuffle': False,
-                'drop_last': True,
-                'sampler': train_sampler,
-                'num_workers': 2,
-            }
-            train_loader = DataLoader(self.train_dataset, **train_params)
-            for idx, batch in enumerate(train_loader, start=1):
-                model_outputs = ddp_model(**batch)
-                loss, sample_size, logging_output = self.criterion(
-                    model_outputs, batch)
-                loss.backward()
-                optimizer.zero_grad()
-                if lr_scheduler is not None:
-                    lr_scheduler.step()
-                optimizer.step()
-                optimizer.zero_grad()
-                if idx % 10 == 0:
-                    logger.info(
-                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
-                            epoch, idx, len(train_loader), loss.item()))
-            if dist.get_rank() == 0:
-                os.makedirs(self.ckpt_dir, exist_ok=True)
-                torch.save(ddp_model.module.state_dict(),
-                           f'{self.ckpt_dir}/epoch{epoch}.bin')
-
-    def evaluate(self,
-                 checkpoint_path: Optional[str] = None,
-                 *args,
-                 **kwargs) -> Dict[str, float]:
-        pass
@@ -172,47 +172,11 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
-        if isinstance(sample, list):
-            if self.sample_patch_num > 0:
-                sample[0]['net_input'][
-                    'sample_patch_num'] = self.sample_patch_num
-            loss_v1, sample_size_v1, logging_output_v1 = self.forward(
-                output[0], sample[0], update_num, reduce)
-            loss_v2, sample_size_v2, logging_output_v2 = self.forward(
-                output[1], sample[1], update_num, reduce)
-            loss = loss_v1 / sample_size_v1 + loss_v2 / sample_size_v2
-            sample_size = 1
-            logging_output = {
-                'loss':
-                loss.data,
-                'loss_v1':
-                loss_v1.data,
-                'loss_v2':
-                loss_v2.data,
-                'nll_loss':
-                logging_output_v1['nll_loss'].data / sample_size_v1
-                + logging_output_v2['nll_loss'].data / sample_size_v2,
-                'ntokens':
-                logging_output_v1['ntokens'] + logging_output_v2['ntokens'],
-                'nsentences':
-                logging_output_v1['nsentences']
-                + logging_output_v2['nsentences'],
-                'sample_size':
-                1,
-                'sample_size_v1':
-                sample_size_v1,
-                'sample_size_v2':
-                sample_size_v2,
-            }
-            return loss, sample_size, logging_output
-
         if self.use_rdrop:
             construct_rdrop_sample(sample)

-        net_output = output
-        # model(**sample["net_input"])
         loss, nll_loss, ntokens = self.compute_loss(
-            net_output, sample, update_num, reduce=reduce)
+            output, sample, update_num, reduce=reduce)
         sample_size = (
             sample['target'].size(0) if self.sentence_avg else ntokens)
         logging_output = {
@@ -12,6 +12,7 @@ import numpy as np
 import torch
 from torch import distributed as dist
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.distributed import DistributedSampler
@@ -159,8 +160,6 @@ class EpochBasedTrainer(BaseTrainer):
             train_dataset,
             mode=ModeKeys.TRAIN,
             preprocessor=self.train_preprocessor)
-        # import pdb
-        # pdb.set_trace()
         self.eval_dataset = self.to_task_dataset(
             eval_dataset,
             mode=ModeKeys.EVAL,
@@ -200,7 +199,6 @@ class EpochBasedTrainer(BaseTrainer):
             self._max_epochs = self.cfg.train.max_epochs
         else:
             self._max_epochs = kwargs['max_epochs']
-
         self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None)
         self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None)
         if self._train_iters_per_epoch is None and hasattr(
@@ -220,12 +218,12 @@ class EpochBasedTrainer(BaseTrainer):
             init_dist(kwargs['launcher'])
         self._dist = get_dist_info()[1] > 1

         # model placement
         if self.device.type == 'cuda':
             self.model.to(self.device)
             if not is_parallel(self.model) and self._dist:
                 self.model = self.to_parallel(self.model)
+                self.device = self.model.device

     def rebuild_config(self, cfg: Config):
         """A method used to rebuild the config, any subclass can override this method.
@@ -429,7 +427,7 @@ class EpochBasedTrainer(BaseTrainer):
             self.register_hook_from_cfg(self.cfg.train.hooks)
         self.train_loop(self.train_dataloader)

-    def evaluate(self, checkpoint_path=None):
+    def evaluate(self, checkpoint_path=None, *args, **kwargs):
         self.model.eval()
         self._mode = ModeKeys.EVAL
@@ -475,12 +473,12 @@ class EpochBasedTrainer(BaseTrainer):
             self.cfg.parallel.update(
                 dict(module=model, device_ids=[torch.cuda.current_device()]))
             return build_parallel(self.cfg.parallel)

-        model.to(f'cuda:{torch.cuda.current_device()}')
         dp_cfg = dict(
             type='DistributedDataParallel',
             module=model,
+            find_unused_parameters=True,
             device_ids=[torch.cuda.current_device()])

         return build_parallel(dp_cfg)

     def train_step(self, model, inputs):
@@ -504,8 +502,10 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         # call model forward but not __call__ to skip postprocess
+        forward_func = model.module.forward if \
+            isinstance(model, DistributedDataParallel) else model.forward
         if isinstance(inputs,
-                      Mapping) and not func_receive_dict_inputs(model.forward):
+                      Mapping) and not func_receive_dict_inputs(forward_func):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
@@ -751,7 +751,7 @@ class EpochBasedTrainer(BaseTrainer):
             batch_size = batch_size_per_gpu
             num_workers = workers_per_gpu

-        if dist:
+        if dist and not isinstance(dataset, torch.utils.data.IterableDataset):
             sampler = DistributedSampler(
                 dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
         else:
@@ -9,6 +9,7 @@ from collections.abc import Mapping

 import torch
 from torch import distributed as dist
+from torch.nn.parallel import DistributedDataParallel
 from tqdm import tqdm

 from modelscope.utils.data_utils import to_device
@@ -68,7 +69,10 @@ def single_gpu_test(model,
                 batch_size = 1  # iteration count
             else:
                 if isinstance(data, dict):
-                    batch_size = len(next(iter(data.values())))
+                    if 'nsentences' in data:
+                        batch_size = data['nsentences']
+                    else:
+                        batch_size = len(next(iter(data.values())))
                 else:
                     batch_size = len(data)
             for _ in range(batch_size):
@@ -142,28 +146,38 @@ def multi_gpu_test(model,
         data = to_device(data, device)
         data_list.append(data)
         with torch.no_grad():
-            if isinstance(data, Mapping) and not func_receive_dict_inputs(
-                    model.forward):
+            forward_func = model.module.forward if \
+                isinstance(model, DistributedDataParallel) else model.forward
+            if isinstance(data, Mapping
+                          ) and not func_receive_dict_inputs(forward_func):
                 result = model.forward(**data)
             else:
                 result = model.forward(data)
         results.append(result)

-        if rank == 0:
-            if isinstance(data, dict):
-                batch_size = len(next(iter(data.values())))
+        if isinstance(data, dict):
+            if 'nsentences' in data:
+                batch_size = data['nsentences']
             else:
-                batch_size = len(data)
-            if progress_with_iters:
-                total_samples += batch_size * world_size
-                batch_size = 1  # iteration count
-            batch_size_all = batch_size * world_size
-            count += batch_size_all
+                batch_size = len(next(iter(data.values())))
+        else:
+            batch_size = len(data)
+        if i >= (data_len // world_size) - 1:
+            total_samples = torch.LongTensor([batch_size]).to(model.device)
+            dist.all_reduce(total_samples, op=dist.reduce_op.SUM)
+            total_samples = total_samples.item()
+        else:
+            total_samples = batch_size * world_size
+        if progress_with_iters:
+            iter_cnt_all = world_size
+        else:
+            iter_cnt_all = total_samples
+        count += iter_cnt_all
+
+        if rank == 0:
             if count > data_len:
-                batch_size_all = data_len - (count - batch_size_all)
-            for _ in range(batch_size_all):
+                iter_cnt_all = data_len - (count - iter_cnt_all)
+            for _ in range(iter_cnt_all):
                 pbar.update()

         if progress_with_iters and (i + 1) >= data_len:
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
+import os
 from contextlib import contextmanager

 from modelscope.utils.constant import Devices, Frameworks
@@ -105,3 +105,17 @@ def create_device(device_name):
         device = torch.device('cpu')

     return device
+
+
+def get_device():
+    import torch
+    from torch import distributed as dist
+    if torch.cuda.is_available():
+        if dist.is_available() and dist.is_initialized() \
+                and 'LOCAL_RANK' in os.environ:
+            device_id = f"cuda:{os.environ['LOCAL_RANK']}"
+        else:
+            device_id = 'cuda:0'
+    else:
+        device_id = 'cpu'
+    return torch.device(device_id)
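
Note: what get_device resolves to in the three cases it distinguishes (a sketch; the rank value is illustrative):

    # no CUDA available:                    torch.device('cpu')
    # CUDA, no process group:               torch.device('cuda:0')
    # torchrun worker with LOCAL_RANK=3:    torch.device('cuda:3')
    device = get_device()
    x = torch.zeros(1).to(device)
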
@@ -0,0 +1,17 @@
+import pdb
+import sys
+
+
+class ForkedPdb(pdb.Pdb):
+    """A Pdb subclass that may be used
+    from a forked multiprocessing child.
+    """
+
+    def interaction(self, *args, **kwargs):
+        _stdin = sys.stdin
+        try:
+            sys.stdin = open('/dev/stdin')
+            pdb.Pdb.interaction(self, *args, **kwargs)
+        finally:
+            sys.stdin = _stdin
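
Note: ForkedPdb reopens /dev/stdin so the prompt reattaches to the controlling terminal, which is what makes breakpoints usable inside DataLoader workers and other forked children. Typical usage:

    ForkedPdb().set_trace()  # drop into the debugger from a forked process
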
@@ -252,6 +252,27 @@ class OfaTasksTest(unittest.TestCase):
         result[OutputKeys.OUTPUT_IMG].save('result.png')
         print(f'Output written to {osp.abspath("result.png")}')

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_visual_question_answering_huge_with_name(self):
+        model = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_visual-question-answering_pretrain_huge_en'
+        ofa_pipe = pipeline(Tasks.visual_question_answering, model=model)
+        image = 'data/test/images/visual_question_answering.png'
+        text = 'what is grown on the plant?'
+        input = {'image': image, 'text': text}
+        result = ofa_pipe(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_with_image_captioning_huge_with_name(self):
+        model = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_image-caption_coco_huge_en'
+        img_captioning = pipeline(
+            task=Tasks.image_captioning,
+            model=model,
+        )
+        result = img_captioning(
+            {'image': 'data/test/images/image_captioning.png'})
+        print(result[OutputKeys.CAPTION])
+

 if __name__ == '__main__':
     unittest.main()
@@ -11,6 +11,7 @@ class TestOfaTrainer(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
+        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
         model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
         self.trainer = OFATrainer(model_id)
         self.trainer.train()