Browse Source

[to #42322933] add synonym

主要做了如下修改:
1. 加入了同义词词典
2. 对SQL进行后处理,如果包含排序,则将空列转化成Primary列
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10670121
master
caorongyu.cry yingda.chen 2 years ago
parent
commit
e2a9695f93
3 changed files with 34 additions and 26 deletions
  1. +25
    -13
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  2. +5
    -1
      modelscope/preprocessors/nlp/space_T_cn/fields/database.py
  3. +4
    -12
      tests/pipelines/test_table_question_answering.py

+ 25
- 13
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -231,19 +231,6 @@ class TableQuestionAnsweringPipeline(Pipeline):
header_ids = table['header_id'] + ['null']
sql = result['sql']

str_sel_list, sql_sel_list = [], []
for idx, sel in enumerate(sql['sel']):
header_name = header_names[sel]
header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel])
if sql['agg'][idx] == 0:
str_sel_list.append(header_name)
sql_sel_list.append(header_id)
else:
str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+ header_name + ')')
sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+ header_id + ')')

str_cond_list, sql_cond_list = [], []
where_conds, orderby_conds = [], []
for cond in sql['conds']:
@@ -285,9 +272,34 @@ class TableQuestionAnsweringPipeline(Pipeline):
if is_in:
str_orderby += ' LIMIT %d' % (limit_num)
sql_orderby += ' LIMIT %d' % (limit_num)
# post process null column
for idx, sel in enumerate(sql['sel']):
if sel == len(header_ids) - 1:
primary_sel = 0
for index, attrib in enumerate(table['header_attribute']):
if attrib == 'PRIMARY':
primary_sel = index
break
if primary_sel not in sql['sel']:
sql['sel'][idx] = primary_sel
else:
del sql['sel'][idx]
else:
str_orderby = ''

str_sel_list, sql_sel_list = [], []
for idx, sel in enumerate(sql['sel']):
header_name = header_names[sel]
header_id = '`%s`.`%s`' % (table['table_id'], header_ids[sel])
if sql['agg'][idx] == 0:
str_sel_list.append(header_name)
sql_sel_list.append(header_id)
else:
str_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+ header_name + ')')
sql_sel_list.append(self.agg_ops[sql['agg'][idx]] + '('
+ header_id + ')')

if len(str_cond_list) != 0 and len(str_orderby) != 0:
final_str = 'SELECT %s FROM %s WHERE %s ORDER BY %s' % (
', '.join(str_sel_list), table['table_name'], str_where_conds,


+ 5
- 1
modelscope/preprocessors/nlp/space_T_cn/fields/database.py View File

@@ -20,9 +20,9 @@ class Database:
self.connection_obj = sqlite3.connect(
':memory:', check_same_thread=False)
self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
self.tables = self.init_tables(table_file_path=table_file_path)
self.syn_dict = self.init_syn_dict(
syn_dict_file_path=syn_dict_file_path)
self.tables = self.init_tables(table_file_path=table_file_path)

def __del__(self):
if self.is_use_sqlite:
@@ -75,6 +75,10 @@ class Database:
continue
word = str(cell).strip().lower()
trie_set[ii].insert(word, word)
if word in self.syn_dict.keys():
for term in self.syn_dict[word]:
if term.strip() != '':
trie_set[ii].insert(term, word)

table['value_trie'] = trie_set



+ 4
- 12
tests/pipelines/test_table_question_answering.py View File

@@ -24,13 +24,10 @@ def tableqa_tracking_and_print_results_with_history(
'utterance': [
'有哪些风险类型?',
'风险类型有多少种?',
'珠江流域的小(2)型水库的库容总量是多少?',
'珠江流域的小型水库的库容总量是多少?',
'那平均值是多少?',
'那水库的名称呢?',
'换成中型的呢?',
'枣庄营业厅的电话',
'那地址呢?',
'枣庄营业厅的电话和地址',
]
}
for p in pipelines:
@@ -55,9 +52,7 @@ def tableqa_tracking_and_print_results_without_history(
'utterance': [
'有哪些风险类型?',
'风险类型有多少种?',
'珠江流域的小(2)型水库的库容总量是多少?',
'枣庄营业厅的电话',
'枣庄营业厅的电话和地址',
'珠江流域的小型水库的库容总量是多少?',
]
}
for p in pipelines:
@@ -77,13 +72,10 @@ def tableqa_tracking_and_print_results_with_tableid(
'utterance': [
['有哪些风险类型?', 'fund'],
['风险类型有多少种?', 'reservoir'],
['珠江流域的小(2)型水库的库容总量是多少?', 'reservoir'],
['珠江流域的小型水库的库容总量是多少?', 'reservoir'],
['那平均值是多少?', 'reservoir'],
['那水库的名称呢?', 'reservoir'],
['换成中型的呢?', 'reservoir'],
['枣庄营业厅的电话', 'business'],
['那地址呢?', 'business'],
['枣庄营业厅的电话和地址', 'business'],
],
}
for p in pipelines:
@@ -157,7 +149,7 @@ class TableQuestionAnswering(unittest.TestCase):
os.path.join(model.model_dir, 'databases'))
],
syn_dict_file_path=os.path.join(model.model_dir, 'synonym.txt'),
is_use_sqlite=False)
is_use_sqlite=True)
preprocessor = TableQuestionAnsweringPreprocessor(
model_dir=model.model_dir, db=db)
pipelines = [


Loading…
Cancel
Save