You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sequence_classification_model.py 2.3 kB

(62 lines)
  1. from typing import Any, Dict, Optional, Union
  2. import numpy as np
  3. import torch
  4. from maas_lib.utils.constant import Tasks
  5. from ..base import Model
  6. from ..builder import MODELS
  7. __all__ = ['SequenceClassificationModel']
  8. @MODELS.register_module(
  9. Tasks.text_classification, module_name=r'bert-sentiment-analysis')
  10. class SequenceClassificationModel(Model):
  11. def __init__(self,
  12. model_dir: str,
  13. model_cls: Optional[Any] = None,
  14. *args,
  15. **kwargs):
  16. # Model.__init__(self, model_dir, model_cls, first_sequence, *args, **kwargs)
  17. # Predictor.__init__(self, *args, **kwargs)
  18. """initilize the sequence classification model from the `model_dir` path
  19. Args:
  20. model_dir (str): the model path
  21. model_cls (Optional[Any], optional): model loader, if None, use the
  22. default loader to load model weights, by default None
  23. """
  24. super().__init__(model_dir, model_cls, *args, **kwargs)
  25. from easynlp.appzoo import SequenceClassification
  26. from easynlp.core.predictor import get_model_predictor
  27. self.model_dir = model_dir
  28. model_cls = SequenceClassification if not model_cls else model_cls
  29. self.model = get_model_predictor(
  30. model_dir=model_dir,
  31. model_cls=model_cls,
  32. input_keys=[('input_ids', torch.LongTensor),
  33. ('attention_mask', torch.LongTensor),
  34. ('token_type_ids', torch.LongTensor)],
  35. output_keys=['predictions', 'probabilities', 'logits'])
  36. def forward(self, input: Dict[str, Any]) -> Dict[str, np.ndarray]:
  37. """return the result by the model
  38. Args:
  39. input (Dict[str, Any]): the preprocessed data
  40. Returns:
  41. Dict[str, np.ndarray]: results
  42. Example:
  43. {
  44. 'predictions': array([1]), # lable 0-negative 1-positive
  45. 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
  46. 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
  47. }
  48. """
  49. return self.model.predict(input)
  50. ...

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展