You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_dialog_state_tracking.py 4.2 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. import tempfile
  5. import unittest
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.models import Model, SpaceForDialogStateTrackingModel
  8. from modelscope.pipelines import DialogStateTrackingPipeline, pipeline
  9. from modelscope.preprocessors import DialogStateTrackingPreprocessor
  10. from modelscope.utils.constant import Tasks
  11. class DialogStateTrackingTest(unittest.TestCase):
  12. model_id = 'damo/nlp_space_dialog-state-tracking'
  13. test_case = [{
  14. 'User-1':
  15. 'Hi, I\'m looking for a train that is going to cambridge and arriving there by 20:45, '
  16. 'is there anything like that?'
  17. }, {
  18. 'System-1':
  19. 'There are over 1,000 trains like that. Where will you be departing from?',
  20. 'Dialog_Act-1': {
  21. 'Train-Inform': [['Choice', 'over 1'], ['Choice', '000']],
  22. 'Train-Request': [['Depart', '?']]
  23. },
  24. 'User-2': 'I am departing from birmingham new street.'
  25. }, {
  26. 'System-2': 'Can you confirm your desired travel day?',
  27. 'Dialog_Act-2': {
  28. 'Train-Request': [['Day', '?']]
  29. },
  30. 'User-3': 'I would like to leave on wednesday'
  31. }, {
  32. 'System-3':
  33. 'I show a train leaving birmingham new street at 17:40 and arriving at 20:23 on Wednesday. '
  34. 'Will this work for you?',
  35. 'Dialog_Act-3': {
  36. 'Train-Inform': [['Arrive', '20:23'], ['Leave', '17:40'],
  37. ['Day', 'Wednesday'],
  38. ['Depart', 'birmingham new street']]
  39. },
  40. 'User-4':
  41. 'That will, yes. Please make a booking for 5 people please.',
  42. }, {
  43. 'System-4':
  44. 'I\'ve booked your train tickets, and your reference number is A9NHSO9Y.',
  45. 'Dialog_Act-4': {
  46. 'Train-OfferBooked': [['Ref', 'A9NHSO9Y']]
  47. },
  48. 'User-5':
  49. 'Thanks so much. I would also need a place to say. '
  50. 'I am looking for something with 4 stars and has free wifi.'
  51. }, {
  52. 'System-5':
  53. 'How about the cambridge belfry? '
  54. 'It has all the attributes you requested and a great name! '
  55. 'Maybe even a real belfry?',
  56. 'Dialog_Act-5': {
  57. 'Hotel-Recommend': [['Name', 'the cambridge belfry']]
  58. },
  59. 'User-6':
  60. 'That sounds great, could you make a booking for me please?',
  61. }, {
  62. 'System-6':
  63. 'What day would you like your booking for?',
  64. 'Dialog_Act-6': {
  65. 'Booking-Request': [['Day', '?']]
  66. },
  67. 'User-7':
  68. 'Please book it for Wednesday for 5 people and 5 nights, please.',
  69. }, {
  70. 'System-7': 'Booking was successful. Reference number is : 5NAWGJDC.',
  71. 'Dialog_Act-7': {
  72. 'Booking-Book': [['Ref', '5NAWGJDC']]
  73. },
  74. 'User-8': 'Thank you, goodbye',
  75. }]
  76. def test_run(self):
  77. cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking'
  78. # cache_path = snapshot_download(self.model_id)
  79. model = SpaceForDialogStateTrackingModel(cache_path)
  80. preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
  81. pipelines = [
  82. DialogStateTrackingPipeline(
  83. model=model, preprocessor=preprocessor),
  84. pipeline(
  85. task=Tasks.dialog_state_tracking,
  86. model=model,
  87. preprocessor=preprocessor)
  88. ]
  89. pipelines_len = len(pipelines)
  90. import json
  91. for _test_case in self.test_case:
  92. history_states = [{}]
  93. utter = {}
  94. for step, item in enumerate(_test_case):
  95. utter.update(item)
  96. result = pipelines[step % pipelines_len]({
  97. 'utter':
  98. utter,
  99. 'history_states':
  100. history_states
  101. })
  102. print(json.dumps(result))
  103. history_states.extend([result['dialog_states'], {}])
  104. @unittest.skip('test with snapshot_download')
  105. def test_run_with_model_from_modelhub(self):
  106. pass
  107. if __name__ == '__main__':
  108. unittest.main()