|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Any, Dict
-
- import torch
-
- from modelscope.utils.constant import ModeKeys
- from .base import OfaBasePreprocessor
-
-
class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
    """Build OFA model inputs for a two-sentence text classification task.

    Each input dict supplies a ``text`` / ``text2`` pair which is rendered into
    the instruction ``can text1 " ... " imply text2 " ... "?`` and tokenized.
    In train mode a target sequence is produced as well.
    """

    def __init__(self,
                 cfg,
                 model_dir,
                 mode=ModeKeys.INFERENCE,
                 *args,
                 **kwargs):
        """preprocess the data

        Args:
            cfg(modelscope.utils.config.ConfigDict) : model config
            model_dir (str): model path,
            mode: preprocessor mode (model mode)
        """
        super().__init__(cfg, model_dir, mode, *args, **kwargs)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch to the train or inference sample builder.

        Args:
            data: raw sample; must contain ``text`` and ``text2`` and, in
                train mode, ``label``.

        Returns:
            The sample dict built by ``_build_train_sample`` when
            ``self.mode`` is ``ModeKeys.TRAIN``, otherwise the one built by
            ``_build_infer_sample``.
        """
        if self.mode == ModeKeys.TRAIN:
            return self._build_train_sample(data)
        return self._build_infer_sample(data)

    def _map_label(self, label):
        """Translate a raw label through ``self.label2ans`` when it is set.

        When ``self.label2ans`` is falsy (no mapping available) the label is
        returned unchanged, so the train and inference paths treat labels the
        same way.
        """
        if self.label2ans:
            return self.label2ans[label]
        return label

    def _build_instruction(self, data):
        """Tokenize the entailment prompt built from ``text`` and ``text2``.

        Both texts are lower-cased, stripped, split on whitespace and cut to
        at most ``self.max_src_length`` words each before being inserted into
        the prompt.

        Args:
            data: dict with string values under ``text`` and ``text2``.

        Returns:
            Whatever ``self.tokenize_text`` returns for the formatted prompt.
        """
        text1 = ' '.join(
            data['text'].lower().strip().split()[:self.max_src_length])
        text2 = ' '.join(
            data['text2'].lower().strip().split()[:self.max_src_length])
        prompt = ' can text1 " {} " imply text2 " {} "?'
        text = prompt.format(text1, text2)
        instruction_itm = self.tokenize_text(text)
        return instruction_itm

    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build a training sample with source, target and decoder input.

        Args:
            data: dict with ``text``, ``text2`` and ``label``.

        Returns:
            Dict with ``source``, ``target`` and ``prev_output_tokens``, plus
            whatever ``self.add_constraint_mask`` adds to it.

        Raises:
            AssertionError: if ``label`` is missing from ``data``.
            NotImplementedError: if ``self.prompt_type`` is neither ``'none'``
                nor ``'prev_output'``.
        """
        instruction_itm = self._build_instruction(data)
        assert 'label' in data, 'there must has `label` column in train phase '
        label = self._map_label(data['label'])  # ans
        label_itm = self.tokenize_text(f' {label}', add_bos=False)
        if self.prompt_type == 'none':
            target_itm = label_itm
        elif self.prompt_type == 'prev_output':
            # Drop the first and last instruction tokens (presumably bos/eos
            # -- confirm against tokenize_text) and append the label tokens.
            target_itm = torch.cat([instruction_itm[1:-1], label_itm])
        else:
            raise NotImplementedError
        # The decoder input must be derived before the target is masked below.
        prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]])
        # Overwrite everything except the trailing label tokens with padding.
        target_itm[:-len(label_itm)] = self.pad_item
        sample = {
            'source': instruction_itm,
            'target': target_itm,
            'prev_output_tokens': prev_output_itm,
        }
        self.add_constraint_mask(sample)
        return sample

    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Build an inference sample with source and decoding prefix.

        Args:
            data: dict with ``text`` and ``text2``; ``label`` is optional.

        Returns:
            Dict with ``source`` and ``prefix_token``; ``label`` is included
            when the input carries one. The label is mapped through
            ``self.label2ans`` if that mapping is set and passed through
            unchanged otherwise.

        Raises:
            NotImplementedError: if ``self.prompt_type`` is neither ``'none'``
                nor ``'prev_output'``.
        """
        instruction_itm = self._build_instruction(data)
        if self.prompt_type == 'none':
            prefix_token = []
        elif self.prompt_type == 'prev_output':
            prefix_token = instruction_itm[:-1]  # remove eos
        else:
            raise NotImplementedError
        sample = {
            'source': instruction_itm,
            'prefix_token': prefix_token,
        }
        if 'label' in data:
            # Guarded like the train path: indexing a missing label map
            # would otherwise fail here.
            sample['label'] = self._map_label(data['label'])
        return sample
|