Browse Source

[to #42322933]fix ocr prepreocess & conflict

修复ocr预处理逻辑不一致问题
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10581697
master
liugao.lg yingda.chen 2 years ago
parent
commit
40b6770956
4 changed files with 9 additions and 7 deletions
  1. +0
    -1
      modelscope/preprocessors/multi_modal.py
  2. +6
    -5
      modelscope/preprocessors/ofa/ocr_recognition.py
  3. +2
    -0
      requirements/multi-modal.txt
  4. +1
    -1
      tests/trainers/test_ofa_trainer.py

+ 0
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -96,7 +96,6 @@ class OfaPreprocessor(Preprocessor):
data = input
else:
data = self._build_dict(input)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data)
str_data = dict()
for k, v in data.items():


+ 6
- 5
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -2,12 +2,12 @@
from typing import Any, Dict

import torch
from PIL import Image
import unicodedata2
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from zhconv import convert

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target = sample['label']
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
return sample

+ 2
- 0
requirements/multi-modal.txt View File

@@ -11,3 +11,5 @@ timm
tokenizers
torchvision
transformers>=4.12.0
unicodedata2
zhconv

+ 1
- 1
tests/trainers/test_ofa_trainer.py View File

@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
'ocr_fudanvi_zh',
subset_name='scene',
namespace='modelscope',
split='train[:200]',
split='train[800:900]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
eval_dataset=MsDataset.load(
'ocr_fudanvi_zh',


Loading…
Cancel
Save