
Space DST pipeline is ready, but the model's result is wrong

master
ly119399 committed 3 years ago
parent commit 760dcf0247
11 changed files with 453 additions and 166 deletions:

  1. modelscope/models/nlp/__init__.py (+1 -1)
  2. modelscope/models/nlp/space/dialog_state_tracking.py (+0 -77)
  3. modelscope/models/nlp/space/dialog_state_tracking_model.py (+101 -0)
  4. modelscope/pipelines/nlp/__init__.py (+1 -1)
  5. modelscope/pipelines/nlp/dialog_state_tracking.py (+0 -45)
  6. modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py (+146 -0)
  7. modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py (+100 -12)
  8. modelscope/preprocessors/space/dst_processors.py (+17 -14)
  9. modelscope/preprocessors/space/tensorlistdataset.py (+59 -0)
  10. modelscope/utils/nlp/space/utils_dst.py (+10 -0)
  11. tests/pipelines/nlp/test_dialog_state_tracking.py (+18 -16)

+1 -1   modelscope/models/nlp/__init__.py

@@ -7,4 +7,4 @@ from .sbert_for_sentiment_classification import * # noqa F403
 from .sbert_for_token_classification import *  # noqa F403
 from .space.dialog_intent_prediction_model import *  # noqa F403
 from .space.dialog_modeling_model import *  # noqa F403
-from .space.dialog_state_tracking import *  # noqa F403
+from .space.dialog_state_tracking_model import *  # noqa F403

+0 -77   modelscope/models/nlp/space/dialog_state_tracking.py (deleted)

@@ -1,77 +0,0 @@
import os
from typing import Any, Dict

from modelscope.utils.config import Config
from modelscope.utils.constant import Tasks
from ...base import Model, Tensor
from ...builder import MODELS
from .model.generator import Generator
from .model.model_base import ModelBase

__all__ = ['DialogStateTrackingModel']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space')
class DialogStateTrackingModel(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the test generation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
            model_cls (Optional[Any], optional): model loader, if None, use the
                default loader to load model weights, by default None.
        """

        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.config = kwargs.pop(
            'config',
            Config.from_file(
                os.path.join(self.model_dir, 'configuration.json')))
        self.text_field = kwargs.pop(
            'text_field',
            IntentBPETextField(self.model_dir, config=self.config))

        self.generator = Generator.create(self.config, reader=self.text_field)
        self.model = ModelBase.create(
            model_dir=model_dir,
            config=self.config,
            reader=self.text_field,
            generator=self.generator)

        def to_tensor(array):
            """
            numpy array -> tensor
            """
            import torch
            array = torch.tensor(array)
            return array.cuda() if self.config.use_gpu else array

        self.trainer = IntentTrainer(
            model=self.model,
            to_tensor=to_tensor,
            config=self.config,
            reader=self.text_field)
        self.trainer.load()

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """return the result by the model

        Args:
            input (Dict[str, Any]): the preprocessed data

        Returns:
            Dict[str, np.ndarray]: results
            Example:
                {
                    'predictions': array([1]),  # lable 0-negative 1-positive
                    'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
                    'logits': array([[-0.53860897, 1.5029076]], dtype=float32)  # true value
                }
        """
        import numpy as np
        pred = self.trainer.forward(input)
        pred = np.squeeze(pred[0], 0)

        return {'pred': pred}

+101 -0   modelscope/models/nlp/space/dialog_state_tracking_model.py (new file)

@@ -0,0 +1,101 @@
import os
from typing import Any, Dict

from modelscope.utils.constant import Tasks
from ....utils.nlp.space.utils_dst import batch_to_device
from ...base import Model, Tensor
from ...builder import MODELS

__all__ = ['DialogStateTrackingModel']


@MODELS.register_module(Tasks.dialog_state_tracking, module_name=r'space')
class DialogStateTrackingModel(Model):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the dialog state tracking model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """

        super().__init__(model_dir, *args, **kwargs)

        from sofa.models.space import SpaceForDST, SpaceConfig
        self.model_dir = model_dir

        self.config = SpaceConfig.from_pretrained(self.model_dir)
        # self.model = SpaceForDST(self.config)
        self.model = SpaceForDST.from_pretrained(self.model_dir)
        self.model.to(self.config.device)

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run one dialog turn through the model and return its raw outputs.

        Args:
            input (Dict[str, Any]): the preprocessed data, containing
                'batch' (the tensor list built by the preprocessor),
                'features' (the converted examples) and 'diag_state'
                (the per-slot dialog state carried over from earlier turns).

        Returns:
            Dict[str, Any]: the model inputs and outputs plus the bookkeeping
            data (unique ids, unmasked input ids, ground-truth values and
            informed slots) that the pipeline's postprocess step needs.
        """
        import numpy as np
        import torch

        self.model.eval()
        batch = input['batch']
        batch = batch_to_device(batch, self.config.device)

        features = input['features']
        diag_state = input['diag_state']
        # Reset the carried-over dialog state at the first turn of a dialog.
        turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
        reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
        for slot in self.config.dst_slot_list:
            for i in reset_diag_state:
                diag_state[slot][i] = 0

        with torch.no_grad():
            inputs = {
                'input_ids': batch[0],
                'input_mask': batch[1],
                'segment_ids': batch[2],
                'start_pos': batch[3],
                'end_pos': batch[4],
                'inform_slot_id': batch[5],
                'refer_id': batch[6],
                'diag_state': diag_state,
                'class_label_id': batch[8]
            }
            unique_ids = [features[i.item()].guid for i in batch[9]]
            values = [features[i.item()].values for i in batch[9]]
            input_ids_unmasked = [
                features[i.item()].input_ids_unmasked for i in batch[9]
            ]
            inform = [features[i.item()].inform for i in batch[9]]
            outputs = self.model(**inputs)

            # Update dialog state for next turn.
            for slot in self.config.dst_slot_list:
                updates = outputs[2][slot].max(1)[1]
                for i, u in enumerate(updates):
                    if u != 0:
                        diag_state[slot][i] = u

        print(outputs)

        return {
            'inputs': inputs,
            'outputs': outputs,
            'unique_ids': unique_ids,
            'input_ids_unmasked': input_ids_unmasked,
            'values': values,
            'inform': inform,
            'prefix': 'final'
        }
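For reference, a minimal sketch (not part of this commit) of the positional contract between the preprocessor's 'batch' list and the indices forward() reads above; the slot names and tensor sizes are made-up placeholders:

import torch

slot_list = ['hotel-internet', 'hotel-type']  # placeholder for config.dst_slot_list
batch_size, seq_len = 1, 180

def slot_dict():
    # one LongTensor per slot, batch-aligned
    return {s: torch.zeros(batch_size, dtype=torch.long) for s in slot_list}

batch = [
    torch.zeros(batch_size, seq_len, dtype=torch.long),  # 0: input_ids
    torch.ones(batch_size, seq_len, dtype=torch.long),   # 1: input_mask
    torch.zeros(batch_size, seq_len, dtype=torch.long),  # 2: segment_ids
    slot_dict(),  # 3: start_pos
    slot_dict(),  # 4: end_pos
    slot_dict(),  # 5: inform_slot_id
    slot_dict(),  # 6: refer_id
    slot_dict(),  # 7: diag_state (forward() reads input['diag_state'] instead)
    slot_dict(),  # 8: class_label_id
    torch.arange(batch_size),  # 9: indices into input['features']
]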

+1 -1   modelscope/pipelines/nlp/__init__.py

@@ -1,6 +1,6 @@
 from .dialog_intent_prediction_pipeline import *  # noqa F403
 from .dialog_modeling_pipeline import *  # noqa F403
-from .dialog_state_tracking import *  # noqa F403
+from .dialog_state_tracking_pipeline import *  # noqa F403
 from .fill_mask_pipeline import *  # noqa F403
 from .nli_pipeline import *  # noqa F403
 from .sentence_similarity_pipeline import *  # noqa F403


+0 -45   modelscope/pipelines/nlp/dialog_state_tracking.py (deleted)

@@ -1,45 +0,0 @@
from typing import Any, Dict

from ...metainfo import Pipelines
from ...models.nlp import DialogStateTrackingModel
from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
from ..builder import PIPELINES

__all__ = ['DialogStateTrackingPipeline']


@PIPELINES.register_module(
    Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

    def __init__(self, model: DialogStateTrackingModel,
                 preprocessor: DialogStateTrackingPreprocessor, **kwargs):
        """use `model` and `preprocessor` to create a nlp text classification pipeline for prediction

        Args:
            model (SequenceClassificationModel): a model instance
            preprocessor (SequenceClassificationPreprocessor): a preprocessor instance
        """

        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model = model
        # self.tokenizer = preprocessor.tokenizer

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): _description_

        Returns:
            Dict[str, str]: the prediction results
        """
        import numpy as np
        pred = inputs['pred']
        pos = np.where(pred == np.max(pred))

        result = {'pred': pred, 'label': pos[0]}

        return result

+146 -0   modelscope/pipelines/nlp/dialog_state_tracking_pipeline.py (new file)

@@ -0,0 +1,146 @@
from typing import Any, Dict

from ...metainfo import Pipelines
from ...models.nlp import DialogStateTrackingModel
from ...preprocessors import DialogStateTrackingPreprocessor
from ...utils.constant import Tasks
from ..base import Pipeline
from ..builder import PIPELINES

__all__ = ['DialogStateTrackingPipeline']


@PIPELINES.register_module(
    Tasks.dialog_state_tracking, module_name=Pipelines.dialog_state_tracking)
class DialogStateTrackingPipeline(Pipeline):

    def __init__(self, model: DialogStateTrackingModel,
                 preprocessor: DialogStateTrackingPreprocessor, **kwargs):
        """Use `model` and `preprocessor` to create a dialog state tracking pipeline for prediction.

        Args:
            model (DialogStateTrackingModel): a model instance
            preprocessor (DialogStateTrackingPreprocessor): a preprocessor instance
        """

        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model = model
        self.tokenizer = preprocessor.tokenizer
        self.config = preprocessor.config

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Turn the raw model outputs into a slot -> value dialog state.

        Args:
            inputs (Dict[str, Any]): the dict returned by the model's forward,
                holding the model inputs/outputs and the per-turn bookkeeping data.

        Returns:
            Dict[str, str]: the predicted dialog state
        """

        _inputs = inputs['inputs']
        _outputs = inputs['outputs']
        unique_ids = inputs['unique_ids']
        input_ids_unmasked = inputs['input_ids_unmasked']
        values = inputs['values']
        inform = inputs['inform']
        prefix = inputs['prefix']
        ds = {slot: 'none' for slot in self.config.dst_slot_list}

        ds = predict_and_format(self.config, self.tokenizer, _inputs,
                                _outputs[2], _outputs[3], _outputs[4],
                                _outputs[5], unique_ids, input_ids_unmasked,
                                values, inform, prefix, ds)

        return ds


def predict_and_format(config, tokenizer, features, per_slot_class_logits,
                       per_slot_start_logits, per_slot_end_logits,
                       per_slot_refer_logits, ids, input_ids_unmasked, values,
                       inform, prefix, ds):
    import re

    prediction_list = []
    dialog_state = ds
    for i in range(len(ids)):
        if int(ids[i].split('-')[2]) == 0:
            dialog_state = {slot: 'none' for slot in config.dst_slot_list}

        prediction = {}
        prediction_addendum = {}
        for slot in config.dst_slot_list:
            class_logits = per_slot_class_logits[slot][i]
            start_logits = per_slot_start_logits[slot][i]
            end_logits = per_slot_end_logits[slot][i]
            refer_logits = per_slot_refer_logits[slot][i]

            input_ids = features['input_ids'][i].tolist()
            class_label_id = int(features['class_label_id'][slot][i])
            start_pos = int(features['start_pos'][slot][i])
            end_pos = int(features['end_pos'][slot][i])
            refer_id = int(features['refer_id'][slot][i])

            class_prediction = int(class_logits.argmax())
            start_prediction = int(start_logits.argmax())
            end_prediction = int(end_logits.argmax())
            refer_prediction = int(refer_logits.argmax())

            prediction['guid'] = ids[i].split('-')
            prediction['class_prediction_%s' % slot] = class_prediction
            prediction['class_label_id_%s' % slot] = class_label_id
            prediction['start_prediction_%s' % slot] = start_prediction
            prediction['start_pos_%s' % slot] = start_pos
            prediction['end_prediction_%s' % slot] = end_prediction
            prediction['end_pos_%s' % slot] = end_pos
            prediction['refer_prediction_%s' % slot] = refer_prediction
            prediction['refer_id_%s' % slot] = refer_id
            prediction['input_ids_%s' % slot] = input_ids

            if class_prediction == config.dst_class_types.index('dontcare'):
                dialog_state[slot] = 'dontcare'
            elif class_prediction == config.dst_class_types.index(
                    'copy_value'):
                input_tokens = tokenizer.convert_ids_to_tokens(
                    input_ids_unmasked[i])
                dialog_state[slot] = ' '.join(
                    input_tokens[start_prediction:end_prediction + 1])
                dialog_state[slot] = re.sub('(^| )##', '', dialog_state[slot])
            elif 'true' in config.dst_class_types and class_prediction == config.dst_class_types.index(
                    'true'):
                dialog_state[slot] = 'true'
            elif 'false' in config.dst_class_types and class_prediction == config.dst_class_types.index(
                    'false'):
                dialog_state[slot] = 'false'
            elif class_prediction == config.dst_class_types.index('inform'):
                dialog_state[slot] = '§§' + inform[i][slot]
            # Referral case is handled below

            prediction_addendum['slot_prediction_%s'
                                % slot] = dialog_state[slot]
            prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]

        # Referral case. All other slot values need to be seen first in order
        # to be able to do this correctly.
        for slot in config.dst_slot_list:
            class_logits = per_slot_class_logits[slot][i]
            refer_logits = per_slot_refer_logits[slot][i]

            class_prediction = int(class_logits.argmax())
            refer_prediction = int(refer_logits.argmax())

            if 'refer' in config.dst_class_types and class_prediction == config.dst_class_types.index(
                    'refer'):
                # Only slots that have been mentioned before can be referred to.
                # One can think of a situation where one slot is referred to in the same utterance.
                # This phenomenon is however currently not properly covered in the training data
                # label generation process.
                dialog_state[slot] = dialog_state[config.dst_slot_list[
                    refer_prediction - 1]]
                prediction_addendum['slot_prediction_%s' %
                                    slot] = dialog_state[slot]  # Value update

        prediction.update(prediction_addendum)
        prediction_list.append(prediction)

    return dialog_state
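For reference, a standalone sketch (not part of this commit) of the 'copy_value' branch above: the span picked by the start/end argmax is joined and the WordPiece '##' markers are stripped. The tokens and logits below are fabricated:

import re

import torch

tokens = ['[CLS]', 'i', 'need', 'free', 'wi', '##fi', '[SEP]']
start_logits = torch.tensor([0., 0., 0., 0., 9., 0., 0.])
end_logits = torch.tensor([0., 0., 0., 0., 0., 9., 0.])

start_prediction = int(start_logits.argmax())  # 4
end_prediction = int(end_logits.argmax())      # 5
value = ' '.join(tokens[start_prediction:end_prediction + 1])  # 'wi ##fi'
value = re.sub('(^| )##', '', value)  # 'wifi'
print(value)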

+100 -12   modelscope/preprocessors/space/dialog_state_tracking_preprocessor.py

@@ -3,13 +3,12 @@
 import os
 from typing import Any, Dict

-from modelscope.preprocessors.space.fields.intent_field import \
-    IntentBPETextField
-from modelscope.utils.config import Config
 from modelscope.utils.constant import Fields
 from modelscope.utils.type_assert import type_assert
 from ..base import Preprocessor
 from ..builder import PREPROCESSORS
+from .dst_processors import convert_examples_to_features, multiwoz22Processor
+from .tensorlistdataset import TensorListDataset

 __all__ = ['DialogStateTrackingPreprocessor']

@@ -25,14 +24,14 @@ class DialogStateTrackingPreprocessor(Preprocessor):
"""
super().__init__(*args, **kwargs)

from sofa.models.space import SpaceTokenizer, SpaceConfig
self.model_dir: str = model_dir
self.config = Config.from_file(
os.path.join(self.model_dir, 'configuration.json'))
self.text_field = IntentBPETextField(
self.model_dir, config=self.config)
self.config = SpaceConfig.from_pretrained(self.model_dir)
self.tokenizer = SpaceTokenizer.from_pretrained(self.model_dir)
self.processor = multiwoz22Processor()

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
@type_assert(object, dict)
def __call__(self, data: Dict) -> Dict[str, Any]:
"""process the raw input data

Args:
@@ -43,7 +42,96 @@ class DialogStateTrackingPreprocessor(Preprocessor):
         Returns:
             Dict[str, Any]: the preprocessed data
         """
-        samples = self.text_field.preprocessor([data])
-        samples, _ = self.text_field.collate_fn_multi_turn(samples)
-
-        return samples
+        import torch
+        from torch.utils.data import (DataLoader, RandomSampler,
+                                      SequentialSampler)
+
+        utter = data['utter']
+        history_states = data['history_states']
+        example = self.processor.create_example(
+            inputs=utter,
+            history_states=history_states,
+            set_type='test',
+            slot_list=self.config.dst_slot_list,
+            label_maps={},
+            append_history=True,
+            use_history_labels=True,
+            swap_utterances=True,
+            label_value_repetitions=True,
+            delexicalize_sys_utts=True,
+            unk_token='[UNK]',
+            analyze=False)
+        print(example)
+
+        features = convert_examples_to_features(
+            examples=[example],
+            slot_list=self.config.dst_slot_list,
+            class_types=self.config.dst_class_types,
+            model_type=self.config.model_type,
+            tokenizer=self.tokenizer,
+            max_seq_length=180,  # args.max_seq_length
+            slot_value_dropout=(0.0))
+
+        all_input_ids = torch.tensor([f.input_ids for f in features],
+                                     dtype=torch.long)
+        all_input_mask = torch.tensor([f.input_mask for f in features],
+                                      dtype=torch.long)
+        all_segment_ids = torch.tensor([f.segment_ids for f in features],
+                                       dtype=torch.long)
+        all_example_index = torch.arange(
+            all_input_ids.size(0), dtype=torch.long)
+        f_start_pos = [f.start_pos for f in features]
+        f_end_pos = [f.end_pos for f in features]
+        f_inform_slot_ids = [f.inform_slot for f in features]
+        f_refer_ids = [f.refer_id for f in features]
+        f_diag_state = [f.diag_state for f in features]
+        f_class_label_ids = [f.class_label_id for f in features]
+        all_start_positions = {}
+        all_end_positions = {}
+        all_inform_slot_ids = {}
+        all_refer_ids = {}
+        all_diag_state = {}
+        all_class_label_ids = {}
+        for s in self.config.dst_slot_list:
+            all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos],
+                                                  dtype=torch.long)
+            all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos],
+                                                dtype=torch.long)
+            all_inform_slot_ids[s] = torch.tensor(
+                [f[s] for f in f_inform_slot_ids], dtype=torch.long)
+            all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids],
+                                            dtype=torch.long)
+            all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state],
+                                             dtype=torch.long)
+            all_class_label_ids[s] = torch.tensor(
+                [f[s] for f in f_class_label_ids], dtype=torch.long)
+        # dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
+        #                             all_start_positions, all_end_positions,
+        #                             all_inform_slot_ids,
+        #                             all_refer_ids,
+        #                             all_diag_state,
+        #                             all_class_label_ids, all_example_index)
+        #
+        # eval_sampler = SequentialSampler(dataset)
+        # eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.config.eval_batch_size)
+        dataset = [
+            all_input_ids, all_input_mask, all_segment_ids,
+            all_start_positions, all_end_positions, all_inform_slot_ids,
+            all_refer_ids, all_diag_state, all_class_label_ids,
+            all_example_index
+        ]
+
+        with torch.no_grad():
+            diag_state = {
+                slot:
+                torch.tensor([0 for _ in range(self.config.eval_batch_size)
+                              ]).to(self.config.device)
+                for slot in self.config.dst_slot_list
+            }
+        # print(diag_state)
+
+        return {
+            'batch': dataset,
+            'features': features,
+            'diag_state': diag_state
+        }
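For reference, a minimal sketch (not part of this commit) of the preprocessor's input and output contract; the local model path is hypothetical:

from modelscope.preprocessors import DialogStateTrackingPreprocessor

preprocessor = DialogStateTrackingPreprocessor(
    model_dir='/path/to/nlp_space_dialog-state-tracking')  # hypothetical path
out = preprocessor({
    'utter': {'User-1': 'I need a cheap hotel in the centre.'},
    'history_states': [{}],
})
# out['batch']: ten-element tensor list, consumed positionally by the model
# out['features']: converted examples carrying guid, values, inform, ...
# out['diag_state']: per-slot zero tensors, reset per dialog inside the model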

+17 -14   modelscope/preprocessors/space/fields/dst_processors.py → modelscope/preprocessors/space/dst_processors.py (renamed)

@@ -1097,29 +1097,31 @@ class DSTExample(object):
         return self.__repr__()

     def __repr__(self):
-        s = ''
-        s += 'guid: %s' % (self.guid)
-        s += ', text_a: %s' % (self.text_a)
-        s += ', text_b: %s' % (self.text_b)
-        s += ', history: %s' % (self.history)
+        s_dict = dict()
+        s_dict['guid'] = self.guid
+        s_dict['text_a'] = self.text_a
+        s_dict['text_b'] = self.text_b
+        s_dict['history'] = self.history
         if self.text_a_label:
-            s += ', text_a_label: %d' % (self.text_a_label)
+            s_dict['text_a_label'] = self.text_a_label
         if self.text_b_label:
-            s += ', text_b_label: %d' % (self.text_b_label)
+            s_dict['text_b_label'] = self.text_b_label
         if self.history_label:
-            s += ', history_label: %d' % (self.history_label)
+            s_dict['history_label'] = self.history_label
         if self.values:
-            s += ', values: %d' % (self.values)
+            s_dict['values'] = self.values
         if self.inform_label:
-            s += ', inform_label: %d' % (self.inform_label)
+            s_dict['inform_label'] = self.inform_label
         if self.inform_slot_label:
-            s += ', inform_slot_label: %d' % (self.inform_slot_label)
+            s_dict['inform_slot_label'] = self.inform_slot_label
         if self.refer_label:
-            s += ', refer_label: %d' % (self.refer_label)
+            s_dict['refer_label'] = self.refer_label
         if self.diag_state:
-            s += ', diag_state: %d' % (self.diag_state)
+            s_dict['diag_state'] = self.diag_state
         if self.class_label:
-            s += ', class_label: %d' % (self.class_label)
+            s_dict['class_label'] = self.class_label

+        s = json.dumps(s_dict)
         return s
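The rewritten __repr__ also sidesteps a latent bug: the old '%d' formats would raise TypeError for non-integer fields such as values or diag_state. A small sketch (not part of this commit) of the new JSON output, with fabricated field values:

import json

s_dict = {'guid': 'test-example-0', 'text_a': 'i need free wifi', 'history': []}
print(json.dumps(s_dict))
# {"guid": "test-example-0", "text_a": "i need free wifi", "history": []}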


@@ -1515,6 +1517,7 @@ if __name__ == '__main__':
     delexicalize_sys_utts = True
     unk_token = '[UNK]'
     analyze = False
+
     example = processor.create_example(utter1, history_states1, set_type,
                                        slot_list, {}, append_history,
                                        use_history_labels, swap_utterances,

+59 -0   modelscope/preprocessors/space/tensorlistdataset.py (new file)

@@ -0,0 +1,59 @@
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.utils.data import Dataset


class TensorListDataset(Dataset):
    r"""Dataset wrapping tensors, tensor dicts and tensor lists.

    Arguments:
        *data (Tensor or dict or list of Tensors): tensors that have the same size
            of the first dimension.
    """

    def __init__(self, *data):
        if isinstance(data[0], dict):
            size = list(data[0].values())[0].size(0)
        elif isinstance(data[0], list):
            size = data[0][0].size(0)
        else:
            size = data[0].size(0)
        for element in data:
            if isinstance(element, dict):
                assert all(
                    size == tensor.size(0)
                    for name, tensor in element.items())  # dict of tensors
            elif isinstance(element, list):
                assert all(size == tensor.size(0)
                           for tensor in element)  # list of tensors
            else:
                assert size == element.size(0)  # tensor
        self.size = size
        self.data = data

    def __getitem__(self, index):
        result = []
        for element in self.data:
            if isinstance(element, dict):
                result.append({k: v[index] for k, v in element.items()})
            elif isinstance(element, list):
                # List comprehension, not a bare generator expression:
                # appending the generator would store a lazy iterator
                # instead of the indexed values.
                result.append([v[index] for v in element])
            else:
                result.append(element[index])
        return tuple(result)

    def __len__(self):
        return self.size
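For reference, an illustrative usage sketch (not part of this commit): unlike torch.utils.data.TensorDataset, this class also accepts dicts of tensors, which is what the per-slot fields built by the preprocessor need. Shapes are made up:

import torch
from torch.utils.data import DataLoader

from modelscope.preprocessors.space.tensorlistdataset import TensorListDataset

input_ids = torch.zeros(4, 8, dtype=torch.long)
start_pos = {'hotel-type': torch.zeros(4, dtype=torch.long)}

dataset = TensorListDataset(input_ids, start_pos)
loader = DataLoader(dataset, batch_size=2)
for ids, starts in loader:
    print(ids.shape, starts['hotel-type'].shape)
    # torch.Size([2, 8]) torch.Size([2])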

+ 10
- 0
modelscope/utils/nlp/space/utils_dst.py View File

@@ -0,0 +1,10 @@
def batch_to_device(batch, device):
    batch_on_device = []
    for element in batch:
        if isinstance(element, dict):
            batch_on_device.append(
                {k: v.to(device)
                 for k, v in element.items()})
        else:
            batch_on_device.append(element.to(device))
    return tuple(batch_on_device)
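An illustrative usage sketch (not part of this commit): tensors and dicts of tensors are moved element by element, mirroring the batch layout built by the preprocessor:

import torch

from modelscope.utils.nlp.space.utils_dst import batch_to_device

batch = (
    torch.zeros(2, 4),
    {'hotel-type': torch.zeros(2, dtype=torch.long)},
)
batch = batch_to_device(batch, torch.device('cpu'))  # 'cuda' when available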

+18 -16   tests/pipelines/nlp/test_dialog_state_tracking.py

@@ -14,26 +14,28 @@ from modelscope.utils.constant import Tasks

 class DialogStateTrackingTest(unittest.TestCase):
     model_id = 'damo/nlp_space_dialog-state-tracking'
-    test_case = {}
+    test_case = [{
+        'utter': {
+            'User-1':
+            "I'm looking for a place to stay. It needs to be a guesthouse and include free wifi."
+        },
+        'history_states': [{}]
+    }]

     def test_run(self):
+        # cache_path = ''
+        cache_path = '/Users/yangliu/Space/maas_model/nlp_space_dialog-state-tracking'
+        # cache_path = snapshot_download(self.model_id)

-        # preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
-        # model = DialogStateTrackingModel(
-        #     model_dir=cache_path,
-        #     text_field=preprocessor.text_field,
-        #     config=preprocessor.config)
-        # pipelines = [
-        #     DialogStateTrackingPipeline(model=model, preprocessor=preprocessor),
-        #     pipeline(
-        #         task=Tasks.dialog_modeling,
-        #         model=model,
-        #         preprocessor=preprocessor)
-        # ]

+        print('jizhu test')
+        model = DialogStateTrackingModel(cache_path)
+        preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path)
+        pipeline1 = DialogStateTrackingPipeline(
+            model=model, preprocessor=preprocessor)

+        history_states = {}
+        for step, item in enumerate(self.test_case):
+            history_states = pipeline1(item)
+            print(history_states)

     @unittest.skip('test with snapshot_download')
     def test_run_with_model_from_modelhub(self):
