ly119399 yingda.chen 3 years ago
parent
commit
b41b10f897
12 changed files with 1744 additions and 76 deletions
  1. +1
    -0
      modelscope/metainfo.py
  2. +1
    -1
      modelscope/models/nlp/space/model/__init__.py
  3. +5
    -5
      modelscope/models/nlp/space/model/generator.py
  4. +1
    -1
      modelscope/models/nlp/space/space_for_dialog_modeling.py
  5. +1
    -1
      modelscope/preprocessors/space/dialog_modeling_preprocessor.py
  6. +216
    -20
      modelscope/preprocessors/space/fields/gen_field.py
  7. +130
    -0
      modelscope/trainers/nlp/space/dialog_modeling_trainer.py
  8. +952
    -0
      modelscope/trainers/nlp/space/eval.py
  9. +28
    -44
      modelscope/trainers/nlp/space/trainer/gen_trainer.py
  10. +333
    -0
      modelscope/utils/nlp/space/clean_dataset.py
  11. +8
    -4
      modelscope/utils/nlp/space/utils.py
  12. +68
    -0
      tests/trainers/test_dialog_modeling_trainer.py

+ 1
- 0
modelscope/metainfo.py View File

@@ -241,6 +241,7 @@ class Trainers(object):

# nlp trainers
bert_sentiment_analysis = 'bert-sentiment-analysis'
dialog_modeling_trainer = 'dialog-modeling-trainer'
dialog_intent_trainer = 'dialog-intent-trainer'
nlp_base_trainer = 'nlp-base-trainer'
nlp_veco_trainer = 'nlp-veco-trainer'


+ 1
- 1
modelscope/models/nlp/space/model/__init__.py View File

@@ -1,6 +1,6 @@
from .configuration_space import SpaceConfig
from .gen_unified_transformer import GenUnifiedTransformer
from .generator import Generator as SpaceGenerator
from .generator import SpaceGenerator
from .intent_unified_transformer import IntentUnifiedTransformer
from .model_base import SpaceModelBase
from .modeling_space import (SpaceForDST, SpaceForMaskedLM,


+ 5
- 5
modelscope/models/nlp/space/model/generator.py View File

@@ -38,24 +38,24 @@ def gather(var, idx):
return var


class Generator(object):
class SpaceGenerator(object):
""" Genrator class. """

_registry = dict()

@classmethod
def register(cls, name):
Generator._registry[name] = cls
SpaceGenerator._registry[name] = cls
return

@staticmethod
def by_name(name):
return Generator._registry[name]
return SpaceGenerator._registry[name]

@staticmethod
def create(config, *args, **kwargs):
""" Create generator. """
generator_cls = Generator.by_name(config.Generator.generator)
generator_cls = SpaceGenerator.by_name(config.Generator.generator)
return generator_cls(config, *args, **kwargs)

def __init__(self, config, reader):
@@ -83,7 +83,7 @@ class Generator(object):
raise NotImplementedError


class BeamSearch(Generator):
class BeamSearch(SpaceGenerator):
""" BeamSearch generator. """

def __init__(self, config, reader):


+ 1
- 1
modelscope/models/nlp/space/space_for_dialog_modeling.py View File

@@ -41,7 +41,7 @@ class SpaceForDialogModeling(TorchModel):

self.text_field = kwargs.pop(
'text_field',
MultiWOZBPETextField(self.model_dir, config=self.config))
MultiWOZBPETextField(config=self.config, model_dir=self.model_dir))
self.generator = SpaceGenerator.create(
self.config, reader=self.text_field)
self.model = SpaceModelBase.create(


+ 1
- 1
modelscope/preprocessors/space/dialog_modeling_preprocessor.py View File

@@ -35,7 +35,7 @@ class DialogModelingPreprocessor(Preprocessor):
self.config.use_gpu = self.config.use_gpu and torch.cuda.is_available()

self.text_field = MultiWOZBPETextField(
self.model_dir, config=self.config)
config=self.config, model_dir=self.model_dir)

@type_assert(object, Dict)
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:


+ 216
- 20
modelscope/preprocessors/space/fields/gen_field.py View File

@@ -2,9 +2,11 @@

import os
import random
from asyncio import constants
from collections import OrderedDict
from itertools import chain

import json
import numpy as np

from modelscope.preprocessors.space.tokenizer import Tokenizer
@@ -117,7 +119,8 @@ class BPETextField(object):
return self.tokenizer.convert_tokens_to_ids([self.eos_d_token])[0]

def __init__(self, config):
self.gpu = 0
self.train, self.dev, self.test = [], [], []
self.gpu = config.Trainer.gpu
self.tokenizer = None
self.vocab = None
self.db = None
@@ -249,13 +252,9 @@ class BPETextField(object):
for dial in data:
batch.append(dial)
if len(batch) == self.batch_size:
# print('batch size: %d, batch num +1'%(len(batch)))
all_batches.append(batch)
batch = []
# if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch
# print('last batch size: %d, batch num +1'%(len(batch)))
# if (len(batch) % len(cfg.cuda_device)) != 0:
# batch = batch[:-(len(batch) % len(cfg.cuda_device))]

# TODO deal with deleted data
if self.gpu <= 1:
if len(batch) > 0.5 * self.batch_size:
@@ -308,7 +307,7 @@ class BPETextField(object):

class MultiWOZBPETextField(BPETextField):

def __init__(self, model_dir, config):
def __init__(self, config, **kwargs):
super(MultiWOZBPETextField, self).__init__(config)

import spacy
@@ -327,8 +326,12 @@ class MultiWOZBPETextField(BPETextField):
)
self.nlp = spacy.load('en_core_web_sm')

if config.do_train:
db_dir = kwargs['data_dir']
else:
db_dir = kwargs['model_dir']
self.db = MultiWozDB(
model_dir, {
db_dir, {
'attraction': 'db/attraction_db_processed.json',
'hospital': 'db/hospital_db_processed.json',
'hotel': 'db/hotel_db_processed.json',
@@ -337,14 +340,14 @@ class MultiWOZBPETextField(BPETextField):
'taxi': 'db/taxi_db_processed.json',
'train': 'db/train_db_processed.json',
})
self._build_vocab(model_dir)
self._build_vocab(db_dir)

special_tokens = [
self.pad_token, self.bos_token, self.eos_token, self.unk_token
]
special_tokens.extend(self.add_sepcial_tokens())
self.tokenizer = Tokenizer(
vocab_path=os.path.join(model_dir, ModelFile.VOCAB_FILE),
vocab_path=os.path.join(kwargs['model_dir'], ModelFile.VOCAB_FILE),
special_tokens=special_tokens,
tokenizer_type=config.BPETextField.tokenizer_type)
self.understand_ids = self.tokenizer.convert_tokens_to_ids(
@@ -352,6 +355,26 @@ class MultiWOZBPETextField(BPETextField):
self.policy_ids = self.tokenizer.convert_tokens_to_ids(
self.policy_tokens)

if config.do_train:
test_list = [
line.strip().lower() for line in open(
os.path.join(kwargs['data_dir'], 'testListFile.json'),
'r').readlines()
]
dev_list = [
line.strip().lower() for line in open(
os.path.join(kwargs['data_dir'], 'valListFile.json'),
'r').readlines()
]

self.dev_files, self.test_files = {}, {}
for fn in test_list:
self.test_files[fn.replace('.json', '')] = 1
for fn in dev_list:
self.dev_files[fn.replace('.json', '')] = 1

self._load_data(kwargs['data_dir'])

return

def get_ids(self, data: str):
@@ -414,7 +437,6 @@ class MultiWOZBPETextField(BPETextField):
name_to_set = {'train': self.train, 'test': self.test, 'dev': self.dev}
dial = name_to_set[set_name]
turn_bucket = self._bucket_by_turn(dial)
# self._shuffle_turn_bucket(turn_bucket)
all_batches = []

if set_name not in self.set_stats:
@@ -433,19 +455,13 @@ class MultiWOZBPETextField(BPETextField):
except Exception:
log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
k, len(turn_bucket[k]), len(batches), 0.0)
# print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches)))
num_training_steps += k * len(batches)
num_turns += k * len(turn_bucket[k])
num_dials += len(turn_bucket[k])
all_batches += batches
log_str += 'total batch num: %d\n' % len(all_batches)
# print('total batch num: %d'%len(all_batches))
# print('dialog count: %d'%dia_count)
# return all_batches

# log stats
# logging.info(log_str)
# cfg.num_training_steps = num_training_steps * cfg.epoch_num
self.set_stats[set_name][
'num_training_steps_per_epoch'] = num_training_steps # turn-level steps
self.set_stats[set_name]['num_turns'] = num_turns
@@ -484,6 +500,71 @@ class MultiWOZBPETextField(BPETextField):
self.vocab.load_vocab(vp)
return self.vocab.vocab_size

def _load_data(self, data_dir, save_temp=True):
    """
    load processed data and encode, or load already encoded data

    Populates self.train / self.dev / self.test with encoded dialogs,
    splitting by the dev/test file lists loaded in __init__.
    """

    def load_data_from_resource(data_resource):
        # Lower-case the raw json and route each dialog into train/dev/test
        # according to self.dev_files / self.test_files.
        data = json.loads(
            open(
                os.path.join(data_dir, data_resource),
                'r',
                encoding='utf-8').read().lower())
        train, dev, test = [], [], []
        for fn, dial in data.items():
            if '.json' in fn:
                fn = fn.replace('.json', '')
            if self.dev_files.get(fn):
                dev.append(self._get_encoded_data(fn, dial))
            elif self.test_files.get(fn):
                test.append(self._get_encoded_data(fn, dial))
            else:
                train.append(self._get_encoded_data(fn, dial))
        return train, dev, test

    data_processed = 'new_db_se_blank_encoded_domain.data.json'
    data_resource = 'data_for_damd.json'
    if save_temp:  # save encoded data
        # encoded: no sos, se_encoded: sos and eos
        encoded_file = os.path.join(data_dir, data_processed)

        if os.path.exists(encoded_file):
            logger.info(
                'Reading encoded data from {}'.format(encoded_file))
            # NOTE(review): self.data (the raw lower-cased corpus) is only
            # set on this cached path, not in the encode-from-scratch
            # branches below — confirm downstream users tolerate that.
            self.data = json.loads(
                open(
                    os.path.join(data_dir, data_resource),
                    'r',
                    encoding='utf-8').read().lower())
            encoded_data = json.loads(
                open(encoded_file, 'r', encoding='utf-8').read())
            self.train = encoded_data['train']
            self.dev = encoded_data['dev']
            self.test = encoded_data['test']
        else:
            logger.info(
                'Encoding data now and save the encoded data in {}'.format(
                    encoded_file))
            # not exists, encode data and save
            self.train, self.dev, self.test = load_data_from_resource(
                data_resource)
            # save encoded data
            encoded_data = {
                'train': self.train,
                'dev': self.dev,
                'test': self.test
            }
            json.dump(encoded_data, open(encoded_file, 'w'), indent=2)
    else:  # directly read processed data and encode
        self.train, self.dev, self.test = load_data_from_resource(
            data_resource)

    # fixed seed so the shuffled training order is reproducible across runs
    random.seed(10)
    random.shuffle(self.train)
    logger.info('train size:{}, dev size:{}, test size:{}'.format(
        len(self.train), len(self.dev), len(self.test)))

def _get_convert_str(self, sent):
assert isinstance(sent, str)
return ' '.join([
@@ -491,14 +572,65 @@ class MultiWOZBPETextField(BPETextField):
for tok in sent.split()
])

def _get_encoded_data(self, fn, dial):
    """Encode one raw dialog into a list of per-turn id dicts.

    Each turn's user/resp/belief/act spans are tokenized and wrapped in the
    matching <sos_*>/<eos_*> token ids; domain, pointer and db-indicator
    fields are either computed from the turn or inherited from the previous
    turn for '[general]' turns.
    """
    encoded_dial = []
    for idx, t in enumerate(dial['log']):  # tokenize to list of ids
        enc = {}
        enc['dial_id'] = fn

        # (output key, sos token id, key in the raw turn, eos token id)
        enc_info_list = [
            ('user', self.sos_u_id, 'user', self.eos_u_id),
            ('usdx', self.sos_u_id, 'user', self.eos_u_id),
            ('resp', self.sos_r_id, 'resp', self.eos_r_id),
            ('bspn', self.sos_b_id, 'constraint', self.eos_b_id),
            ('bsdx', self.sos_b_id, 'cons_delex', self.eos_b_id),
            ('aspn', self.sos_a_id, 'sys_act', self.eos_a_id)
        ]
        for enc_key, start_token, item_key, end_token in enc_info_list:
            enc[enc_key] = [
                start_token
            ] + self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(
                    self._get_convert_str(t[item_key]))) + [end_token]

        enc['turn_num'] = t['turn_num']

        # Non-initial '[general]' turns carry no domain of their own:
        # inherit dspn/turn_domain/db from the previous turn, keeping only
        # the last two pointer bits from the current turn.
        if idx > 0 and t['turn_domain'] == '[general]':
            enc['dspn'] = encoded_dial[idx - 1]['dspn']
            enc['pointer'] = encoded_dial[idx - 1]['pointer'][:4] + [
                int(i) for i in t['pointer'].split(',')
            ][-2:]
            enc['turn_domain'] = encoded_dial[idx - 1]['turn_domain']
            enc['db'] = encoded_dial[idx - 1]['db']
        else:
            # a first-turn '[general]' must not carry belief constraints
            if t['turn_domain'] == '[general]':
                assert not t['constraint'], f'{fn}-{idx}'
            enc['dspn'] = [
                self.sos_d_id
            ] + self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(
                    self._get_convert_str(
                        t['turn_domain']))) + [self.eos_d_id]
            enc['pointer'] = [int(i) for i in t['pointer'].split(',')]
            enc['turn_domain'] = t['turn_domain'].split()
            # DB match-count indicator derived from the belief span
            db_pointer = self.bspan_to_DBpointer(t['constraint'],
                                                 t['turn_domain'].split())
            enc['db'] = [
                self.sos_db_id
            ] + self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(
                    self._get_convert_str(db_pointer))) + [self.eos_db_id]

        encoded_dial.append(enc)
    return encoded_dial

def bspan_to_DBpointer(self, bspan, turn_domain):
    """Convert a belief span into a DB-indicator vector for the active domain."""
    constraints = self.bspan_to_constraint_dict(bspan)
    match_counts = self.db.get_match_num(constraints)
    # single-domain turns use their only entry; multi-domain turns the second
    if len(turn_domain) == 1:
        active_domain = turn_domain[0]
    else:
        active_domain = turn_domain[1]
    # strip surrounding brackets, e.g. '[hotel]' -> 'hotel'
    if active_domain.startswith('['):
        active_domain = active_domain[1:-1]
    return self.db.addDBIndicator(active_domain, match_counts[active_domain])

@@ -691,3 +823,67 @@ class MultiWOZBPETextField(BPETextField):
inputs['labels'] = [context] # use previous turn

return inputs, prompt_id

def restore(self, resp, domain, constraint_dict, mat_ents):
    """Fill delexicalized ``[value_*]`` placeholders in a generated response.

    Args:
        resp: delexicalized response string.
        domain: non-empty list of active domain names, most recent last.
        constraint_dict: ``{domain: {slot: value}}`` belief constraints.
        mat_ents: ``{domain: [entity dict, ...]}`` DB matches per domain.

    Returns:
        The response with placeholders replaced by concrete values.
    """
    restored = resp

    # values with no grounded source get fixed, plausible stand-ins
    restored = restored.replace('[value_reference]', '53022')
    restored = restored.replace('[value_car]', 'BMW')

    for d in domain:
        constraint = constraint_dict.get(d, None)
        if constraint:
            # booking-style slots come straight from the user's constraints
            replace_res_list = [('stay', '[value_stay]'),
                                ('day', '[value_day]'),
                                ('people', '[value_people]'),
                                ('time', '[value_time]'),
                                ('type', '[value_type]')]
            for key, value_key in replace_res_list:
                if key in constraint:
                    restored = restored.replace(value_key, constraint[key])

            # no DB hit for this domain: fall back to the constraints
            if d in mat_ents and len(mat_ents[d]) == 0:
                for s in constraint:
                    if s == 'pricerange' and d in [
                            'hotel', 'restaurant'
                    ] and 'price]' in restored:
                        restored = restored.replace(
                            '[value_price]', constraint['pricerange'])
                    if s + ']' in restored:
                        restored = restored.replace(
                            '[value_%s]' % s, constraint[s])

        if '[value_choice' in restored and mat_ents.get(d):
            restored = restored.replace('[value_choice]',
                                        str(len(mat_ents[d])))
    # default choice count when no domain supplied one
    if '[value_choice' in restored:
        restored = restored.replace('[value_choice]', '3')

    try:
        ent = mat_ents.get(domain[-1], [])
        if ent:
            ent = ent[0]

            for t in restored.split():
                if '[value' in t:
                    slot = t[7:-1]  # strip '[value_' prefix and ']' suffix
                    if ent.get(slot):
                        if domain[-1] == 'hotel' and slot == 'price':
                            slot = 'pricerange'
                        restored = restored.replace(t, ent[slot])
                    elif slot == 'price':
                        if ent.get('pricerange'):
                            restored = restored.replace(
                                t, ent['pricerange'])
                    else:
                        # fix: the original passed `domain` as a %-format
                        # argument with no placeholder in the message
                        logger.info('%s %s', restored, domain)
    except Exception:
        logger.error(resp)
        logger.error(restored)
        # NOTE(review): quit() aborts the whole process on any restore
        # error; consider re-raising instead of exiting
        quit()

    restored = restored.replace('[value_phone]', '62781111')
    restored = restored.replace('[value_postcode]', 'CG9566')
    restored = restored.replace('[value_address]', 'Parkside, Cambridge')

    return restored

+ 130
- 0
modelscope/trainers/nlp/space/dialog_modeling_trainer.py View File

@@ -0,0 +1,130 @@
import os
import time
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np

from modelscope.metainfo import Trainers
from modelscope.models.nlp.space.model.generator import SpaceGenerator
from modelscope.models.nlp.space.model.model_base import SpaceModelBase
from modelscope.preprocessors.space.fields.gen_field import \
MultiWOZBPETextField
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.trainers.nlp.space.eval import MultiWOZEvaluator
from modelscope.trainers.nlp.space.trainer.gen_trainer import MultiWOZTrainer
from modelscope.utils.config import Config, ModelFile
from modelscope.utils.logger import get_logger

logger = get_logger()


def setup_seed(seed: int):
    """Seed every RNG used during training (python, numpy, torch) for reproducibility."""
    import random
    import torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # force cuDNN onto deterministic kernels so repeated runs match exactly
    torch.backends.cudnn.deterministic = True


@TRAINERS.register_module(module_name=Trainers.dialog_modeling_trainer)
class DialogModelingTrainer(BaseTrainer):
    """Trainer for SPACE dialog modeling (response generation) on MultiWOZ.

    Expects ``model_dir`` and ``cfg_name`` in kwargs; reads the config from
    ``<model_dir>/<cfg_name>``, builds the text field, evaluator, generator
    and model, then delegates the optimization loop to ``MultiWOZTrainer``.
    """

    def __init__(self,
                 cfg_file: Optional[str] = None,
                 cfg_modify_fn: Optional[Callable] = None,
                 *args,
                 **kwargs):

        super().__init__(os.path.join(kwargs['model_dir'], kwargs['cfg_name']))

        self.cfg_modify_fn = cfg_modify_fn
        self.cfg = self.rebuild_config(self.cfg)

        setup_seed(self.cfg.Trainer.seed)

        # set reader and evaluator
        self.bpe = MultiWOZBPETextField(self.cfg, **kwargs)

        # embedding table sizes are only known after the vocab is built
        self.cfg.Model.num_token_embeddings = self.bpe.vocab_size
        self.cfg.Model.num_turn_embeddings = self.bpe.max_ctx_turn + 1

        self.cfg.Trainer.save_dir = kwargs.get('work_dir',
                                               './default_save_dir')

        # set data and data status
        self.train_data = self.bpe.get_batches('train')
        self.dev_data = self.bpe.get_batches('dev')

        self.evaluator = MultiWOZEvaluator(reader=self.bpe, **kwargs)
        # set generator
        self.generator = SpaceGenerator.create(self.cfg, reader=self.bpe)
        self._load_model(**kwargs)

    def _load_model(self, **kwargs):
        """Build (or adopt) the model and wrap it in a ready-to-run MultiWOZTrainer."""

        def to_tensor(array):
            """numpy array -> tensor, moved to GPU when configured and available."""
            import torch
            array = torch.tensor(array)
            return array.cuda(
            ) if self.cfg.use_gpu and torch.cuda.is_available() else array

        # construct model: an externally supplied model takes precedence
        if 'model' in kwargs:
            self.model = kwargs['model']
        else:
            self.model = SpaceModelBase.create(
                kwargs['model_dir'],
                self.cfg,
                reader=self.bpe,
                generator=self.generator)

        import torch
        # multi-gpu
        if self.cfg.Trainer.gpu > 1 and torch.cuda.device_count() > 1:
            self.model = torch.nn.DataParallel(self.model)

        # construct trainer
        self.trainer = MultiWOZTrainer(
            self.model,
            to_tensor,
            self.cfg,
            reader=self.bpe,
            evaluator=self.evaluator)
        self.trainer.set_optimizers()
        # load model, optimizer and lr_scheduler
        self.trainer.load()

    def rebuild_config(self, cfg: Config):
        """Apply the user-supplied config-modification hook, if any."""
        if self.cfg_modify_fn is not None:
            return self.cfg_modify_fn(cfg)
        return cfg

    def train(self, *args, **kwargs):
        logger.info('Train')

        self.trainer.train(train_data=self.train_data, dev_data=self.dev_data)

    def evaluate(self,
                 checkpoint_path: Optional[str] = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        """Run inference on the test split using the given checkpoint.

        Args:
            checkpoint_path: path ``<dir>/<TORCH_MODEL_BIN_FILE>``; required.
        """
        logger.info('Evaluate')
        self.cfg.do_infer = True

        assert checkpoint_path is not None, \
            'checkpoint_path is required for evaluation'
        # os.path.split is robust to platform separators and bare filenames,
        # unlike slicing at rfind('/')
        checkpoint_dir, checkpoint_name = os.path.split(checkpoint_path)

        assert checkpoint_name == ModelFile.TORCH_MODEL_BIN_FILE
        kwargs['model_dir'] = checkpoint_dir
        self._load_model(**kwargs)
        self.trainer.infer(data_type='test')

+ 952
- 0
modelscope/trainers/nlp/space/eval.py View File

@@ -0,0 +1,952 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright from https://github.com/thu-spmi/LABES
# Copyright from https://github.com/TonyNemo/UBAR-MultiWOZ
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import Counter

import json
import numpy as np
from nltk.util import ngrams
from sklearn.metrics import f1_score

from modelscope.utils.nlp.space import ontology, utils
from modelscope.utils.nlp.space.clean_dataset import clean_slot_values


def similar(a, b):
    """True when two values plausibly denote the same entity.

    Matches on equality, substring containment in either direction, or a
    shared first or last whitespace-separated token.
    """
    if a == b or a in b or b in a:
        return True
    a_tokens = a.split()
    b_tokens = b.split()
    return a_tokens[0] == b_tokens[0] or a_tokens[-1] == b_tokens[-1]


def setsub(a, b):
    """Whether every element of *a* has a similar counterpart in *b*.

    Elements of *a* with no counterpart are tolerated only if they mention
    one of the known junk constraint substrings.
    """
    ignorable = [
        'temperature', 'week', 'est ', 'quick', 'reminder', 'near'
    ]
    # elements of a with no similar counterpart anywhere in b
    unmatched = [x for x in a if not any(similar(x, y) for y in b)]
    # every unmatched entry must contain at least one ignorable marker
    return all(
        any(marker in x for marker in ignorable) for x in unmatched)


def setsim(a, b):
    """Symmetric similarity: the two collections mutually cover each other."""
    set_a = set(a)
    set_b = set(b)
    return setsub(set_a, set_b) and setsub(set_b, set_a)


def DA_evaluate(preds, labels):
    """Micro-averaged F1 between predicted and gold dialog-act label arrays."""
    y_pred = np.array(preds)
    y_true = np.array(labels)

    results = {}
    for avg_name in ['micro']:
        score = f1_score(y_true=y_true, y_pred=y_pred, average=avg_name)
        results['f1_{}'.format(avg_name)] = score
    return results


class BLEUScorer(object):
    # BLEU score calculator via GentScorer interface
    # it calculates the BLEU-4 by taking the entire corpus in
    # Calculate based multiple candidates against multiple references
    def __init__(self):
        pass

    def score(self, parallel_corpus):
        """Corpus-level smoothed BLEU-4 over (hypotheses, references) pairs.

        Args:
            parallel_corpus: iterable of (hyps, refs) where each is a list of
                whitespace-tokenizable strings.

        Returns:
            BLEU-4 scaled to 0-100.
        """

        # containers
        count = [0, 0, 0, 0]        # total hypothesis n-gram counts, orders 1..4
        clip_count = [0, 0, 0, 0]   # reference-clipped (matched) counts per order
        r = 0                       # cumulative closest-reference length
        c = 0                       # cumulative hypothesis length
        weights = [0.25, 0.25, 0.25, 0.25]  # uniform BLEU-4 weights

        # accumulate ngram statistics
        for hyps, refs in parallel_corpus:
            hyps = [hyp.split() for hyp in hyps]
            refs = [ref.split() for ref in refs]
            for hyp in hyps:

                for i in range(4):
                    # accumulate ngram counts
                    hypcnts = Counter(ngrams(hyp, i + 1))
                    cnt = sum(hypcnts.values())
                    count[i] += cnt

                    # compute clipped counts: each hyp n-gram is credited at
                    # most the max number of times it occurs in any reference
                    max_counts = {}
                    for ref in refs:
                        refcnts = Counter(ngrams(ref, i + 1))
                        for ng in hypcnts:
                            max_counts[ng] = max(
                                max_counts.get(ng, 0), refcnts[ng])
                    clipcnt = \
                        dict((ng, min(count, max_counts[ng])) for ng, count in hypcnts.items())
                    clip_count[i] += sum(clipcnt.values())

                # accumulate r & c: pick the reference closest in length
                bestmatch = [1000, 1000]  # [length diff, ref length]
                for ref in refs:
                    if bestmatch[0] == 0:
                        break
                    diff = abs(len(ref) - len(hyp))
                    if diff < bestmatch[0]:
                        bestmatch[0] = diff
                        bestmatch[1] = len(ref)
                r += bestmatch[1]
                c += len(hyp)

        # computing bleu score
        p0 = 1e-7  # smoothing term: avoids division by zero and log(0)
        # brevity penalty: no penalty when hypotheses are longer than refs
        bp = \
            1 if c > r else math.exp(1 - float(r) / float(c))
        p_ns = \
            [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)]
        s = \
            math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
        bleu = bp * math.exp(s)
        return bleu * 100


""""
For the data preparation and evaluation on MultiWOZ2.0/2.1,
we refer to the code of UBAR (https://github.com/TonyNemo/UBAR-MultiWOZ)
"""


class MultiWOZEvaluator(object):
    """Corpus-level evaluator for MultiWOZ response generation.

    Computes BLEU over generated responses plus the standard inform (match)
    and success rates against the user goals stored in the reader's data.
    """

    def __init__(self, reader, **kwargs):
        # reader: text field providing raw data, test split, DB access and
        # belief-span utilities
        self.reader = reader
        self.domains = ontology.all_domains
        self.all_data = self.reader.data
        self.test_data = self.reader.test

        self.bleu_scorer = BLEUScorer()

        # flat 'domain-slot' names for all informable slots
        self.all_info_slot = []
        for d, s_list in ontology.informable_slots.items():
            for s in s_list:
                self.all_info_slot.append(d + '-' + s)

        # only evaluate these slots for dialog success
        self.requestables = ['phone', 'address', 'postcode', 'reference', 'id']
        self.db_dir = kwargs['data_dir']

    def pack_dial(self, data):
        """Group flat per-turn dicts into {dial_id: [turns...]}."""
        dials = {}
        for turn in data:
            dial_id = turn['dial_id']
            if dial_id not in dials:
                dials[dial_id] = []
            dials[dial_id].append(turn)
        return dials

    def validation_metric(self, data, fout=None):
        """Return (bleu, success_rate, match_rate) for a list of decoded turns."""
        bleu = self.bleu_metric(data)
        # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data)
        success, match, req_offer_counts, dial_num = \
            self.context_to_response_eval(data, same_eval_as_cambridge=True, fout=fout)
        return bleu, success, match

    def bleu_metric(self, data, eval_dial_list=None):
        """Corpus BLEU of 'resp_gen' against 'resp'; 0.0 when nothing to score."""
        gen, truth = [], []
        for row in data:
            if eval_dial_list and row[
                    'dial_id'] + '.json' not in eval_dial_list:
                continue
            gen.append(row['resp_gen'])
            truth.append(row['resp'])
        wrap_generated = [[_] for _ in gen]
        wrap_truth = [[_] for _ in truth]
        if gen and truth:
            try:
                sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth))
            except Exception:
                # scoring is best-effort; any failure counts as zero BLEU
                sc = 0.0
        else:
            sc = 0.0
        return sc

    def context_to_response_eval(self,
                                 data,
                                 eval_dial_list=None,
                                 same_eval_as_cambridge=False,
                                 fout=None):
        """Compute corpus success rate, match (inform) rate and per-slot counts.

        Returns (success_rate, match_rate, counts, dial_num) with rates in
        percent over the evaluated dialogs.
        """
        dials = self.pack_dial(data)
        counts = {}
        for req in self.requestables:
            counts[req + '_total'] = 0
            counts[req + '_offer'] = 0

        dial_num, successes, matches = 0, 0, 0

        for dial_id in dials:
            if eval_dial_list and dial_id + '.json' not in eval_dial_list:
                continue
            dial = dials[dial_id]
            reqs = {}
            goal = {}
            # keys in all_data may carry a '.json' suffix; align dial_id
            if '.json' not in dial_id and '.json' in list(
                    self.all_data.keys())[0]:
                dial_id = dial_id + '.json'
            for domain in ontology.all_domains:
                if self.all_data[dial_id]['goal'].get(domain):
                    true_goal = self.all_data[dial_id]['goal']
                    goal = self._parseGoal(goal, true_goal, domain)

            for domain in goal.keys():
                reqs[domain] = goal[domain]['requestable']

            success, match, stats, counts = \
                self._evaluateGeneratedDialogue(dial, goal, reqs, counts,
                                                same_eval_as_cambridge=same_eval_as_cambridge, fout=fout)

            successes += success
            matches += match
            dial_num += 1

        # epsilon guards against division by zero when no dialogs evaluated
        succ_rate = successes / (float(dial_num) + 1e-10) * 100
        match_rate = matches / (float(dial_num) + 1e-10) * 100
        return succ_rate, match_rate, counts, dial_num

    def _evaluateGeneratedDialogue(self,
                                   dialog,
                                   goal,
                                   real_requestables,
                                   counts,
                                   soft_acc=False,
                                   same_eval_as_cambridge=False,
                                   fout=None):
        """Evaluates the dialogue created by the model.
        First we load the user goal of the dialogue, then for each turn
        generated by the system we look for key-words.
        For the Inform rate we look whether the entity was proposed.
        For the Success rate we look for requestables slots"""
        # for computing corpus success
        requestables = self.requestables

        # CHECK IF MATCH HAPPENED
        provided_requestables = {}
        venue_offered = {}
        domains_in_goal = []
        log = []
        bspans = {}

        for domain in goal.keys():
            venue_offered[domain] = []
            provided_requestables[domain] = []
            domains_in_goal.append(domain)

        for t, turn in enumerate(dialog):
            # turn 0 is context-only; no system response to score
            if t == 0:
                continue
            if fout is not None:
                # keep a per-turn trace for failure-case dumping below
                log.append({
                    'turn_num': turn['turn_num'],
                    'turn_domain': turn['dspn'],
                    'user': turn['user'],
                    'aspn': turn['aspn'],
                    'aspn_gen': turn['aspn_gen'],
                    'resp': turn['resp'],
                    'resp_gen': turn['resp_gen'],
                    'pointer': turn['pointer'],
                })

            sent_t = turn['resp_gen']

            for domain in goal.keys():
                # for computing success
                if same_eval_as_cambridge:
                    # [restaurant_name], [hotel_name] instead of [value_name]
                    if self.reader.use_true_domain_for_ctr_eval:
                        dom_pred = [d[1:-1] for d in turn['dspn'].split()]
                    else:
                        dom_pred = [d[1:-1] for d in turn['dspn_gen'].split()]

                    if domain not in dom_pred:  # fail
                        continue
                if '[value_name]' in sent_t or '[value_id]' in sent_t:
                    if domain in [
                            'restaurant', 'hotel', 'attraction', 'train'
                    ]:
                        # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
                        if not self.reader.use_true_curr_bspn and not self.reader.use_true_bspn_for_ctr_eval:
                            bspn = turn['bspn_gen']
                        else:
                            bspn = turn['bspn']

                        constraint_dict = self.reader.bspan_to_constraint_dict(
                            bspn)
                        if constraint_dict.get(domain):
                            venues = self.reader.db.queryJsons(
                                domain,
                                constraint_dict[domain],
                                return_name=True)
                        else:
                            venues = []

                        if len(venue_offered[domain]) == 0 and venues:

                            venue_offered[domain] = venues
                            bspans[domain] = constraint_dict[domain]
                        else:
                            # replace the offered venues only when the new
                            # query returned something not seen before
                            flag = False
                            for ven in venues:
                                if ven not in venue_offered[domain]:
                                    flag = True
                                    break
                            if flag and venues:  # sometimes there are no results so sample won't work
                                venue_offered[domain] = venues
                                bspans[domain] = constraint_dict[domain]
                else:  # not limited so we can provide one
                    venue_offered[domain] = '[value_name]'

                # ATTENTION: assumption here - we didn't provide phone or address twice! etc
                for requestable in requestables:
                    if requestable == 'reference':
                        if '[value_reference]' in sent_t:
                            if domain in ['restaurant', 'hotel', 'train']:
                                # NOTE(review): assumes turn['pointer'] is a
                                # string containing booking markers — confirm
                                if 'booked' in turn['pointer'] or 'ok' in turn[
                                        'pointer'] or '[value_reference]' in turn[
                                            'resp']:
                                    # if pointer was allowing for that?
                                    provided_requestables[domain].append(
                                        'reference')
                            else:
                                provided_requestables[domain].append(
                                    'reference')
                    else:
                        if '[value_' + requestable + ']' in sent_t:
                            provided_requestables[domain].append(requestable)

        # if name was given in the task
        for domain in goal.keys():
            # if name was provided for the user, the match is being done automatically
            if 'name' in goal[domain]['informable']:
                venue_offered[domain] = '[value_name]'

            # special domains - entity does not need to be provided
            if domain in ['taxi', 'police', 'hospital']:
                venue_offered[domain] = '[value_name]'

            if domain == 'train':
                if not venue_offered[domain] and 'id' not in goal[domain][
                        'requestable']:
                    venue_offered[domain] = '[value_name]'
        """
        Given all inform and requestable slots
        we go through each domain from the user goal
        and check whether right entity was provided and
        all requestable slots were given to the user.
        The dialogue is successful if that's the case for all domains.
        """
        # HARD EVAL
        stats = {
            'restaurant': [0, 0, 0],
            'hotel': [0, 0, 0],
            'attraction': [0, 0, 0],
            'train': [0, 0, 0],
            'taxi': [0, 0, 0],
            'hospital': [0, 0, 0],
            'police': [0, 0, 0]
        }

        match = 0
        success = 0
        # MATCH
        for domain in goal.keys():
            match_stat = 0
            if domain in ['restaurant', 'hotel', 'attraction', 'train']:
                goal_venues = self.reader.db.queryJsons(
                    domain, goal[domain]['informable'], return_name=True)
                # a '[value_name]' marker or any venue overlap counts as match
                if type(venue_offered[domain]
                        ) is str and '_name' in venue_offered[domain]:
                    match += 1
                    match_stat = 1
                elif len(venue_offered[domain]) > 0 and len(
                        set(venue_offered[domain]) & set(goal_venues)) > 0:
                    match += 1
                    match_stat = 1
            else:
                if '_name]' in venue_offered[domain]:
                    match += 1
                    match_stat = 1

            stats[domain][0] = match_stat
            stats[domain][2] = 1

        if soft_acc:
            match = float(match) / len(goal.keys())
        else:
            # hard match: every goal domain must be matched
            if match == len(goal.keys()):
                match = 1.0
            else:
                match = 0.0

        for domain in domains_in_goal:
            for request in real_requestables[domain]:
                counts[request + '_total'] += 1
                if request in provided_requestables[domain]:
                    counts[request + '_offer'] += 1

        # SUCCESS
        # With fout, success is computed for every dialog (for dumping
        # failure cases); otherwise only matched dialogs can succeed.
        if fout is not None:
            for domain in domains_in_goal:
                success_stat = 0
                domain_success = 0
                if len(real_requestables[domain]) == 0:
                    success += 1
                    success_stat = 1
                    stats[domain][1] = success_stat
                    continue
                # if values in sentences are super set of requestables
                for request in real_requestables[domain]:
                    if request in provided_requestables[domain]:
                        domain_success += 1

                if domain_success == len(real_requestables[domain]):
                    success += 1
                    success_stat = 1

                stats[domain][1] = success_stat

            # final eval
            if soft_acc:
                success = float(success) / len(real_requestables)
            else:
                if success >= len(real_requestables):
                    success = 1
                else:
                    success = 0
        else:
            if match == 1.0:
                for domain in domains_in_goal:
                    success_stat = 0
                    domain_success = 0
                    if len(real_requestables[domain]) == 0:
                        success += 1
                        success_stat = 1
                        stats[domain][1] = success_stat
                        continue
                    # if values in sentences are super set of requestables
                    for request in real_requestables[domain]:
                        if request in provided_requestables[domain]:
                            domain_success += 1

                    if domain_success == len(real_requestables[domain]):
                        success += 1
                        success_stat = 1

                    stats[domain][1] = success_stat

                # final eval
                if soft_acc:
                    success = float(success) / len(real_requestables)
                else:
                    if success >= len(real_requestables):
                        success = 1
                    else:
                        success = 0

        # dump the full trace of unsuccessful dialogs for inspection
        if fout is not None and success == 0:
            sample = {
                dialog[0]['dial_id']: {
                    'log': log,
                    'real_requestables': real_requestables,
                    'provided_requestables': provided_requestables
                }
            }
            line = json.dumps(sample)
            fout.write(line)
            fout.write('\n')

        return success, match, stats, counts

    def _parseGoal(self, goal, true_goal, domain):
        """Parses user goal into dictionary format."""
        goal[domain] = {}
        goal[domain] = {'informable': {}, 'requestable': [], 'booking': []}
        if 'info' in true_goal[domain]:
            if domain == 'train':
                # we consider dialogues only where train had to be booked!
                if 'book' in true_goal[domain]:
                    goal[domain]['requestable'].append('reference')
                if 'reqt' in true_goal[domain]:
                    if 'id' in true_goal[domain]['reqt']:
                        goal[domain]['requestable'].append('id')
            else:
                if 'reqt' in true_goal[domain]:
                    for s in true_goal[domain]['reqt']:  # addtional requests:
                        if s in [
                                'phone', 'address', 'postcode', 'reference',
                                'id'
                        ]:
                            # ones that can be easily delexicalized
                            goal[domain]['requestable'].append(s)
                if 'book' in true_goal[domain]:
                    goal[domain]['requestable'].append('reference')

            for s, v in true_goal[domain]['info'].items():
                # normalize slot/value pairs the same way the corpus was cleaned
                s_, v_ = clean_slot_values(self.db_dir, domain, s, v)
                if len(v_.split()) > 1:
                    v_ = ' '.join(
                        [token.text for token in self.reader.nlp(v_)]).strip()
                goal[domain]['informable'][s_] = v_

            if 'book' in true_goal[domain]:
                goal[domain]['booking'] = true_goal[domain]['book']
        return goal


class GenericEvaluator:
    """Base class for end-to-end dialog evaluators.

    Subclasses must provide the dataset-specific attributes
    ``informable_slots``, ``requestable_slots`` and ``entities_flat``
    (plus a ``clean`` method) and implement :meth:`run_metrics`.
    """

    def __init__(self, reader):
        # reader supplies tokenization / ontology helpers used by subclasses
        self.reader = reader
        self.metric_dict = {}

    def pack_dial(self, data):
        """Group a flat list of turn dicts into ``{dial_id: [turns]}``."""
        dials = {}
        for turn in data:
            dials.setdefault(turn['dial_id'], []).append(turn)
        return dials

    def run_metrics(self, results):
        """Compute all metrics for ``results``; implemented by subclasses."""
        raise ValueError('Please specify the evaluator first')

    def bleu_metric(self, data, type='bleu'):
        """Corpus-level BLEU between generated and reference responses.

        NOTE(review): ``type`` is unused (and shadows the builtin); it is
        kept only for backward compatibility with existing callers.
        """
        gen, truth = [], []
        for row in data:
            gen.append(self.clean(row['resp_gen']))
            truth.append(self.clean(row['resp']))
        # BLEUScorer expects (hypothesis-list, reference-list) pairs
        wrap_generated = [[g] for g in gen]
        wrap_truth = [[t] for t in truth]
        sc = BLEUScorer().score(zip(wrap_generated, wrap_truth))
        return sc

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
        :param intersection: if true, only keeps the words that appear in the ontology
            we set intersection=True as in previous works
        :returns: normalized constraint dict
            e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
        """
        normalized = {}
        for s in self.informable_slots:
            normalized[s] = ''
        for s, v in constraint.items():
            if ignore_dontcare and v == 'dontcare':
                continue
            # keep only values known to the ontology (except 'dontcare')
            if intersection and v != 'dontcare' and v not in self.entities_flat:
                continue

            normalized[s] = v

        return normalized

    def _normalize_act(self, aspn, intersection=False):
        """Split an act span into per-act-type word sets.

        ``aspn`` segments are separated by ``|`` and mapped positionally
        onto ``self.reader.act_order``.
        """
        aspn_list = aspn.split('|')
        normalized = {}
        for i, v in enumerate(aspn_list):
            seq = v.strip()
            word_set = set()
            for w in seq.split():
                if intersection:
                    if self.reader.act_order[i] == 'av':
                        # act values: keep only '[value_*]' placeholders
                        if '[value' in w:
                            word_set.add(w)
                    else:
                        # act slots: keep only requestable slot names
                        if w in self.requestable_slots:
                            word_set.add(w)
                else:
                    word_set.add(w)
            normalized[self.reader.act_order[i]] = word_set
        return normalized

    def tracker_metric(self, data, normalize=True):
        """Turn-level belief-state tracking metrics.

        :returns: (precision, recall, f1, joint-goal accuracy,
                   per-slot accuracy dict, DB-pointer accuracy)
        """
        tp, fp, fn, db_correct = 0, 0, 0, 0
        goal_accr, slot_accr, total = 0, {}, 1e-8
        for s in self.informable_slots:
            slot_accr[s] = 0

        for row in data:
            if normalize:
                gen = self._normalize_constraint(row['bspn_gen'])
                truth = self._normalize_constraint(row['bspn'])
            else:
                gen = self._normalize_constraint(
                    row['bspn_gen'], intersection=False)
                truth = self._normalize_constraint(
                    row['bspn'], intersection=False)
            # skip closing turns, which carry no goal information
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                # NOTE(review): `value in truth[slot]` is a substring test
                # (both sides are strings); kept as in the reference code
                # for comparable results.
                for slot, value in gen.items():
                    if value in truth[slot]:
                        tp += 1
                    else:
                        fp += 1
                for slot, value in truth.items():
                    if value not in gen[slot]:
                        fn += 1

            if truth and valid:
                total += 1
                for s in self.informable_slots:
                    if gen[s] == truth[s]:
                        slot_accr[s] += 1
                if gen == truth:
                    goal_accr += 1
                if row.get('db_gen') and row.get('db_match'):
                    if row['db_gen'] == row['db_match']:
                        db_correct += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        goal_accr /= total
        db_correct /= total
        for s in slot_accr:
            slot_accr[s] /= total
        return precision, recall, f1, goal_accr, slot_accr, db_correct

    def request_metric(self, data):
        """Dialog-level requestable-slot F1/precision/recall."""
        dials = self.pack_dial(data)
        tp, fp, fn = 0, 0, 0
        for dial_id in dials:
            truth_req, gen_req = set(), set()
            dial = dials[dial_id]
            for turn_num, turn in enumerate(dial):
                resp_gen_token = self.clean(turn['resp_gen']).split()
                resp_token = self.clean(turn['resp']).split()
                # collect requested slots from '[value_slot]' placeholders,
                # ignoring the entity name itself
                for w in resp_gen_token:
                    if '[value_' in w and w.endswith(
                            ']') and w != '[value_name]':
                        gen_req.add(w[1:-1].split('_')[1])
                for w in resp_token:
                    if '[value_' in w and w.endswith(
                            ']') and w != '[value_name]':
                        truth_req.add(w[1:-1].split('_')[1])
            for req in gen_req:
                if req in truth_req:
                    tp += 1
                else:
                    fp += 1
            for req in truth_req:
                if req not in gen_req:
                    fn += 1
        precision, recall = tp / (tp + fp + 1e-8), tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return f1, precision, recall

    def act_metric(self, data):
        """Turn-level act prediction metrics.

        :returns: ``{key: [f1, precision, recall]}`` where keys are the
            'all_s'/'all_v' aggregates plus each requestable slot and its
            '[value_slot]' placeholder.
        """
        tp, fp, fn = {
            'all_s': 0,
            'all_v': 0
        }, {
            'all_s': 0,
            'all_v': 0
        }, {
            'all_s': 0,
            'all_v': 0
        }
        for s in self.requestable_slots:
            tp[s], fp[s], fn[s] = 0, 0, 0
            tp['[value_%s]' % s], fp['[value_%s]' % s], fn['[value_%s]'
                                                           % s] = 0, 0, 0

        for row in data:
            gen = self._normalize_act(row['aspn_gen'])
            truth = self._normalize_act(row['aspn'])
            valid = 'thank' not in row['user'] and 'bye' not in row['user']
            if valid:
                # how well the act decoder captures user's requests.
                # Fix: the original used `if tp.get(value)`, which is falsy
                # while a counter is still 0, so per-slot counts could never
                # be incremented; key membership is what is intended.
                for value in gen['av']:
                    if value in truth['av']:
                        tp['all_v'] += 1
                        if value in tp:
                            tp[value] += 1
                    else:
                        fp['all_v'] += 1
                        if value in fp:
                            fp[value] += 1
                for value in truth['av']:
                    if value not in gen['av']:
                        fn['all_v'] += 1
                        if value in fn:
                            fn[value] += 1

                # how accurately the act decoder predicts system's question
                if 'as' not in gen:
                    continue
                for slot in gen['as']:
                    if slot in truth['as']:
                        tp['all_s'] += 1
                        if slot in tp:
                            tp[slot] += 1
                    else:
                        fp['all_s'] += 1
                        if slot in fp:
                            fp[slot] += 1
                for slot in truth['as']:
                    if slot not in gen['as']:
                        fn['all_s'] += 1
                        if slot in fn:
                            fn[slot] += 1

        result = {}
        for k in tp:  # fp/fn share exactly the same key set as tp
            precision, recall = tp[k] / (tp[k] + fp[k] + 1e-8), tp[k] / (
                tp[k] + fn[k] + 1e-8)
            f1 = 2 * precision * recall / (precision + recall + 1e-8)
            result[k] = [f1, precision, recall]
        return result


"""
For the data preparation and evaluation on In-Car Assistant/CamRest,
we refer to the code of LABES (https://github.com/thu-spmi/LABES)
"""


class CamRestEvaluator(GenericEvaluator):
    """Evaluator for the CamRest676 dataset."""

    def __init__(self, reader):
        super().__init__(reader)
        # value -> slot lookup built from the ontology file
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Compute BLEU, match, request-F1 and tracking metrics."""
        metrics = {}
        bleu = self.bleu_metric(results)
        p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(results)
        match = self.match_metric(results)
        req_f1, req_p, req_r = self.request_metric(results)

        metrics['bleu'] = bleu
        metrics['match'] = match
        metrics['req_f1'] = req_f1
        metrics['joint_goal'] = goal_acc
        metrics['slot_accu'] = slot_acc
        metrics['slot-p/r/f1'] = (p, r, f1)
        metrics['db_acc'] = db_acc

        return metrics

    def get_entities(self, entity_path):
        """Load the ontology JSON and build value lookup tables.

        :returns: (flat list of informable values,
                   value -> slot-name dict)
        Fix: the file is opened with a context manager so the handle is
        closed even if parsing raises (the original leaked it).
        """
        entities_flat = []
        entitiy_to_slot_dict = {}
        with open(entity_path) as fin:
            raw_entities = json.loads(fin.read().lower())
        for s in raw_entities['informable']:
            entities_flat.extend(raw_entities['informable'][s])
            for v in raw_entities['informable'][s]:
                entitiy_to_slot_dict[v] = s
        return entities_flat, entitiy_to_slot_dict

    def constraint_same(self, truth_cons, gen_cons):
        """True if both constraint sets are empty, else their set similarity."""
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def match_metric(self, data):
        """Dialog-level entity match rate."""
        dials = self.pack_dial(data)
        match, total = 0, 1e-8
        for dial_id in dials:
            dial = dials[dial_id]
            truth_cons, gen_cons = {'1': '', '2': '', '3': ''}, None
            for turn_num, turn in enumerate(dial):
                # find the last turn which the system provide an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)
            if not gen_cons:
                # if no entity is provided, choose the state of the last dialog turn
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)
            if list(truth_cons.values()) != ['', '', '']:
                if gen_cons == truth_cons:
                    match += 1
                total += 1

        return match / total

    def clean(self, resp):
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        resp = resp.replace(f'{self.reader.sos_r_token} ', '')
        resp = resp.replace(f' {self.reader.eos_r_token}', '')
        resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp


class KvretEvaluator(GenericEvaluator):
    """Evaluator for the KVRET (In-Car Assistant) dataset."""

    def __init__(self, reader):
        super().__init__(reader)
        # value -> slot lookup comes from the reader's pre-built entity dict
        self.entities_flat, self.entitiy_to_slot_dict = self.get_entities(
            self.reader.ontology_path)
        self.informable_slots = self.reader.otlg.informable_slots
        self.requestable_slots = self.reader.otlg.requestable_slots

    def run_metrics(self, results):
        """Compute BLEU, match, request-F1 and tracking metrics."""
        metrics = {}
        bleu = self.bleu_metric(results)
        p, r, f1, goal_acc, slot_acc, db_acc = self.tracker_metric(
            results, normalize=True)
        match = self.match_metric(results)
        req_f1, req_p, req_r = self.request_metric(results)

        metrics['bleu'] = bleu
        metrics['match'] = match
        metrics['req_f1'] = req_f1
        metrics['joint_goal'] = goal_acc
        metrics['slot_accu'] = slot_acc
        metrics['slot-p/r/f1'] = (p, r, f1)
        metrics['db_acc'] = db_acc

        return metrics

    def _normalize_constraint(self,
                              constraint,
                              ignore_dontcare=False,
                              intersection=True):
        """
        Normalize belief span, e.g. delete repeated words
        :param constraint - {'food': 'asian oritental', 'pricerange': 'cheap'}
        :param intersection: if true, only keeps the words that appear in the ontology
            we set intersection=True as in previous works
        :returns: normalized constraint dict
            e.g. - {'food': 'asian oritental', 'pricerange': 'cheap', 'area': ''}
        """
        # filler words stripped from values before ontology matching
        junk = [
            'good', 'great', 'quickest', 'shortest', 'route', 'week',
            'fastest', 'nearest', 'next', 'closest', 'way', 'mile', 'activity',
            'restaurant', 'appointment'
        ]
        normalized = {}
        for s in self.informable_slots:
            normalized[s] = ''
        for s, v in constraint.items():
            for j in junk:
                v = ' '.join(v.replace(j, '').split())
            if intersection and v not in self.entities_flat:
                continue

            if s in self.informable_slots:
                normalized[s] = v
            else:
                # TODO only use slot (not domain) in s for matching !!!
                pass

        return normalized

    def get_entities(self, entity_path):
        """Return (flat entity list, value -> slot dict) for KVRET.

        ``entity_path`` is unused here; it is kept for interface parity
        with :meth:`CamRestEvaluator.get_entities`. Entities come from the
        reader's pre-built ``entity_dict``.
        """
        entitiy_to_slot_dict = self.reader.entity_dict
        # dict keys are unique, so a plain list() replaces the original
        # O(n^2) `not in` membership loop while yielding the same result
        entities_flat = list(entitiy_to_slot_dict)
        return entities_flat, entitiy_to_slot_dict

    def constraint_same(self, truth_cons, gen_cons):
        """True if both constraint sets are empty, else their set similarity."""
        if not truth_cons and not gen_cons:
            return True
        if not truth_cons or not gen_cons:
            return False
        return setsim(gen_cons, truth_cons)

    def match_metric(self, data):
        """Dialog-level entity match rate (11 KVRET informable slots)."""
        dials = self.pack_dial(data)
        match, total = 0, 1e-8
        for dial_id in dials:
            dial = dials[dial_id]
            # 11 informable slots, keyed '1'..'11', all initially empty
            truth_cons = {str(i): '' for i in range(1, 12)}
            gen_cons = None
            for turn_num, turn in enumerate(dial):
                # find the last turn which the system provide an entity
                if '[value' in turn['resp_gen']:
                    gen_cons = self._normalize_constraint(
                        turn['bspn_gen'], ignore_dontcare=True)
                if '[value' in turn['resp']:
                    truth_cons = self._normalize_constraint(
                        turn['bspn'], ignore_dontcare=True)

            if not gen_cons:
                # if no entity is provided, choose the state of the last dialog turn
                gen_cons = self._normalize_constraint(
                    dial[-1]['bspn_gen'], ignore_dontcare=True)

            if list(truth_cons.values()) != [''] * 11:
                gen_cons = [x for x in gen_cons.values() if x]
                truth_cons = [x for x in truth_cons.values() if x]
                if self.constraint_same(gen_cons, truth_cons):
                    match += 1
                total += 1

        return match / total

    def clean(self, resp):
        # we use the same clean process as in Sequicity, SEDST, FSDM
        # to ensure comparable results
        resp = resp.replace(f'{self.reader.sos_r_token} ', '')
        resp = resp.replace(f' {self.reader.eos_r_token}', '')
        resp = f'{self.reader.sos_r_token} {resp} {self.reader.eos_r_token}'
        for value, slot in self.entitiy_to_slot_dict.items():
            resp = utils.clean_replace(resp, value, '[value_%s]' % slot)
        return resp

+ 28
- 44
modelscope/trainers/nlp/space/trainer/gen_trainer.py View File

@@ -15,27 +15,11 @@ from transformers.optimization import AdamW, get_linear_schedule_with_warmup

from modelscope.trainers.nlp.space.metrics.metrics_tracker import \
MetricsTracker
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.nlp.space import ontology


def get_logger(log_path, name='default'):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(message)s')

sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
logger.addHandler(sh)

fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(formatter)
logger.addHandler(fh)

return logger


class Trainer(object):

def __init__(self,
@@ -51,15 +35,16 @@ class Trainer(object):

self.do_train = config.do_train
self.do_infer = config.do_infer
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
0] == '-'
self.valid_metric_name = config.Trainer.valid_metric_name[1:]
self.num_epochs = config.Trainer.num_epochs
# self.save_dir = config.Trainer.save_dir
self.log_steps = config.Trainer.log_steps
self.valid_steps = config.Trainer.valid_steps
self.save_checkpoint = config.Trainer.save_checkpoint
self.save_summary = config.Trainer.save_summary
if self.do_train:
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
0] == '-'
self.valid_metric_name = config.Trainer.valid_metric_name[1:]
self.num_epochs = config.Trainer.num_epochs
self.save_dir = config.Trainer.save_dir
self.log_steps = config.Trainer.log_steps
self.valid_steps = config.Trainer.valid_steps
self.save_checkpoint = config.Trainer.save_checkpoint
self.save_summary = config.Trainer.save_summary
self.lr = config.Model.lr
self.weight_decay = config.Model.weight_decay
self.batch_size = config.Trainer.batch_size
@@ -71,22 +56,21 @@ class Trainer(object):
self.optimizer = optimizer

self.model = model
self.func_model = self.model.module if self.gpu > 1 else self.model
self.func_model = self.model.module if self.gpu > 1 and config.use_gpu else self.model
self.reader = reader
self.evaluator = evaluator
self.tokenizer = reader.tokenizer

# if not os.path.exists(self.save_dir):
# os.makedirs(self.save_dir)

# self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")
self.logger = logger or get_logger('trainer.log', 'trainer')
self.logger = get_logger()

self.batch_metrics_tracker = MetricsTracker()
self.token_metrics_tracker = MetricsTracker()

self.best_valid_metric = float(
'inf' if self.is_decreased_valid_metric else '-inf')
if self.do_train:
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.best_valid_metric = float(
'inf' if self.is_decreased_valid_metric else '-inf')
self.epoch = 0

def decode_generated_bspn_resp(self, generated):
@@ -248,9 +232,12 @@ class Trainer(object):

# Save current best model
if is_best:
best_model_file = os.path.join(self.save_dir, 'best.model')
best_model_file = os.path.join(self.save_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
torch.save(self.model.state_dict(), best_model_file)
best_train_file = os.path.join(self.save_dir, 'best.train')
best_train_file = os.path.join(
self.save_dir,
'{}.train'.format(ModelFile.TORCH_MODEL_BIN_FILE))
torch.save(train_state, best_train_file)
self.logger.info(
f"Saved best model state to '{best_model_file}' with new best valid metric "
@@ -324,8 +311,7 @@ class Trainer(object):

self.func_model.load_state_dict(model_state_dict)
self.logger.info(
f"Loaded model state from '{self.func_model.init_checkpoint}.model'"
)
f"Loaded model state from '{self.func_model.init_checkpoint}'")

def _load_train_state():
train_file = f'{self.func_model.init_checkpoint}.train'
@@ -558,19 +544,17 @@ class MultiWOZTrainer(Trainer):
generated_bs = outputs[0].cpu().numpy().tolist()
bspn_gen = self.decode_generated_bspn(generated_bs)
# check DB result
if self.reader.use_true_db_pointer: # To control whether current db is ground truth
if self.reader.use_true_db_pointer:
db = turn['db']
else:
db_result = self.reader.bspan_to_DBpointer(
self.tokenizer.decode(bspn_gen),
turn['turn_domain'])
assert len(turn['db']) == 4
book_result = turn['db'][2]
assert len(turn['db']) == 3
assert isinstance(db_result, str)
db = \
[self.reader.sos_db_id] + \
self.tokenizer.convert_tokens_to_ids([db_result]) + \
[book_result] + \
[self.reader.eos_db_id]
prompt_id = self.reader.sos_a_id

@@ -636,7 +620,7 @@ class MultiWOZTrainer(Trainer):
score = 0.5 * (success + match) + bleu

# log results
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' % \
(match, success, bleu, score)
message_prefix = f'[Infer][{self.epoch}]'
time_cost = f'TIME-{time.time() - begin_time:.3f}'


+ 333
- 0
modelscope/utils/nlp/space/clean_dataset.py View File

@@ -0,0 +1,333 @@
import os
import re

from . import ontology


def clean_text_split_dot(text):
text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2',
text) # 'abc.xyz' -> 'abc . xyz'
text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . '
return text


def clean_text(data_dir, text):
text = text.strip()
text = text.lower()
text = text.replace(u'’', "'")
text = text.replace(u'‘', "'")
text = text.replace(';', ',')
text = text.replace('"', ' ')
text = text.replace('/', ' and ')
text = text.replace("don't", "do n't")
text = clean_time(text)
baddata = {
r'c\.b (\d), (\d) ([a-z])\.([a-z])': r'cb\1\2\3\4',
'c.b. 1 7 d.y': 'cb17dy',
'c.b.1 7 d.y': 'cb17dy',
'c.b 25, 9 a.q': 'cb259aq',
'isc.b 25, 9 a.q': 'is cb259aq',
'c.b2, 1 u.f': 'cb21uf',
'c.b 1,2 q.a': 'cb12qa',
'0-122-336-5664': '01223365664',
'postcodecb21rs': 'postcode cb21rs',
r'i\.d': 'id',
' i d ': 'id',
'Telephone:01223358966': 'Telephone: 01223358966',
'depature': 'departure',
'depearting': 'departing',
'-type': ' type',
r'b[\s]?&[\s]?b': 'bed and breakfast',
'b and b': 'bed and breakfast',
r'guesthouse[s]?': 'guest house',
r'swimmingpool[s]?': 'swimming pool',
"wo n\'t": 'will not',
" \'d ": ' would ',
" \'m ": ' am ',
" \'re' ": ' are ',
" \'ll' ": ' will ',
" \'ve ": ' have ',
r'^\'': '',
r'\'$': '',
}
for tmpl, good in baddata.items():
text = re.sub(tmpl, good, text)

text = re.sub(r'([a-zT]+)\.([a-z])', r'\1 . \2',
text) # 'abc.xyz' -> 'abc . xyz'
text = re.sub(r'(\w+)\.\.? ', r'\1 . ', text) # if 'abc. ' -> 'abc . '

with open(os.path.join(data_dir, 'mapping.pair'), 'r') as fin:
for line in fin.readlines():
fromx, tox = line.replace('\n', '').split('\t')
text = ' ' + text + ' '
text = text.replace(' ' + fromx + ' ', ' ' + tox + ' ')[1:-1]

return text


def clean_time(utter):
utter = re.sub(r'(\d+) ([ap]\.?m)', lambda x: x.group(1) + x.group(2),
utter) # 9 am -> 9am
utter = re.sub(r'((?<!\d)\d:\d+)(am)?', r'0\1', utter)
utter = re.sub(r'((?<!\d)\d)am', r'0\1:00', utter)
utter = re.sub(r'((?<!\d)\d)pm',
lambda x: str(int(x.group(1)) + 12) + ':00', utter)
utter = re.sub(r'(\d+)(:\d+)pm',
lambda x: str(int(x.group(1)) + 12) + x.group(2), utter)
utter = re.sub(r'(\d+)a\.?m', r'\1', utter)
return utter


def clean_slot_values(data_dir, domain, slot, value):
value = clean_text(data_dir, value)
if not value:
value = ''
elif value == 'not mentioned':
value = ''
# value = 'not mentioned' # if in DST setting
elif domain == 'attraction':
if slot == 'name':
if value == 't':
value = ''
if value == 'trinity':
value = 'trinity college'
elif slot == 'area':
if value in ['town centre', 'cent', 'center', 'ce']:
value = 'centre'
elif value in [
'ely', 'in town', 'museum', 'norwich', 'same area as hotel'
]:
value = ''
elif value in ['we']:
value = 'west'
elif slot == 'type':
if value in ['m', 'mus', 'musuem']:
value = 'museum'
elif value in ['art', 'architectural']:
value = 'architecture'
elif value in ['churches']:
value = 'church'
elif value in ['coll']:
value = 'college'
elif value in ['concert', 'concerthall']:
value = 'concert hall'
elif value in ['night club']:
value = 'nightclub'
elif value in [
'mutiple sports', 'mutliple sports', 'sports', 'galleria'
]:
value = 'multiple sports'
elif value in ['ol', 'science', 'gastropub', 'la raza']:
value = ''
elif value in ['swimmingpool', 'pool']:
value = 'swimming pool'
elif value in ['fun']:
value = 'entertainment'

elif domain == 'hotel':
if slot == 'area':
if value in [
'cen', 'centre of town', 'near city center', 'center'
]:
value = 'centre'
elif value in ['east area', 'east side']:
value = 'east'
elif value in ['in the north', 'north part of town']:
value = 'north'
elif value in ['we']:
value = 'west'
elif slot == 'day':
if value == 'monda':
value = 'monday'
elif value == 't':
value = 'tuesday'
elif slot == 'name':
if value == 'uni':
value = 'university arms hotel'
elif value == 'university arms':
value = 'university arms hotel'
elif value == 'acron':
value = 'acorn guest house'
elif value == 'ashley':
value = 'ashley hotel'
elif value == 'arbury lodge guesthouse':
value = 'arbury lodge guest house'
elif value == 'la':
value = 'la margherit'
elif value == 'no':
value = ''
elif slot == 'internet':
if value == 'does not':
value = 'no'
elif value in ['y', 'free', 'free internet']:
value = 'yes'
elif value in ['4']:
value = ''
elif slot == 'parking':
if value == 'n':
value = 'no'
elif value in ['free parking']:
value = 'yes'
elif value in ['y']:
value = 'yes'
elif slot in ['pricerange', 'price range']:
slot = 'pricerange'
if value == 'moderately':
value = 'moderate'
elif value in ['any']:
value = "do n't care"
elif value in ['any']:
value = "do n't care"
elif value in ['inexpensive']:
value = 'cheap'
elif value in ['2', '4']:
value = ''
elif slot == 'stars':
if value == 'two':
value = '2'
elif value == 'three':
value = '3'
elif value in [
'4-star', '4 stars', '4 star', 'four star', 'four stars'
]:
value = '4'
elif slot == 'type':
if value == '0 star rarting':
value = ''
elif value == 'guesthouse':
value = 'guest house'
elif value not in ['hotel', 'guest house', "do n't care"]:
value = ''
elif domain == 'restaurant':
if slot == 'area':
if value in [
'center', 'scentre', 'center of town', 'city center',
'cb30aq', 'town center', 'centre of cambridge',
'city centre'
]:
value = 'centre'
elif value == 'west part of town':
value = 'west'
elif value == 'n':
value = 'north'
elif value in ['the south']:
value = 'south'
elif value not in [
'centre', 'south', "do n't care", 'west', 'east', 'north'
]:
value = ''
elif slot == 'day':
if value == 'monda':
value = 'monday'
elif value == 't':
value = 'tuesday'
elif slot in ['pricerange', 'price range']:
slot = 'pricerange'
if value in ['moderately', 'mode', 'mo']:
value = 'moderate'
elif value in ['not']:
value = ''
elif value in ['inexpensive', 'ch']:
value = 'cheap'
elif slot == 'food':
if value == 'barbecue':
value = 'barbeque'
elif slot == 'pricerange':
if value == 'moderately':
value = 'moderate'
elif slot == 'time':
if value == '9:00':
value = '09:00'
elif value == '9:45':
value = '09:45'
elif value == '1330':
value = '13:30'
elif value == '1430':
value = '14:30'
elif value == '9:15':
value = '09:15'
elif value == '9:30':
value = '09:30'
elif value == '1830':
value = '18:30'
elif value == '9':
value = '09:00'
elif value == '2:00':
value = '14:00'
elif value == '1:00':
value = '13:00'
elif value == '3:00':
value = '15:00'
elif domain == 'taxi':
if slot in ['arriveBy', 'arrive by']:
slot = 'arriveby'
if value == '1530':
value = '15:30'
elif value == '15 minutes':
value = ''
elif slot in ['leaveAt', 'leave at']:
slot = 'leaveat'
if value == '1:00':
value = '01:00'
elif value == '21:4':
value = '21:04'
elif value == '4:15':
value = '04:15'
elif value == '5:45':
value = '05:45'
elif value == '0700':
value = '07:00'
elif value == '4:45':
value = '04:45'
elif value == '8:30':
value = '08:30'
elif value == '9:30':
value = '09:30'
value = value.replace('.', ':')

elif domain == 'train':
if slot in ['arriveBy', 'arrive by']:
slot = 'arriveby'
if value == '1':
value = '01:00'
elif value in ['does not care', 'doesnt care', "doesn't care"]:
value = "do n't care"
elif value == '8:30':
value = '08:30'
elif value == 'not 15:45':
value = ''
value = value.replace('.', ':')
elif slot == 'day':
if value == 'doesnt care' or value == "doesn't care":
value = "do n't care"
elif slot in ['leaveAt', 'leave at']:
slot = 'leaveat'
if value == '2:30':
value = '02:30'
elif value == '7:54':
value = '07:54'
elif value == 'after 5:45 pm':
value = '17:45'
elif value in [
'early evening', 'friday', 'sunday', 'tuesday', 'afternoon'
]:
value = ''
elif value == '12':
value = '12:00'
elif value == '1030':
value = '10:30'
elif value == '1700':
value = '17:00'
elif value in [
'does not care', 'doesnt care', 'do nt care',
"doesn't care"
]:
value = "do n't care"

value = value.replace('.', ':')
if value in ['dont care', "don't care", 'do nt care', "doesn't care"]:
value = "do n't care"
if ontology.normlize_slot_names.get(slot):
slot = ontology.normlize_slot_names[slot]
return slot, value

+ 8
- 4
modelscope/utils/nlp/space/utils.py View File

@@ -4,8 +4,11 @@ from collections import OrderedDict
import json
import numpy as np

from modelscope.utils.logger import get_logger
from . import ontology

logger = get_logger()


def max_lens(X):
lens = [len(X)]
@@ -117,8 +120,8 @@ class MultiWOZVocab(object):
def construct(self):
freq_dict_sorted = sorted(
self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
print('Vocabulary size including oov: %d' %
(len(freq_dict_sorted) + len(self._idx2word)))
logger.info('Vocabulary size including oov: %d' %
(len(freq_dict_sorted) + len(self._idx2word)))
if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size:
logging.warning(
'actual label set smaller than that configured: {}/{}'.format(
@@ -148,8 +151,9 @@ class MultiWOZVocab(object):
for w, idx in self._word2idx.items():
self._idx2word[idx] = w
self.vocab_size_oov = len(self._idx2word)
print('vocab file loaded from "' + vocab_path + '"')
print('Vocabulary size including oov: %d' % (self.vocab_size_oov))
logger.info('vocab file loaded from "' + vocab_path + '"')
logger.info('Vocabulary size including oov: %d' %
(self.vocab_size_oov))

def save_vocab(self, vocab_path):
_freq_dict = OrderedDict(


+ 68
- 0
tests/trainers/test_dialog_modeling_trainer.py View File

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest

import torch

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Preprocessors, Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


class TestDialogModelingTrainer(unittest.TestCase):

model_id = 'damo/nlp_space_pretrained-dialog-model'
output_dir = './dialog_fintune_result'

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_trainer_with_model_and_args(self):
# download data set
data_multiwoz = MsDataset.load(
'MultiWoz2.0', download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
data_dir = os.path.join(
data_multiwoz._hf_ds.config_kwargs['split_config']['train'],
'data')

# download model
model_dir = snapshot_download(self.model_id)

# dialog finetune config
def cfg_modify_fn(cfg):
config = {
'seed': 10,
'gpu': 4,
'use_data_distributed': False,
'valid_metric_name': '-loss',
'num_epochs': 60,
'save_dir': self.output_dir,
'token_loss': True,
'batch_size': 32,
'log_steps': 10,
'valid_steps': 0,
'save_checkpoint': True,
'save_summary': False,
'shuffle': True,
'sort_pool_size': 0
}

cfg.Trainer = config
cfg.use_gpu = torch.cuda.is_available() and config['gpu'] >= 1
return cfg

# trainer config
kwargs = dict(
model_dir=model_dir,
cfg_name='gen_train_config.json',
data_dir=data_dir,
cfg_modify_fn=cfg_modify_fn)

trainer = build_trainer(
name=Trainers.dialog_modeling_trainer, default_args=kwargs)
trainer.train()
checkpoint_path = os.path.join(self.output_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
assert os.path.exists(checkpoint_path)
trainer.evaluate(checkpoint_path=checkpoint_path)

Loading…
Cancel
Save