主要做了如下修改: 1. 加入了同义词词典 2. 对SQL进行后处理,如果包含排序,则将空列转化成Primary列 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10670121master
@@ -231,19 +231,6 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
header_ids = table['header_id'] + ['null'] | |||
sql = result['sql'] | |||
str_sel_list, sql_sel_list = [], [] | |||
for idx, sel in enumerate(sql['sel']): | |||
header_name = header_names[sel] | |||
header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||
if sql['agg'][idx] == 0: | |||
str_sel_list.append(header_name) | |||
sql_sel_list.append(header_id) | |||
else: | |||
str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
+ header_name + ')') | |||
sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
+ header_id + ')') | |||
str_cond_list, sql_cond_list = [], [] | |||
where_conds, orderby_conds = [], [] | |||
for cond in sql['conds']: | |||
@@ -285,9 +272,34 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||
if is_in: | |||
str_orderby += ' LIMIT %d' % (limit_num) | |||
sql_orderby += ' LIMIT %d' % (limit_num) | |||
# post process null column | |||
for idx, sel in enumerate(sql['sel']): | |||
if sel == len(header_ids) - 1: | |||
primary_sel = 0 | |||
for index, attrib in enumerate(table['header_attribute']): | |||
if attrib == 'PRIMARY': | |||
primary_sel = index | |||
break | |||
if primary_sel not in sql['sel']: | |||
sql['sel'][idx] = primary_sel | |||
else: | |||
del sql['sel'][idx] | |||
else: | |||
str_orderby = '' | |||
str_sel_list, sql_sel_list = [], [] | |||
for idx, sel in enumerate(sql['sel']): | |||
header_name = header_names[sel] | |||
header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||
if sql['agg'][idx] == 0: | |||
str_sel_list.append(header_name) | |||
sql_sel_list.append(header_id) | |||
else: | |||
str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
+ header_name + ')') | |||
sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||
+ header_id + ')') | |||
if len(str_cond_list) != 0 and len(str_orderby) != 0: | |||
final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( | |||
', '.join(str_sel_list), table['table_name'], str_where_conds, | |||
@@ -20,9 +20,9 @@ class Database: | |||
self.connection_obj = sqlite3.connect( | |||
':memory:', check_same_thread=False) | |||
self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | |||
self.tables = self.init_tables(table_file_path=table_file_path) | |||
self.syn_dict = self.init_syn_dict( | |||
syn_dict_file_path=syn_dict_file_path) | |||
self.tables = self.init_tables(table_file_path=table_file_path) | |||
def __del__(self): | |||
if self.is_use_sqlite: | |||
@@ -75,6 +75,10 @@ class Database: | |||
continue | |||
word = str(cell).strip().lower() | |||
trie_set[ii].insert(word, word) | |||
if word in self.syn_dict.keys(): | |||
for term in self.syn_dict[word]: | |||
if term.strip() != '': | |||
trie_set[ii].insert(term, word) | |||
table['value_trie'] = trie_set | |||
@@ -24,13 +24,10 @@ def tableqa_tracking_and_print_results_with_history( | |||
'utterance': [ | |||
'有哪些风险类型?', | |||
'风险类型有多少种?', | |||
'珠江流域的小(2)型水库的库容总量是多少?', | |||
'珠江流域的小型水库的库容总量是多少?', | |||
'那平均值是多少?', | |||
'那水库的名称呢?', | |||
'换成中型的呢?', | |||
'枣庄营业厅的电话', | |||
'那地址呢?', | |||
'枣庄营业厅的电话和地址', | |||
] | |||
} | |||
for p in pipelines: | |||
@@ -55,9 +52,7 @@ def tableqa_tracking_and_print_results_without_history( | |||
'utterance': [ | |||
'有哪些风险类型?', | |||
'风险类型有多少种?', | |||
'珠江流域的小(2)型水库的库容总量是多少?', | |||
'枣庄营业厅的电话', | |||
'枣庄营业厅的电话和地址', | |||
'珠江流域的小型水库的库容总量是多少?', | |||
] | |||
} | |||
for p in pipelines: | |||
@@ -77,13 +72,10 @@ def tableqa_tracking_and_print_results_with_tableid( | |||
'utterance': [ | |||
['有哪些风险类型?', 'fund'], | |||
['风险类型有多少种?', 'reservoir'], | |||
['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], | |||
['珠江流域的小型水库的库容总量是多少?', 'reservoir'], | |||
['那平均值是多少?', 'reservoir'], | |||
['那水库的名称呢?', 'reservoir'], | |||
['换成中型的呢?', 'reservoir'], | |||
['枣庄营业厅的电话', 'business'], | |||
['那地址呢?', 'business'], | |||
['枣庄营业厅的电话和地址', 'business'], | |||
], | |||
} | |||
for p in pipelines: | |||
@@ -157,7 +149,7 @@ class TableQuestionAnswering(unittest.TestCase): | |||
os.path.join(model.model_dir, 'databases')) | |||
], | |||
syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | |||
is_use_sqlite=False) | |||
is_use_sqlite=True) | |||
preprocessor = TableQuestionAnsweringPreprocessor( | |||
model_dir=model.model_dir, db=db) | |||
pipelines = [ | |||