|
|
@@ -1,6 +1,7 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import os |
|
|
|
import unittest |
|
|
|
from threading import Thread |
|
|
|
from typing import List |
|
|
|
|
|
|
|
import json |
|
|
@@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase): |
|
|
|
self.task = Tasks.table_question_answering |
|
|
|
self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn' |
|
|
|
|
|
|
|
model_id = 'damo/nlp_convai_text2sql_pretrain_cn' |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_run_by_direct_model_download(self): |
|
|
|
cache_path = snapshot_download(self.model_id) |
|
|
@@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase): |
|
|
|
] |
|
|
|
tableqa_tracking_and_print_results_with_history(pipelines) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_run_by_direct_model_download_with_multithreads(self): |
|
|
|
cache_path = snapshot_download(self.model_id) |
|
|
|
pl = pipeline(Tasks.table_question_answering, model=cache_path) |
|
|
|
|
|
|
|
def print_func(pl, i): |
|
|
|
result = pl({ |
|
|
|
'question': '长江流域的小(2)型水库的库容总量是多少?', |
|
|
|
'table_id': 'reservoir', |
|
|
|
'history_sql': None |
|
|
|
}) |
|
|
|
print(i, json.dumps(result)) |
|
|
|
|
|
|
|
procs = [] |
|
|
|
for i in range(5): |
|
|
|
proc = Thread(target=print_func, args=(pl, i)) |
|
|
|
procs.append(proc) |
|
|
|
proc.start() |
|
|
|
for proc in procs: |
|
|
|
proc.join() |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
def test_run_with_model_from_modelhub(self): |
|
|
|
model = Model.from_pretrained(self.model_id) |
|
|
|