You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dialog_modeling.py 6.6 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from typing import List
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.models import Model
  6. from modelscope.models.nlp import SpaceForDialogModeling
  7. from modelscope.pipelines import DialogModelingPipeline, pipeline
  8. from modelscope.preprocessors import DialogModelingPreprocessor
  9. from modelscope.utils.constant import Tasks
  10. from modelscope.utils.test_utils import test_level
  class DialogModelingTest(unittest.TestCase):
      """Smoke tests for the SPACE dialog-modeling pipeline.

      Each test obtains two pipelines, replays the recorded dialogue stored in
      ``test_case`` through them turn by turn, and prints every user utterance
      together with the pipeline's response. No output is asserted; a test
      fails only if an exception is raised along the way.
      """

      model_id = 'damo/nlp_space_dialog-modeling'

      # One recorded four-turn taxi-booking dialogue, keyed by its id.
      # Only each turn's 'user' field is consumed by the tests below; the
      # remaining fields (goal, delexicalised text, system acts, ...) are
      # carried along unchanged.
      test_case = {
          'sng0073': {
              'goal': {
                  'taxi': {
                      'info': {
                          'leaveat': '17:15',
                          'destination': 'pizza hut fen ditton',
                          'departure': "saint john's college"
                      },
                      'reqt': ['car', 'phone'],
                      'fail_info': {}
                  }
              },
              'log': [{
                  'user':
                  "i would like a taxi from saint john 's college to pizza hut fen ditton .",
                  'user_delex':
                  'i would like a taxi from [value_departure] to [value_destination] .',
                  'resp':
                  'what time do you want to leave and what time do you want to arrive by ?',
                  'sys':
                  'what time do you want to leave and what time do you want to arrive by ?',
                  'pointer': '0,0,0,0,0,0',
                  'match': '',
                  'constraint':
                  "[taxi] destination pizza hut fen ditton departure saint john 's college",
                  'cons_delex': '[taxi] destination departure',
                  'sys_act': '[taxi] [request] leave arrive',
                  'turn_num': 0,
                  'turn_domain': '[taxi]'
              }, {
                  'user': 'i want to leave after 17:15 .',
                  'user_delex': 'i want to leave after [value_leave] .',
                  'resp':
                  'booking completed ! your taxi will be [value_car] contact number is [value_phone]',
                  'sys':
                  'booking completed ! your taxi will be blue honda contact number is 07218068540',
                  'pointer': '0,0,0,0,0,0',
                  'match': '',
                  'constraint':
                  "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
                  'cons_delex': '[taxi] destination departure leave',
                  'sys_act': '[taxi] [inform] car phone',
                  'turn_num': 1,
                  'turn_domain': '[taxi]'
              }, {
                  'user': 'thank you for all the help ! i appreciate it .',
                  'user_delex': 'thank you for all the help ! i appreciate it .',
                  'resp':
                  'you are welcome . is there anything else i can help you with today ?',
                  'sys':
                  'you are welcome . is there anything else i can help you with today ?',
                  'pointer': '0,0,0,0,0,0',
                  'match': '',
                  'constraint':
                  "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
                  'cons_delex': '[taxi] destination departure leave',
                  'sys_act': '[general] [reqmore]',
                  'turn_num': 2,
                  'turn_domain': '[general]'
              }, {
                  'user': 'no , i am all set . have a nice day . bye .',
                  'user_delex': 'no , i am all set . have a nice day . bye .',
                  'resp': 'you too ! thank you',
                  'sys': 'you too ! thank you',
                  'pointer': '0,0,0,0,0,0',
                  'match': '',
                  'constraint':
                  "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
                  'cons_delex': '[taxi] destination departure leave',
                  'sys_act': '[general] [bye]',
                  'turn_num': 3,
                  'turn_domain': '[general]'
              }]
          }
      }

      @staticmethod
      def _build_pipelines(model, preprocessor):
          """Return ``[direct, via_factory]`` pipelines sharing one model.

          The ``DialogModelingPipeline`` constructor is invoked first, then the
          generic ``pipeline`` factory for ``Tasks.dialog_modeling``; both are
          handed the same ``model`` and ``preprocessor`` objects.
          """
          direct = DialogModelingPipeline(
              model=model, preprocessor=preprocessor)
          via_factory = pipeline(
              task=Tasks.dialog_modeling,
              model=model,
              preprocessor=preprocessor)
          return [direct, via_factory]

      def generate_and_print_dialog_response(
              self, pipelines: List[DialogModelingPipeline]):
          """Replay the 'sng0073' dialogue through ``pipelines``, printing it.

          Even-numbered turns go to ``pipelines[0]`` and odd-numbered turns to
          ``pipelines[1]``. Whatever a call returns is passed back in as the
          ``history`` of the next call; the first call receives ``{}``.
          """
          history = {}
          turns = self.test_case['sng0073']['log']
          for turn_idx, turn in enumerate(turns):
              utterance = turn['user']
              print('user: {}'.format(utterance))
              # Alternate between the two pipelines on every turn.
              active_pipeline = pipelines[turn_idx % 2]
              history = active_pipeline({
                  'user_input': utterance,
                  'history': history
              })
              print('response : {}'.format(history['response']))

      @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
      def test_run_by_direct_model_download(self):
          """Download the snapshot and assemble model and preprocessor by hand."""
          model_dir = snapshot_download(self.model_id)
          preprocessor = DialogModelingPreprocessor(model_dir=model_dir)
          # The model is built from the preprocessor's text field and config.
          model = SpaceForDialogModeling(
              model_dir=model_dir,
              text_field=preprocessor.text_field,
              config=preprocessor.config)
          self.generate_and_print_dialog_response(
              self._build_pipelines(model, preprocessor))

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_with_model_from_modelhub(self):
          """Load the model through ``Model.from_pretrained``."""
          model = Model.from_pretrained(self.model_id)
          preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir)
          self.generate_and_print_dialog_response(
              self._build_pipelines(model, preprocessor))

      @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
      def test_run_with_model_name(self):
          """Let the ``pipeline`` factory resolve the model from its id."""
          self.generate_and_print_dialog_response([
              pipeline(task=Tasks.dialog_modeling, model=self.model_id)
              for _ in range(2)
          ])

      @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
      def test_run_with_default_model(self):
          """Rely on the task's default model chosen by ``pipeline``."""
          self.generate_and_print_dialog_response([
              pipeline(task=Tasks.dialog_modeling) for _ in range(2)
          ])
  if __name__ == '__main__':
      # Executed only when this file is run as a script (not on import):
      # hand control to unittest's command-line runner, which collects the
      # TestCase classes defined in this module and runs their tests.
      unittest.main()