You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sequence_classification_model.py 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from typing import Any, Dict, Optional, Union
  2. import numpy as np
  3. from maas_lib.utils.constant import Tasks
  4. from ..base import Model
  5. from ..builder import MODELS
  6. __all__ = ['SequenceClassificationModel']
  7. @MODELS.register_module(
  8. Tasks.text_classification, module_name=r'bert-sentiment-analysis')
  9. class SequenceClassificationModel(Model):
  10. def __init__(self, model_dir: str, *args, **kwargs):
  11. # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs)
  12. # Predictor.__init__(self, *args, **kwargs)
  13. """initialize the sequence classification model from the `model_dir` path.
  14. Args:
  15. model_dir (str): the model path.
  16. """
  17. super().__init__(model_dir, *args, **kwargs)
  18. from easynlp.appzoo import SequenceClassification
  19. from easynlp.core.predictor import get_model_predictor
  20. import torch
  21. self.model = get_model_predictor(
  22. model_dir=self.model_dir,
  23. model_cls=SequenceClassification,
  24. input_keys=[('input_ids', torch.LongTensor),
  25. ('attention_mask', torch.LongTensor),
  26. ('token_type_ids', torch.LongTensor)],
  27. output_keys=['predictions', 'probabilities', 'logits'])
  28. def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
  29. """return the result by the model
  30. Args:
  31. input (Dict[str, Any]): the preprocessed data
  32. Returns:
  33. Dict[str, np.ndarray]: results
  34. Example:
  35. {
  36. 'predictions': array([1]), # lable 0-negative 1-positive
  37. 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
  38. 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
  39. }
  40. """
  41. return self.model.predict(input)

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展