You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

text_classification.py 2.9 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from modelscope.utils.constant import ModeKeys
  5. from .base import OfaBasePreprocessor
  6. class OfaTextClassificationPreprocessor(OfaBasePreprocessor):
  7. def __init__(self,
  8. cfg,
  9. model_dir,
  10. mode=ModeKeys.INFERENCE,
  11. *args,
  12. **kwargs):
  13. """preprocess the data
  14. Args:
  15. cfg(modelscope.utils.config.ConfigDict) : model config
  16. model_dir (str): model path,
  17. mode: preprocessor mode (model mode)
  18. """
  19. super(OfaTextClassificationPreprocessor,
  20. self).__init__(cfg, model_dir, mode, *args, **kwargs)
  21. def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
  22. if self.mode == ModeKeys.TRAIN:
  23. return self._build_train_sample(data)
  24. else:
  25. return self._build_infer_sample(data)
  26. def _build_instruction(self, data):
  27. text1 = ' '.join(
  28. data['text'].lower().strip().split()[:self.max_src_length])
  29. text2 = ' '.join(
  30. data['text2'].lower().strip().split()[:self.max_src_length])
  31. prompt = ' can text1 " {} " imply text2 " {} "?'
  32. text = prompt.format(text1, text2)
  33. instruction_itm = self.tokenize_text(text)
  34. return instruction_itm
  35. def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  36. instruction_itm = self._build_instruction(data)
  37. assert 'label' in data, 'there must has `label` column in train phase '
  38. label = data['label']
  39. if self.label2ans:
  40. label = self.label2ans[label] # ans
  41. label_itm = self.tokenize_text(f' {label}', add_bos=False)
  42. if self.prompt_type == 'none':
  43. target_itm = label_itm
  44. elif self.prompt_type == 'prev_output':
  45. target_itm = torch.cat([instruction_itm[1:-1], label_itm])
  46. else:
  47. raise NotImplementedError
  48. prev_output_itm = torch.cat([self.bos_item, target_itm[:-1]])
  49. target_itm[:-len(label_itm)] = self.pad_item
  50. sample = {
  51. 'source': instruction_itm,
  52. 'target': target_itm,
  53. 'prev_output_tokens': prev_output_itm,
  54. }
  55. self.add_constraint_mask(sample)
  56. return sample
  57. def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
  58. instruction_itm = self._build_instruction(data)
  59. if self.prompt_type == 'none':
  60. prefix_token = []
  61. elif self.prompt_type == 'prev_output':
  62. prefix_token = instruction_itm[:-1] # remove eos
  63. else:
  64. raise NotImplementedError
  65. sample = {
  66. 'source': instruction_itm,
  67. 'prefix_token': prefix_token,
  68. }
  69. if 'label' in data:
  70. sample['label'] = self.label2ans[data['label']]
  71. return sample