From b41b10f8970a748dee26baf8e5e1d13e04568e54 Mon Sep 17 00:00:00 2001 From: ly119399 Date: Fri, 9 Sep 2022 14:27:08 +0800 Subject: [PATCH] [to #42322933] space finetune on generation task Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10061562 --- modelscope/metainfo.py | 1 + modelscope/models/nlp/space/model/__init__.py | 2 +- .../models/nlp/space/model/generator.py | 10 +- .../nlp/space/space_for_dialog_modeling.py | 2 +- .../space/dialog_modeling_preprocessor.py | 2 +- .../preprocessors/space/fields/gen_field.py | 236 ++++- .../nlp/space/dialog_modeling_trainer.py | 130 +++ modelscope/trainers/nlp/space/eval.py | 952 ++++++++++++++++++ .../trainers/nlp/space/trainer/gen_trainer.py | 72 +- modelscope/utils/nlp/space/clean_dataset.py | 333 ++++++ modelscope/utils/nlp/space/utils.py | 12 +- .../trainers/test_dialog_modeling_trainer.py | 68 ++ 12 files changed, 1744 insertions(+), 76 deletions(-) create mode 100644 modelscope/trainers/nlp/space/dialog_modeling_trainer.py create mode 100644 modelscope/trainers/nlp/space/eval.py create mode 100644 modelscope/utils/nlp/space/clean_dataset.py create mode 100644 tests/trainers/test_dialog_modeling_trainer.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index e051bb76..63b4f1c2 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -241,6 +241,7 @@ class Trainers(object): # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' + dialog_modeling_trainer = 'dialog-modeling-trainer' dialog_intent_trainer = 'dialog-intent-trainer' nlp_base_trainer = 'nlp-base-trainer' nlp_veco_trainer = 'nlp-veco-trainer' diff --git a/modelscope/models/nlp/space/model/__init__.py b/modelscope/models/nlp/space/model/__init__.py index 24641f06..bb1d18e4 100644 --- a/modelscope/models/nlp/space/model/__init__.py +++ b/modelscope/models/nlp/space/model/__init__.py @@ -1,6 +1,6 @@ from .configuration_space import SpaceConfig from .gen_unified_transformer import GenUnifiedTransformer -from .generator import Generator as SpaceGenerator +from .generator import SpaceGenerator from .intent_unified_transformer import IntentUnifiedTransformer from .model_base import SpaceModelBase from .modeling_space import (SpaceForDST, SpaceForMaskedLM, diff --git a/modelscope/models/nlp/space/model/generator.py b/modelscope/models/nlp/space/model/generator.py index c1521e3d..0e7833e6 100644 --- a/modelscope/models/nlp/space/model/generator.py +++ b/modelscope/models/nlp/space/model/generator.py @@ -38,24 +38,24 @@ def gather(var, idx): return var -class Generator(object): +class SpaceGenerator(object): """ Genrator class. """ _registry = dict() @classmethod def register(cls, name): - Generator._registry[name] = cls + SpaceGenerator._registry[name] = cls return @staticmethod def by_name(name): - return Generator._registry[name] + return SpaceGenerator._registry[name] @staticmethod def create(config, *args, **kwargs): """ Create generator. """ - generator_cls = Generator.by_name(config.Generator.generator) + generator_cls = SpaceGenerator.by_name(config.Generator.generator) return generator_cls(config, *args, **kwargs) def __init__(self, config, reader): @@ -83,7 +83,7 @@ class Generator(object): raise NotImplementedError -class BeamSearch(Generator): +class BeamSearch(SpaceGenerator): """ BeamSearch generator. """ def __init__(self, config, reader): diff --git a/modelscope/models/nlp/space/space_for_dialog_modeling.py b/modelscope/models/nlp/space/space_for_dialog_modeling.py index 4c65c7d1..efa9b851 100644 --- a/modelscope/models/nlp/space/space_for_dialog_modeling.py +++ b/modelscope/models/nlp/space/space_for_dialog_modeling.py @@ -41,7 +41,7 @@ class SpaceForDialogModeling(TorchModel): self.text_field = kwargs.pop( 'text_field', - MultiWOZBPETextField(self.model_dir, config=self.config)) + MultiWOZBPETextField(config=self.config, model_dir=self.model_dir)) self.generator = SpaceGenerator.create( self.config, reader=self.text_field) self.model = SpaceModelBase.create( diff --git a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py index a2157c2b..c461ade1 100644 --- a/modelscope/preprocessors/space/dialog_modeling_preprocessor.py +++ b/modelscope/preprocessors/space/dialog_modeling_preprocessor.py @@ -35,7 +35,7 @@ class DialogModelingPreprocessor(Preprocessor): self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available() self.text_field = MultiWOZBPETextField( - self.model_dir, config=self.config) + config=self.config, model_dir=self.model_dir) @type_assert(object, Dict) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/modelscope/preprocessors/space/fields/gen_field.py b/modelscope/preprocessors/space/fields/gen_field.py index 5bff360f..32346bd5 100644 --- a/modelscope/preprocessors/space/fields/gen_field.py +++ b/modelscope/preprocessors/space/fields/gen_field.py @@ -2,9 +2,11 @@ import os import random +from asyncio import constants from collections import OrderedDict from itertools import chain +import json import numpy as np from modelscope.preprocessors.space.tokenizer import Tokenizer @@ -117,7 +119,8 @@ class BPETextField(object): return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0] def __init__(self, config): - self.gpu = 0 + self.train, self.dev, self.test = [], [], [] + self.gpu = config.Trainer.gpu self.tokenizer = None self.vocab = None self.db = None @@ -249,13 +252,9 @@ class BPETextField(object): for dial in data: batch.append(dial) if len(batch) == self.batch_size: - # print('batch size: %d, batch num +1'%(len(batch))) all_batches.append(batch) batch = [] - # if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch - # print('last batch size: %d, batch num +1'%(len(batch))) - # if (len(batch) % len(cfg.cuda_device)) != 0: - # batch = batch[:-(len(batch) % len(cfg.cuda_device))] + # TODO deal with deleted data if self.gpu <= 1: if len(batch) > 0.5 * self.batch_size: @@ -308,7 +307,7 @@ class BPETextField(object): class MultiWOZBPETextField(BPETextField): - def __init__(self, model_dir, config): + def __init__(self, config, **kwargs): super(MultiWOZBPETextField, self).__init__(config) import spacy @@ -327,8 +326,12 @@ class MultiWOZBPETextField(BPETextField): ) self.nlp = spacy.load('en_core_web_sm') + if config.do_train: + db_dir = kwargs['data_dir'] + else: + db_dir = kwargs['model_dir'] self.db = MultiWozDB( - model_dir, { + db_dir, { 'attraction': 'db/attraction_db_processed.json', 'hospital': 'db/hospital_db_processed.json', 'hotel': 'db/hotel_db_processed.json', @@ -337,14 +340,14 @@ class MultiWOZBPETextField(BPETextField): 'taxi': 'db/taxi_db_processed.json', 'train': 'db/train_db_processed.json', }) - self._build_vocab(model_dir) + self._build_vocab(db_dir) special_tokens = [ self.pad_token, self.bos_token, self.eos_token, self.unk_token ] special_tokens.extend(self.add_sepcial_tokens()) self.tokenizer = Tokenizer( - vocab_path=os.path.join(model_dir, ModelFile.VOCAB_FILE), + vocab_path=os.path.join(kwargs['model_dir'], ModelFile.VOCAB_FILE), special_tokens=special_tokens, tokenizer_type=config.BPETextField.tokenizer_type) self.understand_ids = self.tokenizer.convert_tokens_to_ids( @@ -352,6 +355,26 @@ class MultiWOZBPETextField(BPETextField): self.policy_ids = self.tokenizer.convert_tokens_to_ids( self.policy_tokens) + if config.do_train: + test_list = [ + line.strip().lower() for line in open( + os.path.join(kwargs['data_dir'], 'testListFile.json'), + 'r').readlines() + ] + dev_list = [ + line.strip().lower() for line in open( + os.path.join(kwargs['data_dir'], 'valListFile.json'), + 'r').readlines() + ] + + self.dev_files, self.test_files = {}, {} + for fn in test_list: + self.test_files[fn.replace('.json', '')] = 1 + for fn in dev_list: + self.dev_files[fn.replace('.json', '')] = 1 + + self._load_data(kwargs['data_dir']) + return def get_ids(self, data: str): @@ -414,7 +437,6 @@ class MultiWOZBPETextField(BPETextField): name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev} dial = name_to_set[set_name] turn_bucket = self._bucket_by_turn(dial) - # self._shuffle_turn_bucket(turn_bucket) all_batches = [] if set_name not in self.set_stats: @@ -433,19 +455,13 @@ class MultiWOZBPETextField(BPETextField): except Exception: log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % ( k, len(turn_bucket[k]), len(batches), 0.0) - # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches))) + num_training_steps += k * len(batches) num_turns += k * len(turn_bucket[k]) num_dials += len(turn_bucket[k]) all_batches += batches log_str += 'total batch num: %d\n' % len(all_batches) - # print('total batch num: %d'%len(all_batches)) - # print('dialog count: %d'%dia_count) - # return all_batches - # log stats - # logging.info(log_str) - # cfg.num_training_steps = num_training_steps * cfg.epoch_num self.set_stats[set_name][ 'num_training_steps_per_epoch'] = num_training_steps # turn-level steps self.set_stats[set_name]['num_turns'] = num_turns @@ -484,6 +500,71 @@ class MultiWOZBPETextField(BPETextField): self.vocab.load_vocab(vp) return self.vocab.vocab_size + def _load_data(self, data_dir, save_temp=True): + """ + load processed data and encode, or load already encoded data + """ + + def load_data_from_resource(data_resource): + data = json.loads( + open( + os.path.join(data_dir, data_resource), + 'r', + encoding='utf-8').read().lower()) + train, dev, test = [], [], [] + for fn, dial in data.items(): + if '.json' in fn: + fn = fn.replace('.json', '') + if self.dev_files.get(fn): + dev.append(self._get_encoded_data(fn, dial)) + elif self.test_files.get(fn): + test.append(self._get_encoded_data(fn, dial)) + else: + train.append(self._get_encoded_data(fn, dial)) + return train, dev, test + + data_processed = 'new_db_se_blank_encoded_domain.data.json' + data_resource = 'data_for_damd.json' + if save_temp: # save encoded data + # encoded: no sos, se_encoded: sos and eos + encoded_file = os.path.join(data_dir, data_processed) + + if os.path.exists(encoded_file): + logger.info( + 'Reading encoded data from {}'.format(encoded_file)) + self.data = json.loads( + open( + os.path.join(data_dir, data_resource), + 'r', + encoding='utf-8').read().lower()) + encoded_data = json.loads( + open(encoded_file, 'r', encoding='utf-8').read()) + self.train = encoded_data['train'] + self.dev = encoded_data['dev'] + self.test = encoded_data['test'] + else: + logger.info( + 'Encoding data now and save the encoded data in {}'.format( + encoded_file)) + # not exists, encode data and save + self.train, self.dev, self.test = load_data_from_resource( + data_resource) + # save encoded data + encoded_data = { + 'train': self.train, + 'dev': self.dev, + 'test': self.test + } + json.dump(encoded_data, open(encoded_file, 'w'), indent=2) + else: # directly read processed data and encode + self.train, self.dev, self.test = load_data_from_resource( + data_resource) + + random.seed(10) + random.shuffle(self.train) + logger.info('train size:{}, dev size:{}, test size:{}'.format( + len(self.train), len(self.dev), len(self.test))) + def _get_convert_str(self, sent): assert isinstance(sent, str) return ' '.join([ @@ -491,14 +572,65 @@ class MultiWOZBPETextField(BPETextField): for tok in sent.split() ]) + def _get_encoded_data(self, fn, dial): + encoded_dial = [] + for idx, t in enumerate(dial['log']): # tokenize to list of ids + enc = {} + enc['dial_id'] = fn + + enc_info_list = [ + ('user', self.sos_u_id, 'user', self.eos_u_id), + ('usdx', self.sos_u_id, 'user', self.eos_u_id), + ('resp', self.sos_r_id, 'resp', self.eos_r_id), + ('bspn', self.sos_b_id, 'constraint', self.eos_b_id), + ('bsdx', self.sos_b_id, 'cons_delex', self.eos_b_id), + ('aspn', self.sos_a_id, 'sys_act', self.eos_a_id) + ] + for enc_key, start_token, item_key, end_token in enc_info_list: + enc[enc_key] = [ + start_token + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str(t[item_key]))) + [end_token] + + enc['turn_num'] = t['turn_num'] + + if idx > 0 and t['turn_domain'] == '[general]': + enc['dspn'] = encoded_dial[idx - 1]['dspn'] + enc['pointer'] = encoded_dial[idx - 1]['pointer'][:4] + [ + int(i) for i in t['pointer'].split(',') + ][-2:] + enc['turn_domain'] = encoded_dial[idx - 1]['turn_domain'] + enc['db'] = encoded_dial[idx - 1]['db'] + else: + if t['turn_domain'] == '[general]': + assert not t['constraint'], f'{fn}-{idx}' + enc['dspn'] = [ + self.sos_d_id + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str( + t['turn_domain']))) + [self.eos_d_id] + enc['pointer'] = [int(i) for i in t['pointer'].split(',')] + enc['turn_domain'] = t['turn_domain'].split() + db_pointer = self.bspan_to_DBpointer(t['constraint'], + t['turn_domain'].split()) + enc['db'] = [ + self.sos_db_id + ] + self.tokenizer.convert_tokens_to_ids( + self.tokenizer.tokenize( + self._get_convert_str(db_pointer))) + [self.eos_db_id] + + encoded_dial.append(enc) + return encoded_dial + def bspan_to_DBpointer(self, bspan, turn_domain): constraint_dict = self.bspan_to_constraint_dict(bspan) - # print(constraint_dict) matnums = self.db.get_match_num(constraint_dict) match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1] match_dom = match_dom[1:-1] if match_dom.startswith('[') else match_dom match = matnums[match_dom] - # vector = self.db.addDBPointer(match_dom, match) + vector = self.db.addDBIndicator(match_dom, match) return vector @@ -691,3 +823,67 @@ class MultiWOZBPETextField(BPETextField): inputs['labels'] = [context] # use previous turn return inputs, prompt_id + + def restore(self, resp, domain, constraint_dict, mat_ents): + restored = resp + + restored = restored.replace('[value_reference]', '53022') + restored = restored.replace('[value_car]', 'BMW') + + for d in domain: + constraint = constraint_dict.get(d, None) + if constraint: + replace_res_list = [('stay', '[value_stay]'), + ('day', '[value_day]'), + ('people', '[value_people]'), + ('time', '[value_time]'), + ('type', '[value_type]')] + for key, value_key in replace_res_list: + if key in constraint: + restored = restored.replace(value_key, constraint[key]) + + if d in mat_ents and len(mat_ents[d]) == 0: + for s in constraint: + if s == 'pricerange' and d in [ + 'hotel', 'restaurant' + ] and 'price]' in restored: + restored = restored.replace( + '[value_price]', constraint['pricerange']) + if s + ']' in restored: + restored = restored.replace( + '[value_%s]' % s, constraint[s]) + + if '[value_choice' in restored and mat_ents.get(d): + restored = restored.replace('[value_choice]', + str(len(mat_ents[d]))) + if '[value_choice' in restored: + restored = restored.replace('[value_choice]', '3') + + try: + ent = mat_ents.get(domain[-1], []) + if ent: + ent = ent[0] + + for t in restored.split(): + if '[value' in t: + slot = t[7:-1] + if ent.get(slot): + if domain[-1] == 'hotel' and slot == 'price': + slot = 'pricerange' + restored = restored.replace(t, ent[slot]) + elif slot == 'price': + if ent.get('pricerange'): + restored = restored.replace( + t, ent['pricerange']) + else: + logger.info(restored, domain) + except Exception: + logger.error(resp) + logger.error(restored) + quit() + + restored = restored.replace('[value_phone]', '62781111') + restored = restored.replace('[value_postcode]', 'CG9566') + restored = restored.replace('[value_address]', 'Parkside, Cambridge') + + return restored diff --git a/modelscope/trainers/nlp/space/dialog_modeling_trainer.py b/modelscope/trainers/nlp/space/dialog_modeling_trainer.py new file mode 100644 index 00000000..6bdd8a3a --- /dev/null +++ b/modelscope/trainers/nlp/space/dialog_modeling_trainer.py @@ -0,0 +1,130 @@ +import os +import time +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np + +from modelscope.metainfo import Trainers +from modelscope.models.nlp.space.model.generator import SpaceGenerator +from modelscope.models.nlp.space.model.model_base import SpaceModelBase +from modelscope.preprocessors.space.fields.gen_field import \ + MultiWOZBPETextField +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator +from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer +from modelscope.utils.config import Config, ModelFile +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +def setup_seed(seed: int): + import random + import torch + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + + +@TRAINERS.register_module(module_name=Trainers.dialog_modeling_trainer) +class DialogModelingTrainer(BaseTrainer): + + def __init__(self, + cfg_file: Optional[str] = None, + cfg_modify_fn: Optional[Callable] = None, + *args, + **kwargs): + + super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name'])) + + self.cfg_modify_fn = cfg_modify_fn + self.cfg = self.rebuild_config(self.cfg) + + setup_seed(self.cfg.Trainer.seed) + + # set reader and evaluator + self.bpe = MultiWOZBPETextField(self.cfg, **kwargs) + + self.cfg.Model.num_token_embeddings = self.bpe.vocab_size + self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1 + + if 'work_dir' in kwargs: + self.cfg.Trainer.save_dir = kwargs['work_dir'] + else: + self.cfg.Trainer.save_dir = './default_save_dir' + + # set data and data status + self.train_data = self.bpe.get_batches('train') + self.dev_data = self.bpe.get_batches('dev') + + self.evaluator = MultiWOZEvaluator(reader=self.bpe, **kwargs) + # set generator + self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe) + self._load_model(**kwargs) + + def _load_model(self, **kwargs): + + def to_tensor(array): + """ + numpy array -> tensor + """ + import torch + array = torch.tensor(array) + return array.cuda( + ) if self.cfg.use_gpu and torch.cuda.is_available() else array + + # construct model + if 'model' in kwargs: + self.model = kwargs['model'] + else: + self.model = SpaceModelBase.create( + kwargs['model_dir'], + self.cfg, + reader=self.bpe, + generator=self.generator) + + import torch + # multi-gpu + if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(self.model) + + # construct trainer + self.trainer = MultiWOZTrainer( + self.model, + to_tensor, + self.cfg, + reader=self.bpe, + evaluator=self.evaluator) + self.trainer.set_optimizers() + # load model, optimizer and lr_scheduler + self.trainer.load() + + def rebuild_config(self, cfg: Config): + if self.cfg_modify_fn is not None: + return self.cfg_modify_fn(cfg) + return cfg + + def train(self, *args, **kwargs): + logger.info('Train') + + self.trainer.train(train_data=self.train_data, dev_data=self.dev_data) + + def evaluate(self, + checkpoint_path: Optional[str] = None, + *args, + **kwargs) -> Dict[str, float]: + logger.info('Evaluate') + self.cfg.do_infer = True + + # get best checkpoint path + pos = checkpoint_path.rfind('/') + checkpoint_name = checkpoint_path[pos + 1:] + checkpoint_dir = checkpoint_path[:pos] + + assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE + kwargs['model_dir'] = checkpoint_dir + self._load_model(**kwargs) + self.trainer.infer(data_type='test') diff --git a/modelscope/trainers/nlp/space/eval.py b/modelscope/trainers/nlp/space/eval.py new file mode 100644 index 00000000..f315ff07 --- /dev/null +++ b/modelscope/trainers/nlp/space/eval.py @@ -0,0 +1,952 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright from https://github.com/thu-spmi/LABES +# Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from collections import Counter + +import json +import numpy as np +from nltk.util import ngrams +from sklearn.metrics import f1_score + +from modelscope.utils.nlp.space import ontology, utils +from modelscope.utils.nlp.space.clean_dataset import clean_slot_values + + +def similar(a, b): + return a == b or a in b or b in a or a.split()[0] == b.split( + )[0] or a.split()[-1] == b.split()[-1] + + +def setsub(a, b): + junks_a = [] + useless_constraint = [ + 'temperature', 'week', 'est ', 'quick', 'reminder', 'near' + ] + for i in a: + flg = False + for j in b: + if similar(i, j): + flg = True + if not flg: + junks_a.append(i) + for junk in junks_a: + flg = False + for item in useless_constraint: + if item in junk: + flg = True + if not flg: + return False + return True + + +def setsim(a, b): + a, b = set(a), set(b) + return setsub(a, b) and setsub(b, a) + + +def DA_evaluate(preds, labels): + preds = np.array(preds) + labels = np.array(labels) + results = {} + + for avg_name in ['micro']: + my_f1_score = f1_score(y_true=labels, y_pred=preds, average=avg_name) + results['f1_{}'.format(avg_name)] = my_f1_score + + return results + + +class BLEUScorer(object): + # BLEU score calculator via GentScorer interface + # it calculates the BLEU-4 by taking the entire corpus in + # Calulate based multiple candidates against multiple references + def __init__(self): + pass + + def score(self, parallel_corpus): + + # containers + count = [0, 0, 0, 0] + clip_count = [0, 0, 0, 0] + r = 0 + c = 0 + weights = [0.25, 0.25, 0.25, 0.25] + + # accumulate ngram statistics + for hyps, refs in parallel_corpus: + hyps = [hyp.split() for hyp in hyps] + refs = [ref.split() for ref in refs] + for hyp in hyps: + + for i in range(4): + # accumulate ngram counts + hypcnts = Counter(ngrams(hyp, i + 1)) + cnt = sum(hypcnts.values()) + count[i] += cnt + + # compute clipped counts + max_counts = {} + for ref in refs: + refcnts = Counter(ngrams(ref, i + 1)) + for ng in hypcnts: + max_counts[ng] = max( + max_counts.get(ng, 0), refcnts[ng]) + clipcnt = \ + dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items()) + clip_count[i] += sum(clipcnt.values()) + + # accumulate r & c + bestmatch = [1000, 1000] + for ref in refs: + if bestmatch[0] == 0: + break + diff = abs(len(ref) - len(hyp)) + if diff < bestmatch[0]: + bestmatch[0] = diff + bestmatch[1] = len(ref) + r += bestmatch[1] + c += len(hyp) + + # computing bleu score + p0 = 1e-7 + bp = \ + 1 if c > r else math.exp(1 - float(r) / float(c)) + p_ns = \ + [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)] + s = \ + math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n) + bleu = bp * math.exp(s) + return bleu * 100 + + +"""" +For the data preparation and evaluation on MultiWOZ2.0/2.1, +we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ) +""" + + +class MultiWOZEvaluator(object): + + def __init__(self, reader, **kwargs): + self.reader = reader + self.domains = ontology.all_domains + self.all_data = self.reader.data + self.test_data = self.reader.test + + self.bleu_scorer = BLEUScorer() + + self.all_info_slot = [] + for d, s_list in ontology.informable_slots.items(): + for s in s_list: + self.all_info_slot.append(d + '-' + s) + + # only evaluate these slots for dialog success + self.requestables = ['phone', 'address', 'postcode', 'reference', 'id'] + self.db_dir = kwargs['data_dir'] + + def pack_dial(self, data): + dials = {} + for turn in data: + dial_id = turn['dial_id'] + if dial_id not in dials: + dials[dial_id] = [] + dials[dial_id].append(turn) + return dials + + def validation_metric(self, data, fout=None): + bleu = self.bleu_metric(data) + # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data) + success, match, req_offer_counts, dial_num = \ + self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout) + return bleu, success, match + + def bleu_metric(self, data, eval_dial_list=None): + gen, truth = [], [] + for row in data: + if eval_dial_list and row[ + 'dial_id'] + '.json' not in eval_dial_list: + continue + gen.append(row['resp_gen']) + truth.append(row['resp']) + wrap_generated = [[_] for _ in gen] + wrap_truth = [[_] for _ in truth] + if gen and truth: + try: + sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth)) + except Exception: + sc = 0.0 + else: + sc = 0.0 + return sc + + def context_to_response_eval(self, + data, + eval_dial_list=None, + same_eval_as_cambridge=False, + fout=None): + dials = self.pack_dial(data) + counts = {} + for req in self.requestables: + counts[req + '_total'] = 0 + counts[req + '_offer'] = 0 + + dial_num, successes, matches = 0, 0, 0 + + for dial_id in dials: + if eval_dial_list and dial_id + '.json' not in eval_dial_list: + continue + dial = dials[dial_id] + reqs = {} + goal = {} + if '.json' not in dial_id and '.json' in list( + self.all_data.keys())[0]: + dial_id = dial_id + '.json' + for domain in ontology.all_domains: + if self.all_data[dial_id]['goal'].get(domain): + true_goal = self.all_data[dial_id]['goal'] + goal = self._parseGoal(goal, true_goal, domain) + + for domain in goal.keys(): + reqs[domain] = goal[domain]['requestable'] + + success, match, stats, counts = \ + self._evaluateGeneratedDialogue(dial, goal, reqs, counts, + same_eval_as_cambridge=same_eval_as_cambridge, fout=fout) + + successes += success + matches += match + dial_num += 1 + + succ_rate = successes / (float(dial_num) + 1e-10) * 100 + match_rate = matches / (float(dial_num) + 1e-10) * 100 + return succ_rate, match_rate, counts, dial_num + + def _evaluateGeneratedDialogue(self, + dialog, + goal, + real_requestables, + counts, + soft_acc=False, + same_eval_as_cambridge=False, + fout=None): + """Evaluates the dialogue created by the model. + First we load the user goal of the dialogue, then for each turn + generated by the system we look for key-words. + For the Inform rate we look whether the entity was proposed. + For the Success rate we look for requestables slots""" + # for computing corpus success + requestables = self.requestables + + # CHECK IF MATCH HAPPENED + provided_requestables = {} + venue_offered = {} + domains_in_goal = [] + log = [] + bspans = {} + + for domain in goal.keys(): + venue_offered[domain] = [] + provided_requestables[domain] = [] + domains_in_goal.append(domain) + + for t, turn in enumerate(dialog): + if t == 0: + continue + if fout is not None: + log.append({ + 'turn_num': turn['turn_num'], + 'turn_domain': turn['dspn'], + 'user': turn['user'], + 'aspn': turn['aspn'], + 'aspn_gen': turn['aspn_gen'], + 'resp': turn['resp'], + 'resp_gen': turn['resp_gen'], + 'pointer': turn['pointer'], + }) + + sent_t = turn['resp_gen'] + + for domain in goal.keys(): + # for computing success + if same_eval_as_cambridge: + # [restaurant_name], [hotel_name] instead of [value_name] + if self.reader.use_true_domain_for_ctr_eval: + dom_pred = [d[1:-1] for d in turn['dspn'].split()] + else: + dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()] + + if domain not in dom_pred: # fail + continue + if '[value_name]' in sent_t or '[value_id]' in sent_t: + if domain in [ + 'restaurant', 'hotel', 'attraction', 'train' + ]: + # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION + if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval: + bspn = turn['bspn_gen'] + else: + bspn = turn['bspn'] + + constraint_dict = self.reader.bspan_to_constraint_dict( + bspn) + if constraint_dict.get(domain): + venues = self.reader.db.queryJsons( + domain, + constraint_dict[domain], + return_name=True) + else: + venues = [] + + if len(venue_offered[domain]) == 0 and venues: + + venue_offered[domain] = venues + bspans[domain] = constraint_dict[domain] + else: + flag = False + for ven in venues: + if ven not in venue_offered[domain]: + flag = True + break + if flag and venues: # sometimes there are no results so sample won't work + venue_offered[domain] = venues + bspans[domain] = constraint_dict[domain] + else: # not limited so we can provide one + venue_offered[domain] = '[value_name]' + + # ATTENTION: assumption here - we didn't provide phone or address twice! etc + for requestable in requestables: + if requestable == 'reference': + if '[value_reference]' in sent_t: + if domain in ['restaurant', 'hotel', 'train']: + if 'booked' in turn['pointer'] or 'ok' in turn[ + 'pointer'] or '[value_reference]' in turn[ + 'resp']: + # if pointer was allowing for that? + provided_requestables[domain].append( + 'reference') + else: + provided_requestables[domain].append( + 'reference') + else: + if '[value_' + requestable + ']' in sent_t: + provided_requestables[domain].append(requestable) + + # if name was given in the task + for domain in goal.keys(): + # if name was provided for the user, the match is being done automatically + if 'name' in goal[domain]['informable']: + venue_offered[domain] = '[value_name]' + + # special domains - entity does not need to be provided + if domain in ['taxi', 'police', 'hospital']: + venue_offered[domain] = '[value_name]' + + if domain == 'train': + if not venue_offered[domain] and 'id' not in goal[domain][ + 'requestable']: + venue_offered[domain] = '[value_name]' + """ + Given all inform and requestable slots + we go through each domain from the user goal + and check whether right entity was provided and + all requestable slots were given to the user. + The dialogue is successful if that's the case for all domains. + """ + # HARD EVAL + stats = { + 'restaurant': [0, 0, 0], + 'hotel': [0, 0, 0], + 'attraction': [0, 0, 0], + 'train': [0, 0, 0], + 'taxi': [0, 0, 0], + 'hospital': [0, 0, 0], + 'police': [0, 0, 0] + } + + match = 0 + success = 0 + # MATCH + for domain in goal.keys(): + match_stat = 0 + if domain in ['restaurant', 'hotel', 'attraction', 'train']: + goal_venues = self.reader.db.queryJsons( + domain, goal[domain]['informable'], return_name=True) + if type(venue_offered[domain] + ) is str and '_name' in venue_offered[domain]: + match += 1 + match_stat = 1 + elif len(venue_offered[domain]) > 0 and len( + set(venue_offered[domain]) & set(goal_venues)) > 0: + match += 1 + match_stat = 1 + else: + if '_name]' in venue_offered[domain]: + match += 1 + match_stat = 1 + + stats[domain][0] = match_stat + stats[domain][2] = 1 + + if soft_acc: + match = float(match) / len(goal.keys()) + else: + if match == len(goal.keys()): + match = 1.0 + else: + match = 0.0 + + for domain in domains_in_goal: + for request in real_requestables[domain]: + counts[request + '_total'] += 1 + if request in provided_requestables[domain]: + counts[request + '_offer'] += 1 + + # SUCCESS + if fout is not None: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in real_requestables[domain]: + if request in provided_requestables[domain]: + domain_success += 1 + + if domain_success == len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success) / len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + else: + if match == 1.0: + for domain in domains_in_goal: + success_stat = 0 + domain_success = 0 + if len(real_requestables[domain]) == 0: + success += 1 + success_stat = 1 + stats[domain][1] = success_stat + continue + # if values in sentences are super set of requestables + for request in real_requestables[domain]: + if request in provided_requestables[domain]: + domain_success += 1 + + if domain_success == len(real_requestables[domain]): + success += 1 + success_stat = 1 + + stats[domain][1] = success_stat + + # final eval + if soft_acc: + success = float(success) / len(real_requestables) + else: + if success >= len(real_requestables): + success = 1 + else: + success = 0 + + if fout is not None and success == 0: + sample = { + dialog[0]['dial_id']: { + 'log': log, + 'real_requestables': real_requestables, + 'provided_requestables': provided_requestables + } + } + line = json.dumps(sample) + fout.write(line) + fout.write('\n') + + return success, match, stats, counts + + def _parseGoal(self, goal, true_goal, domain): + """Parses user goal into dictionary format.""" + goal[domain] = {} + goal[domain] = {'informable': {}, 'requestable': [], 'booking': []} + if 'info' in true_goal[domain]: + if domain == 'train': + # we consider dialogues only where train had to be booked! + if 'book' in true_goal[domain]: + goal[domain]['requestable'].append('reference') + if 'reqt' in true_goal[domain]: + if 'id' in true_goal[domain]['reqt']: + goal[domain]['requestable'].append('id') + else: + if 'reqt' in true_goal[domain]: + for s in true_goal[domain]['reqt']: # addtional requests: + if s in [ + 'phone', 'address', 'postcode', 'reference', + 'id' + ]: + # ones that can be easily delexicalized + goal[domain]['requestable'].append(s) + if 'book' in true_goal[domain]: + goal[domain]['requestable'].append('reference') + + for s, v in true_goal[domain]['info'].items(): + s_, v_ = clean_slot_values(self.db_dir, domain, s, v) + if len(v_.split()) > 1: + v_ = ' '.join( + [token.text for token in self.reader.nlp(v_)]).strip() + goal[domain]['informable'][s_] = v_ + + if 'book' in true_goal[domain]: + goal[domain]['booking'] = true_goal[domain]['book'] + return goal + + +class GenericEvaluator: + + def __init__(self, reader): + self.reader = reader + self.metric_dict = {} + + def pack_dial(self, data): + dials = {} + for turn in data: + dial_id = turn['dial_id'] + if dial_id not in dials: + dials[dial_id] = [] + dials[dial_id].append(turn) + return dials + + def run_metrics(self, results): + raise ValueError('Please specify the evaluator first') + + def bleu_metric(self, data, type='bleu'): + gen, truth = [], [] + for row in data: + gen.append(self.clean(row['resp_gen'])) + # gen.append(self.clean(row['resp'])) + truth.append(self.clean(row['resp'])) + wrap_generated = [[_] for _ in gen] + wrap_truth = [[_] for _ in truth] + sc = BLEUScorer().score(zip(wrap_generated, wrap_truth)) + return sc + + def _normalize_constraint(self, + constraint, + ignore_dontcare=False, + intersection=True): + """ + Normalize belief span, e.g. delete repeated words + :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'} + :param intersection: if true, only keeps the words that appear in th ontology + we set intersection=True as in previous works + :returns: normalized constraint dict + e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''} + """ + normalized = {} + for s in self.informable_slots: + normalized[s] = '' + for s, v in constraint.items(): + if ignore_dontcare and v == 'dontcare': + continue + if intersection and v != 'dontcare' and v not in self.entities_flat: + continue + + normalized[s] = v + + return normalized + + def _normalize_act(self, aspn, intersection=False): + aspn_list = aspn.split('|') + normalized = {} + for i, v in enumerate(aspn_list): + seq = v.strip() + word_set = set() + for w in seq.split(): + if intersection: + if self.reader.act_order[i] == 'av': + if '[value' in w: + word_set.add(w) + else: + if w in self.requestable_slots: + word_set.add(w) + else: + word_set.add(w) + normalized[self.reader.act_order[i]] = word_set + return normalized + + def tracker_metric(self, data, normalize=True): + # turn level metric + tp, fp, fn, db_correct = 0, 0, 0, 0 + goal_accr, slot_accr, total = 0, {}, 1e-8 + for s in self.informable_slots: + slot_accr[s] = 0 + + for row in data: + if normalize: + gen = self._normalize_constraint(row['bspn_gen']) + truth = self._normalize_constraint(row['bspn']) + else: + gen = self._normalize_constraint( + row['bspn_gen'], intersection=False) + truth = self._normalize_constraint( + row['bspn'], intersection=False) + valid = 'thank' not in row['user'] and 'bye' not in row['user'] + if valid: + for slot, value in gen.items(): + if value in truth[slot]: + tp += 1 + else: + fp += 1 + for slot, value in truth.items(): + if value not in gen[slot]: + fn += 1 + + if truth and valid: + total += 1 + for s in self.informable_slots: + if gen[s] == truth[s]: + slot_accr[s] += 1 + if gen == truth: + goal_accr += 1 + if row.get('db_gen') and row.get('db_match'): + if row['db_gen'] == row['db_match']: + db_correct += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + goal_accr /= total + db_correct /= total + for s in slot_accr: + slot_accr[s] /= total + return precision, recall, f1, goal_accr, slot_accr, db_correct + + def request_metric(self, data): + # dialog level metric + dials = self.pack_dial(data) + tp, fp, fn = 0, 0, 0 + for dial_id in dials: + truth_req, gen_req = set(), set() + dial = dials[dial_id] + for turn_num, turn in enumerate(dial): + resp_gen_token = self.clean(turn['resp_gen']).split() + resp_token = self.clean(turn['resp']).split() + for w in resp_gen_token: + if '[value_' in w and w.endswith( + ']') and w != '[value_name]': + gen_req.add(w[1:-1].split('_')[1]) + for w in resp_token: + if '[value_' in w and w.endswith( + ']') and w != '[value_name]': + truth_req.add(w[1:-1].split('_')[1]) + for req in gen_req: + if req in truth_req: + tp += 1 + else: + fp += 1 + for req in truth_req: + if req not in gen_req: + fn += 1 + precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + return f1, precision, recall + + def act_metric(self, data): + # turn level metric + tp, fp, fn = { + 'all_s': 0, + 'all_v': 0 + }, { + 'all_s': 0, + 'all_v': 0 + }, { + 'all_s': 0, + 'all_v': 0 + } + for s in self.requestable_slots: + tp[s], fp[s], fn[s] = 0, 0, 0 + tp['[value_%s]' % s], fp['[value_%s]' % s], fn['[value_%s]' + % s] = 0, 0, 0 + + for row in data: + gen = self._normalize_act(row['aspn_gen']) + truth = self._normalize_act(row['aspn']) + valid = 'thank' not in row['user'] and 'bye' not in row['user'] + if valid: + # how well the act decoder captures user's requests + for value in gen['av']: + if value in truth['av']: + tp['all_v'] += 1 + if tp.get(value): + tp[value] += 1 + else: + fp['all_v'] += 1 + if fp.get(value): + fp[value] += 1 + for value in truth['av']: + if value not in gen['av']: + fn['all_v'] += 1 + if fn.get(value): + fn[value] += 1 + + # how accurately the act decoder predicts system's question + if 'as' not in gen: + continue + for slot in gen['as']: + if slot in truth['as']: + tp['all_s'] += 1 + if tp.get(slot): + tp[slot] += 1 + else: + fp['all_s'] += 1 + if fp.get(slot): + fp[slot] += 1 + for slot in truth['as']: + if slot not in gen['as']: + fn['all_s'] += 1 + if fn.get(slot): + fn[slot] += 1 + + result = {} + for k, v in tp.items(): + precision, recall = tp[k] / (tp[k] + fp[k] + 1e-8), tp[k] / ( + tp[k] + fn[k] + 1e-8) + f1 = 2 * precision * recall / (precision + recall + 1e-8) + result[k] = [f1, precision, recall] + return result + + +""" +For the data preparation and evaluation on In-Car Assistant/CamRest, +we refer to the code of LABES (https://github.com/thu-spmi/LABES) +""" + + +class CamRestEvaluator(GenericEvaluator): + + def __init__(self, reader): + super().__init__(reader) + self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( + self.reader.ontology_path) + self.informable_slots = self.reader.otlg.informable_slots + self.requestable_slots = self.reader.otlg.requestable_slots + + def run_metrics(self, results): + metrics = {} + bleu = self.bleu_metric(results) + p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results) + match = self.match_metric(results) + req_f1, req_p, req_r = self.request_metric(results) + + metrics['bleu'] = bleu + metrics['match'] = match + metrics['req_f1'] = req_f1 + metrics['joint_goal'] = goal_acc + metrics['slot_accu'] = slot_acc + metrics['slot-p/r/f1'] = (p, r, f1) + metrics['db_acc'] = db_acc + + return metrics + + def get_entities(self, entity_path): + entities_flat = [] + entitiy_to_slot_dict = {} + raw_entities = json.loads(open(entity_path).read().lower()) + for s in raw_entities['informable']: + entities_flat.extend(raw_entities['informable'][s]) + for v in raw_entities['informable'][s]: + entitiy_to_slot_dict[v] = s + return entities_flat, entitiy_to_slot_dict + + def constraint_same(self, truth_cons, gen_cons): + if not truth_cons and not gen_cons: + return True + if not truth_cons or not gen_cons: + return False + return setsim(gen_cons, truth_cons) + + def match_metric(self, data): + dials = self.pack_dial(data) + match, total = 0, 1e-8 + for dial_id in dials: + dial = dials[dial_id] + truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None + for turn_num, turn in enumerate(dial): + # find the last turn which the system provide an entity + if '[value' in turn['resp_gen']: + gen_cons = self._normalize_constraint( + turn['bspn_gen'], ignore_dontcare=True) + if '[value' in turn['resp']: + truth_cons = self._normalize_constraint( + turn['bspn'], ignore_dontcare=True) + if not gen_cons: + # if no entity is provided, choose the state of the last dialog turn + gen_cons = self._normalize_constraint( + dial[-1]['bspn_gen'], ignore_dontcare=True) + if list(truth_cons.values()) != ['', '', '']: + if gen_cons == truth_cons: + match += 1 + total += 1 + + return match / total + + def clean(self, resp): + # we use the same clean process as in Sequicity, SEDST, FSDM + # to ensure comparable results + resp = resp.replace(f'{self.reader.sos_r_token} ', '') + resp = resp.replace(f' {self.reader.eos_r_token}', '') + resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' + for value, slot in self.entitiy_to_slot_dict.items(): + + resp = utils.clean_replace(resp, value, '[value_%s]' % slot) + return resp + + +class KvretEvaluator(GenericEvaluator): + + def __init__(self, reader): + super().__init__(reader) + self.entities_flat, self.entitiy_to_slot_dict = self.get_entities( + self.reader.ontology_path) + self.informable_slots = self.reader.otlg.informable_slots + self.requestable_slots = self.reader.otlg.requestable_slots + + def run_metrics(self, results): + metrics = {} + bleu = self.bleu_metric(results) + p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric( + results, normalize=True) + match = self.match_metric(results) + req_f1, req_p, req_r = self.request_metric(results) + + metrics['bleu'] = bleu + metrics['match'] = match + metrics['req_f1'] = req_f1 + metrics['joint_goal'] = goal_acc + metrics['slot_accu'] = slot_acc + metrics['slot-p/r/f1'] = (p, r, f1) + metrics['db_acc'] = db_acc + + return metrics + + def _normalize_constraint(self, + constraint, + ignore_dontcare=False, + intersection=True): + """ + Normalize belief span, e.g. delete repeated words + :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'} + :param intersection: if true, only keeps the words that appear in th ontology + we set intersection=True as in previous works + :returns: normalized constraint dict + e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''} + """ + junk = [ + 'good', 'great', 'quickest', 'shortest', 'route', 'week', + 'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity', + 'restaurant', 'appointment' + ] + normalized = {} + for s in self.informable_slots: + normalized[s] = '' + for s, v in constraint.items(): + for j in junk: + v = ' '.join(v.replace(j, '').split()) + if intersection and v not in self.entities_flat: + continue + + if s in self.informable_slots: + normalized[s] = v + else: + # TODO only use slot (not domain) in s for matching !!! + pass + + return normalized + + def get_entities(self, entity_path): + entities_flat = [] + entitiy_to_slot_dict = {} + + entitiy_to_slot_dict = self.reader.entity_dict + for s in entitiy_to_slot_dict: + if s not in entities_flat: + entities_flat.append(s) + return entities_flat, entitiy_to_slot_dict + + def constraint_same(self, truth_cons, gen_cons): + if not truth_cons and not gen_cons: + return True + if not truth_cons or not gen_cons: + return False + return setsim(gen_cons, truth_cons) + + def match_metric(self, data): + dials = self.pack_dial(data) + match, total = 0, 1e-8 + for dial_id in dials: + dial = dials[dial_id] + truth_cons, gen_cons = { + '1': '', + '2': '', + '3': '', + '4': '', + '5': '', + '6': '', + '7': '', + '8': '', + '9': '', + '10': '', + '11': '' + }, None + for turn_num, turn in enumerate(dial): + # find the last turn which the system provide an entity + if '[value' in turn['resp_gen']: + gen_cons = self._normalize_constraint( + turn['bspn_gen'], ignore_dontcare=True) + if '[value' in turn['resp']: + truth_cons = self._normalize_constraint( + turn['bspn'], ignore_dontcare=True) + + if not gen_cons: + # if no entity is provided, choose the state of the last dialog turn + gen_cons = self._normalize_constraint( + dial[-1]['bspn_gen'], ignore_dontcare=True) + + if list(truth_cons.values()) != [''] * 11: + gen_cons = [x for x in gen_cons.values() if x] + truth_cons = [x for x in truth_cons.values() if x] + if self.constraint_same(gen_cons, truth_cons): + match += 1 + total += 1 + + return match / total + + def clean(self, resp): + # we use the same clean process as in Sequicity, SEDST, FSDM + # to ensure comparable results + resp = resp.replace(f'{self.reader.sos_r_token} ', '') + resp = resp.replace(f' {self.reader.eos_r_token}', '') + resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}' + for value, slot in self.entitiy_to_slot_dict.items(): + resp = utils.clean_replace(resp, value, '[value_%s]' % slot) + return resp diff --git a/modelscope/trainers/nlp/space/trainer/gen_trainer.py b/modelscope/trainers/nlp/space/trainer/gen_trainer.py index aa28d798..34cd2f9b 100644 --- a/modelscope/trainers/nlp/space/trainer/gen_trainer.py +++ b/modelscope/trainers/nlp/space/trainer/gen_trainer.py @@ -15,27 +15,11 @@ from transformers.optimization import AdamW, get_linear_schedule_with_warmup from modelscope.trainers.nlp.space.metrics.metrics_tracker import \ MetricsTracker +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger from modelscope.utils.nlp.space import ontology -def get_logger(log_path, name='default'): - logger = logging.getLogger(name) - logger.propagate = False - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter('%(message)s') - - sh = logging.StreamHandler(sys.stdout) - sh.setFormatter(formatter) - logger.addHandler(sh) - - fh = logging.FileHandler(log_path, mode='w') - fh.setFormatter(formatter) - logger.addHandler(fh) - - return logger - - class Trainer(object): def __init__(self, @@ -51,15 +35,16 @@ class Trainer(object): self.do_train = config.do_train self.do_infer = config.do_infer - self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ - 0] == '-' - self.valid_metric_name = config.Trainer.valid_metric_name[1:] - self.num_epochs = config.Trainer.num_epochs - # self.save_dir = config.Trainer.save_dir - self.log_steps = config.Trainer.log_steps - self.valid_steps = config.Trainer.valid_steps - self.save_checkpoint = config.Trainer.save_checkpoint - self.save_summary = config.Trainer.save_summary + if self.do_train: + self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ + 0] == '-' + self.valid_metric_name = config.Trainer.valid_metric_name[1:] + self.num_epochs = config.Trainer.num_epochs + self.save_dir = config.Trainer.save_dir + self.log_steps = config.Trainer.log_steps + self.valid_steps = config.Trainer.valid_steps + self.save_checkpoint = config.Trainer.save_checkpoint + self.save_summary = config.Trainer.save_summary self.lr = config.Model.lr self.weight_decay = config.Model.weight_decay self.batch_size = config.Trainer.batch_size @@ -71,22 +56,21 @@ class Trainer(object): self.optimizer = optimizer self.model = model - self.func_model = self.model.module if self.gpu > 1 else self.model + self.func_model = self.model.module if self.gpu > 1 and config.use_gpu else self.model self.reader = reader self.evaluator = evaluator self.tokenizer = reader.tokenizer - # if not os.path.exists(self.save_dir): - # os.makedirs(self.save_dir) - - # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") - self.logger = logger or get_logger('trainer.log', 'trainer') + self.logger = get_logger() self.batch_metrics_tracker = MetricsTracker() self.token_metrics_tracker = MetricsTracker() - self.best_valid_metric = float( - 'inf' if self.is_decreased_valid_metric else '-inf') + if self.do_train: + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') self.epoch = 0 def decode_generated_bspn_resp(self, generated): @@ -248,9 +232,12 @@ class Trainer(object): # Save current best model if is_best: - best_model_file = os.path.join(self.save_dir, 'best.model') + best_model_file = os.path.join(self.save_dir, + ModelFile.TORCH_MODEL_BIN_FILE) torch.save(self.model.state_dict(), best_model_file) - best_train_file = os.path.join(self.save_dir, 'best.train') + best_train_file = os.path.join( + self.save_dir, + '{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE)) torch.save(train_state, best_train_file) self.logger.info( f"Saved best model state to '{best_model_file}' with new best valid metric " @@ -324,8 +311,7 @@ class Trainer(object): self.func_model.load_state_dict(model_state_dict) self.logger.info( - f"Loaded model state from '{self.func_model.init_checkpoint}.model'" - ) + f"Loaded model state from '{self.func_model.init_checkpoint}'") def _load_train_state(): train_file = f'{self.func_model.init_checkpoint}.train' @@ -558,19 +544,17 @@ class MultiWOZTrainer(Trainer): generated_bs = outputs[0].cpu().numpy().tolist() bspn_gen = self.decode_generated_bspn(generated_bs) # check DB result - if self.reader.use_true_db_pointer: # To control whether current db is ground truth + if self.reader.use_true_db_pointer: db = turn['db'] else: db_result = self.reader.bspan_to_DBpointer( self.tokenizer.decode(bspn_gen), turn['turn_domain']) - assert len(turn['db']) == 4 - book_result = turn['db'][2] + assert len(turn['db']) == 3 assert isinstance(db_result, str) db = \ [self.reader.sos_db_id] + \ self.tokenizer.convert_tokens_to_ids([db_result]) + \ - [book_result] + \ [self.reader.eos_db_id] prompt_id = self.reader.sos_a_id @@ -636,7 +620,7 @@ class MultiWOZTrainer(Trainer): score = 0.5 * (success + match) + bleu # log results - metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\ + metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % \ (match, success, bleu, score) message_prefix = f'[Infer][{self.epoch}]' time_cost = f'TIME-{time.time() - begin_time:.3f}' diff --git a/modelscope/utils/nlp/space/clean_dataset.py b/modelscope/utils/nlp/space/clean_dataset.py new file mode 100644 index 00000000..4578ccc4 --- /dev/null +++ b/modelscope/utils/nlp/space/clean_dataset.py @@ -0,0 +1,333 @@ +import os +import re + +from . import ontology + + +def clean_text_split_dot(text): + text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', + text) # 'abc.xyz' -> 'abc . xyz' + text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' + return text + + +def clean_text(data_dir, text): + text = text.strip() + text = text.lower() + text = text.replace(u'’', "'") + text = text.replace(u'‘', "'") + text = text.replace(';', ',') + text = text.replace('"', ' ') + text = text.replace('/', ' and ') + text = text.replace("don't", "do n't") + text = clean_time(text) + baddata = { + r'c\.b (\d), (\d) ([a-z])\.([a-z])': r'cb\1\2\3\4', + 'c.b. 1 7 d.y': 'cb17dy', + 'c.b.1 7 d.y': 'cb17dy', + 'c.b 25, 9 a.q': 'cb259aq', + 'isc.b 25, 9 a.q': 'is cb259aq', + 'c.b2, 1 u.f': 'cb21uf', + 'c.b 1,2 q.a': 'cb12qa', + '0-122-336-5664': '01223365664', + 'postcodecb21rs': 'postcode cb21rs', + r'i\.d': 'id', + ' i d ': 'id', + 'Telephone:01223358966': 'Telephone: 01223358966', + 'depature': 'departure', + 'depearting': 'departing', + '-type': ' type', + r'b[\s]?&[\s]?b': 'bed and breakfast', + 'b and b': 'bed and breakfast', + r'guesthouse[s]?': 'guest house', + r'swimmingpool[s]?': 'swimming pool', + "wo n\'t": 'will not', + " \'d ": ' would ', + " \'m ": ' am ', + " \'re' ": ' are ', + " \'ll' ": ' will ', + " \'ve ": ' have ', + r'^\'': '', + r'\'$': '', + } + for tmpl, good in baddata.items(): + text = re.sub(tmpl, good, text) + + text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2', + text) # 'abc.xyz' -> 'abc . xyz' + text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . ' + + with open(os.path.join(data_dir, 'mapping.pair'), 'r') as fin: + for line in fin.readlines(): + fromx, tox = line.replace('\n', '').split('\t') + text = ' ' + text + ' ' + text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1] + + return text + + +def clean_time(utter): + utter = re.sub(r'(\d+) ([ap]\.?m)', lambda x: x.group(1) + x.group(2), + utter) # 9 am -> 9am + utter = re.sub(r'((?= 1, 'skip test in current test level') + def test_trainer_with_model_and_args(self): + # download data set + data_multiwoz = MsDataset.load( + 'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) + data_dir = os.path.join( + data_multiwoz._hf_ds.config_kwargs['split_config']['train'], + 'data') + + # download model + model_dir = snapshot_download(self.model_id) + + # dialog finetune config + def cfg_modify_fn(cfg): + config = { + 'seed': 10, + 'gpu': 4, + 'use_data_distributed': False, + 'valid_metric_name': '-loss', + 'num_epochs': 60, + 'save_dir': self.output_dir, + 'token_loss': True, + 'batch_size': 32, + 'log_steps': 10, + 'valid_steps': 0, + 'save_checkpoint': True, + 'save_summary': False, + 'shuffle': True, + 'sort_pool_size': 0 + } + + cfg.Trainer = config + cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1 + return cfg + + # trainer config + kwargs = dict( + model_dir=model_dir, + cfg_name='gen_train_config.json', + data_dir=data_dir, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.dialog_modeling_trainer, default_args=kwargs) + trainer.train() + checkpoint_path = os.path.join(self.output_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + assert os.path.exists(checkpoint_path) + trainer.evaluate(checkpoint_path=checkpoint_path)