|
- import os
- import random
- import sqlite3
-
- import json
-
- from .ontology import all_domains, db_domains
-
-
- class MultiWozDB(object):
-
- def __init__(self, db_dir, db_paths):
- self.dbs = {}
- self.sql_dbs = {}
- for domain in all_domains:
- with open(os.path.join(db_dir, db_paths[domain]), 'r') as f:
- self.dbs[domain] = json.loads(f.read().lower())
-
- def oneHotVector(self, domain, num):
- """Return number of available entities for particular domain."""
- vector = [0, 0, 0, 0]
- if num == '':
- return vector
- if domain != 'train':
- if num == 0:
- vector = [1, 0, 0, 0]
- elif num == 1:
- vector = [0, 1, 0, 0]
- elif num <= 3:
- vector = [0, 0, 1, 0]
- else:
- vector = [0, 0, 0, 1]
- else:
- if num == 0:
- vector = [1, 0, 0, 0]
- elif num <= 5:
- vector = [0, 1, 0, 0]
- elif num <= 10:
- vector = [0, 0, 1, 0]
- else:
- vector = [0, 0, 0, 1]
- return vector
-
- def addBookingPointer(self, turn_da):
- """Add information about availability of the booking option."""
- # Booking pointer
- # Do not consider booking two things in a single turn.
- vector = [0, 0]
- if turn_da.get('booking-nobook'):
- vector = [1, 0]
- if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
- vector = [0, 1]
- return vector
-
- def addDBPointer(self, domain, match_num, return_num=False):
- """Create database pointer for all related domains."""
- # if turn_domains is None:
- # turn_domains = db_domains
- if domain in db_domains:
- vector = self.oneHotVector(domain, match_num)
- else:
- vector = [0, 0, 0, 0]
- return vector
-
- def addDBIndicator(self, domain, match_num, return_num=False):
- """Create database indicator for all related domains."""
- # if turn_domains is None:
- # turn_domains = db_domains
- if domain in db_domains:
- vector = self.oneHotVector(domain, match_num)
- else:
- vector = [0, 0, 0, 0]
-
- # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
- if vector == [0, 0, 0, 0]:
- indicator = '[db_nores]'
- else:
- indicator = '[db_%s]' % vector.index(1)
- return indicator
-
- def get_match_num(self, constraints, return_entry=False):
- """Create database pointer for all related domains."""
- match = {'general': ''}
- entry = {}
- # if turn_domains is None:
- # turn_domains = db_domains
- for domain in all_domains:
- match[domain] = ''
- if domain in db_domains and constraints.get(domain):
- matched_ents = self.queryJsons(domain, constraints[domain])
- match[domain] = len(matched_ents)
- if return_entry:
- entry[domain] = matched_ents
- if return_entry:
- return entry
- return match
-
- def pointerBack(self, vector, domain):
- # multi domain implementation
- # domnum = cfg.domain_num
- if domain.endswith(']'):
- domain = domain[1:-1]
- if domain != 'train':
- nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
- else:
- nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
- if vector[:4] == [0, 0, 0, 0]:
- report = ''
- else:
- num = vector.index(1)
- report = domain + ': ' + nummap[num] + '; '
-
- if vector[-2] == 0 and vector[-1] == 1:
- report += 'booking: ok'
- if vector[-2] == 1 and vector[-1] == 0:
- report += 'booking: unable'
-
- return report
-
- def queryJsons(self,
- domain,
- constraints,
- exactly_match=True,
- return_name=False):
- """Returns the list of entities for a given domain
- based on the annotation of the belief state
- constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
- """
- # query the db
- if domain == 'taxi':
- return [{
- 'taxi_colors':
- random.choice(self.dbs[domain]['taxi_colors']),
- 'taxi_types':
- random.choice(self.dbs[domain]['taxi_types']),
- 'taxi_phone': [random.randint(1, 9) for _ in range(10)]
- }]
- if domain == 'police':
- return self.dbs['police']
- if domain == 'hospital':
- if constraints.get('department'):
- for entry in self.dbs['hospital']:
- if entry.get('department') == constraints.get(
- 'department'):
- return [entry]
- else:
- return []
-
- valid_cons = False
- for v in constraints.values():
- if v not in ['not mentioned', '']:
- valid_cons = True
- if not valid_cons:
- return []
-
- match_result = []
-
- if 'name' in constraints:
- for db_ent in self.dbs[domain]:
- if 'name' in db_ent:
- cons = constraints['name']
- dbn = db_ent['name']
- if cons == dbn:
- db_ent = db_ent if not return_name else db_ent['name']
- match_result.append(db_ent)
- return match_result
-
- for db_ent in self.dbs[domain]:
- match = True
- for s, v in constraints.items():
- if s == 'name':
- continue
- if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
- (domain == 'restaurant' and s in ['day', 'time']):
- # 因为这些inform slot属于book info,而数据库中没有这些slot;
- # 能否book是根据user goal中的信息判断,而非通过数据库查询;
- continue
-
- skip_case = {
- "don't care": 1,
- "do n't care": 1,
- 'dont care': 1,
- 'not mentioned': 1,
- 'dontcare': 1,
- '': 1
- }
- if skip_case.get(v):
- continue
-
- if s not in db_ent:
- # logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
- match = False
- break
-
- # v = 'guesthouse' if v == 'guest house' else v
- # v = 'swimmingpool' if v == 'swimming pool' else v
- v = 'yes' if v == 'free' else v
-
- if s in ['arrive', 'leave']:
- try:
- h, m = v.split(
- ':'
- ) # raise error if time value is not xx:xx format
- v = int(h) * 60 + int(m)
- except:
- match = False
- break
- time = int(db_ent[s].split(':')[0]) * 60 + int(
- db_ent[s].split(':')[1])
- if s == 'arrive' and v > time:
- match = False
- if s == 'leave' and v < time:
- match = False
- else:
- if exactly_match and v != db_ent[s]:
- match = False
- break
- elif v not in db_ent[s]:
- match = False
- break
-
- if match:
- match_result.append(db_ent)
-
- if not return_name:
- return match_result
- else:
- if domain == 'train':
- match_result = [e['id'] for e in match_result]
- else:
- match_result = [e['name'] for e in match_result]
- return match_result
-
- def querySQL(self, domain, constraints):
- if not self.sql_dbs:
- for dom in db_domains:
- db = 'db/{}-dbase.db'.format(dom)
- conn = sqlite3.connect(db)
- c = conn.cursor()
- self.sql_dbs[dom] = c
-
- sql_query = 'select * from {}'.format(domain)
-
- flag = True
- for key, val in constraints.items():
- if val == '' or val == 'dontcare' or val == 'not mentioned' or val == "don't care" or val == 'dont care' or val == "do n't care":
- pass
- else:
- if flag:
- sql_query += ' where '
- val2 = val.replace("'", "''")
- # val2 = normalize(val2)
- if key == 'leaveAt':
- sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
- elif key == 'arriveBy':
- sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
- else:
- sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
- flag = False
- else:
- val2 = val.replace("'", "''")
- # val2 = normalize(val2)
- if key == 'leaveAt':
- sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
- elif key == 'arriveBy':
- sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
- else:
- sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"
-
- try: # "select * from attraction where name = 'queens college'"
- print(sql_query)
- return self.sql_dbs[domain].execute(sql_query).fetchall()
- except:
- return [] # TODO test it
-
-
- if __name__ == '__main__':
- dbPATHs = {
- 'attraction': 'db/attraction_db_processed.json',
- 'hospital': 'db/hospital_db_processed.json',
- 'hotel': 'db/hotel_db_processed.json',
- 'police': 'db/police_db_processed.json',
- 'restaurant': 'db/restaurant_db_processed.json',
- 'taxi': 'db/taxi_db_processed.json',
- 'train': 'db/train_db_processed.json',
- }
- db = MultiWozDB(dbPATHs)
- while True:
- constraints = {}
- inp = input(
- 'input belief state in fomat: domain-slot1=value1;slot2=value2...\n'
- )
- domain, cons = inp.split('-')
- for sv in cons.split(';'):
- s, v = sv.split('=')
- constraints[s] = v
- # res = db.querySQL(domain, constraints)
- res = db.queryJsons(domain, constraints, return_name=True)
- report = []
- reidx = {
- 'hotel': 8,
- 'restaurant': 6,
- 'attraction': 5,
- 'train': 1,
- }
- # for ent in res:
- # if reidx.get(domain):
- # report.append(ent[reidx[domain]])
- # for ent in res:
- # if 'name' in ent:
- # report.append(ent['name'])
- # if 'trainid' in ent:
- # report.append(ent['trainid'])
- print(constraints)
- print(res)
- print('count:', len(res), '\nnames:', report)
|