nlp.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import uuid
from typing import Any, Dict, Union

from transformers import AutoTokenizer

from maas_lib.utils.constant import Fields, InputFields
from maas_lib.utils.type_assert import type_assert
from .base import Preprocessor
from .builder import PREPROCESSORS

__all__ = [
    'Tokenize',
    'SequenceClassificationPreprocessor',
]


@PREPROCESSORS.register_module(Fields.nlp)
class Tokenize(Preprocessor):

    def __init__(self, tokenizer_name) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, data: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
        # Accept either a raw string or a dict carrying the text field.
        if isinstance(data, str):
            data = {InputFields.text: data}
        token_dict = self._tokenizer(data[InputFields.text])
        data.update(token_dict)
        return data


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=r'bert-sentiment-analysis')
class SequenceClassificationPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Preprocess the data via the vocab.txt from the `model_dir` path.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)
        # Lazy import: easynlp's AutoTokenizer intentionally shadows the
        # transformers AutoTokenizer imported at module level.
        from easynlp.modelzoo import AutoTokenizer
        self.model_dir: str = model_dir
        self.first_sequence: str = kwargs.pop('first_sequence',
                                              'first_sequence')
        self.second_sequence = kwargs.pop('second_sequence', 'second_sequence')
        self.sequence_length = kwargs.pop('sequence_length', 128)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (str): a sentence
                Example:
                    'you are so handsome.'

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        new_data = {self.first_sequence: data}
        # preprocess the data for the model input; each feature is stored
        # as a single-element list, so the output looks like a batch of one
        rst = {
            'id': [],
            'input_ids': [],
            'attention_mask': [],
            'token_type_ids': []
        }
        max_seq_length = self.sequence_length
        text_a = new_data[self.first_sequence]
        text_b = new_data.get(self.second_sequence, None)
        feature = self.tokenizer(
            text_a,
            text_b,
            padding='max_length',
            truncation=True,
            max_length=max_seq_length)
        rst['id'].append(new_data.get('id', str(uuid.uuid4())))
        rst['input_ids'].append(feature['input_ids'])
        rst['attention_mask'].append(feature['attention_mask'])
        rst['token_type_ids'].append(feature['token_type_ids'])
        return rst
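
For context, a minimal usage sketch of the two preprocessors defined above. The import path `maas_lib.preprocessors.nlp`, the tokenizer name `bert-base-uncased`, and the local `./sentiment_model` directory are illustrative assumptions, not part of this file:

# Usage sketch. Assumptions: this module is importable as
# maas_lib.preprocessors.nlp; 'bert-base-uncased' stands in for any
# Hugging Face tokenizer name; ./sentiment_model is a hypothetical
# directory holding the vocab files expected by easynlp's AutoTokenizer.
from maas_lib.preprocessors.nlp import (
    SequenceClassificationPreprocessor, Tokenize)

# Tokenize accepts a raw string or a dict with a text field.
tokenize = Tokenize(tokenizer_name='bert-base-uncased')
out = tokenize('you are so handsome.')
print(out['input_ids'])  # token ids added alongside the original text

# The sentiment preprocessor pads/truncates to sequence_length and
# returns each feature as a single-element (batch-of-one) list.
preprocessor = SequenceClassificationPreprocessor(
    model_dir='./sentiment_model', sequence_length=64)
features = preprocessor('you are so handsome.')
print(len(features['input_ids'][0]))  # 64 after padding/truncation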
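
The `register_module` decorators suggest these classes are normally constructed through the `PREPROCESSORS` registry rather than instantiated directly. A sketch of that pattern, under the assumption (not confirmed by this file) that the sibling builder module exposes a `build_preprocessor` helper taking an mmcv-style config dict whose `type` key names the registered module:

# Hypothetical registry-driven construction; build_preprocessor and its
# signature are assumptions based on the registration decorators above.
from maas_lib.preprocessors.builder import build_preprocessor
from maas_lib.utils.constant import Fields

cfg = {'type': 'bert-sentiment-analysis', 'model_dir': './sentiment_model'}
preprocessor = build_preprocessor(cfg, Fields.nlp)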
