Browse Source

[to #42322933]fix ocr preprocess & conflict

修复ocr预处理逻辑不一致问题
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10581697
master
liugao.lg yingda.chen 2 years ago
parent
commit
40b6770956
4 changed files with 9 additions and 7 deletions
  1. +0
    -1
      modelscope/preprocessors/multi_modal.py
  2. +6
    -5
      modelscope/preprocessors/ofa/ocr_recognition.py
  3. +2
    -0
      requirements/multi-modal.txt
  4. +1
    -1
      tests/trainers/test_ofa_trainer.py

+ 0
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -96,7 +96,6 @@ class OfaPreprocessor(Preprocessor):
            data = input
        else:
            data = self._build_dict(input)
-       data = self._ofa_input_compatibility_conversion(data)
        sample = self.preprocess(data)
        str_data = dict()
        for k, v in data.items():


+ 6
- 5
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -2,12 +2,12 @@
 from typing import Any, Dict

 import torch
-from PIL import Image
+import unicodedata2
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
+from zhconv import convert

+from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor


@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):


     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
-        target = target.translate(self.transtab).strip()
+        target = sample['label']
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
-            sample['label'] = data[self.column_map['text']]
+            target = data[self.column_map['text']]
+            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
+            sample['label'] = target
         return sample

+ 2
- 0
requirements/multi-modal.txt View File

@@ -11,3 +11,5 @@ timm
 tokenizers
 torchvision
 transformers>=4.12.0
+unicodedata2
+zhconv

+ 1
- 1
tests/trainers/test_ofa_trainer.py View File

@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
                 'ocr_fudanvi_zh',
                 subset_name='scene',
                 namespace='modelscope',
-                split='train[:200]',
+                split='train[800:900]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
                 'ocr_fudanvi_zh',


Loading…
Cancel
Save