
nlp.py 2.8 kB

# Copyright (c) Alibaba, Inc. and its affiliates.
import uuid
from typing import Any, Dict, Union

from transformers import AutoTokenizer

from maas_lib.utils.constant import Fields, InputFields
from maas_lib.utils.type_assert import type_assert

from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = ['Tokenize', 'SequenceClassificationPreprocessor']


@PREPROCESSORS.register_module(Fields.nlp)
class Tokenize(Preprocessor):
    """Generic tokenization preprocessor backed by a Hugging Face tokenizer."""

    def __init__(self, tokenizer_name) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        # Accept either a bare sentence or a dict carrying one under the
        # standard text field; merge the tokenizer output into that dict.
        if isinstance(data, str):
            data = {InputFields.text: data}
        token_dict = self._tokenizer(data[InputFields.text])
        data.update(token_dict)
        return data
@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess data using the vocab.txt under the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)

        # EasyNLP's AutoTokenizer loads from a local model directory.
        from easynlp.modelzoo import AutoTokenizer
        self.model_dir: str = model_dir
        self.first_sequence: str = kwargs.pop('first_sequence',
                                              'first_sequence')
        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
        self.sequence_length = kwargs.pop('sequence_length', 128)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence, e.g. 'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        new_data = {self.first_sequence: data}

        # Preprocess the data into the model input format.
        rst = {
            'id': [],
            'input_ids': [],
            'attention_mask': [],
            'token_type_ids': []
        }

        max_seq_length = self.sequence_length

        text_a = new_data[self.first_sequence]
        text_b = new_data.get(self.second_sequence, None)
        # Pad/truncate to a fixed length so batched inputs share one shape.
        feature = self.tokenizer(
            text_a,
            text_b,
            padding='max_length',
            truncation=True,
            max_length=max_seq_length)

        # Fall back to a random UUID when the input carries no id.
        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
        rst['input_ids'].append(feature['input_ids'])
        rst['attention_mask'].append(feature['attention_mask'])
        rst['token_type_ids'].append(feature['token_type_ids'])

        return rst
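
For context, a minimal usage sketch of the two preprocessors above. This is not part of nlp.py; it assumes the module is importable as maas_lib.preprocessors.nlp (inferred from the relative imports), that 'bert-base-uncased' is just an illustrative tokenizer name, and that the model directory path is a placeholder for a local checkpoint containing a vocab.txt.

# Usage sketch, not part of nlp.py. Module path and asset names are assumed.
from maas_lib.preprocessors.nlp import (Tokenize,
                                        SequenceClassificationPreprocessor)

# Tokenize wraps transformers.AutoTokenizer; any compatible name works here.
tokenize = Tokenize(tokenizer_name='bert-base-uncased')
out = tokenize('you are so handsome.')
# `out` keeps the original text field and gains 'input_ids',
# 'attention_mask', etc. from the tokenizer.

# SequenceClassificationPreprocessor reads its vocab from a local model dir;
# the path below is a placeholder.
preprocessor = SequenceClassificationPreprocessor(
    model_dir='/path/to/bert-sentiment-analysis', sequence_length=128)
features = preprocessor('you are so handsome.')
# features: {'id': [...], 'input_ids': [[...]], 'attention_mask': [[...]],
#            'token_type_ids': [[...]]}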
