From b698506a2c7a8ca8b8cc7b2bf7cc3bb23b6bb8de Mon Sep 17 00:00:00 2001
From: ly119399
Date: Wed, 8 Jun 2022 23:12:59 +0800
Subject: [PATCH] model forward ready

---
 .../nlp/space/dialog_generation_model.py      |  55 +-
 maas_lib/trainers/nlp/space/__init__.py       |   0
 .../trainers/nlp/space/metrics/__init__.py    |   0
 .../nlp/space/metrics/metrics_tracker.py      |  73 ++
 .../trainers/nlp/space/trainers/__init__.py   |   0
 .../nlp/space/trainers/gen_trainer.py         | 725 ++++++++++++++++++
 tests/pipelines/nlp/test_dialog_generation.py |  25 +-
 7 files changed, 861 insertions(+), 17 deletions(-)
 create mode 100644 maas_lib/trainers/nlp/space/__init__.py
 create mode 100644 maas_lib/trainers/nlp/space/metrics/__init__.py
 create mode 100644 maas_lib/trainers/nlp/space/metrics/metrics_tracker.py
 create mode 100644 maas_lib/trainers/nlp/space/trainers/__init__.py
 create mode 100644 maas_lib/trainers/nlp/space/trainers/gen_trainer.py

diff --git a/maas_lib/models/nlp/space/dialog_generation_model.py b/maas_lib/models/nlp/space/dialog_generation_model.py
index a5d286a4..440c1163 100644
--- a/maas_lib/models/nlp/space/dialog_generation_model.py
+++ b/maas_lib/models/nlp/space/dialog_generation_model.py
@@ -1,5 +1,6 @@
 from typing import Any, Dict, Optional
 
+from maas_lib.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
 from maas_lib.utils.constant import Tasks
 from ...base import Model, Tensor
 from ...builder import MODELS
@@ -32,6 +33,22 @@ class DialogGenerationModel(Model):
             reader=self.text_field,
             generator=self.generator)
 
+        def to_tensor(array):
+            """
+            numpy array -> tensor (moved to GPU if config.use_gpu)
+            """
+            import torch
+            array = torch.tensor(array)
+            return array.cuda() if self.config.use_gpu else array
+
+        self.trainer = MultiWOZTrainer(
+            model=self.model,
+            to_tensor=to_tensor,
+            config=self.config,
+            reader=self.text_field,
+            evaluator=None)
+        self.trainer.load()
+
     def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """return the result by the model
 
@@ -48,10 +65,38 @@ class DialogGenerationModel(Model):
             }
         """
         from numpy import array, float32
+        import torch
 
-        return {
-            'predictions': array([1]),  # lable 0-negative 1-positive
-            'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
-            'logits': array([[-0.53860897, 1.5029076]],
-                            dtype=float32)  # true value
+        turn_1 = {  # hardcoded first-turn sample (pre-tokenized ids, currently unused)
+            'user': [
+                13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
+                1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
+            ]
         }
+        old_pv_turn_1 = {}
+
+        turn_2 = {  # second-turn sample that is actually fed to the trainer below
+            'user':
+            [13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7]
+        }
+        old_pv_turn_2 = {  # previous-turn context that would normally come from the last call
+            'labels': [[
+                13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
+                1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
+            ]],
+            'resp': [
+                14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010,
+                2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681,
+                2030, 7180, 2011, 1029, 8
+            ],
+            'bspn': [
+                15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002,
+                2198, 1005, 1055, 2267, 9
+            ],
+            'db': [19, 24, 21, 20],
+            'aspn': [16, 43, 48, 2681, 7180, 10]
+        }
+
+        pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2)
+
+        return pv_turn
diff --git a/maas_lib/trainers/nlp/space/__init__.py b/maas_lib/trainers/nlp/space/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/maas_lib/trainers/nlp/space/metrics/__init__.py b/maas_lib/trainers/nlp/space/metrics/__init__.py
new file mode 100644
index 00000000..e69de29b
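Note: at this stage forward() ignores its `input` argument and pushes the hardcoded second-turn sample through the newly wired trainer. The sketch below shows how that smoke test is expected to be driven; it is illustrative only, not part of the patch, and assumes the model has been built as in tests/pipelines/nlp/test_dialog_generation.py further down.

    # illustrative sketch, not part of the patch
    model = DialogGenerationModel(
        model_dir=modeldir,                  # local checkpoint dir, assumed to exist
        text_field=preprocessor.text_field,
        config=preprocessor.config)
    pv_turn = model.forward(None)            # input is ignored for now
    print(pv_turn['resp'])                   # token ids of the generated response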
diff --git a/maas_lib/trainers/nlp/space/metrics/metrics_tracker.py b/maas_lib/trainers/nlp/space/metrics/metrics_tracker.py
new file mode 100644
index 00000000..37441522
--- /dev/null
+++ b/maas_lib/trainers/nlp/space/metrics/metrics_tracker.py
@@ -0,0 +1,73 @@
+"""
+MetricsTracker class
+"""
+
+import math
+from collections import defaultdict
+
+
+class MetricsTracker(object):
+    """ Tracking metrics. """
+
+    def __init__(self):
+        self.metrics_val = defaultdict(float)  # metrics reported by the latest batch
+        self.metrics_avg = defaultdict(float)  # running averages over the batches trained so far in this epoch
+        self.num_samples = 0
+
+    def update(self, metrics, num_samples):
+        for key, val in metrics.items():
+            if val is not None:
+                val = float(val)  # [val] -> val
+                self.metrics_val[key] = val
+                avg_val = (self.metrics_avg.get(key, 0) * self.num_samples +
+                           val * num_samples) / (
+                               self.num_samples + num_samples)
+                self.metrics_avg[key] = avg_val
+        self.num_samples += num_samples
+
+    def clear(self):
+        self.metrics_val = defaultdict(float)
+        self.metrics_avg = defaultdict(float)
+        self.num_samples = 0
+
+    def items(self):
+        return self.metrics_avg.items()
+
+    def get(self, name):
+        if self.num_samples == 0:
+            raise ValueError('There is no data in Metrics.')
+        return self.metrics_avg.get(name)
+
+    def state_dict(self):
+        return {
+            'metrics_val': self.metrics_val,
+            'metrics_avg': self.metrics_avg,
+            'num_samples': self.num_samples,
+        }
+
+    def load_state_dict(self, state_dict):
+        self.metrics_val = state_dict['metrics_val']
+        self.metrics_avg = state_dict['metrics_avg']
+        self.num_samples = state_dict['num_samples']
+
+    def value(self):
+        metric_strs = []
+        for key, val in self.metrics_val.items():
+            metric_str = f'{key.upper()}-{val:.3f}'
+            metric_strs.append(metric_str)
+        if 'token_nll' in self.metrics_val:
+            metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}"
+            metric_strs.append(metric_str)
+        metric_strs = ' '.join(metric_strs)
+        return metric_strs
+
+    def summary(self):
+        metric_strs = []
+        for key, val in self.metrics_avg.items():
+            metric_str = f'{key.upper()}-{val:.3f}'
+            metric_strs.append(metric_str)
+        if 'token_nll' in self.metrics_avg:
+            metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}"
+            metric_strs.append(metric_str)
+        metric_strs = ' '.join(metric_strs)
+        return metric_strs
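For reference, a quick usage sketch of the tracker above (numbers are made up):

    tracker = MetricsTracker()
    tracker.update({'nll': 3.2, 'token_nll': 2.9}, num_samples=16)
    tracker.update({'nll': 2.8, 'token_nll': 2.7}, num_samples=16)
    print(tracker.value())     # latest batch, e.g. 'NLL-2.800 TOKEN_NLL-2.700 TOKEN_PPL-14.880'
    print(tracker.summary())   # running averages over the 32 samples seen so far
    print(tracker.get('nll'))  # 3.0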
+""" +import logging +import os +import sys +import time +from collections import OrderedDict + +import json +import numpy as np +import torch +from tqdm import tqdm +from transformers.optimization import AdamW, get_linear_schedule_with_warmup + +from ..metrics.metrics_tracker import MetricsTracker + + +def get_logger(log_path, name='default'): + logger = logging.getLogger(name) + logger.propagate = False + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter('%(message)s') + + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(formatter) + logger.addHandler(sh) + + fh = logging.FileHandler(log_path, mode='w') + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger + + +class Trainer(object): + + def __init__(self, + model, + to_tensor, + config, + logger=None, + lr_scheduler=None, + optimizer=None, + reader=None, + evaluator=None): + self.to_tensor = to_tensor + + self.do_train = config.do_train + self.do_infer = config.do_infer + self.is_decreased_valid_metric = config.Trainer.valid_metric_name[ + 0] == '-' + self.valid_metric_name = config.Trainer.valid_metric_name[1:] + self.num_epochs = config.Trainer.num_epochs + # self.save_dir = config.Trainer.save_dir + self.log_steps = config.Trainer.log_steps + self.valid_steps = config.Trainer.valid_steps + self.save_checkpoint = config.Trainer.save_checkpoint + self.save_summary = config.Trainer.save_summary + self.lr = config.Model.lr + self.weight_decay = config.Model.weight_decay + self.batch_size = config.Trainer.batch_size + self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps + self.warmup_steps = config.Model.warmup_steps + self.gpu = config.Trainer.gpu + + self.lr_scheduler = lr_scheduler + self.optimizer = optimizer + + self.model = model + self.func_model = self.model.module if self.gpu > 1 else self.model + self.reader = reader + self.evaluator = evaluator + self.tokenizer = reader.tokenizer + + # if not os.path.exists(self.save_dir): + # os.makedirs(self.save_dir) + + # self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer") + self.logger = logger or get_logger('trainer.log', 'trainer') + + self.batch_metrics_tracker = MetricsTracker() + self.token_metrics_tracker = MetricsTracker() + + self.best_valid_metric = float( + 'inf' if self.is_decreased_valid_metric else '-inf') + self.epoch = 0 + + def decode_generated_bspn_resp(self, generated): + """ + decode generated + return decoded ('bspn', 'resp') + """ + decoded = {} + eos_r_id = self.reader.eos_r_id + eos_b_id = self.reader.eos_b_id + + # eos_r may not exists if gpt2 generated repetitive words. + if eos_r_id in generated: + eos_r_idx = generated.index(eos_r_id) + else: + eos_r_idx = len(generated) - 1 + # self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated)) + + # predicted bspn, resp + eos_b_idx = generated.index(eos_b_id) + decoded['bspn'] = generated[:eos_b_idx + 1] + decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1] + return decoded + + def decode_generated_act_resp(self, generated): + """ + decode generated + return decoded['resp'] ('bspn', 'aspn') + """ + decoded = {} + eos_a_id = self.reader.eos_a_id + eos_r_id = self.reader.eos_r_id + eos_b_id = self.reader.eos_b_id + + # eos_r may not exists if gpt2 generated repetitive words. 
+    def decode_generated_act_resp(self, generated):
+        """
+        decode generated
+        return decoded['resp'] ('bspn', 'aspn')
+        """
+        decoded = {}
+        eos_a_id = self.reader.eos_a_id
+        eos_r_id = self.reader.eos_r_id
+        eos_b_id = self.reader.eos_b_id
+
+        # eos_r may not exist if gpt2 generated repetitive words.
+        if eos_r_id in generated:
+            eos_r_idx = generated.index(eos_r_id)
+        else:
+            eos_r_idx = len(generated) - 1
+            self.logger.info('eos_r not in generated: '
+                             + self.tokenizer.decode(generated))
+
+        if self.reader.use_true_curr_aspn:  # only predict resp
+            decoded['resp'] = generated[:eos_r_idx + 1]
+        else:  # predicted aspn, resp
+            eos_a_idx = generated.index(eos_a_id)
+            decoded['aspn'] = generated[:eos_a_idx + 1]
+            decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1]
+        return decoded
+
+    def decode_generated_bspn(self, generated):
+        eos_b_id = self.reader.eos_b_id
+        if eos_b_id in generated:
+            eos_b_idx = generated.index(eos_b_id)
+        else:
+            eos_b_idx = len(generated) - 1
+        return generated[:eos_b_idx + 1]
+
+    def set_optimizers(self):
+        """
+        Setup the optimizer and the learning rate scheduler.
+
+        from transformers.Trainer
+
+        parameters from cfg: lr (1e-3); warmup_steps
+        """
+        # Prepare optimizer and schedule (linear warmup and decay)
+        no_decay = ['bias', 'norm.weight']
+        optimizer_grouped_parameters = [
+            {
+                'params': [
+                    p for n, p in self.model.named_parameters()
+                    if not any(nd in n for nd in no_decay)
+                ],
+                'weight_decay':
+                self.weight_decay,
+            },
+            {
+                'params': [
+                    p for n, p in self.model.named_parameters()
+                    if any(nd in n for nd in no_decay)
+                ],
+                'weight_decay':
+                0.0,
+            },
+        ]
+        optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)
+
+        num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] * \
+            self.num_epochs // self.gradient_accumulation_steps
+        num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
+            num_training_steps * 0.1)
+        lr_scheduler = get_linear_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps)
+
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+
+    def train(self, train_data, dev_data):
+        # log info
+        set_stats = self.reader.set_stats['train']
+        self.logger.info('***** Running training *****')
+        self.logger.info(
+            ' Num Training steps(one turn in a batch of dialogs) per epoch = %d',
+            set_stats['num_training_steps_per_epoch'])
+        self.logger.info(' Num Turns = %d', set_stats['num_turns'])
+        self.logger.info(' Num Dialogs = %d', set_stats['num_dials'])
+        self.logger.info(' Num Epochs = %d', self.num_epochs)
+        self.logger.info(' Batch size = %d', self.batch_size)
+        self.logger.info(' Gradient Accumulation steps = %d',
+                         self.gradient_accumulation_steps)
+        self.logger.info(
+            ' Total optimization steps = %d',
+            set_stats['num_training_steps_per_epoch'] * self.num_epochs //
+            self.gradient_accumulation_steps)
+
+        # begin training
+        num_epochs = self.num_epochs - self.epoch
+        for epoch in range(num_epochs):
+            self.train_epoch(train_data=train_data, dev_data=dev_data)
+
+    def train_epoch(self, train_data, dev_data):
+        """
+        Train an epoch.
+        """
+        raise NotImplementedError
+
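As a worked example of the schedule arithmetic in set_optimizers above (numbers are illustrative): with 1,000 training steps per epoch, 10 epochs and gradient_accumulation_steps = 4, num_training_steps = 1000 * 10 // 4 = 2500; if warmup_steps is negative, warmup falls back to 10% of that, i.e. int(2500 * 0.1) = 250 optimizer steps.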
+ """ + raise NotImplementedError + + def forward(self, turn, old_pv_turn): + """ + one turn inference + """ + raise NotImplementedError + + def save(self, is_best=False): + """ save """ + train_state = { + 'epoch': self.epoch, + 'best_valid_metric': self.best_valid_metric, + 'optimizer': self.optimizer.state_dict() + } + if self.lr_scheduler is not None: + train_state['lr_scheduler'] = self.lr_scheduler.state_dict() + + # Save checkpoint + if self.save_checkpoint: + model_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.model') + torch.save(self.model.state_dict(), model_file) + self.logger.info(f"Saved model state to '{model_file}'") + + train_file = os.path.join(self.save_dir, + f'state_epoch_{self.epoch}.train') + torch.save(train_state, train_file) + self.logger.info(f"Saved train state to '{train_file}'") + + # Save current best model + if is_best: + best_model_file = os.path.join(self.save_dir, 'best.model') + torch.save(self.model.state_dict(), best_model_file) + best_train_file = os.path.join(self.save_dir, 'best.train') + torch.save(train_state, best_train_file) + self.logger.info( + f"Saved best model state to '{best_model_file}' with new best valid metric " + f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}' + ) + + def load(self): + """ load """ + + def _load_model_state(): + model_state_dict = torch.load( + f'{self.func_model.init_checkpoint}', + map_location=lambda storage, loc: storage) + + if 'module.' in list(model_state_dict.keys())[0]: + new_model_state_dict = OrderedDict() + for k, v in model_state_dict.items(): + assert k[:7] == 'module.' + new_model_state_dict[k[7:]] = v + model_state_dict = new_model_state_dict + + new_model_state_dict = OrderedDict() + parameters = { + name: param + for name, param in self.func_model.named_parameters() + } + for name, param in model_state_dict.items(): + if name in parameters: + if param.shape != parameters[name].shape: + assert hasattr(param, 'numpy') + arr = param.numpy() + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + if name == 'embedder.token_embedding.weight': + z[-param.shape[0]:] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + else: + if z.shape[0] < param.shape[0]: + z = arr[:z.shape[0]] + print(f'part of parameter({name}) are dropped') + else: + z[:param.shape[0]] = arr + print( + f'part of parameter({name}) random normlize initialize' + ) + dtype, device = param.dtype, param.device + z = torch.tensor(z, dtype=dtype, device=device) + new_model_state_dict[name] = z + else: + new_model_state_dict[name] = param + else: + print(f'parameter({name}) are dropped') + model_state_dict = new_model_state_dict + + for name in parameters: + if name not in model_state_dict: + if parameters[name].requires_grad: + print(f'parameter({name}) random normlize initialize') + z = np.random.normal( + scale=self.func_model.initializer_range, + size=parameters[name].shape).astype('float32') + dtype, device = parameters[name].dtype, parameters[ + name].device + model_state_dict[name] = torch.tensor( + z, dtype=dtype, device=device) + else: + model_state_dict[name] = parameters[name] + + self.func_model.load_state_dict(model_state_dict) + self.logger.info( + f"Loaded model state from '{self.func_model.init_checkpoint}.model'" + ) + + def _load_train_state(): + train_file = f'{self.func_model.init_checkpoint}.train' + if os.path.exists(train_file): + train_state_dict = torch.load( + train_file, map_location=lambda 
+                train_state_dict = torch.load(
+                    train_file, map_location=lambda storage, loc: storage)
+                self.epoch = train_state_dict['epoch']
+                self.best_valid_metric = train_state_dict['best_valid_metric']
+                if self.optimizer is not None and 'optimizer' in train_state_dict:
+                    self.optimizer.load_state_dict(
+                        train_state_dict['optimizer'])
+                if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
+                    self.lr_scheduler.load_state_dict(
+                        train_state_dict['lr_scheduler'])
+                self.logger.info(
+                    f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
+                    f'best_valid_metric={self.best_valid_metric:.3f})')
+            else:
+                self.logger.info('Loaded no train state')
+
+        if self.func_model.init_checkpoint is None:
+            self.logger.info('Loaded no model !!!')
+            return
+
+        if self.do_train:
+            _load_model_state()
+            return
+
+        if self.do_infer:
+            _load_model_state()
+            _load_train_state()
+
+
+class MultiWOZTrainer(Trainer):
+
+    def __init__(self,
+                 model,
+                 to_tensor,
+                 config,
+                 logger=None,
+                 lr_scheduler=None,
+                 optimizer=None,
+                 reader=None,
+                 evaluator=None):
+        super(MultiWOZTrainer,
+              self).__init__(model, to_tensor, config, logger, lr_scheduler,
+                             optimizer, reader, evaluator)
+
+    def train_epoch(self, train_data, dev_data):
+        """
+        Train an epoch.
+        """
+        times = []
+        epoch_step = 0
+        global_step = 0
+        tr_batch_loss = 0.0
+        tr_token_loss = 0.0
+        self.epoch += 1
+        self.batch_metrics_tracker.clear()
+        self.token_metrics_tracker.clear()
+        num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] // \
+            self.gradient_accumulation_steps  # similar to the original num_batches
+
+        self.model.zero_grad()
+        data_iterator = self.reader.get_data_iterator(all_batches=train_data)
+
+        for batch_idx, dial_batch in enumerate(data_iterator):
+            pv_batch = []
+            for turn_num, turn_batch in enumerate(dial_batch):
+                first_turn = (turn_num == 0)
+                samples, pv_batch = self.reader.convert_batch_turn(
+                    turn_batch, pv_batch, first_turn)
+                batch, batch_size = self.reader.collate_fn_multi_turn(
+                    samples=samples)
+                batch = type(batch)(
+                    map(lambda kv: (kv[0], self.to_tensor(kv[1])),
+                        batch.items()))
+
+                # Do a training iteration
+                start_time = time.time()
+                metrics = self.model(batch, is_training=True)
+                if self.gpu > 1:
+                    for metric in metrics:
+                        if metric is not None:
+                            assert len(metric) == self.gpu
+                    nll, token_nll, token_num = metrics
+                    metrics = {}
+
+                    token_num = torch.sum(token_num)
+                    token_nll = torch.sum(nll) * (batch_size /
+                                                  self.gpu) / token_num
+                    nll = torch.mean(nll)
+                    metrics['token_num'] = token_num
+                    metrics['token_nll'] = token_nll
+                    metrics['nll'] = nll
+                    loss = token_nll if self.func_model.token_loss else nll
+
+                    metrics['loss'] = loss
+                else:
+                    loss = metrics['loss']
+                self.func_model._optimize(
+                    loss, do_update=False, optimizer=self.optimizer)
+                metrics = {
+                    k: v.cpu().detach().numpy()
+                    if isinstance(v, torch.Tensor) else v
+                    for k, v in metrics.items()
+                }
+                token_num = metrics.pop('token_num', None)
+                # bow_num = metrics.pop("bow_num", None)
+                elapsed = time.time() - start_time
+                times.append(elapsed)
+                epoch_step += 1
+
+                tr_batch_loss += metrics['nll']
+                tr_token_loss += metrics['token_nll']
+                batch_metrics = {
+                    k: v
+                    for k, v in metrics.items() if 'token' not in k
+                }
+                token_metrics = {
+                    k: v
+                    for k, v in metrics.items() if 'token' in k
+                }
+                self.batch_metrics_tracker.update(batch_metrics, batch_size)
+                self.token_metrics_tracker.update(token_metrics, token_num)
+
+                if (epoch_step % self.gradient_accumulation_steps == 0) or \
+                        (epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']):
+                    self.optimizer.step()
+                    self.lr_scheduler.step()
+                    self.optimizer.zero_grad()
+                    global_step += 1
+
+                    if self.log_steps > 0 and global_step % self.log_steps == 0:
+                        batch_metrics_message = self.batch_metrics_tracker.value(
+                        )
+                        token_metrics_message = self.token_metrics_tracker.value(
+                        )
+                        message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]'
+                        avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}'
+                        message = ' '.join([
+                            message_prefix, batch_metrics_message,
+                            token_metrics_message, avg_time
+                        ])
+                        self.logger.info(message)
+
+        self.logger.info('-' * 150)
+        avg_batch_loss = tr_batch_loss / epoch_step
+        avg_token_loss = tr_token_loss / epoch_step
+        batch_metrics_message = self.batch_metrics_tracker.summary()
+        token_metrics_message = self.token_metrics_tracker.summary()
+        message_prefix = f'[Valid][{self.epoch}]'
+        message = ' '.join([
+            message_prefix, batch_metrics_message, token_metrics_message,
+            str(avg_batch_loss),
+            str(avg_token_loss)
+        ])
+        self.logger.info(message)
+
+        cur_valid_metric = self.batch_metrics_tracker.get(
+            self.valid_metric_name)
+        if self.is_decreased_valid_metric:
+            is_best = cur_valid_metric < self.best_valid_metric
+        else:
+            is_best = cur_valid_metric > self.best_valid_metric
+        if is_best:
+            self.best_valid_metric = cur_valid_metric
+        self.save(is_best)
+        self.logger.info('-' * 150)
+
+        return
+
+    def infer(self, data_type='test'):
+        """
+        Inference interface.
+        """
+        self.logger.info('Generation starts ...')
+        infer_save_file = os.path.join(self.save_dir,
+                                       f'infer_{self.epoch}.result.json')
+        infer_samples_save_file = os.path.join(
+            self.save_dir, f'infer_samples_{self.epoch}.result.json')
+
+        # Inference
+        result_collection = {}
+        begin_time = time.time()
+
+        eval_data = self.reader.get_eval_data(data_type)
+        set_stats = self.reader.set_stats[data_type]
+        self.logger.info('***** Running Evaluation *****')
+        self.logger.info(' Num Turns = %d', set_stats['num_turns'])
+
+        with torch.no_grad():
+            pbar = tqdm(eval_data)
+            for dial_idx, dialog in enumerate(pbar):
+                pv_turn = {}
+                for turn_idx, turn in enumerate(dialog):
+                    first_turn = (turn_idx == 0)
+                    inputs, prompt_id = self.reader.convert_turn_eval(
+                        turn, pv_turn, first_turn)
+                    batch, batch_size = self.reader.collate_fn_multi_turn(
+                        samples=[inputs])
+                    batch = type(batch)(
+                        map(lambda kv: (kv[0], self.to_tensor(kv[1])),
+                            batch.items()))
+                    if self.reader.use_true_curr_bspn:  # generate act, response
+                        max_len = 60
+                        if not self.reader.use_true_curr_aspn:
+                            max_len = 80
+                        outputs = self.func_model.infer(
+                            inputs=batch,
+                            start_id=prompt_id,
+                            eos_id=self.reader.eos_r_id,
+                            max_gen_len=max_len)
+                        # resp_gen, need to trim previous context
+                        generated = outputs[0].cpu().numpy().tolist()
+                        try:
+                            decoded = self.decode_generated_act_resp(generated)
+                        except ValueError as exception:
+                            self.logger.info(str(exception))
+                            self.logger.info(self.tokenizer.decode(generated))
+                            decoded = {'resp': [], 'bspn': [], 'aspn': []}
+                    else:  # predict bspn, access db, then generate act and resp
+                        outputs = self.func_model.infer(
+                            inputs=batch,
+                            start_id=prompt_id,
+                            eos_id=self.reader.eos_b_id,
+                            max_gen_len=60)
+                        generated_bs = outputs[0].cpu().numpy().tolist()
+                        bspn_gen = self.decode_generated_bspn(generated_bs)
+                        # check DB result
+                        if self.reader.use_true_db_pointer:  # use the ground-truth db result for the current turn
+                            db = turn['db']
+                        else:
+                            db_result = self.reader.bspan_to_DBpointer(
+                                self.tokenizer.decode(bspn_gen),
+                                turn['turn_domain'])
+                            assert len(turn['db']) == 4
+                            book_result = turn['db'][2]
+                            assert isinstance(db_result, str)
+                            db = [self.reader.sos_db_id] + \
+                                self.tokenizer.convert_tokens_to_ids([db_result]) + \
+                                [book_result] + \
+                                [self.reader.eos_db_id]
+                        prompt_id = self.reader.sos_a_id
+
+                        prev_input = torch.tensor(bspn_gen + db)
+                        if self.func_model.use_gpu:
+                            prev_input = prev_input.cuda()
+                        outputs_db = self.func_model.infer(
+                            inputs=batch,
+                            start_id=prompt_id,
+                            eos_id=self.reader.eos_r_id,
+                            max_gen_len=80,
+                            prev_input=prev_input)
+                        generated_ar = outputs_db[0].cpu().numpy().tolist()
+                        try:
+                            decoded = self.decode_generated_act_resp(
+                                generated_ar)
+                            decoded['bspn'] = bspn_gen
+                        except ValueError as exception:
+                            self.logger.info(str(exception))
+                            self.logger.info(
+                                self.tokenizer.decode(generated_ar))
+                            decoded = {'resp': [], 'bspn': [], 'aspn': []}
+
+                    turn['resp_gen'] = decoded['resp']
+                    turn['bspn_gen'] = turn[
+                        'bspn'] if self.reader.use_true_curr_bspn else decoded[
+                            'bspn']
+                    turn['aspn_gen'] = turn[
+                        'aspn'] if self.reader.use_true_curr_aspn else decoded[
+                            'aspn']
+                    turn['dspn_gen'] = turn['dspn']
+
+                    pv_turn['labels'] = inputs[
+                        'labels']  # all true previous context
+                    pv_turn['resp'] = turn[
+                        'resp'] if self.reader.use_true_prev_resp else decoded[
+                            'resp']
+                    if not self.reader.use_true_curr_bspn:
+                        pv_turn['bspn'] = turn[
+                            'bspn'] if self.reader.use_true_prev_bspn else decoded[
+                                'bspn']
+                        pv_turn['db'] = turn[
+                            'db'] if self.reader.use_true_prev_bspn else db
+                    pv_turn['aspn'] = turn[
+                        'aspn'] if self.reader.use_true_prev_aspn else decoded[
+                            'aspn']
+
+                tmp_dialog_result = self.reader.inverse_transpose_turn(dialog)
+                result_collection.update(tmp_dialog_result)
+
+                # compute tmp scores
+                results, _ = self.reader.wrap_result_lm(tmp_dialog_result)
+                bleu, success, match = self.evaluator.validation_metric(
+                    results)
+                score = 0.5 * (success + match) + bleu
+                pbar.set_description(
+                    'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %
+                    (match, success, bleu, score))
+
+        # compute scores
+        results, _ = self.reader.wrap_result_lm(result_collection)
+        bleu, success, match = self.evaluator.validation_metric(results)
+        score = 0.5 * (success + match) + bleu
+
+        # log results
+        metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\
+            (match, success, bleu, score)
+        message_prefix = f'[Infer][{self.epoch}]'
+        time_cost = f'TIME-{time.time() - begin_time:.3f}'
+        message = ' '.join([message_prefix, metrics_message, time_cost])
+        self.logger.info(message)
+
+        # save results
+        eval_results = {
+            'bleu': bleu,
+            'success': success,
+            'match': match,
+            'score': score,
+            'result': message
+        }
+        with open(infer_save_file, 'w') as fp:
+            json.dump(eval_results, fp, indent=2)
+        self.logger.info(f'Saved inference results to {infer_save_file}')
+        with open(infer_samples_save_file, 'w') as fp:
+            for sample in results:
+                line = json.dumps(sample)
+                fp.write(line)
+                fp.write('\n')
+        self.logger.info(
+            f'Saved inference samples to {infer_samples_save_file}')
+
+        return
+
+    def forward(self, turn, old_pv_turn):
+        with torch.no_grad():
+            first_turn = len(old_pv_turn) == 0
+            inputs, prompt_id = self.reader.convert_turn_eval(
+                turn, old_pv_turn, first_turn)
+            batch, batch_size = self.reader.collate_fn_multi_turn(
+                samples=[inputs])
+            batch = type(batch)(
+                map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
+            pv_turn = {}
+            print(batch)  # debug output while the forward pass is being brought up
+
+            outputs = self.func_model.infer(
+                inputs=batch,
+                start_id=prompt_id,
+                eos_id=self.reader.eos_b_id,
+                max_gen_len=60)
+            generated_bs = outputs[0].cpu().numpy().tolist()
+            bspn_gen = self.decode_generated_bspn(generated_bs)
+            bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen)
+            print(bspn_gen)
+            print(bspn_token)
+            turn_domain = []
+            for item in bspn_token:
+                if item.startswith('[') and item.endswith(']'):
+                    turn_domain.append(item)
+            print(turn_domain)
+            db_result = self.reader.bspan_to_DBpointer(
+                self.tokenizer.decode(bspn_gen), ['[taxi]'])  # domain hardcoded for the taxi smoke test
+            print(db_result)
+            book_result = 21  # hardcoded booking pointer for the smoke test
+            db = [self.reader.sos_db_id] + \
+                self.tokenizer.convert_tokens_to_ids([db_result]) + \
+                [book_result] + \
+                [self.reader.eos_db_id]
+            prompt_id = self.reader.sos_a_id
+
+            prev_input = torch.tensor(bspn_gen + db)
+            if self.func_model.use_gpu:
+                prev_input = prev_input.cuda()
+
+            outputs_db = self.func_model.infer(
+                inputs=batch,
+                start_id=prompt_id,
+                eos_id=self.reader.eos_r_id,
+                max_gen_len=80,
+                prev_input=prev_input)
+            generated_ar = outputs_db[0].cpu().numpy().tolist()
+            decoded = self.decode_generated_act_resp(generated_ar)
+            decoded['bspn'] = bspn_gen
+            print(decoded)
+            print(self.tokenizer.convert_ids_to_tokens(decoded['resp']))
+
+            pv_turn['labels'] = None
+            pv_turn['resp'] = decoded['resp']
+            pv_turn['bspn'] = decoded['bspn']
+            pv_turn['db'] = None
+            pv_turn['aspn'] = None
+
+        return pv_turn
diff --git a/tests/pipelines/nlp/test_dialog_generation.py b/tests/pipelines/nlp/test_dialog_generation.py
index 7b42059a..1baee3df 100644
--- a/tests/pipelines/nlp/test_dialog_generation.py
+++ b/tests/pipelines/nlp/test_dialog_generation.py
@@ -26,18 +26,19 @@ class DialogGenerationTest(unittest.TestCase):
             model_dir=modeldir,
             text_field=preprocessor.text_field,
             config=preprocessor.config)
-        # pipeline = DialogGenerationPipeline(model, preprocessor)
-
-        history_dialog = {}
-        for step, item in enumerate(test_case['sng0073']['log']):
-            user_question = item['user']
-            print('user: {}'.format(user_question))
-
-            # history_dialog_info = merge(history_dialog_info,
-            #                             result) if step > 0 else {}
-            # result = pipeline(user_question, history=history_dialog_info)
-            #
-            # print('sys : {}'.format(result['pred_answer']))
+        print(model.forward(None))  # forward() currently ignores its input and runs the hardcoded sample turn
+        # pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor)
+        #
+        # history_dialog_info = {}
+        # for step, item in enumerate(test_case['sng0073']['log']):
+        #     user_question = item['user']
+        #     print('user: {}'.format(user_question))
+        #
+        #     # history_dialog_info = merge(history_dialog_info,
+        #     #                             result) if step > 0 else {}
+        #     result = pipeline(user_question, history=history_dialog_info)
+        #     #
+        #     # print('sys : {}'.format(result['pred_answer']))


 if __name__ == '__main__':
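Once the pipeline wrapper is restored, the commented-out loop above is the natural place to thread each turn's output back in as the next turn's old_pv_turn. A rough sketch of that flow against the trainer API added in this patch (tokenization is elided; `encode_user` is a hypothetical stand-in for whatever the reader/preprocessor provides):

    pv_turn = {}
    for step, item in enumerate(test_case['sng0073']['log']):
        turn = {'user': encode_user(item['user'])}   # hypothetical helper producing token ids
        pv_turn = model.trainer.forward(turn=turn, old_pv_turn=pv_turn)
        print('sys : {}'.format(
            model.trainer.tokenizer.convert_ids_to_tokens(pv_turn['resp'])))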