主要做了如下修改: 1. 加入了同义词词典 2. 对SQL进行后处理,如果包含排序,则将空列转化成Primary列 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10670121master
@@ -231,19 +231,6 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
header_ids = table['header_id'] + ['null'] | header_ids = table['header_id'] + ['null'] | ||||
sql = result['sql'] | sql = result['sql'] | ||||
str_sel_list, sql_sel_list = [], [] | |||||
for idx, sel in enumerate(sql['sel']): | |||||
header_name = header_names[sel] | |||||
header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||||
if sql['agg'][idx] == 0: | |||||
str_sel_list.append(header_name) | |||||
sql_sel_list.append(header_id) | |||||
else: | |||||
str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
+ header_name + ')') | |||||
sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
+ header_id + ')') | |||||
str_cond_list, sql_cond_list = [], [] | str_cond_list, sql_cond_list = [], [] | ||||
where_conds, orderby_conds = [], [] | where_conds, orderby_conds = [], [] | ||||
for cond in sql['conds']: | for cond in sql['conds']: | ||||
@@ -285,9 +272,34 @@ class TableQuestionAnsweringPipeline(Pipeline): | |||||
if is_in: | if is_in: | ||||
str_orderby += ' LIMIT %d' % (limit_num) | str_orderby += ' LIMIT %d' % (limit_num) | ||||
sql_orderby += ' LIMIT %d' % (limit_num) | sql_orderby += ' LIMIT %d' % (limit_num) | ||||
# post process null column | |||||
for idx, sel in enumerate(sql['sel']): | |||||
if sel == len(header_ids) - 1: | |||||
primary_sel = 0 | |||||
for index, attrib in enumerate(table['header_attribute']): | |||||
if attrib == 'PRIMARY': | |||||
primary_sel = index | |||||
break | |||||
if primary_sel not in sql['sel']: | |||||
sql['sel'][idx] = primary_sel | |||||
else: | |||||
del sql['sel'][idx] | |||||
else: | else: | ||||
str_orderby = '' | str_orderby = '' | ||||
str_sel_list, sql_sel_list = [], [] | |||||
for idx, sel in enumerate(sql['sel']): | |||||
header_name = header_names[sel] | |||||
header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel]) | |||||
if sql['agg'][idx] == 0: | |||||
str_sel_list.append(header_name) | |||||
sql_sel_list.append(header_id) | |||||
else: | |||||
str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
+ header_name + ')') | |||||
sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '(' | |||||
+ header_id + ')') | |||||
if len(str_cond_list) != 0 and len(str_orderby) != 0: | if len(str_cond_list) != 0 and len(str_orderby) != 0: | ||||
final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( | final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % ( | ||||
', '.join(str_sel_list), table['table_name'], str_where_conds, | ', '.join(str_sel_list), table['table_name'], str_where_conds, | ||||
@@ -20,9 +20,9 @@ class Database: | |||||
self.connection_obj = sqlite3.connect( | self.connection_obj = sqlite3.connect( | ||||
':memory:', check_same_thread=False) | ':memory:', check_same_thread=False) | ||||
self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'} | ||||
self.tables = self.init_tables(table_file_path=table_file_path) | |||||
self.syn_dict = self.init_syn_dict( | self.syn_dict = self.init_syn_dict( | ||||
syn_dict_file_path=syn_dict_file_path) | syn_dict_file_path=syn_dict_file_path) | ||||
self.tables = self.init_tables(table_file_path=table_file_path) | |||||
def __del__(self): | def __del__(self): | ||||
if self.is_use_sqlite: | if self.is_use_sqlite: | ||||
@@ -75,6 +75,10 @@ class Database: | |||||
continue | continue | ||||
word = str(cell).strip().lower() | word = str(cell).strip().lower() | ||||
trie_set[ii].insert(word, word) | trie_set[ii].insert(word, word) | ||||
if word in self.syn_dict.keys(): | |||||
for term in self.syn_dict[word]: | |||||
if term.strip() != '': | |||||
trie_set[ii].insert(term, word) | |||||
table['value_trie'] = trie_set | table['value_trie'] = trie_set | ||||
@@ -24,13 +24,10 @@ def tableqa_tracking_and_print_results_with_history( | |||||
'utterance': [ | 'utterance': [ | ||||
'有哪些风险类型?', | '有哪些风险类型?', | ||||
'风险类型有多少种?', | '风险类型有多少种?', | ||||
'珠江流域的小(2)型水库的库容总量是多少?', | |||||
'珠江流域的小型水库的库容总量是多少?', | |||||
'那平均值是多少?', | '那平均值是多少?', | ||||
'那水库的名称呢?', | '那水库的名称呢?', | ||||
'换成中型的呢?', | '换成中型的呢?', | ||||
'枣庄营业厅的电话', | |||||
'那地址呢?', | |||||
'枣庄营业厅的电话和地址', | |||||
] | ] | ||||
} | } | ||||
for p in pipelines: | for p in pipelines: | ||||
@@ -55,9 +52,7 @@ def tableqa_tracking_and_print_results_without_history( | |||||
'utterance': [ | 'utterance': [ | ||||
'有哪些风险类型?', | '有哪些风险类型?', | ||||
'风险类型有多少种?', | '风险类型有多少种?', | ||||
'珠江流域的小(2)型水库的库容总量是多少?', | |||||
'枣庄营业厅的电话', | |||||
'枣庄营业厅的电话和地址', | |||||
'珠江流域的小型水库的库容总量是多少?', | |||||
] | ] | ||||
} | } | ||||
for p in pipelines: | for p in pipelines: | ||||
@@ -77,13 +72,10 @@ def tableqa_tracking_and_print_results_with_tableid( | |||||
'utterance': [ | 'utterance': [ | ||||
['有哪些风险类型?', 'fund'], | ['有哪些风险类型?', 'fund'], | ||||
['风险类型有多少种?', 'reservoir'], | ['风险类型有多少种?', 'reservoir'], | ||||
['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'], | |||||
['珠江流域的小型水库的库容总量是多少?', 'reservoir'], | |||||
['那平均值是多少?', 'reservoir'], | ['那平均值是多少?', 'reservoir'], | ||||
['那水库的名称呢?', 'reservoir'], | ['那水库的名称呢?', 'reservoir'], | ||||
['换成中型的呢?', 'reservoir'], | ['换成中型的呢?', 'reservoir'], | ||||
['枣庄营业厅的电话', 'business'], | |||||
['那地址呢?', 'business'], | |||||
['枣庄营业厅的电话和地址', 'business'], | |||||
], | ], | ||||
} | } | ||||
for p in pipelines: | for p in pipelines: | ||||
@@ -157,7 +149,7 @@ class TableQuestionAnswering(unittest.TestCase): | |||||
os.path.join(model.model_dir, 'databases')) | os.path.join(model.model_dir, 'databases')) | ||||
], | ], | ||||
syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'), | ||||
is_use_sqlite=False) | |||||
is_use_sqlite=True) | |||||
preprocessor = TableQuestionAnsweringPreprocessor( | preprocessor = TableQuestionAnsweringPreprocessor( | ||||
model_dir=model.model_dir, db=db) | model_dir=model.model_dir, db=db) | ||||
pipelines = [ | pipelines = [ | ||||