Browse Source

[to #42322933] add ut for multi threads

1. 修复multi thread引起的问题
2. 增加multi thread的unittest
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10502008
master
caorongyu.cry yingda.chen 2 years ago
parent
commit
6178f46910
3 changed files with 29 additions and 4 deletions
  1. +5
    -1
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  2. +2
    -1
      modelscope/preprocessors/space_T_cn/fields/database.py
  3. +22
    -2
      tests/pipelines/test_table_question_answering.py

+ 5
- 1
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -17,6 +17,9 @@ from modelscope.preprocessors.space_T_cn.fields.database import Database
from modelscope.preprocessors.space_T_cn.fields.struct import (Constant,
SQLQuery)
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['TableQuestionAnsweringPipeline']

@@ -309,7 +312,8 @@ class TableQuestionAnsweringPipeline(Pipeline):
'header_name': header_names,
'rows': rows
}
except Exception:
except Exception as e:
logger.error(e)
tabledata = {'header_id': [], 'header_name': [], 'rows': []}
else:
tabledata = {'header_id': [], 'header_name': [], 'rows': []}


+ 2
- 1
modelscope/preprocessors/space_T_cn/fields/database.py View File

@@ -17,7 +17,8 @@ class Database:
self.tokenizer = tokenizer
self.is_use_sqlite = is_use_sqlite
if self.is_use_sqlite:
self.connection_obj = sqlite3.connect(':memory:')
self.connection_obj = sqlite3.connect(
':memory:', check_same_thread=False)
self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
self.tables = self.init_tables(table_file_path=table_file_path)
self.syn_dict = self.init_syn_dict(


+ 22
- 2
tests/pipelines/test_table_question_answering.py View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from threading import Thread
from typing import List

import json
@@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase):
self.task = Tasks.table_question_answering
self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

model_id = 'damo/nlp_convai_text2sql_pretrain_cn'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
@@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase):
]
tableqa_tracking_and_print_results_with_history(pipelines)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download_with_multithreads(self):
cache_path = snapshot_download(self.model_id)
pl = pipeline(Tasks.table_question_answering, model=cache_path)

def print_func(pl, i):
result = pl({
'question': '长江流域的小(2)型水库的库容总量是多少?',
'table_id': 'reservoir',
'history_sql': None
})
print(i, json.dumps(result))

procs = []
for i in range(5):
proc = Thread(target=print_func, args=(pl, i))
procs.append(proc)
proc.start()
for proc in procs:
proc.join()

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)


Loading…
Cancel
Save