You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ontology.py 6.2 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. all_domains = [
  2. 'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
  3. ]
  4. db_domains = ['restaurant', 'hotel', 'attraction', 'train']
  5. normlize_slot_names = {
  6. 'car type': 'car',
  7. 'entrance fee': 'price',
  8. 'duration': 'time',
  9. 'leaveat': 'leave',
  10. 'arriveby': 'arrive',
  11. 'trainid': 'id'
  12. }
  13. requestable_slots = {
  14. 'taxi': ['car', 'phone'],
  15. 'police': ['postcode', 'address', 'phone'],
  16. 'hospital': ['address', 'phone', 'postcode'],
  17. 'hotel': [
  18. 'address', 'postcode', 'internet', 'phone', 'parking', 'type',
  19. 'pricerange', 'stars', 'area', 'reference'
  20. ],
  21. 'attraction':
  22. ['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'],
  23. 'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'],
  24. 'restaurant': [
  25. 'phone', 'postcode', 'address', 'pricerange', 'food', 'area',
  26. 'reference'
  27. ]
  28. }
  29. all_reqslot = [
  30. 'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type',
  31. 'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave',
  32. 'price', 'arrive', 'id'
  33. ]
  34. informable_slots = {
  35. 'taxi': ['leave', 'destination', 'departure', 'arrive'],
  36. 'police': [],
  37. 'hospital': ['department'],
  38. 'hotel': [
  39. 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
  40. 'area', 'stars', 'name'
  41. ],
  42. 'attraction': ['area', 'type', 'name'],
  43. 'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'],
  44. 'restaurant':
  45. ['food', 'pricerange', 'area', 'name', 'time', 'day', 'people']
  46. }
  47. all_infslot = [
  48. 'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
  49. 'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive',
  50. 'department', 'food', 'time'
  51. ]
  52. all_slots = all_reqslot + [
  53. 'stay', 'day', 'people', 'name', 'destination', 'departure', 'department'
  54. ]
  55. get_slot = {}
  56. for s in all_slots:
  57. get_slot[s] = 1
  58. # mapping slots in dialogue act to original goal slot names
  59. da_abbr_to_slot_name = {
  60. 'addr': 'address',
  61. 'fee': 'price',
  62. 'post': 'postcode',
  63. 'ref': 'reference',
  64. 'ticket': 'price',
  65. 'depart': 'departure',
  66. 'dest': 'destination',
  67. }
  68. dialog_acts = {
  69. 'restaurant': [
  70. 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
  71. 'offerbooked', 'nobook'
  72. ],
  73. 'hotel': [
  74. 'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
  75. 'offerbooked', 'nobook'
  76. ],
  77. 'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'],
  78. 'train':
  79. ['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'],
  80. 'taxi': ['inform', 'request'],
  81. 'police': ['inform', 'request'],
  82. 'hospital': ['inform', 'request'],
  83. # 'booking': ['book', 'inform', 'nobook', 'request'],
  84. 'general': ['bye', 'greet', 'reqmore', 'welcome'],
  85. }
  86. all_acts = []
  87. for acts in dialog_acts.values():
  88. for act in acts:
  89. if act not in all_acts:
  90. all_acts.append(act)
  91. dialog_act_params = {
  92. 'inform': all_slots + ['choice', 'open'],
  93. 'request': all_infslot + ['choice', 'price'],
  94. 'nooffer': all_slots + ['choice'],
  95. 'recommend': all_reqslot + ['choice', 'open'],
  96. 'select': all_slots + ['choice'],
  97. # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
  98. 'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
  99. 'offerbook': all_slots + ['choice'],
  100. 'offerbooked': all_slots + ['choice'],
  101. 'reqmore': [],
  102. 'welcome': [],
  103. 'bye': [],
  104. 'greet': [],
  105. }
  106. dialog_act_all_slots = all_slots + ['choice', 'open']
  107. # special slot tokens in belief span
  108. # no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange]
  109. slot_name_to_slot_token = {}
  110. # special slot tokens in responses
  111. # not use at the momoent
  112. slot_name_to_value_token = {
  113. # 'entrance fee': '[value_price]',
  114. # 'pricerange': '[value_price]',
  115. # 'arriveby': '[value_time]',
  116. # 'leaveat': '[value_time]',
  117. # 'departure': '[value_place]',
  118. # 'destination': '[value_place]',
  119. # 'stay': 'count',
  120. # 'people': 'count'
  121. }
  122. # eos tokens definition
  123. eos_tokens = {
  124. 'user': '<eos_u>',
  125. 'user_delex': '<eos_u>',
  126. 'resp': '<eos_r>',
  127. 'resp_gen': '<eos_r>',
  128. 'pv_resp': '<eos_r>',
  129. 'bspn': '<eos_b>',
  130. 'bspn_gen': '<eos_b>',
  131. 'pv_bspn': '<eos_b>',
  132. 'bsdx': '<eos_b>',
  133. 'bsdx_gen': '<eos_b>',
  134. 'pv_bsdx': '<eos_b>',
  135. 'qspn': '<eos_q>',
  136. 'qspn_gen': '<eos_q>',
  137. 'pv_qspn': '<eos_q>',
  138. 'aspn': '<eos_a>',
  139. 'aspn_gen': '<eos_a>',
  140. 'pv_aspn': '<eos_a>',
  141. 'dspn': '<eos_d>',
  142. 'dspn_gen': '<eos_d>',
  143. 'pv_dspn': '<eos_d>'
  144. }
  145. # sos tokens definition
  146. sos_tokens = {
  147. 'user': '<sos_u>',
  148. 'user_delex': '<sos_u>',
  149. 'resp': '<sos_r>',
  150. 'resp_gen': '<sos_r>',
  151. 'pv_resp': '<sos_r>',
  152. 'bspn': '<sos_b>',
  153. 'bspn_gen': '<sos_b>',
  154. 'pv_bspn': '<sos_b>',
  155. 'bsdx': '<sos_b>',
  156. 'bsdx_gen': '<sos_b>',
  157. 'pv_bsdx': '<sos_b>',
  158. 'qspn': '<sos_q>',
  159. 'qspn_gen': '<sos_q>',
  160. 'pv_qspn': '<sos_q>',
  161. 'aspn': '<sos_a>',
  162. 'aspn_gen': '<sos_a>',
  163. 'pv_aspn': '<sos_a>',
  164. 'dspn': '<sos_d>',
  165. 'dspn_gen': '<sos_d>',
  166. 'pv_dspn': '<sos_d>'
  167. }
  168. # db tokens definition
  169. db_tokens = [
  170. '<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]',
  171. '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
  172. ]
  173. # understand tokens definition
  174. def get_understand_tokens(prompt_num_for_understand):
  175. understand_tokens = []
  176. for i in range(prompt_num_for_understand):
  177. understand_tokens.append(f'<understand_{i}>')
  178. return understand_tokens
  179. # policy tokens definition
  180. def get_policy_tokens(prompt_num_for_policy):
  181. policy_tokens = []
  182. for i in range(prompt_num_for_policy):
  183. policy_tokens.append(f'<policy_{i}>')
  184. return policy_tokens
  185. # all special tokens definition
  186. def get_special_tokens(other_tokens):
  187. special_tokens = ['<go_r>', '<go_b>', '<go_a>', '<go_d>',
  188. '<eos_u>', '<eos_r>', '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>',
  189. '<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'] \
  190. + db_tokens + other_tokens
  191. return special_tokens

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展