|
@@ -18,9 +18,102 @@ class DialogStateTrackingTest(unittest.TestCase): |
|
|
test_case = [{ |
|
|
test_case = [{ |
|
|
'utter': { |
|
|
'utter': { |
|
|
'User-1': |
|
|
'User-1': |
|
|
"I'm looking for a place to stay. It needs to be a guesthouse and include free wifi." |
|
|
|
|
|
|
|
|
'am looking for a place to to stay that has cheap price range it should be in a type of hotel' |
|
|
}, |
|
|
}, |
|
|
'history_states': [{}] |
|
|
'history_states': [{}] |
|
|
|
|
|
}, { |
|
|
|
|
|
'utter': { |
|
|
|
|
|
'User-1': |
|
|
|
|
|
'am looking for a place to to stay that has cheap price range it should be in a type of hotel', |
|
|
|
|
|
'System-1': |
|
|
|
|
|
'Okay, do you have a specific area you want to stay in?', |
|
|
|
|
|
'Dialog_Act-1': { |
|
|
|
|
|
'Hotel-Request': [['Area', '?']] |
|
|
|
|
|
}, |
|
|
|
|
|
'User-2': |
|
|
|
|
|
"no, i just need to make sure it's cheap. oh, and i need parking" |
|
|
|
|
|
}, |
|
|
|
|
|
'history_states': [{}, { |
|
|
|
|
|
'taxi': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [] |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': { |
|
|
|
|
|
'leaveAt': '', |
|
|
|
|
|
'destination': '', |
|
|
|
|
|
'departure': '', |
|
|
|
|
|
'arriveBy': '' |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
'police': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [] |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': {} |
|
|
|
|
|
}, |
|
|
|
|
|
'restaurant': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [], |
|
|
|
|
|
'people': '', |
|
|
|
|
|
'day': '', |
|
|
|
|
|
'time': '' |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': { |
|
|
|
|
|
'food': '', |
|
|
|
|
|
'pricerange': '', |
|
|
|
|
|
'name': '', |
|
|
|
|
|
'area': '' |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
'hospital': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [] |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': { |
|
|
|
|
|
'department': '' |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
'hotel': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [], |
|
|
|
|
|
'people': '', |
|
|
|
|
|
'day': '', |
|
|
|
|
|
'stay': '' |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': { |
|
|
|
|
|
'name': 'not mentioned', |
|
|
|
|
|
'area': 'not mentioned', |
|
|
|
|
|
'parking': 'not mentioned', |
|
|
|
|
|
'pricerange': 'cheap', |
|
|
|
|
|
'stars': 'not mentioned', |
|
|
|
|
|
'internet': 'not mentioned', |
|
|
|
|
|
'type': 'hotel' |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
'attraction': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [] |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': { |
|
|
|
|
|
'type': '', |
|
|
|
|
|
'name': '', |
|
|
|
|
|
'area': '' |
|
|
|
|
|
} |
|
|
|
|
|
}, |
|
|
|
|
|
'train': { |
|
|
|
|
|
'book': { |
|
|
|
|
|
'booked': [], |
|
|
|
|
|
'people': '' |
|
|
|
|
|
}, |
|
|
|
|
|
'semi': { |
|
|
|
|
|
'leaveAt': '', |
|
|
|
|
|
'destination': '', |
|
|
|
|
|
'day': '', |
|
|
|
|
|
'arriveBy': '', |
|
|
|
|
|
'departure': '' |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
}, {}] |
|
|
}] |
|
|
}] |
|
|
|
|
|
|
|
|
def test_run(self): |
|
|
def test_run(self): |
|
@@ -29,12 +122,19 @@ class DialogStateTrackingTest(unittest.TestCase): |
|
|
|
|
|
|
|
|
model = DialogStateTrackingModel(cache_path) |
|
|
model = DialogStateTrackingModel(cache_path) |
|
|
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) |
|
|
preprocessor = DialogStateTrackingPreprocessor(model_dir=cache_path) |
|
|
pipeline1 = DialogStateTrackingPipeline( |
|
|
|
|
|
model=model, preprocessor=preprocessor) |
|
|
|
|
|
|
|
|
pipelines = [ |
|
|
|
|
|
DialogStateTrackingPipeline( |
|
|
|
|
|
model=model, preprocessor=preprocessor), |
|
|
|
|
|
# pipeline( |
|
|
|
|
|
# task=Tasks.dialog_state_tracking, |
|
|
|
|
|
# model=model, |
|
|
|
|
|
# preprocessor=preprocessor) |
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
history_states = {} |
|
|
history_states = {} |
|
|
|
|
|
pipelines_len = len(pipelines) |
|
|
for step, item in enumerate(self.test_case): |
|
|
for step, item in enumerate(self.test_case): |
|
|
history_states = pipeline1(item) |
|
|
|
|
|
|
|
|
history_states = pipelines[step % pipelines_len](item) |
|
|
print(history_states) |
|
|
print(history_states) |
|
|
|
|
|
|
|
|
@unittest.skip('test with snapshot_download') |
|
|
@unittest.skip('test with snapshot_download') |
|
|