You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dialog_state_tracking.py 5.4 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from typing import List
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.models import Model, SpaceForDialogStateTracking
  6. from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
  7. from modelscope.preprocessors import DialogStateTrackingPreprocessor
  8. from modelscope.utils.constant import Tasks
  9. from modelscope.utils.test_utils import test_level
  10. class DialogStateTrackingTest(unittest.TestCase):
  11. model_id = 'damo/nlp_space_dialog-state-tracking'
  12. test_case = [{
  13. 'User-1':
  14. 'Hi, I\'m looking for a train that is going to cambridge and arriving there by 20:45, '
  15. 'is there anything like that?'
  16. }, {
  17. 'System-1':
  18. 'There are over 1,000 trains like that. Where will you be departing from?',
  19. 'Dialog_Act-1': {
  20. 'Train-Inform': [['Choice', 'over 1'], ['Choice', '000']],
  21. 'Train-Request': [['Depart', '?']]
  22. },
  23. 'User-2': 'I am departing from birmingham new street.'
  24. }, {
  25. 'System-2': 'Can you confirm your desired travel day?',
  26. 'Dialog_Act-2': {
  27. 'Train-Request': [['Day', '?']]
  28. },
  29. 'User-3': 'I would like to leave on wednesday'
  30. }, {
  31. 'System-3':
  32. 'I show a train leaving birmingham new street at 17:40 and arriving at 20:23 on Wednesday. '
  33. 'Will this work for you?',
  34. 'Dialog_Act-3': {
  35. 'Train-Inform': [['Arrive', '20:23'], ['Leave', '17:40'],
  36. ['Day', 'Wednesday'],
  37. ['Depart', 'birmingham new street']]
  38. },
  39. 'User-4':
  40. 'That will, yes. Please make a booking for 5 people please.',
  41. }, {
  42. 'System-4':
  43. 'I\'ve booked your train tickets, and your reference number is A9NHSO9Y.',
  44. 'Dialog_Act-4': {
  45. 'Train-OfferBooked': [['Ref', 'A9NHSO9Y']]
  46. },
  47. 'User-5':
  48. 'Thanks so much. I would also need a place to say. '
  49. 'I am looking for something with 4 stars and has free wifi.'
  50. }, {
  51. 'System-5':
  52. 'How about the cambridge belfry? '
  53. 'It has all the attributes you requested and a great name! '
  54. 'Maybe even a real belfry?',
  55. 'Dialog_Act-5': {
  56. 'Hotel-Recommend': [['Name', 'the cambridge belfry']]
  57. },
  58. 'User-6':
  59. 'That sounds great, could you make a booking for me please?',
  60. }, {
  61. 'System-6':
  62. 'What day would you like your booking for?',
  63. 'Dialog_Act-6': {
  64. 'Booking-Request': [['Day', '?']]
  65. },
  66. 'User-7':
  67. 'Please book it for Wednesday for 5 people and 5 nights, please.',
  68. }, {
  69. 'System-7': 'Booking was successful. Reference number is : 5NAWGJDC.',
  70. 'Dialog_Act-7': {
  71. 'Booking-Book': [['Ref', '5NAWGJDC']]
  72. },
  73. 'User-8': 'Thank you, goodbye',
  74. }]
  75. def tracking_and_print_dialog_states(
  76. self, pipelines: List[DialogStateTrackingPipeline]):
  77. import json
  78. pipelines_len = len(pipelines)
  79. history_states = [{}]
  80. utter = {}
  81. for step, item in enumerate(self.test_case):
  82. utter.update(item)
  83. result = pipelines[step % pipelines_len]({
  84. 'utter':
  85. utter,
  86. 'history_states':
  87. history_states
  88. })
  89. print(json.dumps(result))
  90. history_states.extend([result['dialog_states'], {}])
  91. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  92. def test_run_by_direct_model_download(self):
  93. cache_path = snapshot_download(self.model_id)
  94. model = SpaceForDialogStateTracking(cache_path)
  95. preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
  96. pipelines = [
  97. DialogStateTrackingPipeline(
  98. model=model, preprocessor=preprocessor),
  99. pipeline(
  100. task=Tasks.dialog_state_tracking,
  101. model=model,
  102. preprocessor=preprocessor)
  103. ]
  104. self.tracking_and_print_dialog_states(pipelines)
  105. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  106. def test_run_with_model_from_modelhub(self):
  107. model = Model.from_pretrained(self.model_id)
  108. preprocessor = DialogStateTrackingPreprocessor(
  109. model_dir=model.model_dir)
  110. pipelines = [
  111. DialogStateTrackingPipeline(
  112. model=model, preprocessor=preprocessor),
  113. pipeline(
  114. task=Tasks.dialog_state_tracking,
  115. model=model,
  116. preprocessor=preprocessor)
  117. ]
  118. self.tracking_and_print_dialog_states(pipelines)
  119. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  120. def test_run_with_model_name(self):
  121. pipelines = [
  122. pipeline(task=Tasks.dialog_state_tracking, model=self.model_id)
  123. ]
  124. self.tracking_and_print_dialog_states(pipelines)
  125. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  126. def test_run_with_default_model(self):
  127. pipelines = [pipeline(task=Tasks.dialog_state_tracking)]
  128. self.tracking_and_print_dialog_states(pipelines)
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()