Browse Source

fix ocr preprocess

master
翎航 2 years ago
parent
commit
3b21ff10ec
4 changed files with 10 additions and 9 deletions
  1. +0
    -1
      modelscope/preprocessors/multi_modal.py
  2. +6
    -5
      modelscope/preprocessors/ofa/ocr_recognition.py
  3. +2
    -0
      requirements/multi-modal.txt
  4. +2
    -3
      tests/trainers/test_ofa_trainer.py

+ 0
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -93,7 +93,6 @@ class OfaPreprocessor(Preprocessor):
data = input data = input
else: else:
data = self._build_dict(input) data = self._build_dict(input)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data) sample = self.preprocess(data)
str_data = dict() str_data = dict()
for k, v in data.items(): for k, v in data.items():


+ 6
- 5
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -2,12 +2,12 @@
from typing import Any, Dict from typing import Any, Dict


import torch import torch
from PIL import Image
import unicodedata2
from torchvision import transforms from torchvision import transforms
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F from torchvision.transforms import functional as F
from zhconv import convert


from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor from .base import OfaBasePreprocessor


@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):


def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data) sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target = sample['label']
target_token_list = target.strip().split() target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length]) target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False) sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True]) 'patch_mask': torch.tensor([True])
} }
if 'text' in self.column_map and self.column_map['text'] in data: if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
return sample return sample

+ 2
- 0
requirements/multi-modal.txt View File

@@ -11,3 +11,5 @@ timm
tokenizers tokenizers
torchvision torchvision
transformers>=4.12.0 transformers>=4.12.0
unicodedata2
zhconv

+ 2
- 3
tests/trainers/test_ofa_trainer.py View File

@@ -5,7 +5,7 @@ import unittest


import json import json


from modelscope.metainfo import Trainers
from modelscope.metainfo import Metrics, Trainers
from modelscope.msdatasets import MsDataset from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile from modelscope.utils.constant import DownloadMode, ModelFile
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
'ocr_fudanvi_zh', 'ocr_fudanvi_zh',
subset_name='scene', subset_name='scene',
namespace='modelscope', namespace='modelscope',
split='train[:200]',
split='train[800:900]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS), download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
eval_dataset=MsDataset.load( eval_dataset=MsDataset.load(
'ocr_fudanvi_zh', 'ocr_fudanvi_zh',
@@ -96,7 +96,6 @@ class TestOfaTrainer(unittest.TestCase):
cfg_file=config_file) cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa, default_args=args) trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train() trainer.train()

self.assertIn( self.assertIn(
ModelFile.TORCH_MODEL_BIN_FILE, ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))


Loading…
Cancel
Save