@@ -129,8 +129,7 @@ class OfaForAllTasks(TorchModel):
        result_l = list()
        for cap in caption:
            result_l.append(cap.translate(self.transtab).strip())
        input[OutputKeys.CAPTION] = result_l
        return input
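
As a quick illustration of the caption post-processing above: assuming self.transtab is a character-deletion table built from string.punctuation (an assumption about OfaForAllTasks, not shown in this hunk), the loop reduces to:

import string

transtab = str.maketrans({key: None for key in string.punctuation})  # hypothetical stand-in for self.transtab
caption = ['A man riding a horse, on the beach! ']
result_l = [cap.translate(transtab).strip() for cap in caption]
print(result_l)  # ['A man riding a horse on the beach']
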
    def _text_gen_inference(self, input):

@@ -182,6 +181,8 @@ class OfaForAllTasks(TorchModel):
            encoder_input[key] = input['net_input'][key]
        encoder_out = self.model.encoder(**encoder_input)
        valid_result = []
        for val_ans, val_masks in zip(self.val_ans_l, self.val_masks_l):
            valid_size = len(val_ans)
            valid_tgt_items = [
@@ -1,133 +0,0 @@
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.

import os
import pickle

import torch

class OFAFileDataset:

    def __init__(self,
                 file_path,
                 selected_col_ids=None,
                 dtypes=None,
                 separator='\t',
                 cached_index=False):
        self.file_path = file_path
        assert os.path.exists(
            self.file_path), 'Error: The local datafile {} does not exist!'.format(
                self.file_path)
        self.separator = separator
        if selected_col_ids is None:
            # default to all fields
            self.selected_col_ids = list(
                range(
                    len(
                        open(self.file_path).readline().rstrip('\n').split(
                            self.separator))))
        else:
            self.selected_col_ids = [
                int(col_id) for col_id in selected_col_ids.split(',')
            ]
        if dtypes is None:
            # default to str
            self.dtypes = [str for col_id in self.selected_col_ids]
        else:
            self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(',')]
        assert len(self.dtypes) == len(self.selected_col_ids)
        self.data_cnt = 0
        try:
            # shard the file across workers when torch.distributed is initialized
            self.slice_id = torch.distributed.get_rank()
            self.slice_count = torch.distributed.get_world_size()
        except Exception:
            # single-process fallback: one slice holding the whole file
            self.slice_id = 0
            self.slice_count = 1
        self.cached_index = cached_index
        self._init_seek_index()
        self._reader = self._get_reader()
        print('file {} slice_id {} row count {} total row count {}'.format(
            self.file_path, self.slice_id, self.row_count,
            self.total_row_count))
    def _init_seek_index(self):
        if self.cached_index:
            cache_path = '{}.index'.format(self.file_path)
            assert os.path.exists(
                cache_path), 'cache file {} does not exist!'.format(cache_path)
            self.total_row_count, self.lineid_to_offset = pickle.load(
                open(cache_path, 'rb'))
            print(
                'local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
        else:
            # make an iteration over the file to get row_count and line_idx-to-offset mapping
            fp = open(self.file_path, 'r')
            print(
                'local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping'
                .format(self.file_path, self.slice_id))
            self.total_row_count = 0
            offset = 0
            self.lineid_to_offset = []
            for line in fp:
                self.lineid_to_offset.append(offset)
                self.total_row_count += 1
                offset += len(line.encode('utf-8'))
            # dump the same (row_count, offsets) pair that the cached branch expects to load
            pickle.dump((self.total_row_count, self.lineid_to_offset),
                        open('{}.index'.format(self.file_path), 'wb'))
        self._compute_start_pos_and_row_count()
        print(
            'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
            .format(self.file_path, self.slice_id))
    def _compute_start_pos_and_row_count(self):
        self.row_count = self.total_row_count // self.slice_count
        if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
            self.row_count += 1
            self.start_pos = self.row_count * self.slice_id
        else:
            self.start_pos = self.row_count * self.slice_id + (
                self.total_row_count - self.row_count * self.slice_count)
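
    # Worked example (illustration only, not part of the original file): with
    # total_row_count = 10 and slice_count = 4, row_count starts at 10 // 4 = 2
    # and the remainder 10 - 2 * 4 = 2 goes to the first two slices, giving
    # (start_pos, row_count) of (0, 3), (3, 3), (6, 2) and (8, 2) for slices 0-3,
    # so every row is read by exactly one slice.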
    def _get_reader(self):
        fp = open(self.file_path, 'r')
        fp.seek(self.lineid_to_offset[self.start_pos])
        return fp

    def _seek(self, offset=0):
        try:
            print('slice_id {} seek offset {}'.format(self.slice_id,
                                                      self.start_pos + offset))
            self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
            self.data_cnt = offset
        except Exception:
            print('slice_id {} seek offset {}'.format(self.slice_id, offset))
            self._reader.seek(self.lineid_to_offset[offset])
            self.data_cnt = offset
    def __del__(self):
        self._reader.close()

    def __len__(self):
        return self.row_count

    def get_total_row_count(self):
        return self.total_row_count

    def __getitem__(self, index):
        # reads are sequential: the requested index is ignored and the next
        # line of this slice is returned
        if self.data_cnt == self.row_count:
            print('reached the end of the datafile, starting a new reader')
            self.data_cnt = 0
            self._reader = self._get_reader()
        column_l = self._reader.readline().rstrip('\n').split(self.separator)
        self.data_cnt += 1
        column_l = [
            dtype(column_l[col_id])
            for col_id, dtype in zip(self.selected_col_ids, self.dtypes)
        ]
        return column_l
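
For reference, a minimal usage sketch of the class above; the toy file name, column layout and dtypes string are illustrative assumptions, not taken from the original code:

# write a tiny 2-column TSV, then read it back through OFAFileDataset
with open('toy.tsv', 'w') as f:
    f.write('a sentence\t0\n')
    f.write('another sentence\t1\n')

dataset = OFAFileDataset('toy.tsv', selected_col_ids='0,1', dtypes='str,int')
print(len(dataset), dataset.get_total_row_count())  # rows in this slice, rows in the whole file
print(dataset[0])  # ['a sentence', 0]
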
@@ -65,7 +65,7 @@ class OFATrainer(EpochBasedTrainer):
            kwargs['launcher'] = cfg.train.launcher
        if 'use_fp16' not in kwargs and cfg.train.get('use_fp16', False):
            kwargs['use_fp16'] = cfg.train.use_fp16
        kwargs['to_tensor'] = False
        super().__init__(
            cfg_file=cfg_file,
            model=model,
@@ -167,19 +167,20 @@ class EpochBasedTrainer(BaseTrainer):
            device_name = f'cuda:{local_rank}'
        self.device = create_device(device_name)
        self.train_dataset = self.to_task_dataset(
            train_dataset,
            mode=ModeKeys.TRAIN,
            task_data_config=self.cfg.dataset.get('train', None) if hasattr(
                self.cfg, 'dataset') else None,
            preprocessor=self.train_preprocessor,
            **kwargs)
        self.eval_dataset = self.to_task_dataset(
            eval_dataset,
            mode=ModeKeys.EVAL,
            task_data_config=self.cfg.dataset.get('val', None) if hasattr(
                self.cfg, 'dataset') else None,
            preprocessor=self.eval_preprocessor,
            **kwargs)
        self.train_data_collator, self.eval_default_collate = None, None
        if isinstance(data_collator, Mapping):
@@ -305,13 +306,15 @@ class EpochBasedTrainer(BaseTrainer):
                        datasets: Union[Dataset, List[Dataset]],
                        mode: str,
                        task_data_config: Config = None,
                        preprocessor: Optional[Preprocessor] = None,
                        **kwargs):
        """Build the task specific dataset processor for this trainer.

        Returns: The task dataset processor for the task. If there is no processor registered
        for the given model type and task, the default TaskDataset will be returned.
        """
        try:
            to_tensor = kwargs.get('to_tensor', True)
            if not datasets:
                return datasets
            if isinstance(datasets, TorchTaskDataset):
@@ -327,7 +330,8 @@ class EpochBasedTrainer(BaseTrainer):
                return datasets.to_torch_dataset(
                    task_data_config=task_data_config,
                    task_name=self.cfg.task,
                    preprocessors=preprocessor,
                    to_tensor=to_tensor)
            elif isinstance(datasets, List) and isinstance(
                    datasets[0], MsDataset):
                if task_data_config is None:
@@ -341,7 +345,8 @@ class EpochBasedTrainer(BaseTrainer):
                    d.to_torch_dataset(
                        task_data_config=task_data_config,
                        task_name=self.cfg.task,
                        preprocessors=preprocessor,
                        to_tensor=to_tensor) for d in datasets
                ]
                cfg = ConfigDict(
                    type=self.cfg.task, mode=mode, datasets=datasets)
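
Taken together, the trainer-side changes above thread a to_tensor flag from OFATrainer (which forces kwargs['to_tensor'] = False) through EpochBasedTrainer.to_task_dataset into to_torch_dataset. A minimal sketch of that flow, using simplified stand-in classes rather than the real modelscope ones:

class SketchDataset:
    # stand-in for an MsDataset-like object
    def to_torch_dataset(self, to_tensor=True):
        return 'to_torch_dataset(to_tensor={})'.format(to_tensor)


class SketchTrainer:
    # stand-in for EpochBasedTrainer.__init__ plus to_task_dataset
    def __init__(self, dataset, **kwargs):
        self.train_dataset = self.to_task_dataset(dataset, **kwargs)

    def to_task_dataset(self, dataset, **kwargs):
        to_tensor = kwargs.get('to_tensor', True)  # default keeps the old behaviour
        return dataset.to_torch_dataset(to_tensor=to_tensor)


print(SketchTrainer(SketchDataset()).train_dataset)                   # to_torch_dataset(to_tensor=True)
print(SketchTrainer(SketchDataset(), to_tensor=False).train_dataset)  # to_torch_dataset(to_tensor=False), as OFATrainer does
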
@@ -94,8 +94,11 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_text_classification_with_model(self):
        # model = Model.from_pretrained(
        #     'damo/ofa_text-classification_mnli_large_en')
        model = Model.from_pretrained(
            '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
        )
        ofa_pipe = pipeline(Tasks.text_classification, model=model)
        text = 'One of our number will carry out your instructions minutely.'
        text2 = 'A member of my team will execute your orders with immense precision.'
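        # Hedged continuation sketch (not shown in this hunk): the input dict keys
        # 'text' and 'text2' are an assumption about the OFA text-classification
        # preprocessor's expected format.
        result = ofa_pipe({'text': text, 'text2': text2})
        print(result)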
@@ -12,11 +12,10 @@ class TestOfaTrainer(unittest.TestCase):
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
        self.trainer = OFATrainer(model_id, launcher='pytorch')
        self.trainer.train()
        if os.path.exists(self.trainer.work_dir):
            shutil.rmtree(self.trainer.work_dir)


if __name__ == '__main__':