
token to ids

Branch: master
Author: ly119399, 3 years ago
Parent commit: 4a244abbf1
22 changed files with 980 additions and 75 deletions
  1. +5 -5    .pre-commit-config.yaml
  2. +1 -0    maas_lib/pipelines/__init__.py
  3. +1 -1    maas_lib/pipelines/base.py
  4. +20 -19  maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py
  5. +1 -0    maas_lib/preprocessors/__init__.py
  6. +2 -30   maas_lib/preprocessors/nlp.py
  7. +0 -0    maas_lib/preprocessors/space/__init__.py
  8. +48 -0   maas_lib/preprocessors/space/dialog_generation_preprcessor.py
  9. +0 -0    maas_lib/utils/nlp/__init__.py
  10. +0 -0   maas_lib/utils/nlp/space/__init__.py
  11. +66 -0  maas_lib/utils/nlp/space/args.py
  12. +316 -0 maas_lib/utils/nlp/space/db_ops.py
  13. +210 -0 maas_lib/utils/nlp/space/ontology.py
  14. +6 -0   maas_lib/utils/nlp/space/scores.py
  15. +180 -0 maas_lib/utils/nlp/space/utils.py
  16. +2 -0   requirements/nlp/space.txt
  17. +0 -0   tests/case/__init__.py
  18. +0 -0   tests/case/nlp/__init__.py
  19. +76 -0  tests/case/nlp/dialog_generation_case.py
  20. +21 -20 tests/pipelines/nlp/test_dialog_generation.py
  21. +0 -0   tests/preprocessors/nlp/__init__.py
  22. +25 -0  tests/preprocessors/nlp/test_dialog_generation.py

+5 -5   .pre-commit-config.yaml

@@ -1,9 +1,9 @@
 repos:
-  - repo: https://gitlab.com/pycqa/flake8.git
-    rev: 3.8.3
-    hooks:
-      - id: flake8
-        exclude: thirdparty/|examples/
+  # - repo: https://gitlab.com/pycqa/flake8.git
+  #   rev: 3.8.3
+  #   hooks:
+  #     - id: flake8
+  #       exclude: thirdparty/|examples/
   - repo: https://github.com/timothycrosley/isort
     rev: 4.3.21
     hooks:


+1 -0   maas_lib/pipelines/__init__.py

@@ -4,3 +4,4 @@ from .builder import pipeline
 from .cv import * # noqa F403
 from .multi_modal import * # noqa F403
 from .nlp import * # noqa F403
+from .nlp.space import * # noqa F403

+1 -1   maas_lib/pipelines/base.py

@@ -84,7 +84,7 @@ class Pipeline(ABC):

     def _process_single(self, input: Input, *args,
                         **post_kwargs) -> Dict[str, Any]:
-        out = self.preprocess(input)
+        out = self.preprocess(input, **post_kwargs)
         out = self.forward(out)
         out = self.postprocess(out, **post_kwargs)
         return out


+20 -19   maas_lib/pipelines/nlp/space/dialog_generation_pipeline.py

@@ -22,28 +22,29 @@ class DialogGenerationPipeline(Model):
         """

         super().__init__(model=model, preprocessor=preprocessor, **kwargs)
-        pass
+        self.model = model
+        self.tokenizer = preprocessor.tokenizer

-    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        """return the result by the model
+    def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, str]:
+        """process the prediction results

         Args:
-            input (Dict[str, Any]): the preprocessed data
+            inputs (Dict[str, Any]): _description_

         Returns:
-            Dict[str, np.ndarray]: results
-                Example:
-                    {
-                        'predictions': array([1]), # lable 0-negative 1-positive
-                        'probabilities': array([[0.11491239, 0.8850876 ]], dtype=float32),
-                        'logits': array([[-0.53860897, 1.5029076 ]], dtype=float32) # true value
-                    }
+            Dict[str, str]: the prediction results
         """
-        from numpy import array, float32
-
-        return {
-            'predictions': array([1]), # lable 0-negative 1-positive
-            'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
-            'logits': array([[-0.53860897, 1.5029076]],
-                            dtype=float32) # true value
-        }
+
+        vocab_size = len(self.tokenizer.vocab)
+        pred_list = inputs['predictions']
+        pred_ids = pred_list[0][0].cpu().numpy().tolist()
+        for j in range(len(pred_ids)):
+            if pred_ids[j] >= vocab_size:
+                pred_ids[j] = 100
+        pred = self.tokenizer.convert_ids_to_tokens(pred_ids)
+        pred_string = ''.join(pred).replace(
+            '##',
+            '').split('[SEP]')[0].replace('[CLS]',
+                                          '').replace('[SEP]',
+                                                      '').replace('[UNK]', '')
+        return {'pred_string': pred_string}
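
For readers following the commit message ("token to ids"), the heart of the new postprocess step is the reverse mapping from predicted token ids back to a response string. Below is a self-contained sketch of that logic with a stand-in tokenizer; DummyTokenizer and ids_to_string are illustrative names only, and the hard-coded id 100 in the committed code is the [UNK] id in standard BERT vocabularies:

# Illustrative sketch only (not part of the commit): mirrors the id-to-string
# conversion in DialogGenerationPipeline.postprocess.


class DummyTokenizer:
    """Stand-in for the BERT-style tokenizer held by the preprocessor."""

    def __init__(self):
        tokens = ['[PAD]', '[CLS]', '[SEP]', '[UNK]', 'book', '##ing', 'a', 'taxi']
        self.vocab = {t: i for i, t in enumerate(tokens)}
        self._inv = {i: t for t, i in self.vocab.items()}

    def convert_ids_to_tokens(self, ids):
        return [self._inv[i] for i in ids]


def ids_to_string(pred_ids, tokenizer, unk_id=3):
    # clamp out-of-vocab ids (the commit uses 100, BERT's [UNK]) before decoding
    vocab_size = len(tokenizer.vocab)
    pred_ids = [unk_id if i >= vocab_size else i for i in pred_ids]
    tokens = tokenizer.convert_ids_to_tokens(pred_ids)
    # join WordPiece tokens, cut at the first [SEP], drop remaining special tokens
    text = ''.join(tokens).replace('##', '')
    return text.split('[SEP]')[0].replace('[CLS]', '').replace('[UNK]', '')


tokenizer = DummyTokenizer()
print(ids_to_string([1, 4, 5, 6, 7, 2, 0, 99], tokenizer))  # -> 'bookingataxi'

Note that the committed code joins tokens without spaces, so word boundaries survive only where the tokenizer itself emits them.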

+1 -0   maas_lib/preprocessors/__init__.py

@@ -5,3 +5,4 @@ from .builder import PREPROCESSORS, build_preprocessor
 from .common import Compose
 from .image import LoadImage, load_image
 from .nlp import * # noqa F403
+from .space.dialog_generation_preprcessor import * # noqa F403

+2 -30   maas_lib/preprocessors/nlp.py

@@ -11,8 +11,8 @@ from .base import Preprocessor
 from .builder import PREPROCESSORS

 __all__ = [
-    'Tokenize', 'SequenceClassificationPreprocessor',
-    'DialogGenerationPreprocessor'
+    'Tokenize',
+    'SequenceClassificationPreprocessor',
 ]

@@ -92,31 +92,3 @@ class SequenceClassificationPreprocessor(Preprocessor):
             rst['token_type_ids'].append(feature['token_type_ids'])

         return rst
-
-
-@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
-class DialogGenerationPreprocessor(Preprocessor):
-
-    def __init__(self, model_dir: str, *args, **kwargs):
-        """preprocess the data via the vocab.txt from the `model_dir` path
-
-        Args:
-            model_dir (str): model path
-        """
-        super().__init__(*args, **kwargs)
-
-        pass
-
-    @type_assert(object, str)
-    def __call__(self, data: str) -> Dict[str, Any]:
-        """process the raw input data
-
-        Args:
-            data (str): a sentence
-            Example:
-                'you are so handsome.'
-
-        Returns:
-            Dict[str, Any]: the preprocessed data
-        """
-        return None

+0 -0   maas_lib/preprocessors/space/__init__.py


+48 -0   maas_lib/preprocessors/space/dialog_generation_preprcessor.py

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import uuid
from typing import Any, Dict, Union

from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField
from maas_lib.utils.constant import Fields, InputFields
from maas_lib.utils.type_assert import type_assert
from ..base import Preprocessor
from ..builder import PREPROCESSORS

__all__ = ['DialogGenerationPreprocessor']


@PREPROCESSORS.register_module(Fields.nlp, module_name=r'space')
class DialogGenerationPreprocessor(Preprocessor):

def __init__(self, model_dir: str, *args, **kwargs):
"""preprocess the data via the vocab.txt from the `model_dir` path

Args:
model_dir (str): model path
"""
super().__init__(*args, **kwargs)

self.model_dir: str = model_dir

self.text_field = MultiWOZBPETextField(model_dir=self.model_dir)

pass

@type_assert(object, str)
def __call__(self, data: str) -> Dict[str, Any]:
"""process the raw input data

Args:
data (str): a sentence
Example:
'you are so handsome.'

Returns:
Dict[str, Any]: the preprocessed data
"""

idx = self.text_field.get_ids(data)

return {'user_idx': idx}

+0 -0   maas_lib/utils/nlp/__init__.py


+0 -0   maas_lib/utils/nlp/space/__init__.py


+66 -0   maas_lib/utils/nlp/space/args.py

@@ -0,0 +1,66 @@
"""
Parse arguments.
"""

import argparse

import json


def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')


class HParams(dict):
""" Hyper-parameters class

Store hyper-parameters in training / infer / ... scripts.
"""

def __getattr__(self, name):
if name in self.keys():
return self[name]
for v in self.values():
if isinstance(v, HParams):
if name in v:
return v[name]
raise AttributeError(f"'HParams' object has no attribute '{name}'")

def __setattr__(self, name, value):
self[name] = value

def save(self, filename):
with open(filename, 'w', encoding='utf-8') as fp:
json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)

def load(self, filename):
with open(filename, 'r', encoding='utf-8') as fp:
params_dict = json.load(fp)
for k, v in params_dict.items():
if isinstance(v, dict):
self[k].update(HParams(v))
else:
self[k] = v


def parse_args(parser):
""" Parse hyper-parameters from cmdline. """
parsed = parser.parse_args()
args = HParams()
optional_args = parser._action_groups[1]
for action in optional_args._group_actions[1:]:
arg_name = action.dest
args[arg_name] = getattr(parsed, arg_name)
for group in parser._action_groups[2:]:
group_args = HParams()
for action in group._group_actions:
arg_name = action.dest
group_args[arg_name] = getattr(parsed, arg_name)
if len(group_args) > 0:
args[group.title] = group_args
return args
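
A short usage sketch of HParams and parse_args (the argument names below are hypothetical, not taken from this commit): plain optional arguments land at the top level of the returned HParams, each add_argument_group becomes a nested HParams keyed by the group title, and attribute lookup falls through to the nested groups.

import argparse

from maas_lib.utils.nlp.space.args import parse_args, str2bool

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=11)      # top-level entry
trainer = parser.add_argument_group('Trainer')           # becomes args.Trainer
trainer.add_argument('--use_gpu', type=str2bool, default=True)
trainer.add_argument('--lr', type=float, default=1e-4)

args = parse_args(parser)
print(args.seed)          # 11
print(args.Trainer.lr)    # 0.0001, nested under the group title
print(args.use_gpu)       # True: __getattr__ also searches nested HParams groups
args.save('config.json')  # args.load('config.json') reads the values back later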

+316 -0   maas_lib/utils/nlp/space/db_ops.py

@@ -0,0 +1,316 @@
import os
import random
import sqlite3

import json

from .ontology import all_domains, db_domains


class MultiWozDB(object):

def __init__(self, db_dir, db_paths):
self.dbs = {}
self.sql_dbs = {}
for domain in all_domains:
with open(os.path.join(db_dir, db_paths[domain]), 'r') as f:
self.dbs[domain] = json.loads(f.read().lower())

def oneHotVector(self, domain, num):
"""Return number of available entities for particular domain."""
vector = [0, 0, 0, 0]
if num == '':
return vector
if domain != 'train':
if num == 0:
vector = [1, 0, 0, 0]
elif num == 1:
vector = [0, 1, 0, 0]
elif num <= 3:
vector = [0, 0, 1, 0]
else:
vector = [0, 0, 0, 1]
else:
if num == 0:
vector = [1, 0, 0, 0]
elif num <= 5:
vector = [0, 1, 0, 0]
elif num <= 10:
vector = [0, 0, 1, 0]
else:
vector = [0, 0, 0, 1]
return vector

def addBookingPointer(self, turn_da):
"""Add information about availability of the booking option."""
# Booking pointer
# Do not consider booking two things in a single turn.
vector = [0, 0]
if turn_da.get('booking-nobook'):
vector = [1, 0]
if turn_da.get('booking-book') or turn_da.get('train-offerbooked'):
vector = [0, 1]
return vector

def addDBPointer(self, domain, match_num, return_num=False):
"""Create database pointer for all related domains."""
# if turn_domains is None:
# turn_domains = db_domains
if domain in db_domains:
vector = self.oneHotVector(domain, match_num)
else:
vector = [0, 0, 0, 0]
return vector

def addDBIndicator(self, domain, match_num, return_num=False):
"""Create database indicator for all related domains."""
# if turn_domains is None:
# turn_domains = db_domains
if domain in db_domains:
vector = self.oneHotVector(domain, match_num)
else:
vector = [0, 0, 0, 0]

# '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
if vector == [0, 0, 0, 0]:
indicator = '[db_nores]'
else:
indicator = '[db_%s]' % vector.index(1)
return indicator

def get_match_num(self, constraints, return_entry=False):
"""Create database pointer for all related domains."""
match = {'general': ''}
entry = {}
# if turn_domains is None:
# turn_domains = db_domains
for domain in all_domains:
match[domain] = ''
if domain in db_domains and constraints.get(domain):
matched_ents = self.queryJsons(domain, constraints[domain])
match[domain] = len(matched_ents)
if return_entry:
entry[domain] = matched_ents
if return_entry:
return entry
return match

def pointerBack(self, vector, domain):
# multi domain implementation
# domnum = cfg.domain_num
if domain.endswith(']'):
domain = domain[1:-1]
if domain != 'train':
nummap = {0: '0', 1: '1', 2: '2-3', 3: '>3'}
else:
nummap = {0: '0', 1: '1-5', 2: '6-10', 3: '>10'}
if vector[:4] == [0, 0, 0, 0]:
report = ''
else:
num = vector.index(1)
report = domain + ': ' + nummap[num] + '; '

if vector[-2] == 0 and vector[-1] == 1:
report += 'booking: ok'
if vector[-2] == 1 and vector[-1] == 0:
report += 'booking: unable'

return report

def queryJsons(self,
domain,
constraints,
exactly_match=True,
return_name=False):
"""Returns the list of entities for a given domain
based on the annotation of the belief state
constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
"""
# query the db
if domain == 'taxi':
return [{
'taxi_colors':
random.choice(self.dbs[domain]['taxi_colors']),
'taxi_types':
random.choice(self.dbs[domain]['taxi_types']),
'taxi_phone': [random.randint(1, 9) for _ in range(10)]
}]
if domain == 'police':
return self.dbs['police']
if domain == 'hospital':
if constraints.get('department'):
for entry in self.dbs['hospital']:
if entry.get('department') == constraints.get(
'department'):
return [entry]
else:
return []

valid_cons = False
for v in constraints.values():
if v not in ['not mentioned', '']:
valid_cons = True
if not valid_cons:
return []

match_result = []

if 'name' in constraints:
for db_ent in self.dbs[domain]:
if 'name' in db_ent:
cons = constraints['name']
dbn = db_ent['name']
if cons == dbn:
db_ent = db_ent if not return_name else db_ent['name']
match_result.append(db_ent)
return match_result

for db_ent in self.dbs[domain]:
match = True
for s, v in constraints.items():
if s == 'name':
continue
if s in ['people', 'stay'] or (domain == 'hotel' and s == 'day') or \
(domain == 'restaurant' and s in ['day', 'time']):
                    # these inform slots belong to the booking info and do not exist in the database;
                    # whether a booking succeeds is decided from the user goal, not by a database query;
continue

skip_case = {
"don't care": 1,
"do n't care": 1,
'dont care': 1,
'not mentioned': 1,
'dontcare': 1,
'': 1
}
if skip_case.get(v):
continue

if s not in db_ent:
# logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
match = False
break

# v = 'guesthouse' if v == 'guest house' else v
# v = 'swimmingpool' if v == 'swimming pool' else v
v = 'yes' if v == 'free' else v

if s in ['arrive', 'leave']:
try:
h, m = v.split(
':'
) # raise error if time value is not xx:xx format
v = int(h) * 60 + int(m)
except:
match = False
break
time = int(db_ent[s].split(':')[0]) * 60 + int(
db_ent[s].split(':')[1])
if s == 'arrive' and v > time:
match = False
if s == 'leave' and v < time:
match = False
else:
if exactly_match and v != db_ent[s]:
match = False
break
elif v not in db_ent[s]:
match = False
break

if match:
match_result.append(db_ent)

if not return_name:
return match_result
else:
if domain == 'train':
match_result = [e['id'] for e in match_result]
else:
match_result = [e['name'] for e in match_result]
return match_result

def querySQL(self, domain, constraints):
if not self.sql_dbs:
for dom in db_domains:
db = 'db/{}-dbase.db'.format(dom)
conn = sqlite3.connect(db)
c = conn.cursor()
self.sql_dbs[dom] = c

sql_query = 'select * from {}'.format(domain)

flag = True
for key, val in constraints.items():
if val == '' or val == 'dontcare' or val == 'not mentioned' or val == "don't care" or val == 'dont care' or val == "do n't care":
pass
else:
if flag:
sql_query += ' where '
val2 = val.replace("'", "''")
# val2 = normalize(val2)
if key == 'leaveAt':
sql_query += r' ' + key + ' > ' + r"'" + val2 + r"'"
elif key == 'arriveBy':
sql_query += r' ' + key + ' < ' + r"'" + val2 + r"'"
else:
sql_query += r' ' + key + '=' + r"'" + val2 + r"'"
flag = False
else:
val2 = val.replace("'", "''")
# val2 = normalize(val2)
if key == 'leaveAt':
sql_query += r' and ' + key + ' > ' + r"'" + val2 + r"'"
elif key == 'arriveBy':
sql_query += r' and ' + key + ' < ' + r"'" + val2 + r"'"
else:
sql_query += r' and ' + key + '=' + r"'" + val2 + r"'"

try: # "select * from attraction where name = 'queens college'"
print(sql_query)
return self.sql_dbs[domain].execute(sql_query).fetchall()
except:
return [] # TODO test it


if __name__ == '__main__':
dbPATHs = {
'attraction': 'db/attraction_db_processed.json',
'hospital': 'db/hospital_db_processed.json',
'hotel': 'db/hotel_db_processed.json',
'police': 'db/police_db_processed.json',
'restaurant': 'db/restaurant_db_processed.json',
'taxi': 'db/taxi_db_processed.json',
'train': 'db/train_db_processed.json',
}
db = MultiWozDB(dbPATHs)
while True:
constraints = {}
inp = input(
            'input belief state in format: domain-slot1=value1;slot2=value2...\n'
)
domain, cons = inp.split('-')
for sv in cons.split(';'):
s, v = sv.split('=')
constraints[s] = v
# res = db.querySQL(domain, constraints)
res = db.queryJsons(domain, constraints, return_name=True)
report = []
reidx = {
'hotel': 8,
'restaurant': 6,
'attraction': 5,
'train': 1,
}
# for ent in res:
# if reidx.get(domain):
# report.append(ent[reidx[domain]])
# for ent in res:
# if 'name' in ent:
# report.append(ent['name'])
# if 'trainid' in ent:
# report.append(ent['trainid'])
print(constraints)
print(res)
print('count:', len(res), '\nnames:', report)
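
For context, a minimal sketch of how MultiWozDB is meant to be queried. The db files and entries below are toy stand-ins written to a temp directory, since the processed MultiWOZ db files themselves are not part of this commit:

import json
import os
import tempfile

from maas_lib.utils.nlp.space.db_ops import MultiWozDB
from maas_lib.utils.nlp.space.ontology import all_domains

# Toy db files so the constructor has something to load; every domain needs a
# file, even if its content is empty.
db_dir = tempfile.mkdtemp()
sample = {
    'restaurant': [
        {'name': 'pizza hut fen ditton', 'pricerange': 'cheap', 'area': 'east'},
        {'name': 'the gardenia', 'pricerange': 'cheap', 'area': 'west'},
    ],
    'taxi': {'taxi_colors': ['blue'], 'taxi_types': ['honda']},
}
db_paths = {d: '{}_db.json'.format(d) for d in all_domains}
for domain, path in db_paths.items():
    with open(os.path.join(db_dir, path), 'w') as f:
        json.dump(sample.get(domain, []), f)

db = MultiWozDB(db_dir, db_paths)
matches = db.queryJsons('restaurant', {'pricerange': 'cheap', 'area': 'west'})
print(matches)                                        # [{'name': 'the gardenia', ...}]
print(db.addDBIndicator('restaurant', len(matches)))  # one match -> '[db_1]'
print(db.get_match_num({'restaurant': {'area': 'west'}}))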

+210 -0   maas_lib/utils/nlp/space/ontology.py

@@ -0,0 +1,210 @@
all_domains = [
'restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital'
]
db_domains = ['restaurant', 'hotel', 'attraction', 'train']

normlize_slot_names = {
'car type': 'car',
'entrance fee': 'price',
'duration': 'time',
'leaveat': 'leave',
'arriveby': 'arrive',
'trainid': 'id'
}

requestable_slots = {
'taxi': ['car', 'phone'],
'police': ['postcode', 'address', 'phone'],
'hospital': ['address', 'phone', 'postcode'],
'hotel': [
'address', 'postcode', 'internet', 'phone', 'parking', 'type',
'pricerange', 'stars', 'area', 'reference'
],
'attraction':
['price', 'type', 'address', 'postcode', 'phone', 'area', 'reference'],
'train': ['time', 'leave', 'price', 'arrive', 'id', 'reference'],
'restaurant': [
'phone', 'postcode', 'address', 'pricerange', 'food', 'area',
'reference'
]
}
all_reqslot = [
'car', 'address', 'postcode', 'phone', 'internet', 'parking', 'type',
'pricerange', 'food', 'stars', 'area', 'reference', 'time', 'leave',
'price', 'arrive', 'id'
]

informable_slots = {
'taxi': ['leave', 'destination', 'departure', 'arrive'],
'police': [],
'hospital': ['department'],
'hotel': [
'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
'area', 'stars', 'name'
],
'attraction': ['area', 'type', 'name'],
'train': ['destination', 'day', 'arrive', 'departure', 'people', 'leave'],
'restaurant':
['food', 'pricerange', 'area', 'name', 'time', 'day', 'people']
}
all_infslot = [
'type', 'parking', 'pricerange', 'internet', 'stay', 'day', 'people',
'area', 'stars', 'name', 'leave', 'destination', 'departure', 'arrive',
'department', 'food', 'time'
]

all_slots = all_reqslot + [
'stay', 'day', 'people', 'name', 'destination', 'departure', 'department'
]
get_slot = {}
for s in all_slots:
get_slot[s] = 1

# mapping slots in dialogue act to original goal slot names
da_abbr_to_slot_name = {
'addr': 'address',
'fee': 'price',
'post': 'postcode',
'ref': 'reference',
'ticket': 'price',
'depart': 'departure',
'dest': 'destination',
}

dialog_acts = {
'restaurant': [
'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
'offerbooked', 'nobook'
],
'hotel': [
'inform', 'request', 'nooffer', 'recommend', 'select', 'offerbook',
'offerbooked', 'nobook'
],
'attraction': ['inform', 'request', 'nooffer', 'recommend', 'select'],
'train':
['inform', 'request', 'nooffer', 'offerbook', 'offerbooked', 'select'],
'taxi': ['inform', 'request'],
'police': ['inform', 'request'],
'hospital': ['inform', 'request'],
# 'booking': ['book', 'inform', 'nobook', 'request'],
'general': ['bye', 'greet', 'reqmore', 'welcome'],
}
all_acts = []
for acts in dialog_acts.values():
for act in acts:
if act not in all_acts:
all_acts.append(act)

dialog_act_params = {
'inform': all_slots + ['choice', 'open'],
'request': all_infslot + ['choice', 'price'],
'nooffer': all_slots + ['choice'],
'recommend': all_reqslot + ['choice', 'open'],
'select': all_slots + ['choice'],
# 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
'nobook': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
'offerbook': all_slots + ['choice'],
'offerbooked': all_slots + ['choice'],
'reqmore': [],
'welcome': [],
'bye': [],
'greet': [],
}

dialog_act_all_slots = all_slots + ['choice', 'open']

# special slot tokens in belief span
# no need of this, just convert slot to [slot] e.g. pricerange -> [pricerange]
slot_name_to_slot_token = {}

# special slot tokens in responses
# not used at the moment
slot_name_to_value_token = {
# 'entrance fee': '[value_price]',
# 'pricerange': '[value_price]',
# 'arriveby': '[value_time]',
# 'leaveat': '[value_time]',
# 'departure': '[value_place]',
# 'destination': '[value_place]',
# 'stay': 'count',
# 'people': 'count'
}

# eos tokens definition
eos_tokens = {
'user': '<eos_u>',
'user_delex': '<eos_u>',
'resp': '<eos_r>',
'resp_gen': '<eos_r>',
'pv_resp': '<eos_r>',
'bspn': '<eos_b>',
'bspn_gen': '<eos_b>',
'pv_bspn': '<eos_b>',
'bsdx': '<eos_b>',
'bsdx_gen': '<eos_b>',
'pv_bsdx': '<eos_b>',
'qspn': '<eos_q>',
'qspn_gen': '<eos_q>',
'pv_qspn': '<eos_q>',
'aspn': '<eos_a>',
'aspn_gen': '<eos_a>',
'pv_aspn': '<eos_a>',
'dspn': '<eos_d>',
'dspn_gen': '<eos_d>',
'pv_dspn': '<eos_d>'
}

# sos tokens definition
sos_tokens = {
'user': '<sos_u>',
'user_delex': '<sos_u>',
'resp': '<sos_r>',
'resp_gen': '<sos_r>',
'pv_resp': '<sos_r>',
'bspn': '<sos_b>',
'bspn_gen': '<sos_b>',
'pv_bspn': '<sos_b>',
'bsdx': '<sos_b>',
'bsdx_gen': '<sos_b>',
'pv_bsdx': '<sos_b>',
'qspn': '<sos_q>',
'qspn_gen': '<sos_q>',
'pv_qspn': '<sos_q>',
'aspn': '<sos_a>',
'aspn_gen': '<sos_a>',
'pv_aspn': '<sos_a>',
'dspn': '<sos_d>',
'dspn_gen': '<sos_d>',
'pv_dspn': '<sos_d>'
}

# db tokens definition
db_tokens = [
'<sos_db>', '<eos_db>', '[book_nores]', '[book_fail]', '[book_success]',
'[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
]


# understand tokens definition
def get_understand_tokens(prompt_num_for_understand):
understand_tokens = []
for i in range(prompt_num_for_understand):
understand_tokens.append(f'<understand_{i}>')
return understand_tokens


# policy tokens definition
def get_policy_tokens(prompt_num_for_policy):
policy_tokens = []
for i in range(prompt_num_for_policy):
policy_tokens.append(f'<policy_{i}>')
return policy_tokens


# all special tokens definition
def get_special_tokens(other_tokens):
special_tokens = ['<go_r>', '<go_b>', '<go_a>', '<go_d>',
'<eos_u>', '<eos_r>', '<eos_b>', '<eos_a>', '<eos_d>', '<eos_q>',
'<sos_u>', '<sos_r>', '<sos_b>', '<sos_a>', '<sos_d>', '<sos_q>'] \
+ db_tokens + other_tokens
return special_tokens
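
The three helpers at the end assemble the special-token inventory before it is handed to a tokenizer; a small sketch follows (the prompt counts of 2 and 2 are arbitrary, not fixed by this commit):

from maas_lib.utils.nlp.space.ontology import (get_policy_tokens,
                                               get_special_tokens,
                                               get_understand_tokens)

prompt_tokens = get_understand_tokens(2) + get_policy_tokens(2)
# ['<understand_0>', '<understand_1>', '<policy_0>', '<policy_1>']
special_tokens = get_special_tokens(other_tokens=prompt_tokens)
# go/sos/eos markers + db_tokens + the prompt tokens above; typically registered
# with the tokenizer on the model side (e.g. via something like add_special_tokens).
print(len(special_tokens), special_tokens[:4])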

+6 -0   maas_lib/utils/nlp/space/scores.py

@@ -0,0 +1,6 @@
def hierarchical_set_score(frame1, frame2):
# deal with empty frame
if not (frame1 and frame2):
return 0.
pass
return 0.

+180 -0   maas_lib/utils/nlp/space/utils.py

@@ -0,0 +1,180 @@
import logging
from collections import OrderedDict

import json
import numpy as np

from . import ontology


def clean_replace(s, r, t, forward=True, backward=False):

def clean_replace_single(s, r, t, forward, backward, sidx=0):
# idx = s[sidx:].find(r)
idx = s.find(r)
if idx == -1:
return s, -1
idx_r = idx + len(r)
if backward:
while idx > 0 and s[idx - 1]:
idx -= 1
elif idx > 0 and s[idx - 1] != ' ':
return s, -1

if forward:
while idx_r < len(s) and (s[idx_r].isalpha()
or s[idx_r].isdigit()):
idx_r += 1
elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
return s, -1
return s[:idx] + t + s[idx_r:], idx_r

# source, replace, target = s, r, t
# count = 0
sidx = 0
while sidx != -1:
s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
# count += 1
# print(s, sidx)
# if count == 20:
# print(source, '\n', replace, '\n', target)
# quit()
return s


def py2np(list):
return np.array(list)


def write_dict(fn, dic):
with open(fn, 'w') as f:
json.dump(dic, f, indent=2)


def f1_score(label_list, pred_list):
tp = len([t for t in pred_list if t in label_list])
fp = max(0, len(pred_list) - tp)
fn = max(0, len(label_list) - tp)
precision = tp / (tp + fp + 1e-10)
recall = tp / (tp + fn + 1e-10)
f1 = 2 * precision * recall / (precision + recall + 1e-10)
return f1


class MultiWOZVocab(object):

def __init__(self, vocab_size=0):
"""
vocab for multiwoz dataset
"""
self.vocab_size = vocab_size
self.vocab_size_oov = 0 # get after construction
self._idx2word = {} # word + oov
self._word2idx = {} # word
self._freq_dict = {} # word + oov
for w in [
'[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
'<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
]:
self._absolute_add_word(w)

def _absolute_add_word(self, w):
idx = len(self._idx2word)
self._idx2word[idx] = w
self._word2idx[w] = idx

def add_word(self, word):
if word not in self._freq_dict:
self._freq_dict[word] = 0
self._freq_dict[word] += 1

def has_word(self, word):
return self._freq_dict.get(word)

def _add_to_vocab(self, word):
if word not in self._word2idx:
idx = len(self._idx2word)
self._idx2word[idx] = word
self._word2idx[word] = idx

def construct(self):
l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
print('Vocabulary size including oov: %d' %
(len(l) + len(self._idx2word)))
if len(l) + len(self._idx2word) < self.vocab_size:
logging.warning(
'actual label set smaller than that configured: {}/{}'.format(
len(l) + len(self._idx2word), self.vocab_size))
for word in ontology.all_domains + ['general']:
word = '[' + word + ']'
self._add_to_vocab(word)
for word in ontology.all_acts:
word = '[' + word + ']'
self._add_to_vocab(word)
for word in ontology.all_slots:
self._add_to_vocab(word)
for word in l:
if word.startswith('[value_') and word.endswith(']'):
self._add_to_vocab(word)
for word in l:
self._add_to_vocab(word)
self.vocab_size_oov = len(self._idx2word)

def load_vocab(self, vocab_path):
self._freq_dict = json.loads(
open(vocab_path + '.freq.json', 'r').read())
self._word2idx = json.loads(
open(vocab_path + '.word2idx.json', 'r').read())
self._idx2word = {}
for w, idx in self._word2idx.items():
self._idx2word[idx] = w
self.vocab_size_oov = len(self._idx2word)
print('vocab file loaded from "' + vocab_path + '"')
print('Vocabulary size including oov: %d' % (self.vocab_size_oov))

def save_vocab(self, vocab_path):
_freq_dict = OrderedDict(
sorted(
self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
write_dict(vocab_path + '.word2idx.json', self._word2idx)
write_dict(vocab_path + '.freq.json', _freq_dict)

def encode(self, word, include_oov=True):
if include_oov:
if self._word2idx.get(word, None) is None:
raise ValueError(
'Unknown word: %s. Vocabulary should include oovs here.' %
word)
return self._word2idx[word]
else:
word = '<unk>' if word not in self._word2idx else word
return self._word2idx[word]

def sentence_encode(self, word_list):
return [self.encode(_) for _ in word_list]

def oov_idx_map(self, idx):
return 2 if idx > self.vocab_size else idx

def sentence_oov_map(self, index_list):
return [self.oov_idx_map(_) for _ in index_list]

def decode(self, idx, indicate_oov=False):
if not self._idx2word.get(idx):
raise ValueError(
'Error idx: %d. Vocabulary should include oovs here.' % idx)
if not indicate_oov or idx < self.vocab_size:
return self._idx2word[idx]
else:
return self._idx2word[idx] + '(o)'

def sentence_decode(self, index_list, eos=None, indicate_oov=False):
l = [self.decode(_, indicate_oov) for _ in index_list]
if not eos or eos not in l:
return ' '.join(l)
else:
idx = l.index(eos)
return ' '.join(l[:idx])

def nl_decode(self, l, eos=None):
return [self.sentence_decode(_, eos) + '\n' for _ in l]
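
A small sketch of the MultiWOZVocab lifecycle; the toy sentences are illustrative only, since in practice the vocab is built from (or loaded alongside) the MultiWOZ corpus:

from maas_lib.utils.nlp.space.utils import MultiWOZVocab

vocab = MultiWOZVocab(vocab_size=500)
for sent in ['i want a cheap restaurant', 'book a taxi to the hotel']:
    for word in sent.split():
        vocab.add_word(word)          # counts frequencies only
vocab.construct()                     # special + domain/act/slot tokens, then corpus words

ids = vocab.sentence_encode('book a cheap taxi'.split())
print(ids)
print(vocab.sentence_decode(ids, eos='<eos_u>'))   # 'book a cheap taxi'
# vocab.save_vocab('mwoz_vocab') writes mwoz_vocab.word2idx.json / mwoz_vocab.freq.json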

+2 -0   requirements/nlp/space.txt

@@ -0,0 +1,2 @@
spacy==2.3.5
# python -m spacy download en_core_web_sm

+0 -0   tests/case/__init__.py


+0 -0   tests/case/nlp/__init__.py


+76 -0   tests/case/nlp/dialog_generation_case.py

@@ -0,0 +1,76 @@
test_case = {
'sng0073': {
'goal': {
'taxi': {
'info': {
'leaveat': '17:15',
'destination': 'pizza hut fen ditton',
'departure': "saint john's college"
},
'reqt': ['car', 'phone'],
'fail_info': {}
}
},
'log': [{
'user':
"i would like a taxi from saint john 's college to pizza hut fen ditton .",
'user_delex':
'i would like a taxi from [value_departure] to [value_destination] .',
'resp':
'what time do you want to leave and what time do you want to arrive by ?',
'sys':
'what time do you want to leave and what time do you want to arrive by ?',
'pointer': '0,0,0,0,0,0',
'match': '',
'constraint':
"[taxi] destination pizza hut fen ditton departure saint john 's college",
'cons_delex': '[taxi] destination departure',
'sys_act': '[taxi] [request] leave arrive',
'turn_num': 0,
'turn_domain': '[taxi]'
}, {
'user': 'i want to leave after 17:15 .',
'user_delex': 'i want to leave after [value_leave] .',
'resp':
'booking completed ! your taxi will be [value_car] contact number is [value_phone]',
'sys':
'booking completed ! your taxi will be blue honda contact number is 07218068540',
'pointer': '0,0,0,0,0,0',
'match': '',
'constraint':
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
'cons_delex': '[taxi] destination departure leave',
'sys_act': '[taxi] [inform] car phone',
'turn_num': 1,
'turn_domain': '[taxi]'
}, {
'user': 'thank you for all the help ! i appreciate it .',
'user_delex': 'thank you for all the help ! i appreciate it .',
'resp':
'you are welcome . is there anything else i can help you with today ?',
'sys':
'you are welcome . is there anything else i can help you with today ?',
'pointer': '0,0,0,0,0,0',
'match': '',
'constraint':
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
'cons_delex': '[taxi] destination departure leave',
'sys_act': '[general] [reqmore]',
'turn_num': 2,
'turn_domain': '[general]'
}, {
'user': 'no , i am all set . have a nice day . bye .',
'user_delex': 'no , i am all set . have a nice day . bye .',
'resp': 'you too ! thank you',
'sys': 'you too ! thank you',
'pointer': '0,0,0,0,0,0',
'match': '',
'constraint':
"[taxi] destination pizza hut fen ditton departure saint john 's college leave 17:15",
'cons_delex': '[taxi] destination departure leave',
'sys_act': '[general] [bye]',
'turn_num': 3,
'turn_domain': '[general]'
}]
}
}

+21 -20   tests/pipelines/nlp/test_dialog_generation.py

@@ -37,30 +37,31 @@ dialog_case = [{
 }]


+def merge(info, result):
+    return info
+
+
 class DialogGenerationTest(unittest.TestCase):

     def test_run(self):
-        for item in dialog_case:
-            q = item['user']
-            a = item['sys']
-            print('user:{}'.format(q))
-            print('sys:{}'.format(a))
-
-        # preprocessor = DialogGenerationPreprocessor()
-        # # data = DialogGenerationData()
-        # model = DialogGenerationModel(path, preprocessor.tokenizer)
-        # pipeline = DialogGenerationPipeline(model, preprocessor)
-        #
-        # history_dialog = []
-        # for item in dialog_case:
-        #     user_question = item['user']
-        #     print('user: {}'.format(user_question))
-        #
-        #     pipeline(user_question)
-        #
-        #     sys_answer, history_dialog = pipeline()
-        #
-        #     print('sys : {}'.format(sys_answer))
+        modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
+
+        preprocessor = DialogGenerationPreprocessor(model_dir=modeldir)
+        model = DialogGenerationModel(modeldir, preprocessor.tokenizer)
+        pipeline = DialogGenerationPipeline(model, preprocessor)
+
+        history_dialog_info = {}
+        for step in range(0, len(dialog_case)):
+            user_question = dialog_case[step]['user']
+            print('user: {}'.format(user_question))
+
+            history_dialog_info = merge(history_dialog_info,
+                                        result) if step > 0 else {}
+            result = pipeline(user_question, history=history_dialog_info)
+
+            print('sys : {}'.format(result['pred_string']))


 if __name__ == '__main__':


+0 -0   tests/preprocessors/nlp/__init__.py


+ 25
- 0
tests/preprocessors/nlp/test_dialog_generation.py View File

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from tests.case.nlp.dialog_generation_case import test_case

from maas_lib.preprocessors import DialogGenerationPreprocessor
from maas_lib.utils.constant import Fields, InputFields
from maas_lib.utils.logger import get_logger

logger = get_logger()


class DialogGenerationPreprocessorTest(unittest.TestCase):

def test_tokenize(self):
modeldir = '/Users/yangliu/Desktop/space-dialog-generation'
processor = DialogGenerationPreprocessor(model_dir=modeldir)

for item in test_case['sng0073']['log']:
print(processor(item['user']))


if __name__ == '__main__':
unittest.main()
