From 40b677095605594d426b9c731687fb834d04b4fc Mon Sep 17 00:00:00 2001
From: "liugao.lg"
Date: Tue, 1 Nov 2022 10:22:11 +0800
Subject: [PATCH] [to #42322933] fix ocr preprocess & conflict
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix the inconsistency in the OCR preprocessing logic.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10581697
---
 modelscope/preprocessors/multi_modal.py         |  1 -
 modelscope/preprocessors/ofa/ocr_recognition.py | 11 ++++++-----
 requirements/multi-modal.txt                    |  2 ++
 tests/trainers/test_ofa_trainer.py              |  2 +-
 4 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 17dffb48..13876058 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -96,7 +96,6 @@ class OfaPreprocessor(Preprocessor):
             data = input
         else:
             data = self._build_dict(input)
-        data = self._ofa_input_compatibility_conversion(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py
index 26fff9d2..a0342c14 100644
--- a/modelscope/preprocessors/ofa/ocr_recognition.py
+++ b/modelscope/preprocessors/ofa/ocr_recognition.py
@@ -2,12 +2,12 @@ from typing import Any, Dict
 
 import torch
-from PIL import Image
+import unicodedata2
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
+from zhconv import convert
 
-from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
 
 
@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
 
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
-        target = target.translate(self.transtab).strip()
+        target = sample['label']
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
-            sample['label'] = data[self.column_map['text']]
+            target = data[self.column_map['text']]
+            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
+            sample['label'] = target
         return sample
diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt
index 255f6155..578f0b54 100644
--- a/requirements/multi-modal.txt
+++ b/requirements/multi-modal.txt
@@ -11,3 +11,5 @@ timm
 tokenizers
 torchvision
 transformers>=4.12.0
+unicodedata2
+zhconv
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
index 3f68a9fb..85c21881 100644
--- a/tests/trainers/test_ofa_trainer.py
+++ b/tests/trainers/test_ofa_trainer.py
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
                 'ocr_fudanvi_zh',
                 subset_name='scene',
                 namespace='modelscope',
-                split='train[:200]',
+                split='train[800:900]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
                 'ocr_fudanvi_zh',
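
Note on the fix: training labels previously went through a translate/strip table
while inference labels did not, so the two paths could disagree. After this
patch, `_build_train_sample` reuses the `label` that `_build_infer_sample`
produces, and that label is normalized once: Traditional-to-Simplified
conversion via zhconv, then Unicode NFKC folding via unicodedata2. The sketch
below exercises that shared step in isolation; the `normalize_ocr_label` helper
name is illustrative only and does not exist in the patch.

    import unicodedata2
    from zhconv import convert

    def normalize_ocr_label(text: str) -> str:
        # Convert Traditional Chinese to Simplified ('zh-hans'), then apply
        # NFKC so full-width digits/punctuation fold to half-width forms.
        return unicodedata2.normalize('NFKC', convert(text, 'zh-hans'))

    print(normalize_ocr_label('華東師範大學１２３'))  # -> 华东师范大学123

Because the same function feeds both the training target and the stored label,
train and inference now tokenize identical strings.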