ly119399 yingda.chen 2 years ago
parent
commit
5ae1e08db6
4 changed files with 46 additions and 26 deletions
  1. +3
    -0
      modelscope/models/nlp/space_T_cn/backbone.py
  2. +31
    -26
      modelscope/models/nlp/space_T_cn/table_question_answering.py
  3. +7
    -0
      modelscope/pipelines/nlp/table_question_answering_pipeline.py
  4. +5
    -0
      tests/trainers/test_dialog_modeling_trainer.py

+ 3
- 0
modelscope/models/nlp/space_T_cn/backbone.py View File

@@ -891,6 +891,9 @@ class Seq2SQL(nn.Module):
self.slen_model = nn.Linear(iS, max_select_num + 1)
self.wlen_model = nn.Linear(iS, max_where_num + 1)

def set_device(self, device):
    """Record ``device`` on this module as ``self.device``.

    Only the attribute is updated; no parameters or buffers are moved here.

    Args:
        device: The device value to store.
    """
    self.device = device

def forward(self, wemb_layer, l_n, l_hs, start_index, column_index, tokens,
ids):
# chunk input lists for multi-gpu


+ 31
- 26
modelscope/models/nlp/space_T_cn/table_question_answering.py View File

@@ -13,7 +13,6 @@ from modelscope.models.base import Model, Tensor
from modelscope.models.builder import MODELS
from modelscope.preprocessors.nlp.space_T_cn.fields.struct import Constant
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import verify_device
from .backbone import Seq2SQL, SpaceTCnModel
from .configuration import SpaceTCnConfig

@@ -33,9 +32,6 @@ class TableQuestionAnswering(Model):
super().__init__(model_dir, *args, **kwargs)
self.tokenizer = BertTokenizer(
os.path.join(model_dir, ModelFile.VOCAB_FILE))
device_name = kwargs.get('device', 'gpu')
verify_device(device_name)
self._device_name = device_name

state_dict = torch.load(
os.path.join(self.model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
@@ -60,13 +56,24 @@ class TableQuestionAnswering(Model):
n_agg_ops = len(self.agg_ops)
n_action_ops = len(self.action_ops)
iS = self.backbone_config.hidden_size
self.head_model = Seq2SQL(iS, 100, 2, 0.0, n_cond_ops, n_agg_ops,
n_action_ops, self.max_select_num,
self.max_where_num, self._device_name)
self.head_model = Seq2SQL(
iS,
100,
2,
0.0,
n_cond_ops,
n_agg_ops,
n_action_ops,
self.max_select_num,
self.max_where_num,
device=self._device_name)
self.head_model.load_state_dict(state_dict['head_model'], strict=False)

self.backbone_model.to(self._device_name)
self.head_model.to(self._device_name)
def to(self, device):
    """Move the backbone and head sub-models to ``device``.

    The device is stored on ``self.device``, both sub-models are moved
    with their own ``to`` method, and the head model is additionally told
    which device it lives on through ``set_device``.

    Args:
        device: Target device, forwarded unchanged to the sub-models.

    Returns:
        ``self``, so that ``model = model.to(device)`` keeps working the
        same way it does for ``torch.nn.Module.to``.
    """
    self.device = device
    self.backbone_model.to(device)
    self.head_model.to(device)
    # The head keeps its own device attribute, so it must be updated too.
    self.head_model.set_device(device)
    return self

def convert_string(self, pr_wvi, nlu, nlu_tt):
convs = []
@@ -534,21 +541,20 @@ class TableQuestionAnswering(Model):

# Convert to tensor
all_input_ids = torch.tensor(
input_ids, dtype=torch.long).to(self._device_name)
input_ids, dtype=torch.long).to(self.device)
all_order_ids = torch.tensor(
order_ids, dtype=torch.long).to(self._device_name)
all_type_ids = torch.tensor(
type_ids, dtype=torch.long).to(self._device_name)
order_ids, dtype=torch.long).to(self.device)
all_type_ids = torch.tensor(type_ids, dtype=torch.long).to(self.device)
all_input_mask = torch.tensor(
input_mask, dtype=torch.long).to(self._device_name)
input_mask, dtype=torch.long).to(self.device)
all_segment_ids = torch.tensor(
segment_ids, dtype=torch.long).to(self._device_name)
segment_ids, dtype=torch.long).to(self.device)
all_match_ids = torch.tensor(
match_ids, dtype=torch.long).to(self._device_name)
match_ids, dtype=torch.long).to(self.device)
all_header_ids = torch.tensor(
header_ids, dtype=torch.long).to(self._device_name)
header_ids, dtype=torch.long).to(self.device)
all_ids = torch.arange(
all_input_ids.shape[0], dtype=torch.long).to(self._device_name)
all_input_ids.shape[0], dtype=torch.long).to(self.device)

bS = len(header_flatten_tokenid_list)
max_header_flatten_token_length = max(
@@ -566,12 +572,11 @@ class TableQuestionAnswering(Model):
all_header_flatten_output = numpy.zeros((bS, header_max_len + 1),
dtype='int32')
all_header_flatten_tokens = torch.tensor(
all_header_flatten_tokens, dtype=torch.long).to(self._device_name)
all_header_flatten_tokens, dtype=torch.long).to(self.device)
all_header_flatten_index = torch.tensor(
all_header_flatten_index, dtype=torch.long).to(self._device_name)
all_header_flatten_index, dtype=torch.long).to(self.device)
all_header_flatten_output = torch.tensor(
all_header_flatten_output,
dtype=torch.float32).to(self._device_name)
all_header_flatten_output, dtype=torch.float32).to(self.device)

all_token_column_id = numpy.zeros((bS, cur_max_length), dtype='int32')
all_token_column_mask = numpy.zeros((bS, cur_max_length),
@@ -581,9 +586,9 @@ class TableQuestionAnswering(Model):
all_token_column_id[bi, ki] = vi + 1
all_token_column_mask[bi, ki] = 1.0
all_token_column_id = torch.tensor(
all_token_column_id, dtype=torch.long).to(self._device_name)
all_token_column_id, dtype=torch.long).to(self.device)
all_token_column_mask = torch.tensor(
all_token_column_mask, dtype=torch.float32).to(self._device_name)
all_token_column_mask, dtype=torch.float32).to(self.device)

all_schema_link_matrix = numpy.zeros(
(bS, cur_max_length, cur_max_length), dtype='int32')
@@ -596,9 +601,9 @@ class TableQuestionAnswering(Model):
all_schema_link_mask[i, 0:temp_len,
0:temp_len] = schema_link_mask_list[i]
all_schema_link_matrix = torch.tensor(
all_schema_link_matrix, dtype=torch.long).to(self._device_name)
all_schema_link_matrix, dtype=torch.long).to(self.device)
all_schema_link_mask = torch.tensor(
all_schema_link_mask, dtype=torch.long).to(self._device_name)
all_schema_link_mask, dtype=torch.long).to(self.device)

# 5. generate l_hpu from i_hds
l_hpu = self.gen_l_hpu(i_hds)


+ 7
- 0
modelscope/pipelines/nlp/table_question_answering_pipeline.py View File

@@ -83,6 +83,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
self.schema_link_dict = constant.schema_link_dict
self.limit_dict = constant.limit_dict

def prepare_model(self):
    """Place model on certain device for pytorch models before first inference.

    The move is guarded by ``self._model_prepare_lock``. Acquisition waits
    for at most 600 seconds; as before, the model is still moved if the
    wait times out.
    """
    # acquire(timeout=...) reports whether the lock was actually obtained
    # (assuming a threading.Lock-style lock); remember it so we never
    # release a lock this call does not hold.
    acquired = self._model_prepare_lock.acquire(timeout=600)
    try:
        self.model.to(self.device)
    finally:
        # Release even when model.to() raises, otherwise every later
        # caller would block on a lock that is never freed.
        if acquired:
            self._model_prepare_lock.release()

def post_process_multi_turn(self, history_sql, result, table):
action = self.action_ops[result['action']]
headers = table['header_name']


+ 5
- 0
tests/trainers/test_dialog_modeling_trainer.py View File

@@ -61,8 +61,13 @@ class TestDialogModelingTrainer(unittest.TestCase):

trainer = build_trainer(
name=Trainers.dialog_modeling_trainer, default_args=kwargs)
assert trainer is not None

# TODO: training and evaluation take too long to run here; this will be optimized later.
"""
trainer.train()
checkpoint_path = os.path.join(self.output_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
assert os.path.exists(checkpoint_path)
trainer.evaluate(checkpoint_path=checkpoint_path)
"""

Loading…
Cancel
Save