Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10061562master
@@ -241,6 +241,7 @@ class Trainers(object): | |||
# nlp trainers | |||
bert_sentiment_analysis = 'bert-sentiment-analysis' | |||
dialog_modeling_trainer = 'dialog-modeling-trainer' | |||
dialog_intent_trainer = 'dialog-intent-trainer' | |||
nlp_base_trainer = 'nlp-base-trainer' | |||
nlp_veco_trainer = 'nlp-veco-trainer' | |||
@@ -1,6 +1,6 @@ | |||
from .configuration_space import SpaceConfig | |||
from .gen_unified_transformer import GenUnifiedTransformer | |||
from .generator import Generator as SpaceGenerator | |||
from .generator import SpaceGenerator | |||
from .intent_unified_transformer import IntentUnifiedTransformer | |||
from .model_base import SpaceModelBase | |||
from .modeling_space import (SpaceForDST, SpaceForMaskedLM, | |||
@@ -38,24 +38,24 @@ def gather(var, idx): | |||
return var | |||
class Generator(object): | |||
class SpaceGenerator(object): | |||
""" Genrator class. """ | |||
_registry = dict() | |||
@classmethod | |||
def register(cls, name): | |||
Generator._registry[name] = cls | |||
SpaceGenerator._registry[name] = cls | |||
return | |||
@staticmethod | |||
def by_name(name): | |||
return Generator._registry[name] | |||
return SpaceGenerator._registry[name] | |||
@staticmethod | |||
def create(config, *args, **kwargs): | |||
""" Create generator. """ | |||
generator_cls = Generator.by_name(config.Generator.generator) | |||
generator_cls = SpaceGenerator.by_name(config.Generator.generator) | |||
return generator_cls(config, *args, **kwargs) | |||
def __init__(self, config, reader): | |||
@@ -83,7 +83,7 @@ class Generator(object): | |||
raise NotImplementedError | |||
class BeamSearch(Generator): | |||
class BeamSearch(SpaceGenerator): | |||
""" BeamSearch generator. """ | |||
def __init__(self, config, reader): | |||
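Side note for reviewers: the rename makes the registry pattern above easier to follow. A minimal, self-contained sketch of the same register/by_name/create flow (class and argument names here are illustrative, not this module's API):

class BaseGenerator:
    _registry = {}

    @classmethod
    def register(cls, name):
        # called on a subclass, so cls is the subclass being registered
        BaseGenerator._registry[name] = cls

    @staticmethod
    def by_name(name):
        return BaseGenerator._registry[name]

    @staticmethod
    def create(name, *args, **kwargs):
        # look up the registered subclass and instantiate it
        return BaseGenerator.by_name(name)(*args, **kwargs)

class ToyBeamSearch(BaseGenerator):
    def __init__(self, beam_size=5):
        self.beam_size = beam_size

ToyBeamSearch.register('ToyBeamSearch')
gen = BaseGenerator.create('ToyBeamSearch', beam_size=3)
assert isinstance(gen, ToyBeamSearch) and gen.beam_size == 3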
@@ -41,7 +41,7 @@ class SpaceForDialogModeling(TorchModel): | |||
self.text_field = kwargs.pop( | |||
'text_field', | |||
MultiWOZBPETextField(self.model_dir, config=self.config)) | |||
MultiWOZBPETextField(config=self.config, model_dir=self.model_dir)) | |||
self.generator = SpaceGenerator.create( | |||
self.config, reader=self.text_field) | |||
self.model = SpaceModelBase.create( | |||
@@ -35,7 +35,7 @@ class DialogModelingPreprocessor(Preprocessor): | |||
self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() | |||
self.text_field = MultiWOZBPETextField( | |||
self.model_dir, config=self.config) | |||
config=self.config, model_dir=self.model_dir) | |||
@type_assert(object, Dict) | |||
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: | |||
@@ -2,9 +2,11 @@ | |||
import os | |||
import random | |||
from asyncio import constants | |||
from collections import OrderedDict | |||
from itertools import chain | |||
import json | |||
import numpy as np | |||
from modelscope.preprocessors.space.tokenizer import Tokenizer | |||
@@ -117,7 +119,8 @@ class BPETextField(object): | |||
return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] | |||
def __init__(self, config): | |||
self.gpu = 0 | |||
self.train, self.dev, self.test = [], [], [] | |||
self.gpu = config.Trainer.gpu | |||
self.tokenizer = None | |||
self.vocab = None | |||
self.db = None | |||
@@ -249,13 +252,9 @@ class BPETextField(object): | |||
for dial in data: | |||
batch.append(dial) | |||
if len(batch) == self.batch_size: | |||
# print('batch size: %d, batch num +1'%(len(batch))) | |||
all_batches.append(batch) | |||
batch = [] | |||
# if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch | |||
# print('last batch size: %d, batch num +1'%(len(batch))) | |||
# if (len(batch) % len(cfg.cuda_device)) != 0: | |||
# batch = batch[:-(len(batch) % len(cfg.cuda_device))] | |||
# TODO deal with deleted data | |||
if self.gpu <= 1: | |||
if len(batch) > 0.5 * self.batch_size: | |||
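The surviving branch above implements what the deleted comment described for the single-GPU case. A runnable sketch of that leftover-batch rule, under the usual UBAR-style assumption that a remainder larger than half the batch size stands alone and anything smaller is folded into the previous batch:

def bucket_into_batches(dials, batch_size):
    all_batches, batch = [], []
    for dial in dials:
        batch.append(dial)
        if len(batch) == batch_size:
            all_batches.append(batch)
            batch = []
    if len(batch) > 0.5 * batch_size:
        all_batches.append(batch)       # large remainder -> its own batch
    elif all_batches:
        all_batches[-1].extend(batch)   # small remainder -> previous batch
    elif batch:
        all_batches.append(batch)       # nothing earlier to fold into
    return all_batches

assert [len(b) for b in bucket_into_batches(list(range(10)), 4)] == [4, 6]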
@@ -308,7 +307,7 @@ class BPETextField(object): | |||
class MultiWOZBPETextField(BPETextField): | |||
def __init__(self, model_dir, config): | |||
def __init__(self, config, **kwargs): | |||
super(MultiWOZBPETextField, self).__init__(config) | |||
import spacy | |||
@@ -327,8 +326,12 @@ class MultiWOZBPETextField(BPETextField): | |||
) | |||
self.nlp = spacy.load('en_core_web_sm') | |||
if config.do_train: | |||
db_dir = kwargs['data_dir'] | |||
else: | |||
db_dir = kwargs['model_dir'] | |||
self.db = MultiWozDB( | |||
model_dir, { | |||
db_dir, { | |||
'attraction': 'db/attraction_db_processed.json', | |||
'hospital': 'db/hospital_db_processed.json', | |||
'hotel': 'db/hotel_db_processed.json', | |||
@@ -337,14 +340,14 @@ class MultiWOZBPETextField(BPETextField): | |||
'taxi': 'db/taxi_db_processed.json', | |||
'train': 'db/train_db_processed.json', | |||
}) | |||
self._build_vocab(model_dir) | |||
self._build_vocab(db_dir) | |||
special_tokens = [ | |||
self.pad_token, self.bos_token, self.eos_token, self.unk_token | |||
] | |||
special_tokens.extend(self.add_sepcial_tokens()) | |||
self.tokenizer = Tokenizer( | |||
vocab_path=os.path.join(model_dir, ModelFile.VOCAB_FILE), | |||
vocab_path=os.path.join(kwargs['model_dir'], ModelFile.VOCAB_FILE), | |||
special_tokens=special_tokens, | |||
tokenizer_type=config.BPETextField.tokenizer_type) | |||
self.understand_ids = self.tokenizer.convert_tokens_to_ids( | |||
@@ -352,6 +355,26 @@ class MultiWOZBPETextField(BPETextField): | |||
self.policy_ids = self.tokenizer.convert_tokens_to_ids( | |||
self.policy_tokens) | |||
if config.do_train: | |||
test_list = [ | |||
line.strip().lower() for line in open( | |||
os.path.join(kwargs['data_dir'], 'testListFile.json'), | |||
'r').readlines() | |||
] | |||
dev_list = [ | |||
line.strip().lower() for line in open( | |||
os.path.join(kwargs['data_dir'], 'valListFile.json'), | |||
'r').readlines() | |||
] | |||
self.dev_files, self.test_files = {}, {} | |||
for fn in test_list: | |||
self.test_files[fn.replace('.json', '')] = 1 | |||
for fn in dev_list: | |||
self.dev_files[fn.replace('.json', '')] = 1 | |||
self._load_data(kwargs['data_dir']) | |||
return | |||
def get_ids(self, data: str): | |||
@@ -414,7 +437,6 @@ class MultiWOZBPETextField(BPETextField): | |||
name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} | |||
dial = name_to_set[set_name] | |||
turn_bucket = self._bucket_by_turn(dial) | |||
# self._shuffle_turn_bucket(turn_bucket) | |||
all_batches = [] | |||
if set_name not in self.set_stats: | |||
@@ -433,19 +455,13 @@ class MultiWOZBPETextField(BPETextField): | |||
except Exception: | |||
log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( | |||
k, len(turn_bucket[k]), len(batches), 0.0) | |||
# print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) | |||
num_training_steps += k * len(batches) | |||
num_turns += k * len(turn_bucket[k]) | |||
num_dials += len(turn_bucket[k]) | |||
all_batches += batches | |||
log_str += 'total batch num: %d\n' % len(all_batches) | |||
# print('total batch num: %d'%len(all_batches)) | |||
# print('dialog count: %d'%dia_count) | |||
# return all_batches | |||
# log stats | |||
# logging.info(log_str) | |||
# cfg.num_training_steps = num_training_steps * cfg.epoch_num | |||
self.set_stats[set_name][ | |||
'num_training_steps_per_epoch'] = num_training_steps # turn-level steps | |||
self.set_stats[set_name]['num_turns'] = num_turns | |||
@@ -484,6 +500,71 @@ class MultiWOZBPETextField(BPETextField): | |||
self.vocab.load_vocab(vp) | |||
return self.vocab.vocab_size | |||
def _load_data(self, data_dir, save_temp=True): | |||
""" | |||
load processed data and encode, or load already encoded data | |||
""" | |||
def load_data_from_resource(data_resource): | |||
data = json.loads( | |||
open( | |||
os.path.join(data_dir, data_resource), | |||
'r', | |||
encoding='utf-8').read().lower()) | |||
train, dev, test = [], [], [] | |||
for fn, dial in data.items(): | |||
if '.json' in fn: | |||
fn = fn.replace('.json', '') | |||
if self.dev_files.get(fn): | |||
dev.append(self._get_encoded_data(fn, dial)) | |||
elif self.test_files.get(fn): | |||
test.append(self._get_encoded_data(fn, dial)) | |||
else: | |||
train.append(self._get_encoded_data(fn, dial)) | |||
return train, dev, test | |||
data_processed = 'new_db_se_blank_encoded_domain.data.json' | |||
data_resource = 'data_for_damd.json' | |||
if save_temp: # save encoded data | |||
# encoded: no sos, se_encoded: sos and eos | |||
encoded_file = os.path.join(data_dir, data_processed) | |||
if os.path.exists(encoded_file): | |||
logger.info( | |||
'Reading encoded data from {}'.format(encoded_file)) | |||
self.data = json.loads( | |||
open( | |||
os.path.join(data_dir, data_resource), | |||
'r', | |||
encoding='utf-8').read().lower()) | |||
encoded_data = json.loads( | |||
open(encoded_file, 'r', encoding='utf-8').read()) | |||
self.train = encoded_data['train'] | |||
self.dev = encoded_data['dev'] | |||
self.test = encoded_data['test'] | |||
else: | |||
logger.info( | |||
'Encoding data now and save the encoded data in {}'.format( | |||
encoded_file)) | |||
# not exists, encode data and save | |||
self.train, self.dev, self.test = load_data_from_resource( | |||
data_resource) | |||
# save encoded data | |||
encoded_data = { | |||
'train': self.train, | |||
'dev': self.dev, | |||
'test': self.test | |||
} | |||
json.dump(encoded_data, open(encoded_file, 'w'), indent=2) | |||
else: # directly read processed data and encode | |||
self.train, self.dev, self.test = load_data_from_resource( | |||
data_resource) | |||
random.seed(10) | |||
random.shuffle(self.train) | |||
logger.info('train size:{}, dev size:{}, test size:{}'.format( | |||
len(self.train), len(self.dev), len(self.test))) | |||
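_load_data above follows a load-or-encode-and-cache pattern: reuse new_db_se_blank_encoded_domain.data.json when it exists, otherwise encode data_for_damd.json and write the cache. A stripped-down sketch of that pattern (the cache file name and encode step here are toy stand-ins):

import json
import os

def load_or_encode(data_dir, encode_fn, cache_name='encoded.cache.json'):
    cache_file = os.path.join(data_dir, cache_name)
    if os.path.exists(cache_file):
        with open(cache_file, 'r', encoding='utf-8') as f:
            return json.load(f)            # reuse previously encoded splits
    splits = encode_fn()                   # expensive tokenization pass
    with open(cache_file, 'w', encoding='utf-8') as f:
        json.dump(splits, f, indent=2)     # persist for the next run
    return splits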
def _get_convert_str(self, sent): | |||
assert isinstance(sent, str) | |||
return ' '.join([ | |||
@@ -491,14 +572,65 @@ class MultiWOZBPETextField(BPETextField): | |||
for tok in sent.split() | |||
]) | |||
def _get_encoded_data(self, fn, dial): | |||
encoded_dial = [] | |||
for idx, t in enumerate(dial['log']): # tokenize to list of ids | |||
enc = {} | |||
enc['dial_id'] = fn | |||
enc_info_list = [ | |||
('user', self.sos_u_id, 'user', self.eos_u_id), | |||
('usdx', self.sos_u_id, 'user', self.eos_u_id), | |||
('resp', self.sos_r_id, 'resp', self.eos_r_id), | |||
('bspn', self.sos_b_id, 'constraint', self.eos_b_id), | |||
('bsdx', self.sos_b_id, 'cons_delex', self.eos_b_id), | |||
('aspn', self.sos_a_id, 'sys_act', self.eos_a_id) | |||
] | |||
for enc_key, start_token, item_key, end_token in enc_info_list: | |||
enc[enc_key] = [ | |||
start_token | |||
] + self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize( | |||
self._get_convert_str(t[item_key]))) + [end_token] | |||
enc['turn_num'] = t['turn_num'] | |||
if idx > 0 and t['turn_domain'] == '[general]': | |||
enc['dspn'] = encoded_dial[idx - 1]['dspn'] | |||
enc['pointer'] = encoded_dial[idx - 1]['pointer'][:4] + [ | |||
int(i) for i in t['pointer'].split(',') | |||
][-2:] | |||
enc['turn_domain'] = encoded_dial[idx - 1]['turn_domain'] | |||
enc['db'] = encoded_dial[idx - 1]['db'] | |||
else: | |||
if t['turn_domain'] == '[general]': | |||
assert not t['constraint'], f'{fn}-{idx}' | |||
enc['dspn'] = [ | |||
self.sos_d_id | |||
] + self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize( | |||
self._get_convert_str( | |||
t['turn_domain']))) + [self.eos_d_id] | |||
enc['pointer'] = [int(i) for i in t['pointer'].split(',')] | |||
enc['turn_domain'] = t['turn_domain'].split() | |||
db_pointer = self.bspan_to_DBpointer(t['constraint'], | |||
t['turn_domain'].split()) | |||
enc['db'] = [ | |||
self.sos_db_id | |||
] + self.tokenizer.convert_tokens_to_ids( | |||
self.tokenizer.tokenize( | |||
self._get_convert_str(db_pointer))) + [self.eos_db_id] | |||
encoded_dial.append(enc) | |||
return encoded_dial | |||
def bspan_to_DBpointer(self, bspan, turn_domain): | |||
constraint_dict = self.bspan_to_constraint_dict(bspan) | |||
# print(constraint_dict) | |||
matnums = self.db.get_match_num(constraint_dict) | |||
match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] | |||
match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom | |||
match = matnums[match_dom] | |||
# vector = self.db.addDBPointer(match_dom, match) | |||
vector = self.db.addDBIndicator(match_dom, match) | |||
return vector | |||
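bspan_to_DBpointer parses the decoded belief span into a constraint dict, counts DB matches for the active domain, and renders the count as an indicator token. A toy stand-in for that last step (the real bucket boundaries and token format come from MultiWozDB.addDBIndicator, so treat these as illustrative):

def toy_db_indicator(match_count):
    if match_count == 0:
        bucket = 0
    elif match_count <= 3:
        bucket = 1
    else:
        bucket = 2
    return '[db_%d]' % bucket

assert toy_db_indicator(2) == '[db_1]'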
@@ -691,3 +823,67 @@ class MultiWOZBPETextField(BPETextField): | |||
inputs['labels'] = [context] # use previous turn | |||
return inputs, prompt_id | |||
def restore(self, resp, domain, constraint_dict, mat_ents): | |||
restored = resp | |||
restored = restored.replace('[value_reference]', '53022') | |||
restored = restored.replace('[value_car]', 'BMW') | |||
for d in domain: | |||
constraint = constraint_dict.get(d, None) | |||
if constraint: | |||
replace_res_list = [('stay', '[value_stay]'), | |||
('day', '[value_day]'), | |||
('people', '[value_people]'), | |||
('time', '[value_time]'), | |||
('type', '[value_type]')] | |||
for key, value_key in replace_res_list: | |||
if key in constraint: | |||
restored = restored.replace(value_key, constraint[key]) | |||
if d in mat_ents and len(mat_ents[d]) == 0: | |||
for s in constraint: | |||
if s == 'pricerange' and d in [ | |||
'hotel', 'restaurant' | |||
] and 'price]' in restored: | |||
restored = restored.replace( | |||
'[value_price]', constraint['pricerange']) | |||
if s + ']' in restored: | |||
restored = restored.replace( | |||
'[value_%s]' % s, constraint[s]) | |||
if '[value_choice' in restored and mat_ents.get(d): | |||
restored = restored.replace('[value_choice]', | |||
str(len(mat_ents[d]))) | |||
if '[value_choice' in restored: | |||
restored = restored.replace('[value_choice]', '3') | |||
try: | |||
ent = mat_ents.get(domain[-1], []) | |||
if ent: | |||
ent = ent[0] | |||
for t in restored.split(): | |||
if '[value' in t: | |||
slot = t[7:-1] | |||
if ent.get(slot): | |||
if domain[-1] == 'hotel' and slot == 'price': | |||
slot = 'pricerange' | |||
restored = restored.replace(t, ent[slot]) | |||
elif slot == 'price': | |||
if ent.get('pricerange'): | |||
restored = restored.replace( | |||
t, ent['pricerange']) | |||
else: | |||
logger.info(restored, domain) | |||
except Exception: | |||
logger.error(resp) | |||
logger.error(restored) | |||
quit() | |||
restored = restored.replace('[value_phone]', '62781111') | |||
restored = restored.replace('[value_postcode]', 'CG9566') | |||
restored = restored.replace('[value_address]', 'Parkside, Cambridge') | |||
return restored |
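restore() re-lexicalizes a delexicalized response by filling [value_*] placeholders from the constraint dict, matched DB entities, and a few fixed fallbacks. A toy illustration of the core substitution step (the slot values below are invented for the example):

resp = 'the phone number is [value_phone] , postcode [value_postcode]'
fills = {'[value_phone]': '62781111', '[value_postcode]': 'cg9566'}
for slot, value in fills.items():
    resp = resp.replace(slot, value)
assert resp == 'the phone number is 62781111 , postcode cg9566'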
@@ -0,0 +1,130 @@ | |||
import os | |||
import time | |||
from typing import Callable, Dict, Optional, Tuple, Union | |||
import numpy as np | |||
from modelscope.metainfo import Trainers | |||
from modelscope.models.nlp.space.model.generator import SpaceGenerator | |||
from modelscope.models.nlp.space.model.model_base import SpaceModelBase | |||
from modelscope.preprocessors.space.fields.gen_field import \ | |||
MultiWOZBPETextField | |||
from modelscope.trainers.base import BaseTrainer | |||
from modelscope.trainers.builder import TRAINERS | |||
from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator | |||
from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer | |||
from modelscope.utils.config import Config, ModelFile | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
def setup_seed(seed: int): | |||
import random | |||
import torch | |||
torch.manual_seed(seed) | |||
torch.cuda.manual_seed_all(seed) | |||
np.random.seed(seed) | |||
random.seed(seed) | |||
torch.backends.cudnn.deterministic = True | |||
@TRAINERS.register_module(module_name=Trainers.dialog_modeling_trainer) | |||
class DialogModelingTrainer(BaseTrainer): | |||
def __init__(self, | |||
cfg_file: Optional[str] = None, | |||
cfg_modify_fn: Optional[Callable] = None, | |||
*args, | |||
**kwargs): | |||
super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) | |||
self.cfg_modify_fn = cfg_modify_fn | |||
self.cfg = self.rebuild_config(self.cfg) | |||
setup_seed(self.cfg.Trainer.seed) | |||
# set reader and evaluator | |||
self.bpe = MultiWOZBPETextField(self.cfg, **kwargs) | |||
self.cfg.Model.num_token_embeddings = self.bpe.vocab_size | |||
self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 | |||
if 'work_dir' in kwargs: | |||
self.cfg.Trainer.save_dir = kwargs['work_dir'] | |||
else: | |||
self.cfg.Trainer.save_dir = './default_save_dir' | |||
# set data and data status | |||
self.train_data = self.bpe.get_batches('train') | |||
self.dev_data = self.bpe.get_batches('dev') | |||
self.evaluator = MultiWOZEvaluator(reader=self.bpe, **kwargs) | |||
# set generator | |||
self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) | |||
self._load_model(**kwargs) | |||
def _load_model(self, **kwargs): | |||
def to_tensor(array): | |||
""" | |||
numpy array -> tensor | |||
""" | |||
import torch | |||
array = torch.tensor(array) | |||
return array.cuda( | |||
) if self.cfg.use_gpu and torch.cuda.is_available() else array | |||
# construct model | |||
if 'model' in kwargs: | |||
self.model = kwargs['model'] | |||
else: | |||
self.model = SpaceModelBase.create( | |||
kwargs['model_dir'], | |||
self.cfg, | |||
reader=self.bpe, | |||
generator=self.generator) | |||
import torch | |||
# multi-gpu | |||
if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: | |||
self.model = torch.nn.DataParallel(self.model) | |||
# construct trainer | |||
self.trainer = MultiWOZTrainer( | |||
self.model, | |||
to_tensor, | |||
self.cfg, | |||
reader=self.bpe, | |||
evaluator=self.evaluator) | |||
self.trainer.set_optimizers() | |||
# load model, optimizer and lr_scheduler | |||
self.trainer.load() | |||
def rebuild_config(self, cfg: Config): | |||
if self.cfg_modify_fn is not None: | |||
return self.cfg_modify_fn(cfg) | |||
return cfg | |||
def train(self, *args, **kwargs): | |||
logger.info('Train') | |||
self.trainer.train(train_data=self.train_data, dev_data=self.dev_data) | |||
def evaluate(self, | |||
checkpoint_path: Optional[str] = None, | |||
*args, | |||
**kwargs) -> Dict[str, float]: | |||
logger.info('Evaluate') | |||
self.cfg.do_infer = True | |||
# get best checkpoint path | |||
pos = checkpoint_path.rfind('/') | |||
checkpoint_name = checkpoint_path[pos + 1:] | |||
checkpoint_dir = checkpoint_path[:pos] | |||
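# Review note: the manual rfind('/') split assumes '/' separators; an
# equivalent, platform-safe form would be
#   checkpoint_dir, checkpoint_name = os.path.split(checkpoint_path)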
assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE | |||
kwargs['model_dir'] = checkpoint_dir | |||
self._load_model(**kwargs) | |||
self.trainer.infer(data_type='test') |
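For reviewers who want to try the new trainer: a hypothetical invocation via the trainer registry. The paths and cfg_name value are placeholders; the kwargs mirror what __init__ and MultiWOZBPETextField read (model_dir, cfg_name, data_dir, work_dir):

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

kwargs = dict(
    model_dir='./space_dialog_modeling',  # downloaded model: config + vocab
    cfg_name='configuration.json',        # placeholder config file name
    data_dir='./multiwoz2.0_processed',   # data_for_damd.json, db/, list files
    work_dir='./work_dir',                # where best checkpoints are saved
)
trainer = build_trainer(
    name=Trainers.dialog_modeling_trainer, default_args=kwargs)
trainer.train()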
@@ -0,0 +1,952 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright from https://github.com/thu-spmi/LABES | |||
# Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import math | |||
from collections import Counter | |||
import json | |||
import numpy as np | |||
from nltk.util import ngrams | |||
from sklearn.metrics import f1_score | |||
from modelscope.utils.nlp.space import ontology, utils | |||
from modelscope.utils.nlp.space.clean_dataset import clean_slot_values | |||
def similar(a, b): | |||
return a == b or a in b or b in a or a.split()[0] == b.split( | |||
)[0] or a.split()[-1] == b.split()[-1] | |||
def setsub(a, b): | |||
junks_a = [] | |||
useless_constraint = [ | |||
'temperature', 'week', 'est ', 'quick', 'reminder', 'near' | |||
] | |||
for i in a: | |||
flg = False | |||
for j in b: | |||
if similar(i, j): | |||
flg = True | |||
if not flg: | |||
junks_a.append(i) | |||
for junk in junks_a: | |||
flg = False | |||
for item in useless_constraint: | |||
if item in junk: | |||
flg = True | |||
if not flg: | |||
return False | |||
return True | |||
def setsim(a, b): | |||
a, b = set(a), set(b) | |||
return setsub(a, b) and setsub(b, a) | |||
def DA_evaluate(preds, labels): | |||
preds = np.array(preds) | |||
labels = np.array(labels) | |||
results = {} | |||
for avg_name in ['micro']: | |||
my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name) | |||
results['f1_{}'.format(avg_name)] = my_f1_score | |||
return results | |||
class BLEUScorer(object): | |||
# BLEU score calculator via GentScorer interface | |||
# it calculates BLEU-4 by taking in the entire corpus, | |||
# computed over multiple candidates against multiple references | |||
def __init__(self): | |||
pass | |||
def score(self, parallel_corpus): | |||
# containers | |||
count = [0, 0, 0, 0] | |||
clip_count = [0, 0, 0, 0] | |||
r = 0 | |||
c = 0 | |||
weights = [0.25, 0.25, 0.25, 0.25] | |||
# accumulate ngram statistics | |||
for hyps, refs in parallel_corpus: | |||
hyps = [hyp.split() for hyp in hyps] | |||
refs = [ref.split() for ref in refs] | |||
for hyp in hyps: | |||
for i in range(4): | |||
# accumulate ngram counts | |||
hypcnts = Counter(ngrams(hyp, i + 1)) | |||
cnt = sum(hypcnts.values()) | |||
count[i] += cnt | |||
# compute clipped counts | |||
max_counts = {} | |||
for ref in refs: | |||
refcnts = Counter(ngrams(ref, i + 1)) | |||
for ng in hypcnts: | |||
max_counts[ng] = max( | |||
max_counts.get(ng, 0), refcnts[ng]) | |||
clipcnt = \ | |||
dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items()) | |||
clip_count[i] += sum(clipcnt.values()) | |||
# accumulate r & c | |||
bestmatch = [1000, 1000] | |||
for ref in refs: | |||
if bestmatch[0] == 0: | |||
break | |||
diff = abs(len(ref) - len(hyp)) | |||
if diff < bestmatch[0]: | |||
bestmatch[0] = diff | |||
bestmatch[1] = len(ref) | |||
r += bestmatch[1] | |||
c += len(hyp) | |||
# computing bleu score | |||
p0 = 1e-7 | |||
bp = \ | |||
1 if c > r else math.exp(1 - float(r) / float(c)) | |||
p_ns = \ | |||
[float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] | |||
s = \ | |||
math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) | |||
bleu = bp * math.exp(s) | |||
return bleu * 100 | |||
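A quick sanity check of BLEUScorer: parallel_corpus is an iterable of (hypotheses, references) pairs, each a list of plain-string sentences.

scorer = BLEUScorer()
corpus = [(['the hotel is in the centre'],
           ['the hotel is in the centre of town'])]
print(scorer.score(corpus))  # corpus-level BLEU-4, scaled to 0-100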
"""" | |||
For the data preparation and evaluation on MultiWOZ2.0/2.1, | |||
we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ) | |||
""" | |||
class MultiWOZEvaluator(object): | |||
def __init__(self, reader, **kwargs): | |||
self.reader = reader | |||
self.domains = ontology.all_domains | |||
self.all_data = self.reader.data | |||
self.test_data = self.reader.test | |||
self.bleu_scorer = BLEUScorer() | |||
self.all_info_slot = [] | |||
for d, s_list in ontology.informable_slots.items(): | |||
for s in s_list: | |||
self.all_info_slot.append(d + '-' + s) | |||
# only evaluate these slots for dialog success | |||
self.requestables = ['phone', 'address', 'postcode', 'reference', 'id'] | |||
self.db_dir = kwargs['data_dir'] | |||
def pack_dial(self, data): | |||
dials = {} | |||
for turn in data: | |||
dial_id = turn['dial_id'] | |||
if dial_id not in dials: | |||
dials[dial_id] = [] | |||
dials[dial_id].append(turn) | |||
return dials | |||
def validation_metric(self, data, fout=None): | |||
bleu = self.bleu_metric(data) | |||
# accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data) | |||
success, match, req_offer_counts, dial_num = \ | |||
self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout) | |||
return bleu, success, match | |||
def bleu_metric(self, data, eval_dial_list=None): | |||
gen, truth = [], [] | |||
for row in data: | |||
if eval_dial_list and row[ | |||
'dial_id'] + '.json' not in eval_dial_list: | |||
continue | |||
gen.append(row['resp_gen']) | |||
truth.append(row['resp']) | |||
wrap_generated = [[_] for _ in gen] | |||
wrap_truth = [[_] for _ in truth] | |||
if gen and truth: | |||
try: | |||
sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth)) | |||
except Exception: | |||
sc = 0.0 | |||
else: | |||
sc = 0.0 | |||
return sc | |||
def context_to_response_eval(self, | |||
data, | |||
eval_dial_list=None, | |||
same_eval_as_cambridge=False, | |||
fout=None): | |||
dials = self.pack_dial(data) | |||
counts = {} | |||
for req in self.requestables: | |||
counts[req + '_total'] = 0 | |||
counts[req + '_offer'] = 0 | |||
dial_num, successes, matches = 0, 0, 0 | |||
for dial_id in dials: | |||
if eval_dial_list and dial_id + '.json' not in eval_dial_list: | |||
continue | |||
dial = dials[dial_id] | |||
reqs = {} | |||
goal = {} | |||
if '.json' not in dial_id and '.json' in list( | |||
self.all_data.keys())[0]: | |||
dial_id = dial_id + '.json' | |||
for domain in ontology.all_domains: | |||
if self.all_data[dial_id]['goal'].get(domain): | |||
true_goal = self.all_data[dial_id]['goal'] | |||
goal = self._parseGoal(goal, true_goal, domain) | |||
for domain in goal.keys(): | |||
reqs[domain] = goal[domain]['requestable'] | |||
success, match, stats, counts = \ | |||
self._evaluateGeneratedDialogue(dial, goal, reqs, counts, | |||
same_eval_as_cambridge=same_eval_as_cambridge, fout=fout) | |||
successes += success | |||
matches += match | |||
dial_num += 1 | |||
succ_rate = successes / (float(dial_num) + 1e-10) * 100 | |||
match_rate = matches / (float(dial_num) + 1e-10) * 100 | |||
return succ_rate, match_rate, counts, dial_num | |||
def _evaluateGeneratedDialogue(self, | |||
dialog, | |||
goal, | |||
real_requestables, | |||
counts, | |||
soft_acc=False, | |||
same_eval_as_cambridge=False, | |||
fout=None): | |||
"""Evaluates the dialogue created by the model. | |||
First we load the user goal of the dialogue, then for each turn | |||
generated by the system we look for key-words. | |||
For the Inform rate we look whether the entity was proposed. | |||
For the Success rate we look for requestables slots""" | |||
# for computing corpus success | |||
requestables = self.requestables | |||
# CHECK IF MATCH HAPPENED | |||
provided_requestables = {} | |||
venue_offered = {} | |||
domains_in_goal = [] | |||
log = [] | |||
bspans = {} | |||
for domain in goal.keys(): | |||
venue_offered[domain] = [] | |||
provided_requestables[domain] = [] | |||
domains_in_goal.append(domain) | |||
for t, turn in enumerate(dialog): | |||
if t == 0: | |||
continue | |||
if fout is not None: | |||
log.append({ | |||
'turn_num': turn['turn_num'], | |||
'turn_domain': turn['dspn'], | |||
'user': turn['user'], | |||
'aspn': turn['aspn'], | |||
'aspn_gen': turn['aspn_gen'], | |||
'resp': turn['resp'], | |||
'resp_gen': turn['resp_gen'], | |||
'pointer': turn['pointer'], | |||
}) | |||
sent_t = turn['resp_gen'] | |||
for domain in goal.keys(): | |||
# for computing success | |||
if same_eval_as_cambridge: | |||
# [restaurant_name], [hotel_name] instead of [value_name] | |||
if self.reader.use_true_domain_for_ctr_eval: | |||
dom_pred = [d[1:-1] for d in turn['dspn'].split()] | |||
else: | |||
dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()] | |||
if domain not in dom_pred: # fail | |||
continue | |||
if '[value_name]' in sent_t or '[value_id]' in sent_t: | |||
if domain in [ | |||
'restaurant', 'hotel', 'attraction', 'train' | |||
]: | |||
# HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION | |||
if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval: | |||
bspn = turn['bspn_gen'] | |||
else: | |||
bspn = turn['bspn'] | |||
constraint_dict = self.reader.bspan_to_constraint_dict( | |||
bspn) | |||
if constraint_dict.get(domain): | |||
venues = self.reader.db.queryJsons( | |||
domain, | |||
constraint_dict[domain], | |||
return_name=True) | |||
else: | |||
venues = [] | |||
if len(venue_offered[domain]) == 0 and venues: | |||
venue_offered[domain] = venues | |||
bspans[domain] = constraint_dict[domain] | |||
else: | |||
flag = False | |||
for ven in venues: | |||
if ven not in venue_offered[domain]: | |||
flag = True | |||
break | |||
if flag and venues: # sometimes there are no results so sample won't work | |||
venue_offered[domain] = venues | |||
bspans[domain] = constraint_dict[domain] | |||
else: # not limited so we can provide one | |||
venue_offered[domain] = '[value_name]' | |||
# ATTENTION: assumption here - we didn't provide phone or address twice! etc | |||
for requestable in requestables: | |||
if requestable == 'reference': | |||
if '[value_reference]' in sent_t: | |||
if domain in ['restaurant', 'hotel', 'train']: | |||
if 'booked' in turn['pointer'] or 'ok' in turn[ | |||
'pointer'] or '[value_reference]' in turn[ | |||
'resp']: | |||
# if pointer was allowing for that? | |||
provided_requestables[domain].append( | |||
'reference') | |||
else: | |||
provided_requestables[domain].append( | |||
'reference') | |||
else: | |||
if '[value_' + requestable + ']' in sent_t: | |||
provided_requestables[domain].append(requestable) | |||
# if name was given in the task | |||
for domain in goal.keys(): | |||
# if name was provided for the user, the match is being done automatically | |||
if 'name' in goal[domain]['informable']: | |||
venue_offered[domain] = '[value_name]' | |||
# special domains - entity does not need to be provided | |||
if domain in ['taxi', 'police', 'hospital']: | |||
venue_offered[domain] = '[value_name]' | |||
if domain == 'train': | |||
if not venue_offered[domain] and 'id' not in goal[domain][ | |||
'requestable']: | |||
venue_offered[domain] = '[value_name]' | |||
""" | |||
Given all inform and requestable slots | |||
we go through each domain from the user goal | |||
and check whether right entity was provided and | |||
all requestable slots were given to the user. | |||
The dialogue is successful if that's the case for all domains. | |||
""" | |||
# HARD EVAL | |||
stats = { | |||
'restaurant': [0, 0, 0], | |||
'hotel': [0, 0, 0], | |||
'attraction': [0, 0, 0], | |||
'train': [0, 0, 0], | |||
'taxi': [0, 0, 0], | |||
'hospital': [0, 0, 0], | |||
'police': [0, 0, 0] | |||
} | |||
match = 0 | |||
success = 0 | |||
# MATCH | |||
for domain in goal.keys(): | |||
match_stat = 0 | |||
if domain in ['restaurant', 'hotel', 'attraction', 'train']: | |||
goal_venues = self.reader.db.queryJsons( | |||
domain, goal[domain]['informable'], return_name=True) | |||
if type(venue_offered[domain] | |||
) is str and '_name' in venue_offered[domain]: | |||
match += 1 | |||
match_stat = 1 | |||
elif len(venue_offered[domain]) > 0 and len( | |||
set(venue_offered[domain]) & set(goal_venues)) > 0: | |||
match += 1 | |||
match_stat = 1 | |||
else: | |||
if '_name]' in venue_offered[domain]: | |||
match += 1 | |||
match_stat = 1 | |||
stats[domain][0] = match_stat | |||
stats[domain][2] = 1 | |||
if soft_acc: | |||
match = float(match) / len(goal.keys()) | |||
else: | |||
if match == len(goal.keys()): | |||
match = 1.0 | |||
else: | |||
match = 0.0 | |||
for domain in domains_in_goal: | |||
for request in real_requestables[domain]: | |||
counts[request + '_total'] += 1 | |||
if request in provided_requestables[domain]: | |||
counts[request + '_offer'] += 1 | |||
# SUCCESS | |||
if fout is not None: | |||
for domain in domains_in_goal: | |||
success_stat = 0 | |||
domain_success = 0 | |||
if len(real_requestables[domain]) == 0: | |||
success += 1 | |||
success_stat = 1 | |||
stats[domain][1] = success_stat | |||
continue | |||
# if values in sentences are super set of requestables | |||
for request in real_requestables[domain]: | |||
if request in provided_requestables[domain]: | |||
domain_success += 1 | |||
if domain_success == len(real_requestables[domain]): | |||
success += 1 | |||
success_stat = 1 | |||
stats[domain][1] = success_stat | |||
# final eval | |||
if soft_acc: | |||
success = float(success) / len(real_requestables) | |||
else: | |||
if success >= len(real_requestables): | |||
success = 1 | |||
else: | |||
success = 0 | |||
else: | |||
if match == 1.0: | |||
for domain in domains_in_goal: | |||
success_stat = 0 | |||
domain_success = 0 | |||
if len(real_requestables[domain]) == 0: | |||
success += 1 | |||
success_stat = 1 | |||
stats[domain][1] = success_stat | |||
continue | |||
# if values in sentences are super set of requestables | |||
for request in real_requestables[domain]: | |||
if request in provided_requestables[domain]: | |||
domain_success += 1 | |||
if domain_success == len(real_requestables[domain]): | |||
success += 1 | |||
success_stat = 1 | |||
stats[domain][1] = success_stat | |||
# final eval | |||
if soft_acc: | |||
success = float(success) / len(real_requestables) | |||
else: | |||
if success >= len(real_requestables): | |||
success = 1 | |||
else: | |||
success = 0 | |||
if fout is not None and success == 0: | |||
sample = { | |||
dialog[0]['dial_id']: { | |||
'log': log, | |||
'real_requestables': real_requestables, | |||
'provided_requestables': provided_requestables | |||
} | |||
} | |||
line = json.dumps(sample) | |||
fout.write(line) | |||
fout.write('\n') | |||
return success, match, stats, counts | |||
def _parseGoal(self, goal, true_goal, domain): | |||
"""Parses user goal into dictionary format.""" | |||
goal[domain] = {} | |||
goal[domain] = {'informable': {}, 'requestable': [], 'booking': []} | |||
if 'info' in true_goal[domain]: | |||
if domain == 'train': | |||
# we consider dialogues only where train had to be booked! | |||
if 'book' in true_goal[domain]: | |||
goal[domain]['requestable'].append('reference') | |||
if 'reqt' in true_goal[domain]: | |||
if 'id' in true_goal[domain]['reqt']: | |||
goal[domain]['requestable'].append('id') | |||
else: | |||
if 'reqt' in true_goal[domain]: | |||
for s in true_goal[domain]['reqt']:  # additional requests: | |||
if s in [ | |||
'phone', 'address', 'postcode', 'reference', | |||
'id' | |||
]: | |||
# ones that can be easily delexicalized | |||
goal[domain]['requestable'].append(s) | |||
if 'book' in true_goal[domain]: | |||
goal[domain]['requestable'].append('reference') | |||
for s, v in true_goal[domain]['info'].items(): | |||
s_, v_ = clean_slot_values(self.db_dir, domain, s, v) | |||
if len(v_.split()) > 1: | |||
v_ = ' '.join( | |||
[token.text for token in self.reader.nlp(v_)]).strip() | |||
goal[domain]['informable'][s_] = v_ | |||
if 'book' in true_goal[domain]: | |||
goal[domain]['booking'] = true_goal[domain]['book'] | |||
return goal | |||
class GenericEvaluator: | |||
def __init__(self, reader): | |||
self.reader = reader | |||
self.metric_dict = {} | |||
def pack_dial(self, data): | |||
dials = {} | |||
for turn in data: | |||
dial_id = turn['dial_id'] | |||
if dial_id not in dials: | |||
dials[dial_id] = [] | |||
dials[dial_id].append(turn) | |||
return dials | |||
def run_metrics(self, results): | |||
raise ValueError('Please specify the evaluator first') | |||
def bleu_metric(self, data, type='bleu'): | |||
gen, truth = [], [] | |||
for row in data: | |||
gen.append(self.clean(row['resp_gen'])) | |||
# gen.append(self.clean(row['resp'])) | |||
truth.append(self.clean(row['resp'])) | |||
wrap_generated = [[_] for _ in gen] | |||
wrap_truth = [[_] for _ in truth] | |||
sc = BLEUScorer().score(zip(wrap_generated, wrap_truth)) | |||
return sc | |||
def _normalize_constraint(self, | |||
constraint, | |||
ignore_dontcare=False, | |||
intersection=True): | |||
""" | |||
Normalize belief span, e.g. delete repeated words | |||
:param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'} | |||
:param intersection: if true, only keeps the words that appear in the ontology | |||
we set intersection=True as in previous works | |||
:returns: normalized constraint dict | |||
e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''} | |||
""" | |||
normalized = {} | |||
for s in self.informable_slots: | |||
normalized[s] = '' | |||
for s, v in constraint.items(): | |||
if ignore_dontcare and v == 'dontcare': | |||
continue | |||
if intersection and v != 'dontcare' and v not in self.entities_flat: | |||
continue | |||
normalized[s] = v | |||
return normalized | |||
def _normalize_act(self, aspn, intersection=False): | |||
aspn_list = aspn.split('|') | |||
normalized = {} | |||
for i, v in enumerate(aspn_list): | |||
seq = v.strip() | |||
word_set = set() | |||
for w in seq.split(): | |||
if intersection: | |||
if self.reader.act_order[i] == 'av': | |||
if '[value' in w: | |||
word_set.add(w) | |||
else: | |||
if w in self.requestable_slots: | |||
word_set.add(w) | |||
else: | |||
word_set.add(w) | |||
normalized[self.reader.act_order[i]] = word_set | |||
return normalized | |||
def tracker_metric(self, data, normalize=True): | |||
# turn level metric | |||
tp, fp, fn, db_correct = 0, 0, 0, 0 | |||
goal_accr, slot_accr, total = 0, {}, 1e-8 | |||
for s in self.informable_slots: | |||
slot_accr[s] = 0 | |||
for row in data: | |||
if normalize: | |||
gen = self._normalize_constraint(row['bspn_gen']) | |||
truth = self._normalize_constraint(row['bspn']) | |||
else: | |||
gen = self._normalize_constraint( | |||
row['bspn_gen'], intersection=False) | |||
truth = self._normalize_constraint( | |||
row['bspn'], intersection=False) | |||
valid = 'thank' not in row['user'] and 'bye' not in row['user'] | |||
if valid: | |||
for slot, value in gen.items(): | |||
if value in truth[slot]: | |||
tp += 1 | |||
else: | |||
fp += 1 | |||
for slot, value in truth.items(): | |||
if value not in gen[slot]: | |||
fn += 1 | |||
if truth and valid: | |||
total += 1 | |||
for s in self.informable_slots: | |||
if gen[s] == truth[s]: | |||
slot_accr[s] += 1 | |||
if gen == truth: | |||
goal_accr += 1 | |||
if row.get('db_gen') and row.get('db_match'): | |||
if row['db_gen'] == row['db_match']: | |||
db_correct += 1 | |||
precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) | |||
f1 = 2 * precision * recall / (precision + recall + 1e-8) | |||
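# Note: the 1e-8 smoothing terms keep precision/recall/F1 defined when a
# split contributes no positives (tp == fp == fn == 0 yields 0.0 rather
# than raising ZeroDivisionError).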
goal_accr /= total | |||
db_correct /= total | |||
for s in slot_accr: | |||
slot_accr[s] /= total | |||
return precision, recall, f1, goal_accr, slot_accr, db_correct | |||
def request_metric(self, data): | |||
# dialog level metric | |||
dials = self.pack_dial(data) | |||
tp, fp, fn = 0, 0, 0 | |||
for dial_id in dials: | |||
truth_req, gen_req = set(), set() | |||
dial = dials[dial_id] | |||
for turn_num, turn in enumerate(dial): | |||
resp_gen_token = self.clean(turn['resp_gen']).split() | |||
resp_token = self.clean(turn['resp']).split() | |||
for w in resp_gen_token: | |||
if '[value_' in w and w.endswith( | |||
']') and w != '[value_name]': | |||
gen_req.add(w[1:-1].split('_')[1]) | |||
for w in resp_token: | |||
if '[value_' in w and w.endswith( | |||
']') and w != '[value_name]': | |||
truth_req.add(w[1:-1].split('_')[1]) | |||
for req in gen_req: | |||
if req in truth_req: | |||
tp += 1 | |||
else: | |||
fp += 1 | |||
for req in truth_req: | |||
if req not in gen_req: | |||
fn += 1 | |||
precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) | |||
f1 = 2 * precision * recall / (precision + recall + 1e-8) | |||
return f1, precision, recall | |||
def act_metric(self, data): | |||
# turn level metric | |||
tp, fp, fn = { | |||
'all_s': 0, | |||
'all_v': 0 | |||
}, { | |||
'all_s': 0, | |||
'all_v': 0 | |||
}, { | |||
'all_s': 0, | |||
'all_v': 0 | |||
} | |||
for s in self.requestable_slots: | |||
tp[s], fp[s], fn[s] = 0, 0, 0 | |||
tp['[value_%s]' % s], fp['[value_%s]' % s], fn['[value_%s]' | |||
% s] = 0, 0, 0 | |||
for row in data: | |||
gen = self._normalize_act(row['aspn_gen']) | |||
truth = self._normalize_act(row['aspn']) | |||
valid = 'thank' not in row['user'] and 'bye' not in row['user'] | |||
if valid: | |||
# how well the act decoder captures user's requests | |||
for value in gen['av']: | |||
if value in truth['av']: | |||
tp['all_v'] += 1 | |||
if tp.get(value): | |||
tp[value] += 1 | |||
else: | |||
fp['all_v'] += 1 | |||
if fp.get(value): | |||
fp[value] += 1 | |||
for value in truth['av']: | |||
if value not in gen['av']: | |||
fn['all_v'] += 1 | |||
if fn.get(value): | |||
fn[value] += 1 | |||
# how accurately the act decoder predicts system's question | |||
if 'as' not in gen: | |||
continue | |||
for slot in gen['as']: | |||
if slot in truth['as']: | |||
tp['all_s'] += 1 | |||
if tp.get(slot): | |||
tp[slot] += 1 | |||
else: | |||
fp['all_s'] += 1 | |||
if fp.get(slot): | |||
fp[slot] += 1 | |||
for slot in truth['as']: | |||
if slot not in gen['as']: | |||
fn['all_s'] += 1 | |||
if fn.get(slot): | |||
fn[slot] += 1 | |||
result = {} | |||
for k, v in tp.items(): | |||
precision, recall = tp[k] / (tp[k] + fp[k] + 1e-8), tp[k] / ( | |||
tp[k] + fn[k] + 1e-8) | |||
f1 = 2 * precision * recall / (precision + recall + 1e-8) | |||
result[k] = [f1, precision, recall] | |||
return result | |||
""" | |||
For the data preparation and evaluation on In-Car Assistant/CamRest, | |||
we refer to the code of LABES (https://github.com/thu-spmi/LABES) | |||
""" | |||
class CamRestEvaluator(GenericEvaluator): | |||
def __init__(self, reader): | |||
super().__init__(reader) | |||
self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( | |||
self.reader.ontology_path) | |||
self.informable_slots = self.reader.otlg.informable_slots | |||
self.requestable_slots = self.reader.otlg.requestable_slots | |||
def run_metrics(self, results): | |||
metrics = {} | |||
bleu = self.bleu_metric(results) | |||
p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results) | |||
match = self.match_metric(results) | |||
req_f1, req_p, req_r = self.request_metric(results) | |||
metrics['bleu'] = bleu | |||
metrics['match'] = match | |||
metrics['req_f1'] = req_f1 | |||
metrics['joint_goal'] = goal_acc | |||
metrics['slot_accu'] = slot_acc | |||
metrics['slot-p/r/f1'] = (p, r, f1) | |||
metrics['db_acc'] = db_acc | |||
return metrics | |||
def get_entities(self, entity_path): | |||
entities_flat = [] | |||
entitiy_to_slot_dict = {} | |||
raw_entities = json.loads(open(entity_path).read().lower()) | |||
for s in raw_entities['informable']: | |||
entities_flat.extend(raw_entities['informable'][s]) | |||
for v in raw_entities['informable'][s]: | |||
entitiy_to_slot_dict[v] = s | |||
return entities_flat, entitiy_to_slot_dict | |||
def constraint_same(self, truth_cons, gen_cons): | |||
if not truth_cons and not gen_cons: | |||
return True | |||
if not truth_cons or not gen_cons: | |||
return False | |||
return setsim(gen_cons, truth_cons) | |||
def match_metric(self, data): | |||
dials = self.pack_dial(data) | |||
match, total = 0, 1e-8 | |||
for dial_id in dials: | |||
dial = dials[dial_id] | |||
truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None | |||
for turn_num, turn in enumerate(dial): | |||
# find the last turn which the system provide an entity | |||
if '[value' in turn['resp_gen']: | |||
gen_cons = self._normalize_constraint( | |||
turn['bspn_gen'], ignore_dontcare=True) | |||
if '[value' in turn['resp']: | |||
truth_cons = self._normalize_constraint( | |||
turn['bspn'], ignore_dontcare=True) | |||
if not gen_cons: | |||
# if no entity is provided, choose the state of the last dialog turn | |||
gen_cons = self._normalize_constraint( | |||
dial[-1]['bspn_gen'], ignore_dontcare=True) | |||
if list(truth_cons.values()) != ['', '', '']: | |||
if gen_cons == truth_cons: | |||
match += 1 | |||
total += 1 | |||
return match / total | |||
def clean(self, resp): | |||
# we use the same clean process as in Sequicity, SEDST, FSDM | |||
# to ensure comparable results | |||
resp = resp.replace(f'{self.reader.sos_r_token} ', '') | |||
resp = resp.replace(f' {self.reader.eos_r_token}', '') | |||
resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' | |||
for value, slot in self.entitiy_to_slot_dict.items(): | |||
resp = utils.clean_replace(resp, value, '[value_%s]' % slot) | |||
return resp | |||
class KvretEvaluator(GenericEvaluator): | |||
def __init__(self, reader): | |||
super().__init__(reader) | |||
self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( | |||
self.reader.ontology_path) | |||
self.informable_slots = self.reader.otlg.informable_slots | |||
self.requestable_slots = self.reader.otlg.requestable_slots | |||
def run_metrics(self, results): | |||
metrics = {} | |||
bleu = self.bleu_metric(results) | |||
p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric( | |||
results, normalize=True) | |||
match = self.match_metric(results) | |||
req_f1, req_p, req_r = self.request_metric(results) | |||
metrics['bleu'] = bleu | |||
metrics['match'] = match | |||
metrics['req_f1'] = req_f1 | |||
metrics['joint_goal'] = goal_acc | |||
metrics['slot_accu'] = slot_acc | |||
metrics['slot-p/r/f1'] = (p, r, f1) | |||
metrics['db_acc'] = db_acc | |||
return metrics | |||
def _normalize_constraint(self, | |||
constraint, | |||
ignore_dontcare=False, | |||
intersection=True): | |||
""" | |||
Normalize belief span, e.g. delete repeated words | |||
:param constraint - {'food': 'asian oriental', 'pricerange': 'cheap'} | |||
:param intersection: if true, only keeps the words that appear in the ontology | |||
we set intersection=True as in previous works | |||
:returns: normalized constraint dict | |||
e.g. - {'food': 'asian oriental', 'pricerange': 'cheap', 'area': ''} | |||
""" | |||
junk = [ | |||
'good', 'great', 'quickest', 'shortest', 'route', 'week', | |||
'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity', | |||
'restaurant', 'appointment' | |||
] | |||
normalized = {} | |||
for s in self.informable_slots: | |||
normalized[s] = '' | |||
for s, v in constraint.items(): | |||
for j in junk: | |||
v = ' '.join(v.replace(j, '').split()) | |||
if intersection and v not in self.entities_flat: | |||
continue | |||
if s in self.informable_slots: | |||
normalized[s] = v | |||
else: | |||
# TODO only use slot (not domain) in s for matching !!! | |||
pass | |||
return normalized | |||
def get_entities(self, entity_path): | |||
entities_flat = [] | |||
entitiy_to_slot_dict = {} | |||
entitiy_to_slot_dict = self.reader.entity_dict | |||
for s in entitiy_to_slot_dict: | |||
if s not in entities_flat: | |||
entities_flat.append(s) | |||
return entities_flat, entitiy_to_slot_dict | |||
def constraint_same(self, truth_cons, gen_cons): | |||
if not truth_cons and not gen_cons: | |||
return True | |||
if not truth_cons or not gen_cons: | |||
return False | |||
return setsim(gen_cons, truth_cons) | |||
def match_metric(self, data): | |||
dials = self.pack_dial(data) | |||
match, total = 0, 1e-8 | |||
for dial_id in dials: | |||
dial = dials[dial_id] | |||
truth_cons, gen_cons = { | |||
'1': '', | |||
'2': '', | |||
'3': '', | |||
'4': '', | |||
'5': '', | |||
'6': '', | |||
'7': '', | |||
'8': '', | |||
'9': '', | |||
'10': '', | |||
'11': '' | |||
}, None | |||
for turn_num, turn in enumerate(dial): | |||
# find the last turn which the system provide an entity | |||
if '[value' in turn['resp_gen']: | |||
gen_cons = self._normalize_constraint( | |||
turn['bspn_gen'], ignore_dontcare=True) | |||
if '[value' in turn['resp']: | |||
truth_cons = self._normalize_constraint( | |||
turn['bspn'], ignore_dontcare=True) | |||
if not gen_cons: | |||
# if no entity is provided, choose the state of the last dialog turn | |||
gen_cons = self._normalize_constraint( | |||
dial[-1]['bspn_gen'], ignore_dontcare=True) | |||
if list(truth_cons.values()) != [''] * 11: | |||
gen_cons = [x for x in gen_cons.values() if x] | |||
truth_cons = [x for x in truth_cons.values() if x] | |||
if self.constraint_same(gen_cons, truth_cons): | |||
match += 1 | |||
total += 1 | |||
return match / total | |||
def clean(self, resp): | |||
# we use the same clean process as in Sequicity, SEDST, FSDM | |||
# to ensure comparable results | |||
resp = resp.replace(f'{self.reader.sos_r_token} ', '') | |||
resp = resp.replace(f' {self.reader.eos_r_token}', '') | |||
resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' | |||
for value, slot in self.entitiy_to_slot_dict.items(): | |||
resp = utils.clean_replace(resp, value, '[value_%s]' % slot) | |||
return resp |
@@ -15,27 +15,11 @@ from transformers.optimization import AdamW, get_linear_schedule_with_warmup | |||
from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ | |||
MetricsTracker | |||
from modelscope.utils.constant import ModelFile | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.nlp.space import ontology | |||
def get_logger(log_path, name='default'): | |||
logger = logging.getLogger(name) | |||
logger.propagate = False | |||
logger.setLevel(logging.DEBUG) | |||
formatter = logging.Formatter('%(message)s') | |||
sh = logging.StreamHandler(sys.stdout) | |||
sh.setFormatter(formatter) | |||
logger.addHandler(sh) | |||
fh = logging.FileHandler(log_path, mode='w') | |||
fh.setFormatter(formatter) | |||
logger.addHandler(fh) | |||
return logger | |||
class Trainer(object): | |||
def __init__(self, | |||
@@ -51,15 +35,16 @@ class Trainer(object): | |||
self.do_train = config.do_train | |||
self.do_infer = config.do_infer | |||
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||
0] == '-' | |||
self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||
self.num_epochs = config.Trainer.num_epochs | |||
# self.save_dir = config.Trainer.save_dir | |||
self.log_steps = config.Trainer.log_steps | |||
self.valid_steps = config.Trainer.valid_steps | |||
self.save_checkpoint = config.Trainer.save_checkpoint | |||
self.save_summary = config.Trainer.save_summary | |||
if self.do_train: | |||
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ | |||
0] == '-' | |||
self.valid_metric_name = config.Trainer.valid_metric_name[1:] | |||
self.num_epochs = config.Trainer.num_epochs | |||
self.save_dir = config.Trainer.save_dir | |||
self.log_steps = config.Trainer.log_steps | |||
self.valid_steps = config.Trainer.valid_steps | |||
self.save_checkpoint = config.Trainer.save_checkpoint | |||
self.save_summary = config.Trainer.save_summary | |||
self.lr = config.Model.lr | |||
self.weight_decay = config.Model.weight_decay | |||
self.batch_size = config.Trainer.batch_size | |||
@@ -71,22 +56,21 @@ class Trainer(object): | |||
self.optimizer = optimizer | |||
self.model = model | |||
self.func_model = self.model.module if self.gpu > 1 else self.model | |||
self.func_model = self.model.module if self.gpu > 1 and config.use_gpu else self.model | |||
self.reader = reader | |||
self.evaluator = evaluator | |||
self.tokenizer = reader.tokenizer | |||
# if not os.path.exists(self.save_dir): | |||
# os.makedirs(self.save_dir) | |||
# self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") | |||
self.logger = logger or get_logger('trainer.log', 'trainer') | |||
self.logger = get_logger() | |||
self.batch_metrics_tracker = MetricsTracker() | |||
self.token_metrics_tracker = MetricsTracker() | |||
self.best_valid_metric = float( | |||
'inf' if self.is_decreased_valid_metric else '-inf') | |||
if self.do_train: | |||
if not os.path.exists(self.save_dir): | |||
os.makedirs(self.save_dir) | |||
self.best_valid_metric = float( | |||
'inf' if self.is_decreased_valid_metric else '-inf') | |||
self.epoch = 0 | |||
def decode_generated_bspn_resp(self, generated): | |||
@@ -248,9 +232,12 @@ class Trainer(object): | |||
# Save current best model | |||
if is_best: | |||
best_model_file = os.path.join(self.save_dir, 'best.model') | |||
best_model_file = os.path.join(self.save_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
torch.save(self.model.state_dict(), best_model_file) | |||
best_train_file = os.path.join(self.save_dir, 'best.train') | |||
best_train_file = os.path.join( | |||
self.save_dir, | |||
'{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) | |||
torch.save(train_state, best_train_file) | |||
self.logger.info( | |||
f"Saved best model state to '{best_model_file}' with new best valid metric " | |||
@@ -324,8 +311,7 @@ class Trainer(object): | |||
self.func_model.load_state_dict(model_state_dict) | |||
self.logger.info( | |||
f"Loaded model state from '{self.func_model.init_checkpoint}.model'" | |||
) | |||
f"Loaded model state from '{self.func_model.init_checkpoint}'") | |||
def _load_train_state(): | |||
train_file = f'{self.func_model.init_checkpoint}.train' | |||
@@ -558,19 +544,17 @@ class MultiWOZTrainer(Trainer): | |||
generated_bs = outputs[0].cpu().numpy().tolist() | |||
bspn_gen = self.decode_generated_bspn(generated_bs) | |||
# check DB result | |||
if self.reader.use_true_db_pointer: # To control whether current db is ground truth | |||
if self.reader.use_true_db_pointer: | |||
db = turn['db'] | |||
else: | |||
db_result = self.reader.bspan_to_DBpointer( | |||
self.tokenizer.decode(bspn_gen), | |||
turn['turn_domain']) | |||
assert len(turn['db']) == 4 | |||
book_result = turn['db'][2] | |||
assert len(turn['db']) == 3 | |||
assert isinstance(db_result, str) | |||
db = \ | |||
[self.reader.sos_db_id] + \ | |||
self.tokenizer.convert_tokens_to_ids([db_result]) + \ | |||
[book_result] + \ | |||
[self.reader.eos_db_id] | |||
prompt_id = self.reader.sos_a_id | |||
@@ -636,7 +620,7 @@ class MultiWOZTrainer(Trainer): | |||
score = 0.5 * (success + match) + bleu | |||
# log results | |||
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ | |||
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % \ | |||
(match, success, bleu, score) | |||
message_prefix = f'[Infer][{self.epoch}]' | |||
time_cost = f'TIME-{time.time() - begin_time:.3f}' | |||
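The combined metric computed a few lines above is the standard MultiWOZ score. A worked example with invented numbers:

match, success, bleu = 85.0, 75.0, 16.0
score = 0.5 * (success + match) + bleu  # 0.5 * (75 + 85) + 16 = 96.0
assert score == 96.0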
@@ -0,0 +1,333 @@ | |||
import os | |||
import re | |||
from . import ontology | |||
def clean_text_split_dot(text): | |||
text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', | |||
text) # 'abc.xyz' -> 'abc . xyz' | |||
text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' | |||
return text | |||
def clean_text(data_dir, text): | |||
text = text.strip() | |||
text = text.lower() | |||
text = text.replace(u'’', "'") | |||
text = text.replace(u'‘', "'") | |||
text = text.replace(';', ',') | |||
text = text.replace('"', ' ') | |||
text = text.replace('/', ' and ') | |||
text = text.replace("don't", "do n't") | |||
text = clean_time(text) | |||
baddata = { | |||
r'c\.b (\d), (\d) ([a-z])\.([a-z])': r'cb\1\2\3\4', | |||
'c.b. 1 7 d.y': 'cb17dy', | |||
'c.b.1 7 d.y': 'cb17dy', | |||
'c.b 25, 9 a.q': 'cb259aq', | |||
'isc.b 25, 9 a.q': 'is cb259aq', | |||
'c.b2, 1 u.f': 'cb21uf', | |||
'c.b 1,2 q.a': 'cb12qa', | |||
'0-122-336-5664': '01223365664', | |||
'postcodecb21rs': 'postcode cb21rs', | |||
r'i\.d': 'id', | |||
' i d ': 'id', | |||
'Telephone:01223358966': 'Telephone: 01223358966', | |||
'depature': 'departure', | |||
'depearting': 'departing', | |||
'-type': ' type', | |||
r'b[\s]?&[\s]?b': 'bed and breakfast', | |||
'b and b': 'bed and breakfast', | |||
r'guesthouse[s]?': 'guest house', | |||
r'swimmingpool[s]?': 'swimming pool', | |||
"wo n\'t": 'will not', | |||
" \'d ": ' would ', | |||
" \'m ": ' am ', | |||
" \'re' ": ' are ', | |||
" \'ll' ": ' will ', | |||
" \'ve ": ' have ', | |||
r'^\'': '', | |||
r'\'$': '', | |||
} | |||
for tmpl, good in baddata.items(): | |||
text = re.sub(tmpl, good, text) | |||
text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', | |||
text) # 'abc.xyz' -> 'abc . xyz' | |||
text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' | |||
with open(os.path.join(data_dir, 'mapping.pair'), 'r') as fin: | |||
for line in fin.readlines(): | |||
fromx, tox = line.replace('\n', '').split('\t') | |||
text = ' ' + text + ' ' | |||
text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1] | |||
return text | |||
def clean_time(utter): | |||
utter = re.sub(r'(\d+) ([ap]\.?m)', lambda x: x.group(1) + x.group(2), | |||
utter) # 9 am -> 9am | |||
utter = re.sub(r'((?<!\d)\d:\d+)(am)?', r'0\1', utter) | |||
utter = re.sub(r'((?<!\d)\d)am', r'0\1:00', utter) | |||
utter = re.sub(r'((?<!\d)\d)pm', | |||
lambda x: str(int(x.group(1)) + 12) + ':00', utter) | |||
utter = re.sub(r'(\d+)(:\d+)pm', | |||
lambda x: str(int(x.group(1)) + 12) + x.group(2), utter) | |||
utter = re.sub(r'(\d+)a\.?m', r'\1', utter) | |||
return utter | |||
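What the clean_time() rules do to a few inputs, traced by hand; the import path follows the clean_dataset import seen earlier in this diff.

from modelscope.utils.nlp.space.clean_dataset import clean_time

print(clean_time('meet at 9 am'))    # -> 'meet at 09:00'
print(clean_time('leaves at 9pm'))   # -> 'leaves at 21:00'
print(clean_time('arrives 9:15pm'))  # -> 'arrives 21:15'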
def clean_slot_values(data_dir, domain, slot, value): | |||
value = clean_text(data_dir, value) | |||
if not value: | |||
value = '' | |||
elif value == 'not mentioned': | |||
value = '' | |||
# value = 'not mentioned' # if in DST setting | |||
elif domain == 'attraction': | |||
if slot == 'name': | |||
if value == 't': | |||
value = '' | |||
if value == 'trinity': | |||
value = 'trinity college' | |||
elif slot == 'area': | |||
if value in ['town centre', 'cent', 'center', 'ce']: | |||
value = 'centre' | |||
elif value in [ | |||
'ely', 'in town', 'museum', 'norwich', 'same area as hotel' | |||
]: | |||
value = '' | |||
elif value in ['we']: | |||
value = 'west' | |||
elif slot == 'type': | |||
if value in ['m', 'mus', 'musuem']: | |||
value = 'museum' | |||
elif value in ['art', 'architectural']: | |||
value = 'architecture' | |||
elif value in ['churches']: | |||
value = 'church' | |||
elif value in ['coll']: | |||
value = 'college' | |||
elif value in ['concert', 'concerthall']: | |||
value = 'concert hall' | |||
elif value in ['night club']: | |||
value = 'nightclub' | |||
elif value in [ | |||
'mutiple sports', 'mutliple sports', 'sports', 'galleria' | |||
]: | |||
value = 'multiple sports' | |||
elif value in ['ol', 'science', 'gastropub', 'la raza']: | |||
value = '' | |||
elif value in ['swimmingpool', 'pool']: | |||
value = 'swimming pool' | |||
elif value in ['fun']: | |||
value = 'entertainment' | |||
elif domain == 'hotel': | |||
if slot == 'area': | |||
if value in [ | |||
'cen', 'centre of town', 'near city center', 'center' | |||
]: | |||
value = 'centre' | |||
elif value in ['east area', 'east side']: | |||
value = 'east' | |||
elif value in ['in the north', 'north part of town']: | |||
value = 'north' | |||
elif value in ['we']: | |||
value = 'west' | |||
elif slot == 'day': | |||
if value == 'monda': | |||
value = 'monday' | |||
elif value == 't': | |||
value = 'tuesday' | |||
elif slot == 'name': | |||
if value == 'uni': | |||
value = 'university arms hotel' | |||
elif value == 'university arms': | |||
value = 'university arms hotel' | |||
elif value == 'acron': | |||
value = 'acorn guest house' | |||
elif value == 'ashley': | |||
value = 'ashley hotel' | |||
elif value == 'arbury lodge guesthouse': | |||
value = 'arbury lodge guest house' | |||
elif value == 'la': | |||
value = 'la margherit' | |||
elif value == 'no': | |||
value = '' | |||
elif slot == 'internet': | |||
if value == 'does not': | |||
value = 'no' | |||
elif value in ['y', 'free', 'free internet']: | |||
value = 'yes' | |||
elif value in ['4']: | |||
value = '' | |||
elif slot == 'parking': | |||
if value == 'n': | |||
value = 'no' | |||
elif value in ['free parking']: | |||
value = 'yes' | |||
elif value in ['y']: | |||
value = 'yes' | |||
elif slot in ['pricerange', 'price range']: | |||
slot = 'pricerange' | |||
if value == 'moderately': | |||
value = 'moderate' | |||
elif value in ['any']: | |||
value = "do n't care" | |||
elif value in ['inexpensive']: | |||
value = 'cheap' | |||
elif value in ['2', '4']: | |||
value = '' | |||
elif slot == 'stars': | |||
if value == 'two': | |||
value = '2' | |||
elif value == 'three': | |||
value = '3' | |||
elif value in [ | |||
'4-star', '4 stars', '4 star', 'four star', 'four stars' | |||
]: | |||
value = '4' | |||
elif slot == 'type': | |||
if value == '0 star rarting': | |||
value = '' | |||
elif value == 'guesthouse': | |||
value = 'guest house' | |||
elif value not in ['hotel', 'guest house', "do n't care"]: | |||
value = '' | |||
elif domain == 'restaurant': | |||
if slot == 'area': | |||
if value in [ | |||
'center', 'scentre', 'center of town', 'city center', | |||
'cb30aq', 'town center', 'centre of cambridge', | |||
'city centre' | |||
]: | |||
value = 'centre' | |||
elif value == 'west part of town': | |||
value = 'west' | |||
elif value == 'n': | |||
value = 'north' | |||
elif value in ['the south']: | |||
value = 'south' | |||
elif value not in [ | |||
'centre', 'south', "do n't care", 'west', 'east', 'north' | |||
]: | |||
value = '' | |||
elif slot == 'day': | |||
if value == 'monda': | |||
value = 'monday' | |||
elif value == 't': | |||
value = 'tuesday' | |||
elif slot in ['pricerange', 'price range']: | |||
slot = 'pricerange' | |||
if value in ['moderately', 'mode', 'mo']: | |||
value = 'moderate' | |||
elif value in ['not']: | |||
value = '' | |||
elif value in ['inexpensive', 'ch']: | |||
value = 'cheap' | |||
elif slot == 'food': | |||
if value == 'barbecue': | |||
value = 'barbeque' | |||
elif slot == 'time': | |||
if value == '9:00': | |||
value = '09:00' | |||
elif value == '9:45': | |||
value = '09:45' | |||
elif value == '1330': | |||
value = '13:30' | |||
elif value == '1430': | |||
value = '14:30' | |||
elif value == '9:15': | |||
value = '09:15' | |||
elif value == '9:30': | |||
value = '09:30' | |||
elif value == '1830': | |||
value = '18:30' | |||
elif value == '9': | |||
value = '09:00' | |||
elif value == '2:00': | |||
value = '14:00' | |||
elif value == '1:00': | |||
value = '13:00' | |||
elif value == '3:00': | |||
value = '15:00' | |||
elif domain == 'taxi': | |||
if slot in ['arriveBy', 'arrive by']: | |||
slot = 'arriveby' | |||
if value == '1530': | |||
value = '15:30' | |||
elif value == '15 minutes': | |||
value = '' | |||
elif slot in ['leaveAt', 'leave at']: | |||
slot = 'leaveat' | |||
if value == '1:00': | |||
value = '01:00' | |||
elif value == '21:4': | |||
value = '21:04' | |||
elif value == '4:15': | |||
value = '04:15' | |||
elif value == '5:45': | |||
value = '05:45' | |||
elif value == '0700': | |||
value = '07:00' | |||
elif value == '4:45': | |||
value = '04:45' | |||
elif value == '8:30': | |||
value = '08:30' | |||
elif value == '9:30': | |||
value = '09:30' | |||
value = value.replace('.', ':') | |||
elif domain == 'train': | |||
if slot in ['arriveBy', 'arrive by']: | |||
slot = 'arriveby' | |||
if value == '1': | |||
value = '01:00' | |||
elif value in ['does not care', 'doesnt care', "doesn't care"]: | |||
value = "do n't care" | |||
elif value == '8:30': | |||
value = '08:30' | |||
elif value == 'not 15:45': | |||
value = '' | |||
value = value.replace('.', ':') | |||
elif slot == 'day': | |||
if value == 'doesnt care' or value == "doesn't care": | |||
value = "do n't care" | |||
elif slot in ['leaveAt', 'leave at']: | |||
slot = 'leaveat' | |||
if value == '2:30': | |||
value = '02:30' | |||
elif value == '7:54': | |||
value = '07:54' | |||
elif value == 'after 5:45 pm': | |||
value = '17:45' | |||
elif value in [ | |||
'early evening', 'friday', 'sunday', 'tuesday', 'afternoon' | |||
]: | |||
value = '' | |||
elif value == '12': | |||
value = '12:00' | |||
elif value == '1030': | |||
value = '10:30' | |||
elif value == '1700': | |||
value = '17:00' | |||
elif value in [ | |||
'does not care', 'doesnt care', 'do nt care', | |||
"doesn't care" | |||
]: | |||
value = "do n't care" | |||
value = value.replace('.', ':') | |||
if value in ['dont care', "don't care", 'do nt care', "doesn't care"]: | |||
value = "do n't care" | |||
if ontology.normlize_slot_names.get(slot): | |||
slot = ontology.normlize_slot_names[slot] | |||
return slot, value |
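Since clean_slot_values both renames slots and canonicalizes values, a short usage sketch may help (the data_dir path is hypothetical and must contain mapping.pair, because every value first passes through clean_text; the outputs assume mapping.pair leaves these tokens untouched): | |||
    slot, value = clean_slot_values('data/multiwoz', 'hotel', 'price range', 'moderately') | |||
    # ('pricerange', 'moderate') | |||
    slot, value = clean_slot_values('data/multiwoz', 'restaurant', 'time', '9:45') | |||
    # ('time', '09:45') -- clean_time already zero-pads it, so the elif chain is a no-op here | |||
Any slot found in ontology.normlize_slot_names is renamed once more on the way out. | |||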
@@ -4,8 +4,11 @@ from collections import OrderedDict | |||
import json | |||
import numpy as np | |||
from modelscope.utils.logger import get_logger | |||
from . import ontology | |||
logger = get_logger() | |||
def max_lens(X): | |||
lens = [len(X)] | |||
@@ -117,8 +120,8 @@ class MultiWOZVocab(object): | |||
def construct(self): | |||
freq_dict_sorted = sorted( | |||
self._freq_dict.keys(), key=lambda x: -self._freq_dict[x]) | |||
print('Vocabulary size including oov: %d' % | |||
(len(freq_dict_sorted) + len(self._idx2word))) | |||
logger.info('Vocabulary size including oov: %d' % | |||
(len(freq_dict_sorted) + len(self._idx2word))) | |||
if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size: | |||
logger.warning( | |||
'actual label set smaller than that configured: {}/{}'.format( | |||
@@ -148,8 +151,9 @@ class MultiWOZVocab(object): | |||
for w, idx in self._word2idx.items(): | |||
self._idx2word[idx] = w | |||
self.vocab_size_oov = len(self._idx2word) | |||
print('vocab file loaded from "' + vocab_path + '"') | |||
print('Vocabulary size including oov: %d' % (self.vocab_size_oov)) | |||
logger.info('vocab file loaded from "' + vocab_path + '"') | |||
logger.info('Vocabulary size including oov: %d' % | |||
(self.vocab_size_oov)) | |||
def save_vocab(self, vocab_path): | |||
_freq_dict = OrderedDict( | |||
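These hunks route the vocabulary messages through ModelScope's shared logger instead of bare print. A minimal sketch of the pattern, using only the import the diff itself adds: | |||
    from modelscope.utils.logger import get_logger | |||
    logger = get_logger() | |||
    logger.info('vocab file loaded from "%s"', '/path/to/vocab.txt') | |||
    logger.warning('actual label set smaller than configured: %d/%d', 100, 200) | |||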
@@ -0,0 +1,68 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import unittest | |||
import torch | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.metainfo import Preprocessors, Trainers | |||
from modelscope.msdatasets import MsDataset | |||
from modelscope.trainers import build_trainer | |||
from modelscope.utils.constant import DownloadMode, ModelFile | |||
from modelscope.utils.test_utils import test_level | |||
class TestDialogModelingTrainer(unittest.TestCase): | |||
model_id = 'damo/nlp_space_pretrained-dialog-model' | |||
output_dir = './dialog_finetune_result' | |||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
def test_trainer_with_model_and_args(self): | |||
# download data set | |||
data_multiwoz = MsDataset.load( | |||
'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) | |||
data_dir = os.path.join( | |||
data_multiwoz._hf_ds.config_kwargs['split_config']['train'], | |||
'data') | |||
# download model | |||
model_dir = snapshot_download(self.model_id) | |||
# dialog finetune config | |||
def cfg_modify_fn(cfg): | |||
config = { | |||
'seed': 10, | |||
'gpu': 4, | |||
'use_data_distributed': False, | |||
'valid_metric_name': '-loss', | |||
'num_epochs': 60, | |||
'save_dir': self.output_dir, | |||
'token_loss': True, | |||
'batch_size': 32, | |||
'log_steps': 10, | |||
'valid_steps': 0, | |||
'save_checkpoint': True, | |||
'save_summary': False, | |||
'shuffle': True, | |||
'sort_pool_size': 0 | |||
} | |||
cfg.Trainer = config | |||
cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1 | |||
return cfg | |||
# trainer config | |||
kwargs = dict( | |||
model_dir=model_dir, | |||
cfg_name='gen_train_config.json', | |||
data_dir=data_dir, | |||
cfg_modify_fn=cfg_modify_fn) | |||
trainer = build_trainer( | |||
name=Trainers.dialog_modeling_trainer, default_args=kwargs) | |||
trainer.train() | |||
checkpoint_path = os.path.join(self.output_dir, | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
assert os.path.exists(checkpoint_path) | |||
trainer.evaluate(checkpoint_path=checkpoint_path) |
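For reviewers wanting to exercise the new test locally, a hedged invocation (the module path follows the usual tests/trainers layout, and TEST_LEVEL is assumed to be the environment variable consulted by test_level()): | |||
    TEST_LEVEL=1 python -m unittest tests.trainers.test_dialog_modeling_trainer -v | |||
Note that the test pulls the MultiWoz2.0 dataset and the damo/nlp_space_pretrained-dialog-model checkpoint on first run, so it needs network access and disk space; with 'gpu': 4 in cfg_modify_fn it may also expect a multi-GPU host unless that value is lowered. | |||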