You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dialog_modeling.py 6.0 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import SpaceForDialogModeling
  6. from modelscope.pipelines import DialogModelingPipeline, pipeline
  7. from modelscope.preprocessors import DialogModelingPreprocessor
  8. from modelscope.utils.constant import Tasks
  9. from modelscope.utils.test_utils import test_level
  10. class DialogModelingTest(unittest.TestCase):
  11. model_id = 'damo/nlp_space_dialog-modeling'
  12. test_case = {
  13. 'sng0073': {
  14. 'goal': {
  15. 'taxi': {
  16. 'info': {
  17. 'leaveat': '17:15',
  18. 'destination': 'pizza hut fen ditton',
  19. 'departure': "saint john's college"
  20. },
  21. 'reqt': ['car', 'phone'],
  22. 'fail_info': {}
  23. }
  24. },
  25. 'log': [{
  26. 'user':
  27. "i would like a taxi from saint john 's college to pizza hut fen ditton .",
  28. 'user_delex':
  29. 'i would like a taxi from [value_departure] to [value_destination] .',
  30. 'resp':
  31. 'what time do you want to leave and what time do you want to arrive by ?',
  32. 'sys':
  33. 'what time do you want to leave and what time do you want to arrive by ?',
  34. 'pointer': '0,0,0,0,0,0',
  35. 'match': '',
  36. 'constraint':
  37. "[taxi] destination pizza hut fen ditton departure saint john 's college",
  38. 'cons_delex': '[taxi] destination departure',
  39. 'sys_act': '[taxi] [request] leave arrive',
  40. 'turn_num': 0,
  41. 'turn_domain': '[taxi]'
  42. }, {
  43. 'user': 'i want to leave after 17:15 .',
  44. 'user_delex': 'i want to leave after [value_leave] .',
  45. 'resp':
  46. 'booking completed ! your taxi will be [value_car] contact number is [value_phone]',
  47. 'sys':
  48. 'booking completed ! your taxi will be blue honda contact number is 07218068540',
  49. 'pointer': '0,0,0,0,0,0',
  50. 'match': '',
  51. 'constraint':
  52. "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
  53. 'cons_delex': '[taxi] destination departure leave',
  54. 'sys_act': '[taxi] [inform] car phone',
  55. 'turn_num': 1,
  56. 'turn_domain': '[taxi]'
  57. }, {
  58. 'user': 'thank you for all the help ! i appreciate it .',
  59. 'user_delex': 'thank you for all the help ! i appreciate it .',
  60. 'resp':
  61. 'you are welcome . is there anything else i can help you with today ?',
  62. 'sys':
  63. 'you are welcome . is there anything else i can help you with today ?',
  64. 'pointer': '0,0,0,0,0,0',
  65. 'match': '',
  66. 'constraint':
  67. "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
  68. 'cons_delex': '[taxi] destination departure leave',
  69. 'sys_act': '[general] [reqmore]',
  70. 'turn_num': 2,
  71. 'turn_domain': '[general]'
  72. }, {
  73. 'user': 'no , i am all set . have a nice day . bye .',
  74. 'user_delex': 'no , i am all set . have a nice day . bye .',
  75. 'resp': 'you too ! thank you',
  76. 'sys': 'you too ! thank you',
  77. 'pointer': '0,0,0,0,0,0',
  78. 'match': '',
  79. 'constraint':
  80. "[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
  81. 'cons_delex': '[taxi] destination departure leave',
  82. 'sys_act': '[general] [bye]',
  83. 'turn_num': 3,
  84. 'turn_domain': '[general]'
  85. }]
  86. }
  87. }
  88. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  89. def test_run(self):
  90. cache_path = snapshot_download(self.model_id)
  91. preprocessor = DialogModelingPreprocessor(model_dir=cache_path)
  92. model = SpaceForDialogModeling(
  93. model_dir=cache_path,
  94. text_field=preprocessor.text_field,
  95. config=preprocessor.config)
  96. pipelines = [
  97. DialogModelingPipeline(model=model, preprocessor=preprocessor),
  98. pipeline(
  99. task=Tasks.dialog_modeling,
  100. model=model,
  101. preprocessor=preprocessor)
  102. ]
  103. result = {}
  104. for step, item in enumerate(self.test_case['sng0073']['log']):
  105. user = item['user']
  106. print('user: {}'.format(user))
  107. result = pipelines[step % 2]({
  108. 'user_input': user,
  109. 'history': result
  110. })
  111. print('sys : {}'.format(result['sys']))
  112. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  113. def test_run_with_model_from_modelhub(self):
  114. model = Model.from_pretrained(self.model_id)
  115. preprocessor = DialogModelingPreprocessor(model_dir=model.model_dir)
  116. pipelines = [
  117. DialogModelingPipeline(model=model, preprocessor=preprocessor),
  118. pipeline(
  119. task=Tasks.dialog_modeling,
  120. model=model,
  121. preprocessor=preprocessor)
  122. ]
  123. result = {}
  124. for step, item in enumerate(self.test_case['sng0073']['log']):
  125. user = item['user']
  126. print('user: {}'.format(user))
  127. result = pipelines[step % 2]({
  128. 'user_input': user,
  129. 'history': result
  130. })
  131. print('sys : {}'.format(result['sys']))
  132. if __name__ == '__main__':
  133. unittest.main()