diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py
index ab9b0357..38d1538d 100644
--- a/modelscope/models/multi_modal/ofa_for_all_tasks.py
+++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py
@@ -129,8 +129,7 @@ class OfaForAllTasks(TorchModel):
             result_l = list()
             for cap in caption:
                 result_l.append(cap.translate(self.transtab).strip())
-            input[OutputKeys.CAPTION] = caption
-
+            input[OutputKeys.CAPTION] = result_l
         return input
 
     def _text_gen_inference(self, input):
@@ -182,6 +181,8 @@ class OfaForAllTasks(TorchModel):
                 encoder_input[key] = input['net_input'][key]
         encoder_out = self.model.encoder(**encoder_input)
         valid_result = []
+        import pdb
+        pdb.set_trace()
         for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
             valid_size = len(val_ans)
             valid_tgt_items = [
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py b/modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py
deleted file mode 100644
index 138f1303..00000000
--- a/modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import os
-import pickle
-
-import torch
-
-
-class OFAFileDataset:
-
-    def __init__(self,
-                 file_path,
-                 selected_col_ids=None,
-                 dtypes=None,
-                 separator='\t',
-                 cached_index=False):
-        self.file_path = file_path
-        assert os.path.exists(
-            self.file_path), 'Error: The local datafile {} not exists!'.format(
-                self.file_path)
-
-        self.separator = separator
-        if selected_col_ids is None:
-            # default to all fields
-            self.selected_col_ids = list(
-                range(
-                    len(
-                        open(self.file_path).readline().rstrip('\n').split(
-                            self.separator))))
-        else:
-            self.selected_col_ids = [
-                int(col_id) for col_id in selected_col_ids.split(',')
-            ]
-        if dtypes is None:
-            # default to str
-            self.dtypes = [str for col_id in self.selected_col_ids]
-        else:
-            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
-        assert len(self.dtypes) == len(self.selected_col_ids)
-
-        self.data_cnt = 0
-        try:
-            self.slice_id = torch.distributed.get_rank()
-            self.slice_count = torch.distributed.get_world_size()
-        except Exception:
-            self.slice_id = 0
-            self.slice_count = 1
-        self.cached_index = cached_index
-        self._init_seek_index()
-        self._reader = self._get_reader()
-        print('file {} slice_id {} row count {} total row count {}'.format(
-            self.file_path, self.slice_id, self.row_count,
-            self.total_row_count))
-
-    def _init_seek_index(self):
-        if self.cached_index:
-            cache_path = '{}.index'.format(self.file_path)
-            assert os.path.exists(
-                cache_path), 'cache file {} not exists!'.format(cache_path)
-            self.total_row_count, self.lineid_to_offset = pickle.load(
-                open(cache_path, 'rb'))
-            print(
-                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
-                .format(self.file_path, self.slice_id))
-        else:
-            # make an iteration over the file to get row_count and line_idx-to-offset mapping
-            fp = open(self.file_path, 'r')
-            print(
-                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
-                .format(self.file_path, self.slice_id))
-            self.total_row_count = 0
-            offset = 0
-            self.lineid_to_offset = []
-            for line in fp:
-                self.lineid_to_offset.append(offset)
-                self.total_row_count += 1
-                offset += len(line.encode('utf-8'))
-            pickle.dump(self.lineid_to_offset,
-                        open('{}.index'.format(self.file_path), 'wb'))
-        self._compute_start_pos_and_row_count()
-        print(
-            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
-            .format(self.file_path, self.slice_id))
-
-    def _compute_start_pos_and_row_count(self):
-        self.row_count = self.total_row_count // self.slice_count
-        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
-            self.row_count += 1
-            self.start_pos = self.row_count * self.slice_id
-        else:
-            self.start_pos = self.row_count * self.slice_id + (
-                self.total_row_count - self.row_count * self.slice_count)
-
-    def _get_reader(self):
-        fp = open(self.file_path, 'r')
-        fp.seek(self.lineid_to_offset[self.start_pos])
-        return fp
-
-    def _seek(self, offset=0):
-        try:
-            print('slice_id {} seek offset {}'.format(self.slice_id,
-                                                      self.start_pos + offset))
-            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
-            self.data_cnt = offset
-        except Exception:
-            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
-            self._reader.seek(self.lineid_to_offset[offset])
-            self.data_cnt = offset
-
-    def __del__(self):
-        self._reader.close()
-
-    def __len__(self):
-        return self.row_count
-
-    def get_total_row_count(self):
-        return self.total_row_count
-
-    def __getitem__(self, index):
-        if self.data_cnt == self.row_count:
-            print('reach the end of datafile, start a new reader')
-            self.data_cnt = 0
-            self._reader = self._get_reader()
-        column_l = self._reader.readline().rstrip('\n').split(self.separator)
-        self.data_cnt += 1
-        column_l = [
-            dtype(column_l[col_id])
-            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
-        ]
-        return column_l
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
index c17a15f7..42a68d02 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
@@ -65,7 +65,7 @@ class OFATrainer(EpochBasedTrainer):
             kwargs['launcher'] = cfg.train.launcher
         if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
             kwargs['use_fp16'] = cfg.train.use_fp16
-
+        kwargs['to_tensor'] = False
         super().__init__(
             cfg_file=cfg_file,
             model=model,
diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py
index 793092c8..8412280b 100644
--- a/modelscope/trainers/trainer.py
+++ b/modelscope/trainers/trainer.py
@@ -167,19 +167,20 @@ class EpochBasedTrainer(BaseTrainer):
                 device_name = f'cuda:{local_rank}'
 
         self.device = create_device(device_name)
-
         self.train_dataset = self.to_task_dataset(
             train_dataset,
             mode=ModeKeys.TRAIN,
             task_data_config=self.cfg.dataset.get('train', None) if hasattr(
                 self.cfg, 'dataset') else None,
-            preprocessor=self.train_preprocessor)
+            preprocessor=self.train_preprocessor,
+            **kwargs)
         self.eval_dataset = self.to_task_dataset(
             eval_dataset,
             mode=ModeKeys.EVAL,
             task_data_config=self.cfg.dataset.get('val', None) if hasattr(
                 self.cfg, 'dataset') else None,
-            preprocessor=self.eval_preprocessor)
+            preprocessor=self.eval_preprocessor,
+            **kwargs)
 
         self.train_data_collator, self.eval_default_collate = None, None
         if isinstance(data_collator, Mapping):
@@ -305,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer):
                         datasets: Union[Dataset, List[Dataset]],
                         mode: str,
                         task_data_config: Config = None,
-                        preprocessor: Optional[Preprocessor] = None):
+                        preprocessor: Optional[Preprocessor] = None,
+                        **kwargs):
         """Build the task specific dataset processor for this trainer.
 
         Returns: The task dataset processor for the task. If no result for the
             very model-type and task, the default TaskDataset will be returned.
         """
         try:
+            to_tensor = kwargs.get('to_tensor', True)
             if not datasets:
                 return datasets
             if isinstance(datasets, TorchTaskDataset):
@@ -327,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer):
                 return datasets.to_torch_dataset(
                     task_data_config=task_data_config,
                     task_name=self.cfg.task,
-                    preprocessors=preprocessor)
+                    preprocessors=preprocessor,
+                    to_tensor=to_tensor)
             elif isinstance(datasets, List) and isinstance(
                     datasets[0], MsDataset):
                 if task_data_config is None:
@@ -341,7 +345,8 @@ class EpochBasedTrainer(BaseTrainer):
                     d.to_torch_dataset(
                         task_data_config=task_data_config,
                         task_name=self.cfg.task,
-                        preprocessors=preprocessor) for d in datasets
+                        preprocessors=preprocessor,
+                        to_tensor=to_tensor) for d in datasets
                 ]
                 cfg = ConfigDict(
                     type=self.cfg.task, mode=mode, datasets=datasets)
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index e6638dfa..d89e5d48 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -94,8 +94,11 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_text_classification_with_model(self):
+        # model = Model.from_pretrained(
+        #     'damo/ofa_text-classification_mnli_large_en')
         model = Model.from_pretrained(
-            'damo/ofa_text-classification_mnli_large_en')
+            '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
+        )
         ofa_pipe = pipeline(Tasks.text_classification, model=model)
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
index 39d9fe0c..3948aad7 100644
--- a/tests/trainers/test_ofa_trainer.py
+++ b/tests/trainers/test_ofa_trainer.py
@@ -12,11 +12,10 @@ class TestOfaTrainer(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
         model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
-        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
-        self.trainer = OFATrainer(model_id)
+        self.trainer = OFATrainer(model_id, launcher='pytorch')
         self.trainer.train()
         if os.path.exists(self.trainer.work_dir):
-            shutil.rmtree(self.trainer.work_dir)
+            pass
 
 
 if __name__ == '__main__':
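
Note on the to_tensor threading in the trainer changes above: OFATrainer sets kwargs['to_tensor'] = False, EpochBasedTrainer.to_task_dataset reads the flag with a default of True, and the value is forwarded to to_torch_dataset, so only OFA opts out of eager tensor conversion while other trainers keep the old behavior. The snippet below is a minimal, self-contained sketch of that kwarg-forwarding pattern; the class names are simplified stand-ins for illustration, not the actual ModelScope classes.

# Sketch only: illustrates the kwarg threading added in this patch with
# stand-in classes, not the real EpochBasedTrainer / OFATrainer / MsDataset.


class DatasetSketch:
    """Stand-in for MsDataset: converts samples to tensors unless told not to."""

    def to_torch_dataset(self, preprocessors=None, to_tensor=True):
        # When to_tensor is False, preprocessed samples are returned as-is and
        # any tensor conversion is left to the task-specific data collator.
        return {'preprocessors': preprocessors, 'to_tensor': to_tensor}


class BaseTrainerSketch:
    """Stand-in for EpochBasedTrainer: forwards extra kwargs to the dataset."""

    def __init__(self, dataset, **kwargs):
        self.train_dataset = self.to_task_dataset(dataset, **kwargs)

    def to_task_dataset(self, dataset, **kwargs):
        # Default keeps the previous behavior (eager tensor conversion).
        to_tensor = kwargs.get('to_tensor', True)
        return dataset.to_torch_dataset(to_tensor=to_tensor)


class OFATrainerSketch(BaseTrainerSketch):
    """Stand-in for OFATrainer: always requests raw (non-tensor) samples."""

    def __init__(self, dataset, **kwargs):
        kwargs['to_tensor'] = False
        super().__init__(dataset, **kwargs)


if __name__ == '__main__':
    print(BaseTrainerSketch(DatasetSketch()).train_dataset)  # to_tensor: True
    print(OFATrainerSketch(DatasetSketch()).train_dataset)  # to_tensor: False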