You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

db_ops.py 11 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import os
  2. import random
  3. import sqlite3
  4. import json
  5. from .ontology import all_domains, db_domains
  6. class MultiWozDB(object):
  7. def __init__(self, db_dir, db_paths):
  8. self.dbs = {}
  9. self.sql_dbs = {}
  10. for domain in all_domains:
  11. with open(os.path.join(db_dir, db_paths[domain]), 'r') as f:
  12. self.dbs[domain] = json.loads(f.read().lower())
  13. def oneHotVector(self, domain, num):
  14. """Return number of available entities for particular domain."""
  15. vector = [0, 0, 0, 0]
  16. if num == '':
  17. return vector
  18. if domain != 'train':
  19. if num == 0:
  20. vector = [1, 0, 0, 0]
  21. elif num == 1:
  22. vector = [0, 1, 0, 0]
  23. elif num <= 3:
  24. vector = [0, 0, 1, 0]
  25. else:
  26. vector = [0, 0, 0, 1]
  27. else:
  28. if num == 0:
  29. vector = [1, 0, 0, 0]
  30. elif num <= 5:
  31. vector = [0, 1, 0, 0]
  32. elif num <= 10:
  33. vector = [0, 0, 1, 0]
  34. else:
  35. vector = [0, 0, 0, 1]
  36. return vector
  37. def addBookingPointer(self, turn_da):
  38. """Add information about availability of the booking option."""
  39. # Booking pointer
  40. # Do not consider booking two things in a single turn.
  41. vector = [0, 0]
  42. if turn_da.get('booking-nobook'):
  43. vector = [1, 0]
  44. if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
  45. vector = [0, 1]
  46. return vector
  47. def addDBPointer(self, domain, match_num, return_num=False):
  48. """Create database pointer for all related domains."""
  49. # if turn_domains is None:
  50. # turn_domains = db_domains
  51. if domain in db_domains:
  52. vector = self.oneHotVector(domain, match_num)
  53. else:
  54. vector = [0, 0, 0, 0]
  55. return vector
  56. def addDBIndicator(self, domain, match_num, return_num=False):
  57. """Create database indicator for all related domains."""
  58. # if turn_domains is None:
  59. # turn_domains = db_domains
  60. if domain in db_domains:
  61. vector = self.oneHotVector(domain, match_num)
  62. else:
  63. vector = [0, 0, 0, 0]
  64. # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
  65. if vector == [0, 0, 0, 0]:
  66. indicator = '[db_nores]'
  67. else:
  68. indicator = '[db_%s]' % vector.index(1)
  69. return indicator
  70. def get_match_num(self, constraints, return_entry=False):
  71. """Create database pointer for all related domains."""
  72. match = {'general': ''}
  73. entry = {}
  74. # if turn_domains is None:
  75. # turn_domains = db_domains
  76. for domain in all_domains:
  77. match[domain] = ''
  78. if domain in db_domains and constraints.get(domain):
  79. matched_ents = self.queryJsons(domain, constraints[domain])
  80. match[domain] = len(matched_ents)
  81. if return_entry:
  82. entry[domain] = matched_ents
  83. if return_entry:
  84. return entry
  85. return match
  86. def pointerBack(self, vector, domain):
  87. # multi domain implementation
  88. # domnum = cfg.domain_num
  89. if domain.endswith(']'):
  90. domain = domain[1:-1]
  91. if domain != 'train':
  92. nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
  93. else:
  94. nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
  95. if vector[:4] == [0, 0, 0, 0]:
  96. report = ''
  97. else:
  98. num = vector.index(1)
  99. report = domain + ': ' + nummap[num] + '; '
  100. if vector[-2] == 0 and vector[-1] == 1:
  101. report += 'booking: ok'
  102. if vector[-2] == 1 and vector[-1] == 0:
  103. report += 'booking: unable'
  104. return report
  105. def queryJsons(self,
  106. domain,
  107. constraints,
  108. exactly_match=True,
  109. return_name=False):
  110. """Returns the list of entities for a given domain
  111. based on the annotation of the belief state
  112. constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
  113. """
  114. # query the db
  115. if domain == 'taxi':
  116. return [{
  117. 'taxi_colors':
  118. random.choice(self.dbs[domain]['taxi_colors']),
  119. 'taxi_types':
  120. random.choice(self.dbs[domain]['taxi_types']),
  121. 'taxi_phone': [random.randint(1, 9) for _ in range(10)]
  122. }]
  123. if domain == 'police':
  124. return self.dbs['police']
  125. if domain == 'hospital':
  126. if constraints.get('department'):
  127. for entry in self.dbs['hospital']:
  128. if entry.get('department') == constraints.get(
  129. 'department'):
  130. return [entry]
  131. else:
  132. return []
  133. valid_cons = False
  134. for v in constraints.values():
  135. if v not in ['not mentioned', '']:
  136. valid_cons = True
  137. if not valid_cons:
  138. return []
  139. match_result = []
  140. if 'name' in constraints:
  141. for db_ent in self.dbs[domain]:
  142. if 'name' in db_ent:
  143. cons = constraints['name']
  144. dbn = db_ent['name']
  145. if cons == dbn:
  146. db_ent = db_ent if not return_name else db_ent['name']
  147. match_result.append(db_ent)
  148. return match_result
  149. for db_ent in self.dbs[domain]:
  150. match = True
  151. for s, v in constraints.items():
  152. if s == 'name':
  153. continue
  154. if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
  155. (domain == 'restaurant' and s in ['day', 'time']):
  156. # 因为这些inform slot属于book info,而数据库中没有这些slot;
  157. # 能否book是根据user goal中的信息判断,而非通过数据库查询;
  158. continue
  159. skip_case = {
  160. "don't care": 1,
  161. "do n't care": 1,
  162. 'dont care': 1,
  163. 'not mentioned': 1,
  164. 'dontcare': 1,
  165. '': 1
  166. }
  167. if skip_case.get(v):
  168. continue
  169. if s not in db_ent:
  170. # logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
  171. match = False
  172. break
  173. # v = 'guesthouse' if v == 'guest house' else v
  174. # v = 'swimmingpool' if v == 'swimming pool' else v
  175. v = 'yes' if v == 'free' else v
  176. if s in ['arrive', 'leave']:
  177. try:
  178. h, m = v.split(
  179. ':'
  180. ) # raise error if time value is not xx:xx format
  181. v = int(h) * 60 + int(m)
  182. except:
  183. match = False
  184. break
  185. time = int(db_ent[s].split(':')[0]) * 60 + int(
  186. db_ent[s].split(':')[1])
  187. if s == 'arrive' and v > time:
  188. match = False
  189. if s == 'leave' and v < time:
  190. match = False
  191. else:
  192. if exactly_match and v != db_ent[s]:
  193. match = False
  194. break
  195. elif v not in db_ent[s]:
  196. match = False
  197. break
  198. if match:
  199. match_result.append(db_ent)
  200. if not return_name:
  201. return match_result
  202. else:
  203. if domain == 'train':
  204. match_result = [e['id'] for e in match_result]
  205. else:
  206. match_result = [e['name'] for e in match_result]
  207. return match_result
  208. def querySQL(self, domain, constraints):
  209. if not self.sql_dbs:
  210. for dom in db_domains:
  211. db = 'db/{}-dbase.db'.format(dom)
  212. conn = sqlite3.connect(db)
  213. c = conn.cursor()
  214. self.sql_dbs[dom] = c
  215. sql_query = 'select * from {}'.format(domain)
  216. flag = True
  217. for key, val in constraints.items():
  218. if val == '' or val == 'dontcare' or val == 'not mentioned' or val == "don't care" or val == 'dont care' or val == "do n't care":
  219. pass
  220. else:
  221. if flag:
  222. sql_query += ' where '
  223. val2 = val.replace("'", "''")
  224. # val2 = normalize(val2)
  225. if key == 'leaveAt':
  226. sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
  227. elif key == 'arriveBy':
  228. sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
  229. else:
  230. sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
  231. flag = False
  232. else:
  233. val2 = val.replace("'", "''")
  234. # val2 = normalize(val2)
  235. if key == 'leaveAt':
  236. sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
  237. elif key == 'arriveBy':
  238. sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
  239. else:
  240. sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"
  241. try: # "select * from attraction where name = 'queens college'"
  242. print(sql_query)
  243. return self.sql_dbs[domain].execute(sql_query).fetchall()
  244. except:
  245. return [] # TODO test it
  246. if __name__ == '__main__':
  247. dbPATHs = {
  248. 'attraction': 'db/attraction_db_processed.json',
  249. 'hospital': 'db/hospital_db_processed.json',
  250. 'hotel': 'db/hotel_db_processed.json',
  251. 'police': 'db/police_db_processed.json',
  252. 'restaurant': 'db/restaurant_db_processed.json',
  253. 'taxi': 'db/taxi_db_processed.json',
  254. 'train': 'db/train_db_processed.json',
  255. }
  256. db = MultiWozDB(dbPATHs)
  257. while True:
  258. constraints = {}
  259. inp = input(
  260. 'input belief state in fomat: domain-slot1=value1;slot2=value2...\n'
  261. )
  262. domain, cons = inp.split('-')
  263. for sv in cons.split(';'):
  264. s, v = sv.split('=')
  265. constraints[s] = v
  266. # res = db.querySQL(domain, constraints)
  267. res = db.queryJsons(domain, constraints, return_name=True)
  268. report = []
  269. reidx = {
  270. 'hotel': 8,
  271. 'restaurant': 6,
  272. 'attraction': 5,
  273. 'train': 1,
  274. }
  275. # for ent in res:
  276. # if reidx.get(domain):
  277. # report.append(ent[reidx[domain]])
  278. # for ent in res:
  279. # if 'name' in ent:
  280. # report.append(ent['name'])
  281. # if 'trainid' in ent:
  282. # report.append(ent['trainid'])
  283. print(constraints)
  284. print(res)
  285. print('count:', len(res), '\nnames:', report)

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展