Browse Source

fix ocr preprocess

master
翎航 2 years ago
parent
commit
3b21ff10ec
4 changed files with 10 additions and 9 deletions
  1. +0
    -1
      modelscope/preprocessors/multi_modal.py
  2. +6
    -5
      modelscope/preprocessors/ofa/ocr_recognition.py
  3. +2
    -0
      requirements/multi-modal.txt
  4. +2
    -3
      tests/trainers/test_ofa_trainer.py

+ 0
- 1
modelscope/preprocessors/multi_modal.py View File

@@ -93,7 +93,6 @@ class OfaPreprocessor(Preprocessor):
data = input
else:
data = self._build_dict(input)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data)
str_data = dict()
for k, v in data.items():


+ 6
- 5
modelscope/preprocessors/ofa/ocr_recognition.py View File

@@ -2,12 +2,12 @@
from typing import Any, Dict

import torch
from PIL import Image
import unicodedata2
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from zhconv import convert

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):

def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target = sample['label']
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
return sample

+ 2
- 0
requirements/multi-modal.txt View File

@@ -11,3 +11,5 @@ timm
tokenizers
torchvision
transformers>=4.12.0
unicodedata2
zhconv

+ 2
- 3
tests/trainers/test_ofa_trainer.py View File

@@ -5,7 +5,7 @@ import unittest

import json

from modelscope.metainfo import Trainers
from modelscope.metainfo import Metrics, Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
'ocr_fudanvi_zh',
subset_name='scene',
namespace='modelscope',
split='train[:200]',
split='train[800:900]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
eval_dataset=MsDataset.load(
'ocr_fudanvi_zh',
@@ -96,7 +96,6 @@ class TestOfaTrainer(unittest.TestCase):
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()

self.assertIn(
ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))


Loading…
Cancel
Save