
update

master
行嗔 committed 3 years ago
commit 3b09d848ce
6 changed files with 21 additions and 146 deletions:

  1. modelscope/models/multi_modal/ofa_for_all_tasks.py (+3, -2)
  2. modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py (+0, -133)
  3. modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+1, -1)
  4. modelscope/trainers/trainer.py (+11, -6)
  5. tests/pipelines/test_ofa_tasks.py (+4, -1)
  6. tests/trainers/test_ofa_trainer.py (+2, -3)

modelscope/models/multi_modal/ofa_for_all_tasks.py (+3, -2)

@@ -129,8 +129,7 @@ class OfaForAllTasks(TorchModel):
             result_l = list()
             for cap in caption:
                 result_l.append(cap.translate(self.transtab).strip())
-            input[OutputKeys.CAPTION] = caption
-
+            input[OutputKeys.CAPTION] = result_l
         return input

     def _text_gen_inference(self, input):
@@ -182,6 +181,8 @@ class OfaForAllTasks(TorchModel):
             encoder_input[key] = input['net_input'][key]
         encoder_out = self.model.encoder(**encoder_input)
         valid_result = []
+        import pdb
+        pdb.set_trace()
         for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
             valid_size = len(val_ans)
             valid_tgt_items = [

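For context, the first hunk fixes the caption post-processing to return the punctuation-stripped strings instead of the raw captions. A minimal sketch of that logic in isolation (building transtab this way is an assumption; the diff only shows it being used):

import string

# Translation table that deletes ASCII punctuation, as used by
# cap.translate(self.transtab) above (assumed construction).
transtab = str.maketrans({key: None for key in string.punctuation})

caption = ['a dog sitting on a bench.']
result_l = [cap.translate(transtab).strip() for cap in caption]
print(result_l)  # ['a dog sitting on a bench']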

modelscope/trainers/multi_modal/ofa/ofa_file_dataset.py (+0, -133)

@@ -1,133 +0,0 @@
-# Copyright 2022 The OFA-Sys Team.
-# All rights reserved.
-# This source code is licensed under the Apache 2.0 license
-# found in the LICENSE file in the root directory.
-
-import os
-import pickle
-
-import torch
-
-
-class OFAFileDataset:
-
-    def __init__(self,
-                 file_path,
-                 selected_col_ids=None,
-                 dtypes=None,
-                 separator='\t',
-                 cached_index=False):
-        self.file_path = file_path
-        assert os.path.exists(
-            self.file_path), 'Error: The local datafile {} not exists!'.format(
-                self.file_path)
-
-        self.separator = separator
-        if selected_col_ids is None:
-            # default to all fields
-            self.selected_col_ids = list(
-                range(
-                    len(
-                        open(self.file_path).readline().rstrip('\n').split(
-                            self.separator))))
-        else:
-            self.selected_col_ids = [
-                int(col_id) for col_id in selected_col_ids.split(',')
-            ]
-        if dtypes is None:
-            # default to str
-            self.dtypes = [str for col_id in self.selected_col_ids]
-        else:
-            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
-        assert len(self.dtypes) == len(self.selected_col_ids)
-
-        self.data_cnt = 0
-        try:
-            self.slice_id = torch.distributed.get_rank()
-            self.slice_count = torch.distributed.get_world_size()
-        except Exception:
-            self.slice_id = 0
-            self.slice_count = 1
-        self.cached_index = cached_index
-        self._init_seek_index()
-        self._reader = self._get_reader()
-        print('file {} slice_id {} row count {} total row count {}'.format(
-            self.file_path, self.slice_id, self.row_count,
-            self.total_row_count))
-
-    def _init_seek_index(self):
-        if self.cached_index:
-            cache_path = '{}.index'.format(self.file_path)
-            assert os.path.exists(
-                cache_path), 'cache file {} not exists!'.format(cache_path)
-            self.total_row_count, self.lineid_to_offset = pickle.load(
-                open(cache_path, 'rb'))
-            print(
-                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
-                .format(self.file_path, self.slice_id))
-        else:
-            # make an iteration over the file to get row_count and line_idx-to-offset mapping
-            fp = open(self.file_path, 'r')
-            print(
-                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
-                .format(self.file_path, self.slice_id))
-            self.total_row_count = 0
-            offset = 0
-            self.lineid_to_offset = []
-            for line in fp:
-                self.lineid_to_offset.append(offset)
-                self.total_row_count += 1
-                offset += len(line.encode('utf-8'))
-            pickle.dump(self.lineid_to_offset,
-                        open('{}.index'.format(self.file_path), 'wb'))
-        self._compute_start_pos_and_row_count()
-        print(
-            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
-            .format(self.file_path, self.slice_id))
-
-    def _compute_start_pos_and_row_count(self):
-        self.row_count = self.total_row_count // self.slice_count
-        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
-            self.row_count += 1
-            self.start_pos = self.row_count * self.slice_id
-        else:
-            self.start_pos = self.row_count * self.slice_id + (
-                self.total_row_count - self.row_count * self.slice_count)
-
-    def _get_reader(self):
-        fp = open(self.file_path, 'r')
-        fp.seek(self.lineid_to_offset[self.start_pos])
-        return fp
-
-    def _seek(self, offset=0):
-        try:
-            print('slice_id {} seek offset {}'.format(self.slice_id,
-                                                      self.start_pos + offset))
-            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
-            self.data_cnt = offset
-        except Exception:
-            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
-            self._reader.seek(self.lineid_to_offset[offset])
-            self.data_cnt = offset
-
-    def __del__(self):
-        self._reader.close()
-
-    def __len__(self):
-        return self.row_count
-
-    def get_total_row_count(self):
-        return self.total_row_count
-
-    def __getitem__(self, index):
-        if self.data_cnt == self.row_count:
-            print('reach the end of datafile, start a new reader')
-            self.data_cnt = 0
-            self._reader = self._get_reader()
-        column_l = self._reader.readline().rstrip('\n').split(self.separator)
-        self.data_cnt += 1
-        column_l = [
-            dtype(column_l[col_id])
-            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
-        ]
-        return column_l

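For reference, the deleted OFAFileDataset read a delimited text file through a line-offset index (built once, optionally cached as '{file_path}.index') and sliced rows evenly across distributed ranks. A minimal usage sketch as it worked before this commit (the data path and column layout are hypothetical):

from modelscope.trainers.multi_modal.ofa.ofa_file_dataset import OFAFileDataset

# Tab-separated file; keep columns 0 and 2, parsed as str and int.
dataset = OFAFileDataset(
    file_path='train.tsv',      # hypothetical local path
    selected_col_ids='0,2',     # comma-separated column indices
    dtypes='str,int',           # comma-separated dtype names, eval()'d to types
    separator='\t')
print(len(dataset), dataset.get_total_row_count())
uniq_id, label = dataset[0]     # rows are read sequentially per rank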
modelscope/trainers/multi_modal/ofa/ofa_trainer.py (+1, -1)

@@ -65,7 +65,7 @@ class OFATrainer(EpochBasedTrainer):
             kwargs['launcher'] = cfg.train.launcher
         if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
            kwargs['use_fp16'] = cfg.train.use_fp16
+        kwargs['to_tensor'] = False
         super().__init__(
             cfg_file=cfg_file,
             model=model,


modelscope/trainers/trainer.py (+11, -6)

@@ -167,19 +167,20 @@ class EpochBasedTrainer(BaseTrainer):
             device_name = f'cuda:{local_rank}'

         self.device = create_device(device_name)
-
         self.train_dataset = self.to_task_dataset(
             train_dataset,
             mode=ModeKeys.TRAIN,
             task_data_config=self.cfg.dataset.get('train', None) if hasattr(
                 self.cfg, 'dataset') else None,
-            preprocessor=self.train_preprocessor)
+            preprocessor=self.train_preprocessor,
+            **kwargs)
         self.eval_dataset = self.to_task_dataset(
             eval_dataset,
             mode=ModeKeys.EVAL,
             task_data_config=self.cfg.dataset.get('val', None) if hasattr(
                 self.cfg, 'dataset') else None,
-            preprocessor=self.eval_preprocessor)
+            preprocessor=self.eval_preprocessor,
+            **kwargs)

         self.train_data_collator, self.eval_default_collate = None, None
         if isinstance(data_collator, Mapping):
@@ -305,13 +306,15 @@
                         datasets: Union[Dataset, List[Dataset]],
                         mode: str,
                         task_data_config: Config = None,
-                        preprocessor: Optional[Preprocessor] = None):
+                        preprocessor: Optional[Preprocessor] = None,
+                        **kwargs):
         """Build the task specific dataset processor for this trainer.

         Returns: The task dataset processor for the task. If no result for the very model-type and task,
         the default TaskDataset will be returned.
         """
         try:
+            to_tensor = kwargs.get('to_tensor', True)
             if not datasets:
                 return datasets
             if isinstance(datasets, TorchTaskDataset):
@@ -327,7 +330,8 @@
                 return datasets.to_torch_dataset(
                     task_data_config=task_data_config,
                     task_name=self.cfg.task,
-                    preprocessors=preprocessor)
+                    preprocessors=preprocessor,
+                    to_tensor=to_tensor)
             elif isinstance(datasets, List) and isinstance(
                     datasets[0], MsDataset):
                 if task_data_config is None:
@@ -341,7 +345,8 @@
                     d.to_torch_dataset(
                         task_data_config=task_data_config,
                         task_name=self.cfg.task,
-                        preprocessors=preprocessor) for d in datasets
+                        preprocessors=preprocessor,
+                        to_tensor=to_tensor) for d in datasets
                 ]
                 cfg = ConfigDict(
                     type=self.cfg.task, mode=mode, datasets=datasets)

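Taken together with the ofa_trainer.py change above, these edits thread a to_tensor flag from the trainer's kwargs down to MsDataset.to_torch_dataset, so OFA can keep dataset rows as raw Python objects for its own collator. A toy sketch of the plumbing (hypothetical classes, not the real modelscope API):

class ToyDataset:

    def to_torch_dataset(self, to_tensor=True):
        kind = 'tensors' if to_tensor else 'raw python objects'
        print('torch dataset will yield ' + kind)


class ToyTrainer:

    def __init__(self, dataset, **kwargs):
        self.to_task_dataset(dataset, **kwargs)

    def to_task_dataset(self, dataset, **kwargs):
        # Default True preserves the old behavior for every other trainer.
        to_tensor = kwargs.get('to_tensor', True)
        dataset.to_torch_dataset(to_tensor=to_tensor)


ToyTrainer(ToyDataset())                   # torch dataset will yield tensors
ToyTrainer(ToyDataset(), to_tensor=False)  # ... will yield raw python objects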

tests/pipelines/test_ofa_tasks.py (+4, -1)

@@ -94,8 +94,11 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_run_with_text_classification_with_model(self):
+        # model = Model.from_pretrained(
+        #     'damo/ofa_text-classification_mnli_large_en')
         model = Model.from_pretrained(
-            'damo/ofa_text-classification_mnli_large_en')
+            '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
+        )
         ofa_pipe = pipeline(Tasks.text_classification, model=model)
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'


tests/trainers/test_ofa_trainer.py (+2, -3)

@@ -12,11 +12,10 @@ class TestOfaTrainer(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
-        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/maas_mnli_pretrain_ckpt'
+        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
-        self.trainer = OFATrainer(model_id)
+        self.trainer = OFATrainer(model_id, launcher='pytorch')
         self.trainer.train()
         if os.path.exists(self.trainer.work_dir):
             shutil.rmtree(self.trainer.work_dir)
-        pass


 if __name__ == '__main__':


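A note on the launcher change above: with launcher='pytorch', the trainer initializes torch.distributed from environment variables that are normally set by torch.distributed.launch or torchrun. A minimal single-process setup that satisfies the env:// rendezvous (an assumption about running this test standalone; the real test relies on an external launcher):

import os

# Defaults for torch.distributed's env:// init method (hypothetical setup;
# normally provided by the distributed launcher).
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')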