@@ -7,4 +7,4 @@ from .sbert_for_sentiment_classification import *  # noqa F403
from .sbert_for_token_classification import *  # noqa F403
from .space.dialog_intent_prediction_model import *  # noqa F403
from .space.dialog_modeling_model import *  # noqa F403
from .space.dialog_state_tracking import *  # noqa F403
from .space.dialog_state_tracking_model import *  # noqa F403
@@ -1,77 +0,0 @@
import os
from typing import Any, Dict

from modelscope.utils.config import Config
from modelscope.utils.constant import Tasks
from ...base import Model, Tensor
from ...builder import MODELS
from .model.generator import Generator
from .model.model_base import ModelBase

__all__ = ['DialogStateTrackingModel']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space')
class DialogStateTrackingModel(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the text generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader; if None, the
                default loader is used to load model weights. Defaults to None.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.config = kwargs.pop(
            'config',
            Config.from_file(
                os.path.join(self.model_dir, 'configuration.json')))
        self.text_field = kwargs.pop(
            'text_field',
            IntentBPETextField(self.model_dir, config=self.config))
        self.generator = Generator.create(self.config, reader=self.text_field)
        self.model = ModelBase.create(
            model_dir=model_dir,
            config=self.config,
            reader=self.text_field,
            generator=self.generator)

        def to_tensor(array):
            """
            numpy array -> tensor
            """
            import torch
            array = torch.tensor(array)
            return array.cuda() if self.config.use_gpu else array

        self.trainer = IntentTrainer(
            model=self.model,
            to_tensor=to_tensor,
            config=self.config,
            reader=self.text_field)
        self.trainer.load()

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
                Example:
                    {
                        'predictions': array([1]),  # label 0-negative 1-positive
                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32)  # true value
                    }
        """
        import numpy as np
        pred = self.trainer.forward(input)
        pred = np.squeeze(pred[0], 0)
        return {'pred': pred}
@@ -0,0 +1,101 @@
import os
from typing import Any, Dict

from modelscope.utils.constant import Tasks
from ....utils.nlp.space.utils_dst import batch_to_device
from ...base import Model, Tensor
from ...builder import MODELS

__all__ = ['DialogStateTrackingModel']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space')
class DialogStateTrackingModel(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the dialog state tracking model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader; if None, the
                default loader is used to load model weights. Defaults to None.
        """
        super().__init__(model_dir, *args, **kwargs)
        from sofa.models.space import SpaceForDST, SpaceConfig
        self.model_dir = model_dir
        self.config = SpaceConfig.from_pretrained(self.model_dir)
        self.model = SpaceForDST.from_pretrained(self.model_dir)
        self.model.to(self.config.device)

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """run one dialog turn through the model

        Args:
            input (Dict[str, Any]): the preprocessed data, containing the
                tensor batch, the converted features and the dialog state
                carried over from the previous turn.

        Returns:
            Dict[str, Any]: the raw model inputs and outputs together with the
                bookkeeping data ('unique_ids', 'input_ids_unmasked', 'values',
                'inform', 'prefix') needed by the postprocessor.
        """
        import numpy as np
        import torch

        self.model.eval()
        batch = input['batch']
        batch = batch_to_device(batch, self.config.device)
        features = input['features']
        diag_state = input['diag_state']

        # Reset the dialog state for every example that starts a new dialog
        # (turn index 0 in the guid).
        turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
        reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
        for slot in self.config.dst_slot_list:
            for i in reset_diag_state:
                diag_state[slot][i] = 0

        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
                'input_mask': batch[1],
                'segment_ids': batch[2],
                'start_pos': batch[3],
                'end_pos': batch[4],
                'inform_slot_id': batch[5],
                'refer_id': batch[6],
                'diag_state': diag_state,
                'class_label_id': batch[8]
            }
            unique_ids = [features[i.item()].guid for i in batch[9]]
            values = [features[i.item()].values for i in batch[9]]
            input_ids_unmasked = [
                features[i.item()].input_ids_unmasked for i in batch[9]
            ]
            inform = [features[i.item()].inform for i in batch[9]]
            outputs = self.model(**inputs)

            # Update dialog state for next turn.
            for slot in self.config.dst_slot_list:
                updates = outputs[2][slot].max(1)[1]
                for i, u in enumerate(updates):
                    if u != 0:
                        diag_state[slot][i] = u

        return {
            'inputs': inputs,
            'outputs': outputs,
            'unique_ids': unique_ids,
            'input_ids_unmasked': input_ids_unmasked,
            'values': values,
            'inform': inform,
            'prefix': 'final'
        }
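
# A minimal usage sketch (assumptions: a local model directory holding the
# pretrained SpaceForDST weights and config; the path and utterance below are
# placeholders). The model consumes the dict produced by
# DialogStateTrackingPreprocessor and returns the raw outputs that
# DialogStateTrackingPipeline.postprocess expects:
#
#     model = DialogStateTrackingModel('/path/to/model_dir')
#     preprocessor = DialogStateTrackingPreprocessor(model_dir='/path/to/model_dir')
#     processed = preprocessor({
#         'utter': {'User-1': 'I need a cheap hotel with free wifi.'},
#         'history_states': [{}]
#     })
#     raw = model.forward(processed)
#     # raw keys: 'inputs', 'outputs', 'unique_ids', 'input_ids_unmasked',
#     # 'values', 'inform', 'prefix'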
@@ -1,6 +1,6 @@
from .dialog_intent_prediction_pipeline import *  # noqa F403
from .dialog_modeling_pipeline import *  # noqa F403
from .dialog_state_tracking import *  # noqa F403
from .dialog_state_tracking_pipeline import *  # noqa F403
from .fill_mask_pipeline import *  # noqa F403
from .nli_pipeline import *  # noqa F403
from .sentence_similarity_pipeline import *  # noqa F403
@@ -1,45 +0,0 @@
from typing import Any, Dict

from ...metainfo import Pipelines
from ...models.nlp import DialogStateTrackingModel
from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
from ..builder import PIPELINES

__all__ = ['DialogStateTrackingPipeline']


@PIPELINES.register_module(
    Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

    def __init__(self, model: DialogStateTrackingModel,
                 preprocessor: DialogStateTrackingPreprocessor, **kwargs):
        """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

        Args:
            model (SequenceClassificationModel): a model instance
            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model = model
        # self.tokenizer = preprocessor.tokenizer

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): _description_

        Returns:
            Dict[str, str]: the prediction results
        """
        import numpy as np
        pred = inputs['pred']
        pos = np.where(pred == np.max(pred))
        result = {'pred': pred, 'label': pos[0]}
        return result
@@ -0,0 +1,146 @@
from typing import Any, Dict

from ...metainfo import Pipelines
from ...models.nlp import DialogStateTrackingModel
from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
from ..builder import PIPELINES

__all__ = ['DialogStateTrackingPipeline']


@PIPELINES.register_module(
    Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

    def __init__(self, model: DialogStateTrackingModel,
                 preprocessor: DialogStateTrackingPreprocessor, **kwargs):
        """use `model` and `preprocessor` to create a dialog state tracking pipeline for prediction

        Args:
            model (DialogStateTrackingModel): a model instance
            preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model = model
        self.tokenizer = preprocessor.tokenizer
        self.config = preprocessor.config

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """convert the raw model outputs into the tracked dialog state

        Args:
            inputs (Dict[str, Any]): the forward results of the model

        Returns:
            Dict[str, str]: a mapping from each slot in `config.dst_slot_list`
                to its predicted value
        """
        _inputs = inputs['inputs']
        _outputs = inputs['outputs']
        unique_ids = inputs['unique_ids']
        input_ids_unmasked = inputs['input_ids_unmasked']
        values = inputs['values']
        inform = inputs['inform']
        prefix = inputs['prefix']
        ds = {slot: 'none' for slot in self.config.dst_slot_list}
        ds = predict_and_format(self.config, self.tokenizer, _inputs,
                                _outputs[2], _outputs[3], _outputs[4],
                                _outputs[5], unique_ids, input_ids_unmasked,
                                values, inform, prefix, ds)
        return ds
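
    # Sketch of the postprocess output (slot names depend on
    # config.dst_slot_list; these MultiWOZ-style slots are illustrative):
    #
    #     {'hotel-internet': 'true', 'hotel-type': 'guesthouse',
    #      'hotel-name': 'none', ...}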

def predict_and_format(config, tokenizer, features, per_slot_class_logits,
                       per_slot_start_logits, per_slot_end_logits,
                       per_slot_refer_logits, ids, input_ids_unmasked, values,
                       inform, prefix, ds):
    import re

    prediction_list = []
    dialog_state = ds
    for i in range(len(ids)):
        if int(ids[i].split('-')[2]) == 0:
            dialog_state = {slot: 'none' for slot in config.dst_slot_list}

        prediction = {}
        prediction_addendum = {}
        for slot in config.dst_slot_list:
            class_logits = per_slot_class_logits[slot][i]
            start_logits = per_slot_start_logits[slot][i]
            end_logits = per_slot_end_logits[slot][i]
            refer_logits = per_slot_refer_logits[slot][i]

            input_ids = features['input_ids'][i].tolist()
            class_label_id = int(features['class_label_id'][slot][i])
            start_pos = int(features['start_pos'][slot][i])
            end_pos = int(features['end_pos'][slot][i])
            refer_id = int(features['refer_id'][slot][i])

            class_prediction = int(class_logits.argmax())
            start_prediction = int(start_logits.argmax())
            end_prediction = int(end_logits.argmax())
            refer_prediction = int(refer_logits.argmax())

            prediction['guid'] = ids[i].split('-')
            prediction['class_prediction_%s' % slot] = class_prediction
            prediction['class_label_id_%s' % slot] = class_label_id
            prediction['start_prediction_%s' % slot] = start_prediction
            prediction['start_pos_%s' % slot] = start_pos
            prediction['end_prediction_%s' % slot] = end_prediction
            prediction['end_pos_%s' % slot] = end_pos
            prediction['refer_prediction_%s' % slot] = refer_prediction
            prediction['refer_id_%s' % slot] = refer_id
            prediction['input_ids_%s' % slot] = input_ids

            if class_prediction == config.dst_class_types.index('dontcare'):
                dialog_state[slot] = 'dontcare'
            elif class_prediction == config.dst_class_types.index(
                    'copy_value'):
                input_tokens = tokenizer.convert_ids_to_tokens(
                    input_ids_unmasked[i])
                dialog_state[slot] = ' '.join(
                    input_tokens[start_prediction:end_prediction + 1])
                dialog_state[slot] = re.sub('(^| )##', '', dialog_state[slot])
            elif 'true' in config.dst_class_types and class_prediction == config.dst_class_types.index(
                    'true'):
                dialog_state[slot] = 'true'
            elif 'false' in config.dst_class_types and class_prediction == config.dst_class_types.index(
                    'false'):
                dialog_state[slot] = 'false'
            elif class_prediction == config.dst_class_types.index('inform'):
                dialog_state[slot] = '§§' + inform[i][slot]
            # Referral case is handled below

            prediction_addendum['slot_prediction_%s'
                                % slot] = dialog_state[slot]
            prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]

        # Referral case. All other slot values need to be seen first in order
        # to be able to do this correctly.
        for slot in config.dst_slot_list:
            class_logits = per_slot_class_logits[slot][i]
            refer_logits = per_slot_refer_logits[slot][i]

            class_prediction = int(class_logits.argmax())
            refer_prediction = int(refer_logits.argmax())

            if 'refer' in config.dst_class_types and class_prediction == config.dst_class_types.index(
                    'refer'):
                # Only slots that have been mentioned before can be referred to.
                # One can think of a situation where one slot is referred to in the same utterance.
                # This phenomenon is however currently not properly covered in the training data
                # label generation process.
                dialog_state[slot] = dialog_state[config.dst_slot_list[
                    refer_prediction - 1]]
                prediction_addendum['slot_prediction_%s' %
                                    slot] = dialog_state[slot]  # Value update

        prediction.update(prediction_addendum)
        prediction_list.append(prediction)

    return dialog_state
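
# Worked sketch of the 'copy_value' branch above (token values are
# illustrative): the predicted start/end indices select a WordPiece span,
# which is joined and then de-subworded by stripping the '##' markers:
#
#     input_tokens = ['book', 'the', 'hamil', '##ton', 'lodge']
#     start_prediction, end_prediction = 2, 3
#     value = ' '.join(input_tokens[2:3 + 1])  # 'hamil ##ton'
#     value = re.sub('(^| )##', '', value)     # 'hamilton'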
@@ -3,13 +3,12 @@
import os
from typing import Any, Dict

from modelscope.preprocessors.space.fields.intent_field import \
    IntentBPETextField
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields
from modelscope.utils.type_assert import type_assert
from ..base import Preprocessor
from ..builder import PREPROCESSORS
from .dst_processors import convert_examples_to_features, multiwoz22Processor
from .tensorlistdataset import TensorListDataset

__all__ = ['DialogStateTrackingPreprocessor']
@@ -25,14 +24,14 @@ class DialogStateTrackingPreprocessor(Preprocessor):
        """
        super().__init__(*args, **kwargs)
        from sofa.models.space import SpaceTokenizer, SpaceConfig
        self.model_dir: str = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, 'configuration.json'))
        self.text_field = IntentBPETextField(
            self.model_dir, config=self.config)
        self.config = SpaceConfig.from_pretrained(self.model_dir)
        self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir)
        self.processor = multiwoz22Processor()

    @type_assert(object, str)
    def __call__(self, data: str) -> Dict[str, Any]:
    @type_assert(object, dict)
    def __call__(self, data: Dict) -> Dict[str, Any]:
        """process the raw input data

        Args:
@@ -43,7 +42,96 @@ class DialogStateTrackingPreprocessor(Preprocessor):

        Returns:
            Dict[str, Any]: the preprocessed data
        """
        samples = self.text_field.preprocessor([data])
        samples, _ = self.text_field.collate_fn_multi_turn(samples)
        import torch
        from torch.utils.data import (DataLoader, RandomSampler,
                                      SequentialSampler)

        return samples
        utter = data['utter']
        history_states = data['history_states']
        example = self.processor.create_example(
            inputs=utter,
            history_states=history_states,
            set_type='test',
            slot_list=self.config.dst_slot_list,
            label_maps={},
            append_history=True,
            use_history_labels=True,
            swap_utterances=True,
            label_value_repetitions=True,
            delexicalize_sys_utts=True,
            unk_token='[UNK]',
            analyze=False)
        features = convert_examples_to_features(
            examples=[example],
            slot_list=self.config.dst_slot_list,
            class_types=self.config.dst_class_types,
            model_type=self.config.model_type,
            tokenizer=self.tokenizer,
            max_seq_length=180,  # args.max_seq_length
            slot_value_dropout=0.0)

        # Collect the per-example feature fields into batched tensors.
        all_input_ids = torch.tensor([f.input_ids for f in features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                       dtype=torch.long)
        all_example_index = torch.arange(
            all_input_ids.size(0), dtype=torch.long)
        f_start_pos = [f.start_pos for f in features]
        f_end_pos = [f.end_pos for f in features]
        f_inform_slot_ids = [f.inform_slot for f in features]
        f_refer_ids = [f.refer_id for f in features]
        f_diag_state = [f.diag_state for f in features]
        f_class_label_ids = [f.class_label_id for f in features]
        all_start_positions = {}
        all_end_positions = {}
        all_inform_slot_ids = {}
        all_refer_ids = {}
        all_diag_state = {}
        all_class_label_ids = {}
        for s in self.config.dst_slot_list:
            all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos],
                                                  dtype=torch.long)
            all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos],
                                                dtype=torch.long)
            all_inform_slot_ids[s] = torch.tensor(
                [f[s] for f in f_inform_slot_ids], dtype=torch.long)
            all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids],
                                            dtype=torch.long)
            all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state],
                                             dtype=torch.long)
            all_class_label_ids[s] = torch.tensor(
                [f[s] for f in f_class_label_ids], dtype=torch.long)

        # dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
        #                             all_start_positions, all_end_positions,
        #                             all_inform_slot_ids,
        #                             all_refer_ids,
        #                             all_diag_state,
        #                             all_class_label_ids, all_example_index)
        #
        # eval_sampler = SequentialSampler(dataset)
        # eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.config.eval_batch_size)

        dataset = [
            all_input_ids, all_input_mask, all_segment_ids,
            all_start_positions, all_end_positions, all_inform_slot_ids,
            all_refer_ids, all_diag_state, all_class_label_ids,
            all_example_index
        ]
        with torch.no_grad():
            # The dialog state starts out empty (0) for every slot.
            diag_state = {
                slot:
                torch.tensor([0 for _ in range(self.config.eval_batch_size)
                              ]).to(self.config.device)
                for slot in self.config.dst_slot_list
            }
        return {
            'batch': dataset,
            'features': features,
            'diag_state': diag_state
        }
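
    # Sketch of the expected __call__ input (mirrors the unit test in this
    # diff; 'User-1' marks the user turn and `history_states` carries the
    # per-turn dialog states accumulated so far):
    #
    #     data = {
    #         'utter': {'User-1': 'I need a guesthouse with free wifi.'},
    #         'history_states': [{}]
    #     }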
@@ -1097,29 +1097,31 @@ class DSTExample(object):
        return self.__repr__()

    def __repr__(self):
        s = ''
        s += 'guid: %s' % (self.guid)
        s += ', text_a: %s' % (self.text_a)
        s += ', text_b: %s' % (self.text_b)
        s += ', history: %s' % (self.history)
        s_dict = dict()
        s_dict['guid'] = self.guid
        s_dict['text_a'] = self.text_a
        s_dict['text_b'] = self.text_b
        s_dict['history'] = self.history
        if self.text_a_label:
            s += ', text_a_label: %d' % (self.text_a_label)
            s_dict['text_a_label'] = self.text_a_label
        if self.text_b_label:
            s += ', text_b_label: %d' % (self.text_b_label)
            s_dict['text_b_label'] = self.text_b_label
        if self.history_label:
            s += ', history_label: %d' % (self.history_label)
            s_dict['history_label'] = self.history_label
        if self.values:
            s += ', values: %d' % (self.values)
            s_dict['values'] = self.values
        if self.inform_label:
            s += ', inform_label: %d' % (self.inform_label)
            s_dict['inform_label'] = self.inform_label
        if self.inform_slot_label:
            s += ', inform_slot_label: %d' % (self.inform_slot_label)
            s_dict['inform_slot_label'] = self.inform_slot_label
        if self.refer_label:
            s += ', refer_label: %d' % (self.refer_label)
            s_dict['refer_label'] = self.refer_label
        if self.diag_state:
            s += ', diag_state: %d' % (self.diag_state)
            s_dict['diag_state'] = self.diag_state
        if self.class_label:
            s += ', class_label: %d' % (self.class_label)
            s_dict['class_label'] = self.class_label
        s = json.dumps(s_dict)
        return s
@@ -1515,6 +1517,7 @@ if __name__ == '__main__':
    delexicalize_sys_utts = True
    unk_token = '[UNK]'
    analyze = False

    example = processor.create_example(utter1, history_states1, set_type,
                                       slot_list, {}, append_history,
                                       use_history_labels, swap_utterances,
@@ -0,0 +1,59 @@ | |||
# | |||
# Copyright 2020 Heinrich Heine University Duesseldorf | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from torch.utils.data import Dataset | |||
class TensorListDataset(Dataset): | |||
r"""Dataset wrapping tensors, tensor dicts and tensor lists. | |||
Arguments: | |||
*data (Tensor or dict or list of Tensors): tensors that have the same size | |||
of the first dimension. | |||
""" | |||
def __init__(self, *data): | |||
if isinstance(data[0], dict): | |||
size = list(data[0].values())[0].size(0) | |||
elif isinstance(data[0], list): | |||
size = data[0][0].size(0) | |||
else: | |||
size = data[0].size(0) | |||
for element in data: | |||
if isinstance(element, dict): | |||
assert all( | |||
size == tensor.size(0) | |||
for name, tensor in element.items()) # dict of tensors | |||
elif isinstance(element, list): | |||
assert all(size == tensor.size(0) | |||
for tensor in element) # list of tensors | |||
else: | |||
assert size == element.size(0) # tensor | |||
self.size = size | |||
self.data = data | |||
def __getitem__(self, index): | |||
result = [] | |||
for element in self.data: | |||
if isinstance(element, dict): | |||
result.append({k: v[index] for k, v in element.items()}) | |||
elif isinstance(element, list): | |||
result.append(v[index] for v in element) | |||
else: | |||
result.append(element[index]) | |||
return tuple(result) | |||
def __len__(self): | |||
return self.size |
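
# A minimal, self-contained usage sketch (tensor shapes are illustrative):
#
#     import torch
#     ids = torch.arange(6).view(3, 2)        # 3 examples, 2 ids each
#     state = {'hotel-name': torch.zeros(3)}  # dict of per-slot tensors
#     ds = TensorListDataset(ids, state)
#     ds[1]  # -> (tensor([2, 3]), {'hotel-name': tensor(0.)})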
@@ -0,0 +1,10 @@
def batch_to_device(batch, device):
    batch_on_device = []
    for element in batch:
        if isinstance(element, dict):
            batch_on_device.append(
                {k: v.to(device)
                 for k, v in element.items()})
        else:
            batch_on_device.append(element.to(device))
    return tuple(batch_on_device)
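
# Usage sketch: moves every tensor in a mixed batch (plain tensors and dicts
# of per-slot tensors) onto the target device; values are illustrative:
#
#     import torch
#     batch = [torch.ones(2, 3), {'hotel-name': torch.zeros(2)}]
#     batch = batch_to_device(batch, torch.device('cpu'))
#     # -> (tensor of ones, {'hotel-name': tensor([0., 0.])})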
@@ -14,26 +14,28 @@ from modelscope.utils.constant import Tasks
class DialogStateTrackingTest(unittest.TestCase):
    model_id = 'damo/nlp_space_dialog-state-tracking'
    test_case = {}
    test_case = [{
        'utter': {
            'User-1':
            "I'm looking for a place to stay. It needs to be a guesthouse and include free wifi."
        },
        'history_states': [{}]
    }]

    def test_run(self):
        # cache_path = ''
        cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking'
        # cache_path = snapshot_download(self.model_id)

        # preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
        # model = DialogStateTrackingModel(
        #     model_dir=cache_path,
        #     text_field=preprocessor.text_field,
        #     config=preprocessor.config)
        # pipelines = [
        #     DialogStateTrackingPipeline(model=model, preprocessor=preprocessor),
        #     pipeline(
        #         task=Tasks.dialog_modeling,
        #         model=model,
        #         preprocessor=preprocessor)
        # ]
        model = DialogStateTrackingModel(cache_path)
        preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
        pipeline1 = DialogStateTrackingPipeline(
            model=model, preprocessor=preprocessor)

        history_states = {}
        for step, item in enumerate(self.test_case):
            history_states = pipeline1(item)
            print(history_states)

    @unittest.skip('test with snapshot_download')
    def test_run_with_model_from_modelhub(self):