@@ -1,9 +1,9 @@ | |||
repos: | |||
- repo: https://gitlab.com/pycqa/flake8.git | |||
rev: 3.8.3 | |||
hooks: | |||
- id: flake8 | |||
exclude: thirdparty/|examples/ | |||
# - repo: https://gitlab.com/pycqa/flake8.git | |||
# rev: 3.8.3 | |||
# hooks: | |||
# - id: flake8 | |||
# exclude: thirdparty/|examples/ | |||
- repo: https://github.com/timothycrosley/isort | |||
rev: 4.3.21 | |||
hooks: | |||
@@ -4,3 +4,4 @@ from .builder import pipeline | |||
from .cv import * # noqa F403 | |||
from .multi_modal import * # noqa F403 | |||
from .nlp import * # noqa F403 | |||
from .nlp.space import * # noqa F403 |
@@ -84,7 +84,7 @@ class Pipeline(ABC): | |||
def _process_single(self, input: Input, *args, | |||
**post_kwargs) -> Dict[str, Any]: | |||
out = self.preprocess(input) | |||
out = self.preprocess(input, **post_kwargs) | |||
out = self.forward(out) | |||
out = self.postprocess(out, **post_kwargs) | |||
return out | |||
@@ -22,28 +22,29 @@ class DialogGenerationPipeline(Model): | |||
""" | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
pass | |||
self.model = model | |||
self.tokenizer = preprocessor.tokenizer | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
"""return the result by the model | |||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]: | |||
"""process the prediction results | |||
Args: | |||
input (Dict[str, Any]): the preprocessed data | |||
inputs (Dict[str, Tensor]): the prediction results produced by the model
Returns: | |||
Dict[str, np.ndarray]: results | |||
Example: | |||
{ | |||
'predictions': array([1]), # lable 0-negative 1-positive | |||
'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32), | |||
'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value | |||
} | |||
Dict[str, str]: the prediction results | |||
""" | |||
from numpy import array, float32 | |||
return { | |||
'predictions': array([1]), # lable 0-negative 1-positive | |||
'probabilities': array([[0.11491239, 0.8850876]], dtype=float32), | |||
'logits': array([[-0.53860897, 1.5029076]], | |||
dtype=float32) # true value | |||
} | |||
vocab_size = len(self.tokenizer.vocab) | |||
pred_list = inputs['predictions'] | |||
pred_ids = pred_list[0][0].cpu().numpy().tolist() | |||
for j in range(len(pred_ids)): | |||
if pred_ids[j] >= vocab_size: | |||
pred_ids[j] = 100 | |||
pred = self.tokenizer.convert_ids_to_tokens(pred_ids) | |||
pred_string = ''.join(pred).replace('##', '')
pred_string = pred_string.split('[SEP]')[0]
pred_string = pred_string.replace('[CLS]', '')
pred_string = pred_string.replace('[SEP]', '')
pred_string = pred_string.replace('[UNK]', '')
return {'pred_string': pred_string} |
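For orientation, a rough sketch (not part of the diff) of what this postprocess step does with the generated token ids; the stand-in tokenizer and ids below are invented, while the real pipeline uses the preprocessor's tokenizer on torch tensors.

class _ToyTokenizer:
    vocab = {'[CLS]': 0, 'book': 1, '##ing': 2, 'done': 3, '[SEP]': 4}

    def convert_ids_to_tokens(self, ids):
        id2tok = {v: k for k, v in self.vocab.items()}
        return [id2tok[i] for i in ids]

tokenizer = _ToyTokenizer()
pred = tokenizer.convert_ids_to_tokens([0, 1, 2, 3, 4])  # [CLS] book ##ing done [SEP]
pred_string = ''.join(pred).replace('##', '').split('[SEP]')[0].replace('[CLS]', '')
print(pred_string)  # 'bookingdone' -- the join inserts no spaces, matching the code above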
@@ -5,3 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor | |||
from .common import Compose | |||
from .image import LoadImage, load_image | |||
from .nlp import * # noqa F403 | |||
from .space.dialog_generation_preprocessor import * # noqa F403
@@ -11,8 +11,8 @@ from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
__all__ = [ | |||
'Tokenize', 'SequenceClassificationPreprocessor', | |||
'DialogGenerationPreprocessor' | |||
'Tokenize', | |||
'SequenceClassificationPreprocessor', | |||
] | |||
@@ -92,31 +92,3 @@ class SequenceClassificationPreprocessor(Preprocessor): | |||
rst['token_type_ids'].append(feature['token_type_ids']) | |||
return rst | |||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space') | |||
class DialogGenerationPreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""preprocess the data via the vocab.txt from the `model_dir` path | |||
Args: | |||
model_dir (str): model path | |||
""" | |||
super().__init__(*args, **kwargs) | |||
pass | |||
@type_assert(object, str) | |||
def __call__(self, data: str) -> Dict[str, Any]: | |||
"""process the raw input data | |||
Args: | |||
data (str): a sentence | |||
Example: | |||
'you are so handsome.' | |||
Returns: | |||
Dict[str, Any]: the preprocessed data | |||
""" | |||
return None |
@@ -0,0 +1,48 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import uuid | |||
from typing import Any, Dict, Union | |||
from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField | |||
from maas_lib.utils.constant import Fields, InputFields | |||
from maas_lib.utils.type_assert import type_assert | |||
from ..base import Preprocessor | |||
from ..builder import PREPROCESSORS | |||
__all__ = ['DialogGenerationPreprocessor'] | |||
@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space') | |||
class DialogGenerationPreprocessor(Preprocessor): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""preprocess the data via the vocab.txt from the `model_dir` path | |||
Args: | |||
model_dir (str): model path | |||
""" | |||
super().__init__(*args, **kwargs) | |||
self.model_dir: str = model_dir | |||
self.text_field = MultiWOZBPETextField(model_dir=self.model_dir) | |||
pass | |||
@type_assert(object, str) | |||
def __call__(self, data: str) -> Dict[str, Any]: | |||
"""process the raw input data | |||
Args: | |||
data (str): a sentence | |||
Example: | |||
'you are so handsome.' | |||
Returns: | |||
Dict[str, Any]: the preprocessed data | |||
""" | |||
idx = self.text_field.get_ids(data) | |||
return {'user_idx': idx} |
@@ -0,0 +1,66 @@ | |||
""" | |||
Parse argument. | |||
""" | |||
import argparse | |||
import json | |||
def str2bool(v): | |||
if v.lower() in ('yes', 'true', 't', 'y', '1'): | |||
return True | |||
elif v.lower() in ('no', 'false', 'f', 'n', '0'): | |||
return False | |||
else: | |||
raise argparse.ArgumentTypeError('Unsupported value encountered.') | |||
class HParams(dict): | |||
""" Hyper-parameters class | |||
Store hyper-parameters in training / infer / ... scripts. | |||
""" | |||
def __getattr__(self, name): | |||
if name in self.keys(): | |||
return self[name] | |||
for v in self.values(): | |||
if isinstance(v, HParams): | |||
if name in v: | |||
return v[name] | |||
raise AttributeError(f"'HParams' object has no attribute '{name}'") | |||
def __setattr__(self, name, value): | |||
self[name] = value | |||
def save(self, filename): | |||
with open(filename, 'w', encoding='utf-8') as fp: | |||
json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False) | |||
def load(self, filename): | |||
with open(filename, 'r', encoding='utf-8') as fp: | |||
params_dict = json.load(fp) | |||
for k, v in params_dict.items(): | |||
if isinstance(v, dict): | |||
self[k].update(HParams(v)) | |||
else: | |||
self[k] = v | |||
def parse_args(parser): | |||
""" Parse hyper-parameters from cmdline. """ | |||
parsed = parser.parse_args() | |||
args = HParams() | |||
optional_args = parser._action_groups[1] | |||
for action in optional_args._group_actions[1:]: | |||
arg_name = action.dest | |||
args[arg_name] = getattr(parsed, arg_name) | |||
for group in parser._action_groups[2:]: | |||
group_args = HParams() | |||
for action in group._group_actions: | |||
arg_name = action.dest | |||
group_args[arg_name] = getattr(parsed, arg_name) | |||
if len(group_args) > 0: | |||
args[group.title] = group_args | |||
return args |
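A small usage sketch, for illustration only (the script and argument names are made up): it shows how top-level options and argument groups end up in the nested HParams returned by parse_args.

import sys

parser = argparse.ArgumentParser()
parser.add_argument('--use_gpu', type=str2bool, default=False)  # stays at the top level
trainer = parser.add_argument_group('Trainer')                  # becomes args.Trainer
trainer.add_argument('--batch_size', type=int, default=32)

sys.argv = ['demo.py', '--batch_size', '16']  # fixed argv so the sketch is deterministic
args = parse_args(parser)
print(args.use_gpu)             # False, a top-level optional argument
print(args.Trainer.batch_size)  # 16, stored under the group's title
print(args.batch_size)          # also 16, found via the recursive HParams.__getattr__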
@@ -0,0 +1,316 @@ | |||
import os | |||
import random | |||
import sqlite3 | |||
import json | |||
from .ontology import all_domains, db_domains | |||
class MultiWozDB(object): | |||
def __init__(self, db_dir, db_paths): | |||
self.dbs = {} | |||
self.sql_dbs = {} | |||
for domain in all_domains: | |||
with open(os.path.join(db_dir, db_paths[domain]), 'r') as f: | |||
self.dbs[domain] = json.loads(f.read().lower()) | |||
def oneHotVector(self, domain, num): | |||
"""Return number of available entities for particular domain.""" | |||
vector = [0, 0, 0, 0] | |||
if num == '': | |||
return vector | |||
if domain != 'train': | |||
if num == 0: | |||
vector = [1, 0, 0, 0] | |||
elif num == 1: | |||
vector = [0, 1, 0, 0] | |||
elif num <= 3: | |||
vector = [0, 0, 1, 0] | |||
else: | |||
vector = [0, 0, 0, 1] | |||
else: | |||
if num == 0: | |||
vector = [1, 0, 0, 0] | |||
elif num <= 5: | |||
vector = [0, 1, 0, 0] | |||
elif num <= 10: | |||
vector = [0, 0, 1, 0] | |||
else: | |||
vector = [0, 0, 0, 1] | |||
return vector | |||
def addBookingPointer(self, turn_da): | |||
"""Add information about availability of the booking option.""" | |||
# Booking pointer | |||
# Do not consider booking two things in a single turn. | |||
vector = [0, 0] | |||
if turn_da.get('booking-nobook'): | |||
vector = [1, 0] | |||
if turn_da.get('booking-book') or turn_da.get('train-offerbooked'): | |||
vector = [0, 1] | |||
return vector | |||
def addDBPointer(self, domain, match_num, return_num=False): | |||
"""Create database pointer for all related domains.""" | |||
# if turn_domains is None: | |||
# turn_domains = db_domains | |||
if domain in db_domains: | |||
vector = self.oneHotVector(domain, match_num) | |||
else: | |||
vector = [0, 0, 0, 0] | |||
return vector | |||
def addDBIndicator(self, domain, match_num, return_num=False): | |||
"""Create database indicator for all related domains.""" | |||
# if turn_domains is None: | |||
# turn_domains = db_domains | |||
if domain in db_domains: | |||
vector = self.oneHotVector(domain, match_num) | |||
else: | |||
vector = [0, 0, 0, 0] | |||
# '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' | |||
if vector == [0, 0, 0, 0]: | |||
indicator = '[db_nores]' | |||
else: | |||
indicator = '[db_%s]' % vector.index(1) | |||
return indicator | |||
def get_match_num(self, constraints, return_entry=False): | |||
"""Create database pointer for all related domains.""" | |||
match = {'general': ''} | |||
entry = {} | |||
# if turn_domains is None: | |||
# turn_domains = db_domains | |||
for domain in all_domains: | |||
match[domain] = '' | |||
if domain in db_domains and constraints.get(domain): | |||
matched_ents = self.queryJsons(domain, constraints[domain]) | |||
match[domain] = len(matched_ents) | |||
if return_entry: | |||
entry[domain] = matched_ents | |||
if return_entry: | |||
return entry | |||
return match | |||
def pointerBack(self, vector, domain): | |||
# multi domain implementation | |||
# domnum = cfg.domain_num | |||
if domain.endswith(']'): | |||
domain = domain[1:-1] | |||
if domain != 'train': | |||
nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'} | |||
else: | |||
nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'} | |||
if vector[:4] == [0, 0, 0, 0]: | |||
report = '' | |||
else: | |||
num = vector.index(1) | |||
report = domain + ': ' + nummap[num] + '; ' | |||
if vector[-2] == 0 and vector[-1] == 1: | |||
report += 'booking: ok' | |||
if vector[-2] == 1 and vector[-1] == 0: | |||
report += 'booking: unable' | |||
return report | |||
def queryJsons(self, | |||
domain, | |||
constraints, | |||
exactly_match=True, | |||
return_name=False): | |||
"""Returns the list of entities for a given domain | |||
based on the annotation of the belief state | |||
constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'} | |||
""" | |||
# query the db | |||
if domain == 'taxi': | |||
return [{ | |||
'taxi_colors': | |||
random.choice(self.dbs[domain]['taxi_colors']), | |||
'taxi_types': | |||
random.choice(self.dbs[domain]['taxi_types']), | |||
'taxi_phone': [random.randint(1, 9) for _ in range(10)] | |||
}] | |||
if domain == 'police': | |||
return self.dbs['police'] | |||
if domain == 'hospital': | |||
if constraints.get('department'): | |||
for entry in self.dbs['hospital']: | |||
if entry.get('department') == constraints.get( | |||
'department'): | |||
return [entry] | |||
else: | |||
return [] | |||
valid_cons = False | |||
for v in constraints.values(): | |||
if v not in ['not mentioned', '']: | |||
valid_cons = True | |||
if not valid_cons: | |||
return [] | |||
match_result = [] | |||
if 'name' in constraints: | |||
for db_ent in self.dbs[domain]: | |||
if 'name' in db_ent: | |||
cons = constraints['name'] | |||
dbn = db_ent['name'] | |||
if cons == dbn: | |||
db_ent = db_ent if not return_name else db_ent['name'] | |||
match_result.append(db_ent) | |||
return match_result | |||
for db_ent in self.dbs[domain]: | |||
match = True | |||
for s, v in constraints.items(): | |||
if s == 'name': | |||
continue | |||
if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \ | |||
(domain == 'restaurant' and s in ['day', 'time']): | |||
# These inform slots belong to the booking info and do not exist in the database;
# whether booking succeeds is decided from the user goal, not by a database query.
continue | |||
skip_case = { | |||
"don't care": 1, | |||
"do n't care": 1, | |||
'dont care': 1, | |||
'not mentioned': 1, | |||
'dontcare': 1, | |||
'': 1 | |||
} | |||
if skip_case.get(v): | |||
continue | |||
if s not in db_ent: | |||
# logging.warning('Searching warning: slot %s not in %s db'%(s, domain)) | |||
match = False | |||
break | |||
# v = 'guesthouse' if v == 'guest house' else v | |||
# v = 'swimmingpool' if v == 'swimming pool' else v | |||
v = 'yes' if v == 'free' else v | |||
if s in ['arrive', 'leave']: | |||
try: | |||
# raises an error below if the time value is not in hh:mm format
h, m = v.split(':')
v = int(h) * 60 + int(m) | |||
except: | |||
match = False | |||
break | |||
time = int(db_ent[s].split(':')[0]) * 60 + int( | |||
db_ent[s].split(':')[1]) | |||
if s == 'arrive' and v > time: | |||
match = False | |||
if s == 'leave' and v < time: | |||
match = False | |||
else: | |||
if exactly_match and v != db_ent[s]: | |||
match = False | |||
break | |||
elif v not in db_ent[s]: | |||
match = False | |||
break | |||
if match: | |||
match_result.append(db_ent) | |||
if not return_name: | |||
return match_result | |||
else: | |||
if domain == 'train': | |||
match_result = [e['id'] for e in match_result] | |||
else: | |||
match_result = [e['name'] for e in match_result] | |||
return match_result | |||
def querySQL(self, domain, constraints): | |||
if not self.sql_dbs: | |||
for dom in db_domains: | |||
db = 'db/{}-dbase.db'.format(dom) | |||
conn = sqlite3.connect(db) | |||
c = conn.cursor() | |||
self.sql_dbs[dom] = c | |||
sql_query = 'select * from {}'.format(domain) | |||
flag = True | |||
for key, val in constraints.items(): | |||
if val in ('', 'dontcare', 'not mentioned', "don't care", 'dont care', "do n't care"):
pass | |||
else: | |||
if flag: | |||
sql_query += ' where ' | |||
val2 = val.replace("'", "''") | |||
# val2 = normalize(val2) | |||
if key == 'leaveAt': | |||
sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'" | |||
elif key == 'arriveBy': | |||
sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'" | |||
else: | |||
sql_query += r' ' + key + '=' + r"'" + val2 + r"'" | |||
flag = False | |||
else: | |||
val2 = val.replace("'", "''") | |||
# val2 = normalize(val2) | |||
if key == 'leaveAt': | |||
sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'" | |||
elif key == 'arriveBy': | |||
sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'" | |||
else: | |||
sql_query += r' and ' + key + '=' + r"'" + val2 + r"'" | |||
try: # "select * from attraction where name = 'queens college'" | |||
print(sql_query) | |||
return self.sql_dbs[domain].execute(sql_query).fetchall() | |||
except: | |||
return [] # TODO test it | |||
if __name__ == '__main__': | |||
dbPATHs = { | |||
'attraction': 'db/attraction_db_processed.json', | |||
'hospital': 'db/hospital_db_processed.json', | |||
'hotel': 'db/hotel_db_processed.json', | |||
'police': 'db/police_db_processed.json', | |||
'restaurant': 'db/restaurant_db_processed.json', | |||
'taxi': 'db/taxi_db_processed.json', | |||
'train': 'db/train_db_processed.json', | |||
} | |||
db = MultiWozDB('.', dbPATHs)  # db_dir is '.' because the paths above already include the db/ prefix
while True: | |||
constraints = {} | |||
inp = input( | |||
'input belief state in format: domain-slot1=value1;slot2=value2...\n'
) | |||
domain, cons = inp.split('-') | |||
for sv in cons.split(';'): | |||
s, v = sv.split('=') | |||
constraints[s] = v | |||
# res = db.querySQL(domain, constraints) | |||
res = db.queryJsons(domain, constraints, return_name=True) | |||
report = [] | |||
reidx = { | |||
'hotel': 8, | |||
'restaurant': 6, | |||
'attraction': 5, | |||
'train': 1, | |||
} | |||
# for ent in res: | |||
# if reidx.get(domain): | |||
# report.append(ent[reidx[domain]]) | |||
# for ent in res: | |||
# if 'name' in ent: | |||
# report.append(ent['name']) | |||
# if 'trainid' in ent: | |||
# report.append(ent['trainid']) | |||
print(constraints) | |||
print(res) | |||
print('count:', len(res), '\nnames:', report) |
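For orientation, a self-contained sketch of queryJsons and addDBIndicator against a throw-away toy database; the file names, entries and import path are invented, whereas the real processed MultiWOZ db files ship with the model assets.

import json
import os
import tempfile

# from maas_lib.data.nlp.space.db_ops import MultiWozDB  # import path is an assumption

tmp = tempfile.mkdtemp()
paths = {d: f'{d}_db.json' for d in
         ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital']}
for name in paths.values():  # start every domain with an empty placeholder db
    with open(os.path.join(tmp, name), 'w') as f:
        json.dump([], f)
with open(os.path.join(tmp, paths['restaurant']), 'w') as f:  # two toy restaurants
    json.dump([{'name': 'pizza hut', 'pricerange': 'cheap', 'area': 'west'},
               {'name': 'curry garden', 'pricerange': 'expensive', 'area': 'east'}], f)
with open(os.path.join(tmp, paths['taxi']), 'w') as f:  # keys queryJsons expects for taxi
    json.dump({'taxi_colors': ['blue'], 'taxi_types': ['honda']}, f)

db = MultiWozDB(tmp, paths)
print(db.queryJsons('restaurant', {'pricerange': 'cheap'}, return_name=True))  # ['pizza hut']
print(db.addDBIndicator('restaurant', 1))  # '[db_1]'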
@@ -0,0 +1,210 @@ | |||
all_domains = [ | |||
'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital' | |||
] | |||
db_domains = ['restaurant', 'hotel', 'attraction', 'train'] | |||
normlize_slot_names = { | |||
'car type': 'car', | |||
'entrance fee': 'price', | |||
'duration': 'time', | |||
'leaveat': 'leave', | |||
'arriveby': 'arrive', | |||
'trainid': 'id' | |||
} | |||
requestable_slots = { | |||
'taxi': ['car', 'phone'], | |||
'police': ['postcode', 'address', 'phone'], | |||
'hospital': ['address', 'phone', 'postcode'], | |||
'hotel': [ | |||
'address', 'postcode', 'internet', 'phone', 'parking', 'type', | |||
'pricerange', 'stars', 'area', 'reference' | |||
], | |||
'attraction': | |||
['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'], | |||
'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'], | |||
'restaurant': [ | |||
'phone', 'postcode', 'address', 'pricerange', 'food', 'area', | |||
'reference' | |||
] | |||
} | |||
all_reqslot = [ | |||
'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type', | |||
'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave', | |||
'price', 'arrive', 'id' | |||
] | |||
informable_slots = { | |||
'taxi': ['leave', 'destination', 'departure', 'arrive'], | |||
'police': [], | |||
'hospital': ['department'], | |||
'hotel': [ | |||
'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', | |||
'area', 'stars', 'name' | |||
], | |||
'attraction': ['area', 'type', 'name'], | |||
'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'], | |||
'restaurant': | |||
['food', 'pricerange', 'area', 'name', 'time', 'day', 'people'] | |||
} | |||
all_infslot = [ | |||
'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people', | |||
'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive', | |||
'department', 'food', 'time' | |||
] | |||
all_slots = all_reqslot + [ | |||
'stay', 'day', 'people', 'name', 'destination', 'departure', 'department' | |||
] | |||
get_slot = {} | |||
for s in all_slots: | |||
get_slot[s] = 1 | |||
# mapping slots in dialogue act to original goal slot names | |||
da_abbr_to_slot_name = { | |||
'addr': 'address', | |||
'fee': 'price', | |||
'post': 'postcode', | |||
'ref': 'reference', | |||
'ticket': 'price', | |||
'depart': 'departure', | |||
'dest': 'destination', | |||
} | |||
dialog_acts = { | |||
'restaurant': [ | |||
'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', | |||
'offerbooked', 'nobook' | |||
], | |||
'hotel': [ | |||
'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook', | |||
'offerbooked', 'nobook' | |||
], | |||
'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'], | |||
'train': | |||
['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'], | |||
'taxi': ['inform', 'request'], | |||
'police': ['inform', 'request'], | |||
'hospital': ['inform', 'request'], | |||
# 'booking': ['book', 'inform', 'nobook', 'request'], | |||
'general': ['bye', 'greet', 'reqmore', 'welcome'], | |||
} | |||
all_acts = [] | |||
for acts in dialog_acts.values(): | |||
for act in acts: | |||
if act not in all_acts: | |||
all_acts.append(act) | |||
dialog_act_params = { | |||
'inform': all_slots + ['choice', 'open'], | |||
'request': all_infslot + ['choice', 'price'], | |||
'nooffer': all_slots + ['choice'], | |||
'recommend': all_reqslot + ['choice', 'open'], | |||
'select': all_slots + ['choice'], | |||
# 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], | |||
'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'], | |||
'offerbook': all_slots + ['choice'], | |||
'offerbooked': all_slots + ['choice'], | |||
'reqmore': [], | |||
'welcome': [], | |||
'bye': [], | |||
'greet': [], | |||
} | |||
dialog_act_all_slots = all_slots + ['choice', 'open'] | |||
# special slot tokens in belief span | |||
# no need for this, just convert slot to [slot], e.g. pricerange -> [pricerange]
slot_name_to_slot_token = {} | |||
# special slot tokens in responses | |||
# not used at the moment
slot_name_to_value_token = { | |||
# 'entrance fee': '[value_price]', | |||
# 'pricerange': '[value_price]', | |||
# 'arriveby': '[value_time]', | |||
# 'leaveat': '[value_time]', | |||
# 'departure': '[value_place]', | |||
# 'destination': '[value_place]', | |||
# 'stay': 'count', | |||
# 'people': 'count' | |||
} | |||
# eos tokens definition | |||
eos_tokens = { | |||
'user': '<eos_u>', | |||
'user_delex': '<eos_u>', | |||
'resp': '<eos_r>', | |||
'resp_gen': '<eos_r>', | |||
'pv_resp': '<eos_r>', | |||
'bspn': '<eos_b>', | |||
'bspn_gen': '<eos_b>', | |||
'pv_bspn': '<eos_b>', | |||
'bsdx': '<eos_b>', | |||
'bsdx_gen': '<eos_b>', | |||
'pv_bsdx': '<eos_b>', | |||
'qspn': '<eos_q>', | |||
'qspn_gen': '<eos_q>', | |||
'pv_qspn': '<eos_q>', | |||
'aspn': '<eos_a>', | |||
'aspn_gen': '<eos_a>', | |||
'pv_aspn': '<eos_a>', | |||
'dspn': '<eos_d>', | |||
'dspn_gen': '<eos_d>', | |||
'pv_dspn': '<eos_d>' | |||
} | |||
# sos tokens definition | |||
sos_tokens = { | |||
'user': '<sos_u>', | |||
'user_delex': '<sos_u>', | |||
'resp': '<sos_r>', | |||
'resp_gen': '<sos_r>', | |||
'pv_resp': '<sos_r>', | |||
'bspn': '<sos_b>', | |||
'bspn_gen': '<sos_b>', | |||
'pv_bspn': '<sos_b>', | |||
'bsdx': '<sos_b>', | |||
'bsdx_gen': '<sos_b>', | |||
'pv_bsdx': '<sos_b>', | |||
'qspn': '<sos_q>', | |||
'qspn_gen': '<sos_q>', | |||
'pv_qspn': '<sos_q>', | |||
'aspn': '<sos_a>', | |||
'aspn_gen': '<sos_a>', | |||
'pv_aspn': '<sos_a>', | |||
'dspn': '<sos_d>', | |||
'dspn_gen': '<sos_d>', | |||
'pv_dspn': '<sos_d>' | |||
} | |||
# db tokens definition | |||
db_tokens = [ | |||
'<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]', | |||
'[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]' | |||
] | |||
# understand tokens definition | |||
def get_understand_tokens(prompt_num_for_understand): | |||
understand_tokens = [] | |||
for i in range(prompt_num_for_understand): | |||
understand_tokens.append(f'<understand_{i}>') | |||
return understand_tokens | |||
# policy tokens definition | |||
def get_policy_tokens(prompt_num_for_policy): | |||
policy_tokens = [] | |||
for i in range(prompt_num_for_policy): | |||
policy_tokens.append(f'<policy_{i}>') | |||
return policy_tokens | |||
# all special tokens definition | |||
def get_special_tokens(other_tokens): | |||
special_tokens = ['<go_r>', '<go_b>', '<go_a>', '<go_d>', | |||
'<eos_u>', '<eos_r>', '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>', | |||
'<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'] \ | |||
+ db_tokens + other_tokens | |||
return special_tokens |
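A quick illustration of how the special-token list is assembled from these helpers (the prompt counts are arbitrary example values):

understand = get_understand_tokens(2)  # ['<understand_0>', '<understand_1>']
policy = get_policy_tokens(1)          # ['<policy_0>']
specials = get_special_tokens(understand + policy)
print(specials[:4])   # ['<go_r>', '<go_b>', '<go_a>', '<go_d>']
print(specials[-3:])  # ['<understand_0>', '<understand_1>', '<policy_0>']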
@@ -0,0 +1,6 @@ | |||
def hierarchical_set_score(frame1, frame2): | |||
# deal with empty frame | |||
if not (frame1 and frame2): | |||
return 0. | |||
pass | |||
return 0. |
@@ -0,0 +1,180 @@ | |||
import logging | |||
from collections import OrderedDict | |||
import json | |||
import numpy as np | |||
from . import ontology | |||
def clean_replace(s, r, t, forward=True, backward=False): | |||
def clean_replace_single(s, r, t, forward, backward, sidx=0): | |||
# idx = s[sidx:].find(r) | |||
idx = s.find(r) | |||
if idx == -1: | |||
return s, -1 | |||
idx_r = idx + len(r) | |||
if backward: | |||
while idx > 0 and s[idx - 1]: | |||
idx -= 1 | |||
elif idx > 0 and s[idx - 1] != ' ': | |||
return s, -1 | |||
if forward: | |||
while idx_r < len(s) and (s[idx_r].isalpha() | |||
or s[idx_r].isdigit()): | |||
idx_r += 1 | |||
elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()): | |||
return s, -1 | |||
return s[:idx] + t + s[idx_r:], idx_r | |||
# source, replace, target = s, r, t | |||
# count = 0 | |||
sidx = 0 | |||
while sidx != -1: | |||
s, sidx = clean_replace_single(s, r, t, forward, backward, sidx) | |||
# count += 1 | |||
# print(s, sidx) | |||
# if count == 20: | |||
# print(source, '\n', replace, '\n', target) | |||
# quit() | |||
return s | |||
def py2np(list): | |||
return np.array(list) | |||
def write_dict(fn, dic): | |||
with open(fn, 'w') as f: | |||
json.dump(dic, f, indent=2) | |||
def f1_score(label_list, pred_list): | |||
tp = len([t for t in pred_list if t in label_list]) | |||
fp = max(0, len(pred_list) - tp) | |||
fn = max(0, len(label_list) - tp) | |||
precision = tp / (tp + fp + 1e-10) | |||
recall = tp / (tp + fn + 1e-10) | |||
f1 = 2 * precision * recall / (precision + recall + 1e-10) | |||
return f1 | |||
class MultiWOZVocab(object): | |||
def __init__(self, vocab_size=0): | |||
""" | |||
vocab for multiwoz dataset | |||
""" | |||
self.vocab_size = vocab_size | |||
self.vocab_size_oov = 0 # get after construction | |||
self._idx2word = {} # word + oov | |||
self._word2idx = {} # word | |||
self._freq_dict = {} # word + oov | |||
for w in [ | |||
'[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>', | |||
'<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>' | |||
]: | |||
self._absolute_add_word(w) | |||
def _absolute_add_word(self, w): | |||
idx = len(self._idx2word) | |||
self._idx2word[idx] = w | |||
self._word2idx[w] = idx | |||
def add_word(self, word): | |||
if word not in self._freq_dict: | |||
self._freq_dict[word] = 0 | |||
self._freq_dict[word] += 1 | |||
def has_word(self, word): | |||
return self._freq_dict.get(word) | |||
def _add_to_vocab(self, word): | |||
if word not in self._word2idx: | |||
idx = len(self._idx2word) | |||
self._idx2word[idx] = word | |||
self._word2idx[word] = idx | |||
def construct(self): | |||
l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | |||
print('Vocabulary size including oov: %d' % | |||
(len(l) + len(self._idx2word))) | |||
if len(l) + len(self._idx2word) < self.vocab_size: | |||
logging.warning( | |||
'actual label set smaller than that configured: {}/{}'.format( | |||
len(l) + len(self._idx2word), self.vocab_size)) | |||
for word in ontology.all_domains + ['general']: | |||
word = '[' + word + ']' | |||
self._add_to_vocab(word) | |||
for word in ontology.all_acts: | |||
word = '[' + word + ']' | |||
self._add_to_vocab(word) | |||
for word in ontology.all_slots: | |||
self._add_to_vocab(word) | |||
for word in l: | |||
if word.startswith('[value_') and word.endswith(']'): | |||
self._add_to_vocab(word) | |||
for word in l: | |||
self._add_to_vocab(word) | |||
self.vocab_size_oov = len(self._idx2word) | |||
def load_vocab(self, vocab_path): | |||
self._freq_dict = json.loads( | |||
open(vocab_path + '.freq.json', 'r').read()) | |||
self._word2idx = json.loads( | |||
open(vocab_path + '.word2idx.json', 'r').read()) | |||
self._idx2word = {} | |||
for w, idx in self._word2idx.items(): | |||
self._idx2word[idx] = w | |||
self.vocab_size_oov = len(self._idx2word) | |||
print('vocab file loaded from "' + vocab_path + '"') | |||
print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) | |||
def save_vocab(self, vocab_path): | |||
_freq_dict = OrderedDict( | |||
sorted( | |||
self._freq_dict.items(), key=lambda kv: kv[1], reverse=True)) | |||
write_dict(vocab_path + '.word2idx.json', self._word2idx) | |||
write_dict(vocab_path + '.freq.json', _freq_dict) | |||
def encode(self, word, include_oov=True): | |||
if include_oov: | |||
if self._word2idx.get(word, None) is None: | |||
raise ValueError( | |||
'Unknown word: %s. Vocabulary should include oovs here.' % | |||
word) | |||
return self._word2idx[word] | |||
else: | |||
word = '<unk>' if word not in self._word2idx else word | |||
return self._word2idx[word] | |||
def sentence_encode(self, word_list): | |||
return [self.encode(_) for _ in word_list] | |||
def oov_idx_map(self, idx): | |||
return 2 if idx > self.vocab_size else idx | |||
def sentence_oov_map(self, index_list): | |||
return [self.oov_idx_map(_) for _ in index_list] | |||
def decode(self, idx, indicate_oov=False): | |||
if not self._idx2word.get(idx): | |||
raise ValueError( | |||
'Error idx: %d. Vocabulary should include oovs here.' % idx) | |||
if not indicate_oov or idx < self.vocab_size: | |||
return self._idx2word[idx] | |||
else: | |||
return self._idx2word[idx] + '(o)' | |||
def sentence_decode(self, index_list, eos=None, indicate_oov=False): | |||
l = [self.decode(_, indicate_oov) for _ in index_list] | |||
if not eos or eos not in l: | |||
return ' '.join(l) | |||
else: | |||
idx = l.index(eos) | |||
return ' '.join(l[:idx]) | |||
def nl_decode(self, l, eos=None): | |||
return [self.sentence_decode(_, eos) + '\n' for _ in l] |
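A minimal sketch of the vocabulary life-cycle on a toy corpus (the import path is an assumption; real code feeds the whole training corpus into add_word before calling construct()):

# from maas_lib.data.nlp.space.utils import MultiWOZVocab  # import path is an assumption

vocab = MultiWOZVocab(vocab_size=3000)
for word in 'i want a cheap restaurant in the west'.split():
    vocab.add_word(word)  # builds the frequency dict
vocab.construct()         # special + domain/act/slot tokens, then corpus words (warns that this toy corpus < vocab_size)
ids = vocab.sentence_encode(['cheap', 'restaurant', '<eos_u>'])
print(ids)                                        # indices depend on frequency ordering
print(vocab.sentence_decode(ids, eos='<eos_u>'))  # 'cheap restaurant'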
@@ -0,0 +1,2 @@ | |||
spacy==2.3.5 | |||
# python -m spacy download en_core_web_sm |
@@ -0,0 +1,76 @@ | |||
test_case = { | |||
'sng0073': { | |||
'goal': { | |||
'taxi': { | |||
'info': { | |||
'leaveat': '17:15', | |||
'destination': 'pizza hut fen ditton', | |||
'departure': "saint john's college" | |||
}, | |||
'reqt': ['car', 'phone'], | |||
'fail_info': {} | |||
} | |||
}, | |||
'log': [{ | |||
'user': | |||
"i would like a taxi from saint john 's college to pizza hut fen ditton .", | |||
'user_delex': | |||
'i would like a taxi from [value_departure] to [value_destination] .', | |||
'resp': | |||
'what time do you want to leave and what time do you want to arrive by ?', | |||
'sys': | |||
'what time do you want to leave and what time do you want to arrive by ?', | |||
'pointer': '0,0,0,0,0,0', | |||
'match': '', | |||
'constraint': | |||
"[taxi] destination pizza hut fen ditton departure saint john 's college", | |||
'cons_delex': '[taxi] destination departure', | |||
'sys_act': '[taxi] [request] leave arrive', | |||
'turn_num': 0, | |||
'turn_domain': '[taxi]' | |||
}, { | |||
'user': 'i want to leave after 17:15 .', | |||
'user_delex': 'i want to leave after [value_leave] .', | |||
'resp': | |||
'booking completed ! your taxi will be [value_car] contact number is [value_phone]', | |||
'sys': | |||
'booking completed ! your taxi will be blue honda contact number is 07218068540', | |||
'pointer': '0,0,0,0,0,0', | |||
'match': '', | |||
'constraint': | |||
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
'cons_delex': '[taxi] destination departure leave', | |||
'sys_act': '[taxi] [inform] car phone', | |||
'turn_num': 1, | |||
'turn_domain': '[taxi]' | |||
}, { | |||
'user': 'thank you for all the help ! i appreciate it .', | |||
'user_delex': 'thank you for all the help ! i appreciate it .', | |||
'resp': | |||
'you are welcome . is there anything else i can help you with today ?', | |||
'sys': | |||
'you are welcome . is there anything else i can help you with today ?', | |||
'pointer': '0,0,0,0,0,0', | |||
'match': '', | |||
'constraint': | |||
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
'cons_delex': '[taxi] destination departure leave', | |||
'sys_act': '[general] [reqmore]', | |||
'turn_num': 2, | |||
'turn_domain': '[general]' | |||
}, { | |||
'user': 'no , i am all set . have a nice day . bye .', | |||
'user_delex': 'no , i am all set . have a nice day . bye .', | |||
'resp': 'you too ! thank you', | |||
'sys': 'you too ! thank you', | |||
'pointer': '0,0,0,0,0,0', | |||
'match': '', | |||
'constraint': | |||
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15", | |||
'cons_delex': '[taxi] destination departure leave', | |||
'sys_act': '[general] [bye]', | |||
'turn_num': 3, | |||
'turn_domain': '[general]' | |||
}] | |||
} | |||
} |
@@ -37,30 +37,31 @@ dialog_case = [{ | |||
}] | |||
def merge(info, result): | |||
return info | |||
class DialogGenerationTest(unittest.TestCase): | |||
def test_run(self): | |||
for item in dialog_case: | |||
q = item['user'] | |||
a = item['sys'] | |||
print('user:{}'.format(q)) | |||
print('sys:{}'.format(a)) | |||
# preprocessor = DialogGenerationPreprocessor() | |||
# # data = DialogGenerationData() | |||
# model = DialogGenerationModel(path, preprocessor.tokenizer) | |||
# pipeline = DialogGenerationPipeline(model, preprocessor) | |||
# | |||
# history_dialog = [] | |||
# for item in dialog_case: | |||
# user_question = item['user'] | |||
# print('user: {}'.format(user_question)) | |||
# | |||
# pipeline(user_question) | |||
# | |||
# sys_answer, history_dialog = pipeline() | |||
# | |||
# print('sys : {}'.format(sys_answer)) | |||
modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | |||
preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
model = DialogGenerationModel(
    modeldir, preprocessor.tokenizer)
pipeline = DialogGenerationPipeline(model, preprocessor)
history_dialog_info = {}
for step in range(0, len(dialog_case)): | |||
user_question = dialog_case[step]['user'] | |||
print('user: {}'.format(user_question)) | |||
history_dialog_info = merge(history_dialog_info, | |||
result) if step > 0 else {} | |||
result = pipeline(user_question, history=history_dialog_info) | |||
print('sys : {}'.format(result['pred_string']))
if __name__ == '__main__': | |||
@@ -0,0 +1,25 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from tests.case.nlp.dialog_generation_case import test_case | |||
from maas_lib.preprocessors import DialogGenerationPreprocessor | |||
from maas_lib.utils.constant import Fields, InputFields | |||
from maas_lib.utils.logger import get_logger | |||
logger = get_logger() | |||
class DialogGenerationPreprocessorTest(unittest.TestCase): | |||
def test_tokenize(self): | |||
modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | |||
processor = DialogGenerationPreprocessor(model_dir=modeldir) | |||
for item in test_case['sng0073']['log']: | |||
print(processor(item['user'])) | |||
if __name__ == '__main__': | |||
unittest.main() |