@@ -2,12 +2,12 @@
from typing import Any, Dict

import torch
from PIL import Image
import unicodedata2
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from zhconv import convert

from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor

@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
-        target = target.translate(self.transtab).strip()
+        target = sample['label']
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)

@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
-            sample['label'] = data[self.column_map['text']]
+            target = data[self.column_map['text']]
+            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
+            sample['label'] = target
         return sample
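
For reference, a minimal standalone sketch (not part of the diff) of the label handling the changed code performs: the infer-side branch converts raw text to simplified Chinese with zhconv and applies NFKC Unicode normalization via unicodedata2, and the train side whitespace-tokenizes the resulting label and truncates it to max_tgt_length tokens. The helper names and the max_tgt_length default below are illustrative assumptions, not identifiers from the preprocessor.

# Illustrative sketch only; zhconv.convert and unicodedata2.normalize come from
# the diff above, while normalize_label, build_target and the max_tgt_length
# default are assumed names/values for demonstration.
import unicodedata2
from zhconv import convert


def normalize_label(text: str) -> str:
    # Traditional -> simplified Chinese, then NFKC Unicode normalization,
    # mirroring the new branch in _build_infer_sample.
    return unicodedata2.normalize('NFKC', convert(text, 'zh-hans'))


def build_target(label: str, max_tgt_length: int = 128) -> str:
    # Whitespace-tokenize the normalized label and keep at most max_tgt_length
    # tokens, as _build_train_sample does before calling tokenize_text.
    return ' '.join(label.strip().split()[:max_tgt_length])


if __name__ == '__main__':
    raw = '圖片上的文字'            # traditional-Chinese OCR label
    label = normalize_label(raw)    # -> '图片上的文字'
    print(build_target(label))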