@@ -93,7 +93,6 @@ class OfaPreprocessor(Preprocessor):
             data = input
         else:
             data = self._build_dict(input)
-        data = self._ofa_input_compatibility_conversion(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
@@ -2,12 +2,12 @@
 from typing import Any, Dict

 import torch
-from PIL import Image
+import unicodedata2
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
+from zhconv import convert

-from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
-        target = target.translate(self.transtab).strip()
+        target = sample['label']
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
-            sample['label'] = data[self.column_map['text']]
+            target = data[self.column_map['text']]
+            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
+            sample['label'] = target
         return sample
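The OCR preprocessor change above normalizes each text label once in `_build_infer_sample` (zhconv maps traditional characters to simplified, then Unicode NFKC folds full-width forms), and `_build_train_sample` now reuses that normalized `sample['label']` as the training target instead of re-reading `data[self.column_map['text']]` and applying `self.transtab`. A minimal sketch of the normalization's effect; the input string is made up for illustration and is not from the dataset:

```python
# Hedged sketch of the label normalization added above.
import unicodedata2
from zhconv import convert

raw = '識別：ＡＢＣ１２３'  # hypothetical label: traditional chars + full-width forms
target = unicodedata2.normalize('NFKC', convert(raw, 'zh-hans'))
print(target)  # -> '识别:ABC123'  (simplified chars, ASCII punctuation/digits)
```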
@@ -11,3 +11,5 @@ timm
 tokenizers
 torchvision
 transformers>=4.12.0
+unicodedata2
+zhconv
@@ -5,7 +5,7 @@ import unittest
 import json

-from modelscope.metainfo import Trainers
+from modelscope.metainfo import Metrics, Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import DownloadMode, ModelFile
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
                 'ocr_fudanvi_zh',
                 subset_name='scene',
                 namespace='modelscope',
-                split='train[:200]',
+                split='train[800:900]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
                 'ocr_fudanvi_zh',
@@ -96,7 +96,6 @@ class TestOfaTrainer(unittest.TestCase):
             cfg_file=config_file)
         trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
-
         self.assertIn(
             ModelFile.TORCH_MODEL_BIN_FILE,
             os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
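The test change narrows the training data from the first 200 examples to the 100-example range 800–899 and imports `Metrics` alongside `Trainers`. A small sketch of the split-slicing call, using the same arguments as the test above; the bracket syntax is assumed to follow the HuggingFace-datasets slicing convention that MsDataset builds on:

```python
# Hedged sketch: 'train[800:900]' selects examples 800-899 (100 rows) of the
# train split rather than loading the whole split.
from modelscope.msdatasets import MsDataset
from modelscope.utils.constant import DownloadMode

train_slice = MsDataset.load(
    'ocr_fudanvi_zh',
    subset_name='scene',
    namespace='modelscope',
    split='train[800:900]',
    download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
```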