You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dialog_intent_model.py 2.4 kB

3 years ago
3 years ago
3 years ago
3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from typing import Any, Dict, Optional
  2. from maas_lib.trainers.nlp.space.trainers.intent_trainer import IntentTrainer
  3. from maas_lib.utils.constant import Tasks
  4. from ...base import Model, Tensor
  5. from ...builder import MODELS
  6. from .model.generator import Generator
  7. from .model.model_base import ModelBase
  8. __all__ = ['DialogIntentModel']
  9. @MODELS.register_module(Tasks.dialog_intent, module_name=r'space-intent')
  10. class DialogIntentModel(Model):
  11. def __init__(self, model_dir: str, *args, **kwargs):
  12. """initialize the test generation model from the `model_dir` path.
  13. Args:
  14. model_dir (str): the model path.
  15. model_cls (Optional[Any], optional): model loader, if None, use the
  16. default loader to load model weights, by default None.
  17. """
  18. super().__init__(model_dir, *args, **kwargs)
  19. self.model_dir = model_dir
  20. self.text_field = kwargs.pop('text_field')
  21. self.config = kwargs.pop('config')
  22. self.generator = Generator.create(self.config, reader=self.text_field)
  23. self.model = ModelBase.create(
  24. model_dir=model_dir,
  25. config=self.config,
  26. reader=self.text_field,
  27. generator=self.generator)
  28. def to_tensor(array):
  29. """
  30. numpy array -> tensor
  31. """
  32. import torch
  33. array = torch.tensor(array)
  34. return array.cuda() if self.config.use_gpu else array
  35. self.trainer = IntentTrainer(
  36. model=self.model,
  37. to_tensor=to_tensor,
  38. config=self.config,
  39. reader=self.text_field)
  40. self.trainer.load()
  41. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  42. """return the result by the model
  43. Args:
  44. input (Dict[str, Any]): the preprocessed data
  45. Returns:
  46. Dict[str, np.ndarray]: results
  47. Example:
  48. {
  49. 'predictions': array([1]), # lable 0-negative 1-positive
  50. 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
  51. 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
  52. }
  53. """
  54. from numpy import array, float32
  55. import torch
  56. print('--forward--')
  57. result = self.trainer.forward(input)
  58. return result

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展。