You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dialog_generation_model.py 3.5 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from typing import Any, Dict, Optional
  2. from maas_lib.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
  3. from maas_lib.utils.constant import Tasks
  4. from ...base import Model, Tensor
  5. from ...builder import MODELS
  6. from .model.generator import Generator
  7. from .model.model_base import ModelBase
  8. __all__ = ['DialogGenerationModel']
  9. @MODELS.register_module(Tasks.dialog_generation, module_name=r'space')
  10. class DialogGenerationModel(Model):
  11. def __init__(self, model_dir: str, *args, **kwargs):
  12. """initialize the test generation model from the `model_dir` path.
  13. Args:
  14. model_dir (str): the model path.
  15. model_cls (Optional[Any], optional): model loader, if None, use the
  16. default loader to load model weights, by default None.
  17. """
  18. super().__init__(model_dir, *args, **kwargs)
  19. self.model_dir = model_dir
  20. self.text_field = kwargs.pop('text_field')
  21. self.config = kwargs.pop('config')
  22. self.generator = Generator.create(self.config, reader=self.text_field)
  23. self.model = ModelBase.create(
  24. model_dir=model_dir,
  25. config=self.config,
  26. reader=self.text_field,
  27. generator=self.generator)
  28. def to_tensor(array):
  29. """
  30. numpy array -> tensor
  31. """
  32. import torch
  33. array = torch.tensor(array)
  34. return array.cuda() if self.config.use_gpu else array
  35. self.trainer = MultiWOZTrainer(
  36. model=self.model,
  37. to_tensor=to_tensor,
  38. config=self.config,
  39. reader=self.text_field,
  40. evaluator=None)
  41. self.trainer.load()
  42. def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
  43. """return the result by the model
  44. Args:
  45. input (Dict[str, Any]): the preprocessed data
  46. Returns:
  47. Dict[str, np.ndarray]: results
  48. Example:
  49. {
  50. 'predictions': array([1]), # lable 0-negative 1-positive
  51. 'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
  52. 'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
  53. }
  54. """
  55. from numpy import array, float32
  56. import torch
  57. turn_1 = {
  58. 'user': [
  59. 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
  60. 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
  61. ]
  62. }
  63. old_pv_turn_1 = {}
  64. turn_2 = {
  65. 'user':
  66. [13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7]
  67. }
  68. old_pv_turn_2 = {
  69. 'labels': [[
  70. 13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
  71. 1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
  72. ]],
  73. 'resp': [
  74. 14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010,
  75. 2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681,
  76. 2030, 7180, 2011, 1029, 8
  77. ],
  78. 'bspn': [
  79. 15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002,
  80. 2198, 1005, 1055, 2267, 9
  81. ],
  82. 'db': [19, 24, 21, 20],
  83. 'aspn': [16, 43, 48, 2681, 7180, 10]
  84. }
  85. pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2)
  86. return pv_turn

致力于通过开放的社区合作，开源AI模型以及相关创新技术，推动基于模型即服务的生态繁荣发展。