
model forward ready

master
ly119399 3 years ago
commit b698506a2c
7 changed files with 861 additions and 17 deletions
  1. +50 -5    maas_lib/models/nlp/space/dialog_generation_model.py
  2. +0 -0     maas_lib/trainers/nlp/space/__init__.py
  3. +0 -0     maas_lib/trainers/nlp/space/metrics/__init__.py
  4. +73 -0    maas_lib/trainers/nlp/space/metrics/metrics_tracker.py
  5. +0 -0     maas_lib/trainers/nlp/space/trainers/__init__.py
  6. +725 -0   maas_lib/trainers/nlp/space/trainers/gen_trainer.py
  7. +13 -12   tests/pipelines/nlp/test_dialog_generation.py

+50 -5   maas_lib/models/nlp/space/dialog_generation_model.py

@@ -1,5 +1,6 @@
from typing import Any, Dict, Optional

from maas_lib.trainers.nlp.space.trainers.gen_trainer import MultiWOZTrainer
from maas_lib.utils.constant import Tasks
from ...base import Model, Tensor
from ...builder import MODELS
@@ -32,6 +33,22 @@ class DialogGenerationModel(Model):
reader=self.text_field,
generator=self.generator)

def to_tensor(array):
"""
numpy array -> tensor
"""
import torch
array = torch.tensor(array)
return array.cuda() if self.config.use_gpu else array

self.trainer = MultiWOZTrainer(
model=self.model,
to_tensor=to_tensor,
config=self.config,
reader=self.text_field,
evaluator=None)
self.trainer.load()

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

@@ -48,10 +65,38 @@ class DialogGenerationModel(Model):
}
"""
from numpy import array, float32
import torch

return {
'predictions': array([1]), # label 0-negative 1-positive
'probabilities': array([[0.11491239, 0.8850876]], dtype=float32),
'logits': array([[-0.53860897, 1.5029076]],
dtype=float32) # true value
turn_1 = {
'user': [
13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
]
}
old_pv_turn_1 = {}

turn_2 = {
'user':
[13, 1045, 2215, 2000, 2681, 2044, 2459, 1024, 2321, 1012, 7]
}
old_pv_turn_2 = {
'labels': [[
13, 1045, 2052, 2066, 1037, 10095, 2013, 3002, 2198, 1005,
1055, 2267, 2000, 10733, 12570, 21713, 4487, 15474, 1012, 7
]],
'resp': [
14, 1045, 2052, 2022, 3407, 2000, 2393, 2007, 2115, 5227, 1010,
2079, 2017, 2031, 1037, 2051, 2017, 2052, 2066, 2000, 2681,
2030, 7180, 2011, 1029, 8
],
'bspn': [
15, 43, 7688, 10733, 12570, 21713, 4487, 15474, 6712, 3002,
2198, 1005, 1055, 2267, 9
],
'db': [19, 24, 21, 20],
'aspn': [16, 43, 48, 2681, 7180, 10]
}

pv_turn = self.trainer.forward(turn=turn_2, old_pv_turn=old_pv_turn_2)

return pv_turn
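
The `to_tensor` helper added above is what the trainer later applies to every value of the collated batch (see the `type(batch)(map(...))` calls in gen_trainer.py). A minimal sketch of that conversion, assuming a plain dict of numpy arrays; the field names ('src_token', 'src_mask') and the `use_gpu` flag standing in for `self.config.use_gpu` are illustrative only:

import numpy as np
import torch

use_gpu = torch.cuda.is_available()  # stand-in for self.config.use_gpu

def to_tensor(array):
    """numpy array -> tensor, moved to GPU when available."""
    tensor = torch.tensor(array)
    return tensor.cuda() if use_gpu else tensor

# the trainer rebuilds the collated batch with every value converted to a tensor
batch = {'src_token': np.zeros((1, 20), dtype='int64'),
         'src_mask': np.ones((1, 20), dtype='int64')}
batch = type(batch)(map(lambda kv: (kv[0], to_tensor(kv[1])), batch.items()))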

+0 -0   maas_lib/trainers/nlp/space/__init__.py


+0 -0   maas_lib/trainers/nlp/space/metrics/__init__.py


+73 -0   maas_lib/trainers/nlp/space/metrics/metrics_tracker.py

@@ -0,0 +1,73 @@
"""
MetricsTracker class
"""

import math
from collections import defaultdict


class MetricsTracker(object):
""" Tracking metrics. """

def __init__(self):
self.metrics_val = defaultdict(float) # metrics returned by the most recent batch
self.metrics_avg = defaultdict(float) # running average over the batches trained so far in this epoch
self.num_samples = 0

def update(self, metrics, num_samples):
for key, val in metrics.items():
if val is not None:
val = float(val) # unwrap a one-element value to a plain float
self.metrics_val[key] = val
avg_val = (self.metrics_avg.get(key, 0) * self.num_samples +
val * num_samples) / (
self.num_samples + num_samples)
self.metrics_avg[key] = avg_val
self.num_samples += num_samples

def clear(self):
self.metrics_val = defaultdict(float)
self.metrics_avg = defaultdict(float)
self.num_samples = 0

def items(self):
return self.metrics_avg.items()

def get(self, name):
if self.num_samples == 0:
raise ValueError('There is no data in Metrics.')
return self.metrics_avg.get(name)

def state_dict(self):
return {
'metrics_val': self.metrics_val,
'metrics_avg': self.metrics_avg,
'num_samples': self.num_samples,
}

def load_state_dict(self, state_dict):
self.metrics_val = state_dict['metrics_val']
self.metrics_avg = state_dict['metrics_avg']
self.num_samples = state_dict['num_samples']

def value(self):
metric_strs = []
for key, val in self.metrics_val.items():
metric_str = f'{key.upper()}-{val:.3f}'
metric_strs.append(metric_str)
if 'token_nll' in self.metrics_val:
metric_str = f"TOKEN_PPL-{math.exp(self.metrics_val['token_nll']):.3f}"
metric_strs.append(metric_str)
metric_strs = ' '.join(metric_strs)
return metric_strs

def summary(self):
metric_strs = []
for key, val in self.metrics_avg.items():
metric_str = f'{key.upper()}-{val:.3f}'
metric_strs.append(metric_str)
if 'token_nll' in self.metrics_avg:
metric_str = f"TOKEN_PPL-{math.exp(self.metrics_avg['token_nll']):.3f}"
metric_strs.append(metric_str)
metric_strs = ' '.join(metric_strs)
return metric_strs
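
For reference, a minimal sketch of how this tracker is used: each call to update() records the latest batch metrics and folds them into a sample-weighted epoch average, while value()/summary() format the latest and averaged metrics (adding a token perplexity when 'token_nll' is present). The metric names and numbers below are illustrative:

tracker = MetricsTracker()

# two batches of 16 samples each
tracker.update({'nll': 3.2, 'token_nll': 2.9}, num_samples=16)
tracker.update({'nll': 2.8, 'token_nll': 2.5}, num_samples=16)

print(tracker.value())     # latest batch, e.g. "NLL-2.800 TOKEN_NLL-2.500 TOKEN_PPL-12.182"
print(tracker.summary())   # epoch averages, e.g. "NLL-3.000 TOKEN_NLL-2.700 TOKEN_PPL-14.880"
print(tracker.get('nll'))  # 3.0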

+0 -0   maas_lib/trainers/nlp/space/trainers/__init__.py


+725 -0   maas_lib/trainers/nlp/space/trainers/gen_trainer.py

@@ -0,0 +1,725 @@
"""
Trainer class.
"""
import logging
import os
import sys
import time
from collections import OrderedDict

import json
import numpy as np
import torch
from tqdm import tqdm
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

from ..metrics.metrics_tracker import MetricsTracker


def get_logger(log_path, name='default'):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(message)s')

sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
logger.addHandler(sh)

fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(formatter)
logger.addHandler(fh)

return logger


class Trainer(object):

def __init__(self,
model,
to_tensor,
config,
logger=None,
lr_scheduler=None,
optimizer=None,
reader=None,
evaluator=None):
self.to_tensor = to_tensor

self.do_train = config.do_train
self.do_infer = config.do_infer
self.is_decreased_valid_metric = config.Trainer.valid_metric_name[
0] == '-'
self.valid_metric_name = config.Trainer.valid_metric_name[1:]
self.num_epochs = config.Trainer.num_epochs
# self.save_dir = config.Trainer.save_dir
self.log_steps = config.Trainer.log_steps
self.valid_steps = config.Trainer.valid_steps
self.save_checkpoint = config.Trainer.save_checkpoint
self.save_summary = config.Trainer.save_summary
self.lr = config.Model.lr
self.weight_decay = config.Model.weight_decay
self.batch_size = config.Trainer.batch_size
self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps
self.warmup_steps = config.Model.warmup_steps
self.gpu = config.Trainer.gpu

self.lr_scheduler = lr_scheduler
self.optimizer = optimizer

self.model = model
self.func_model = self.model.module if self.gpu > 1 else self.model
self.reader = reader
self.evaluator = evaluator
self.tokenizer = reader.tokenizer

# if not os.path.exists(self.save_dir):
# os.makedirs(self.save_dir)

# self.logger = logger or get_logger(os.path.join(self.save_dir, "trainer.log"), "trainer")
self.logger = logger or get_logger('trainer.log', 'trainer')

self.batch_metrics_tracker = MetricsTracker()
self.token_metrics_tracker = MetricsTracker()

self.best_valid_metric = float(
'inf' if self.is_decreased_valid_metric else '-inf')
self.epoch = 0

def decode_generated_bspn_resp(self, generated):
"""
Decode the generated token ids and
return the decoded 'bspn' and 'resp' spans.
"""
decoded = {}
eos_r_id = self.reader.eos_r_id
eos_b_id = self.reader.eos_b_id

# eos_r may not exist if gpt2 generated repetitive words.
if eos_r_id in generated:
eos_r_idx = generated.index(eos_r_id)
else:
eos_r_idx = len(generated) - 1
# self.logger.info('eos_r not in generated: ' + self.tokenizer.decode(generated))

# predicted bspn, resp
eos_b_idx = generated.index(eos_b_id)
decoded['bspn'] = generated[:eos_b_idx + 1]
decoded['resp'] = generated[eos_b_idx + 1:eos_r_idx + 1]
return decoded

def decode_generated_act_resp(self, generated):
"""
Decode the generated token ids and
return the decoded 'resp' (and 'aspn' when the act is also predicted).
"""
decoded = {}
eos_a_id = self.reader.eos_a_id
eos_r_id = self.reader.eos_r_id
eos_b_id = self.reader.eos_b_id

# eos_r may not exist if gpt2 generated repetitive words.
if eos_r_id in generated:
eos_r_idx = generated.index(eos_r_id)
else:
eos_r_idx = len(generated) - 1
self.logger.info('eos_r not in generated: ' +
self.tokenizer.decode(generated))

if self.reader.use_true_curr_aspn: # only predict resp
decoded['resp'] = generated[:eos_r_idx + 1]
else: # predicted aspn, resp
eos_a_idx = generated.index(eos_a_id)
decoded['aspn'] = generated[:eos_a_idx + 1]
decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1]
return decoded

def decode_generated_bspn(self, generated):
eos_b_id = self.reader.eos_b_id
if eos_b_id in generated:
eos_b_idx = generated.index(eos_b_id)
else:
eos_b_idx = len(generated) - 1
return generated[:eos_b_idx + 1]

def set_optimizers(self):
"""
Set up the optimizer and the learning rate scheduler.

Adapted from transformers.Trainer.

Parameters taken from the config: lr (1e-3); warmup_steps.
"""
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'norm.weight']
optimizer_grouped_parameters = [
{
'params': [
p for n, p in self.model.named_parameters()
if not any(nd in n for nd in no_decay)
],
'weight_decay':
self.weight_decay,
},
{
'params': [
p for n, p in self.model.named_parameters()
if any(nd in n for nd in no_decay)
],
'weight_decay':
0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=self.lr)

num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] * \
self.num_epochs // self.gradient_accumulation_steps
num_warmup_steps = self.warmup_steps if self.warmup_steps >= 0 else int(
num_training_steps * 0.1)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps)

self.optimizer = optimizer
self.lr_scheduler = lr_scheduler

def train(self, train_data, dev_data):
# log info
set_stats = self.reader.set_stats['train']
self.logger.info('***** Running training *****')
self.logger.info(
' Num Training steps (one turn in a batch of dialogs) per epoch = %d',
set_stats['num_training_steps_per_epoch'])
self.logger.info(' Num Turns = %d', set_stats['num_turns'])
self.logger.info(' Num Dialogs = %d', set_stats['num_dials'])
self.logger.info(' Num Epochs = %d', self.num_epochs)
self.logger.info(' Batch size = %d', self.batch_size)
self.logger.info(' Gradient Accumulation steps = %d',
self.gradient_accumulation_steps)
self.logger.info(
' Total optimization steps = %d',
set_stats['num_training_steps_per_epoch'] * self.num_epochs //
self.gradient_accumulation_steps)

# begin training
num_epochs = self.num_epochs - self.epoch
for epoch in range(num_epochs):
self.train_epoch(train_data=train_data, dev_data=dev_data)

def train_epoch(self, train_data, dev_data):
"""
Train an epoch.
"""
raise NotImplementedError

def infer(self, data_type):
"""
Inference interface.
"""
raise NotImplementedError

def forward(self, turn, old_pv_turn):
"""
Run inference for a single dialogue turn.
"""
raise NotImplementedError

def save(self, is_best=False):
""" save """
train_state = {
'epoch': self.epoch,
'best_valid_metric': self.best_valid_metric,
'optimizer': self.optimizer.state_dict()
}
if self.lr_scheduler is not None:
train_state['lr_scheduler'] = self.lr_scheduler.state_dict()

# Save checkpoint
if self.save_checkpoint:
model_file = os.path.join(self.save_dir,
f'state_epoch_{self.epoch}.model')
torch.save(self.model.state_dict(), model_file)
self.logger.info(f"Saved model state to '{model_file}'")

train_file = os.path.join(self.save_dir,
f'state_epoch_{self.epoch}.train')
torch.save(train_state, train_file)
self.logger.info(f"Saved train state to '{train_file}'")

# Save current best model
if is_best:
best_model_file = os.path.join(self.save_dir, 'best.model')
torch.save(self.model.state_dict(), best_model_file)
best_train_file = os.path.join(self.save_dir, 'best.train')
torch.save(train_state, best_train_file)
self.logger.info(
f"Saved best model state to '{best_model_file}' with new best valid metric "
f'{self.valid_metric_name.upper()}={self.best_valid_metric:.3f}'
)

def load(self):
""" load """

def _load_model_state():
model_state_dict = torch.load(
f'{self.func_model.init_checkpoint}',
map_location=lambda storage, loc: storage)

if 'module.' in list(model_state_dict.keys())[0]:
new_model_state_dict = OrderedDict()
for k, v in model_state_dict.items():
assert k[:7] == 'module.'
new_model_state_dict[k[7:]] = v
model_state_dict = new_model_state_dict

new_model_state_dict = OrderedDict()
parameters = {
name: param
for name, param in self.func_model.named_parameters()
}
for name, param in model_state_dict.items():
if name in parameters:
if param.shape != parameters[name].shape:
assert hasattr(param, 'numpy')
arr = param.numpy()
z = np.random.normal(
scale=self.func_model.initializer_range,
size=parameters[name].shape).astype('float32')
if name == 'embedder.token_embedding.weight':
z[-param.shape[0]:] = arr
print(
f'part of parameter({name}) is randomly normal-initialized'
)
else:
if z.shape[0] < param.shape[0]:
z = arr[:z.shape[0]]
print(f'part of parameter({name}) is dropped')
else:
z[:param.shape[0]] = arr
print(
f'part of parameter({name}) is randomly normal-initialized'
)
dtype, device = param.dtype, param.device
z = torch.tensor(z, dtype=dtype, device=device)
new_model_state_dict[name] = z
else:
new_model_state_dict[name] = param
else:
print(f'parameter({name}) is dropped')
model_state_dict = new_model_state_dict

for name in parameters:
if name not in model_state_dict:
if parameters[name].requires_grad:
print(f'parameter({name}) is randomly normal-initialized')
z = np.random.normal(
scale=self.func_model.initializer_range,
size=parameters[name].shape).astype('float32')
dtype, device = parameters[name].dtype, parameters[
name].device
model_state_dict[name] = torch.tensor(
z, dtype=dtype, device=device)
else:
model_state_dict[name] = parameters[name]

self.func_model.load_state_dict(model_state_dict)
self.logger.info(
f"Loaded model state from '{self.func_model.init_checkpoint}'"
)

def _load_train_state():
train_file = f'{self.func_model.init_checkpoint}.train'
if os.path.exists(train_file):
train_state_dict = torch.load(
train_file, map_location=lambda storage, loc: storage)
self.epoch = train_state_dict['epoch']
self.best_valid_metric = train_state_dict['best_valid_metric']
if self.optimizer is not None and 'optimizer' in train_state_dict:
self.optimizer.load_state_dict(
train_state_dict['optimizer'])
if self.lr_scheduler is not None and 'lr_scheduler' in train_state_dict:
self.lr_scheduler.load_state_dict(
train_state_dict['lr_scheduler'])
self.logger.info(
f"Loaded train state from '{train_file}' with (epoch-{self.epoch} "
f'best_valid_metric={self.best_valid_metric:.3f})')
else:
self.logger.info('Loaded no train state')

if self.func_model.init_checkpoint is None:
self.logger.info('Loaded no model!!!')
return

if self.do_train:
_load_model_state()
return

if self.do_infer:
_load_model_state()
_load_train_state()


class MultiWOZTrainer(Trainer):

def __init__(self,
model,
to_tensor,
config,
logger=None,
lr_scheduler=None,
optimizer=None,
reader=None,
evaluator=None):
super(MultiWOZTrainer,
self).__init__(model, to_tensor, config, logger, lr_scheduler,
optimizer, reader, evaluator)

def train_epoch(self, train_data, dev_data):
"""
Train an epoch.
"""
times = []
epoch_step = 0
global_step = 0
tr_batch_loss = 0.0
tr_token_loss = 0.0
self.epoch += 1
self.batch_metrics_tracker.clear()
self.token_metrics_tracker.clear()
num_training_steps = self.reader.set_stats['train']['num_training_steps_per_epoch'] // \
self.gradient_accumulation_steps # similar to the original num_batches

self.model.zero_grad()
data_iterator = self.reader.get_data_iterator(all_batches=train_data)

for batch_idx, dial_batch in enumerate(data_iterator):
pv_batch = []
for turn_num, turn_batch in enumerate(dial_batch):
first_turn = (turn_num == 0)
samples, pv_batch = self.reader.convert_batch_turn(
turn_batch, pv_batch, first_turn)
batch, batch_size = self.reader.collate_fn_multi_turn(
samples=samples)
batch = type(batch)(
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
batch.items()))

# Do a training iteration
start_time = time.time()
metrics = self.model(batch, is_training=True)
if self.gpu > 1:
for metric in metrics:
if metric is not None:
assert len(metric) == self.gpu
nll, token_nll, token_num = metrics
metrics = {}

token_num = torch.sum(token_num)
token_nll = torch.sum(nll) * (batch_size /
self.gpu) / token_num
nll = torch.mean(nll)
metrics['token_num'] = token_num
metrics['token_nll'] = token_nll
metrics['nll'] = nll
loss = token_nll if self.func_model.token_loss else nll

metrics['loss'] = loss
else:
loss = metrics['loss']
self.func_model._optimize(
loss, do_update=False, optimizer=self.optimizer)
metrics = {
k: v.cpu().detach().numpy()
if isinstance(v, torch.Tensor) else v
for k, v in metrics.items()
}
token_num = metrics.pop('token_num', None)
# bow_num = metrics.pop("bow_num", None)
elapsed = time.time() - start_time
times.append(elapsed)
epoch_step += 1

tr_batch_loss += metrics['nll']
tr_token_loss += metrics['token_nll']
batch_metrics = {
k: v
for k, v in metrics.items() if 'token' not in k
}
token_metrics = {
k: v
for k, v in metrics.items() if 'token' in k
}
self.batch_metrics_tracker.update(batch_metrics, batch_size)
self.token_metrics_tracker.update(token_metrics, token_num)

if (epoch_step % self.gradient_accumulation_steps == 0) or \
(epoch_step == self.reader.set_stats['train']['num_training_steps_per_epoch']):
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
global_step += 1

if self.log_steps > 0 and global_step % self.log_steps == 0:
batch_metrics_message = self.batch_metrics_tracker.value(
)
token_metrics_message = self.token_metrics_tracker.value(
)
message_prefix = f'[Train][{self.epoch}][{global_step}/{num_training_steps}]'
avg_time = f'AVG_Time-{sum(times[-self.log_steps:]) / self.log_steps:.3f}'
message = ' '.join([
message_prefix, batch_metrics_message,
token_metrics_message, avg_time
])
self.logger.info(message)

self.logger.info('-' * 150)
avg_batch_loss = tr_batch_loss / epoch_step
avg_token_loss = tr_token_loss / epoch_step
batch_metrics_message = self.batch_metrics_tracker.summary()
token_metrics_message = self.token_metrics_tracker.summary()
message_prefix = f'[Valid][{self.epoch}]'
message = ' '.join([
message_prefix, batch_metrics_message, token_metrics_message,
str(avg_batch_loss),
str(avg_token_loss)
])
self.logger.info(message)

cur_valid_metric = self.batch_metrics_tracker.get(
self.valid_metric_name)
if self.is_decreased_valid_metric:
is_best = cur_valid_metric < self.best_valid_metric
else:
is_best = cur_valid_metric > self.best_valid_metric
if is_best:
self.best_valid_metric = cur_valid_metric
self.save(is_best)
self.logger.info('-' * 150)

return

def infer(self, data_type='test'):
"""
Inference interface.
"""
self.logger.info('Generation starts ...')
infer_save_file = os.path.join(self.save_dir,
f'infer_{self.epoch}.result.json')
infer_samples_save_file = os.path.join(
self.save_dir, f'infer_samples_{self.epoch}.result.json')

# Inference
result_collection = {}
begin_time = time.time()

eval_data = self.reader.get_eval_data(data_type)
set_stats = self.reader.set_stats[data_type]
self.logger.info('***** Running Evaluation *****')
self.logger.info(' Num Turns = %d', set_stats['num_turns'])

with torch.no_grad():
pbar = tqdm(eval_data)
for dial_idx, dialog in enumerate(pbar):
pv_turn = {}
for turn_idx, turn in enumerate(dialog):
first_turn = (turn_idx == 0)
inputs, prompt_id = self.reader.convert_turn_eval(
turn, pv_turn, first_turn)
batch, batch_size = self.reader.collate_fn_multi_turn(
samples=[inputs])
batch = type(batch)(
map(lambda kv: (kv[0], self.to_tensor(kv[1])),
batch.items()))
if self.reader.use_true_curr_bspn: # generate act, response
max_len = 60
if not self.reader.use_true_curr_aspn:
max_len = 80
outputs = self.func_model.infer(
inputs=batch,
start_id=prompt_id,
eos_id=self.reader.eos_r_id,
max_gen_len=max_len)
# resp_gen, need to trim previous context
generated = outputs[0].cpu().numpy().tolist()
try:
decoded = self.decode_generated_act_resp(generated)
except ValueError as exception:
self.logger.info(str(exception))
self.logger.info(self.tokenizer.decode(generated))
decoded = {'resp': [], 'bspn': [], 'aspn': []}
else: # predict bspn, access db, then generate act and resp
outputs = self.func_model.infer(
inputs=batch,
start_id=prompt_id,
eos_id=self.reader.eos_b_id,
max_gen_len=60)
generated_bs = outputs[0].cpu().numpy().tolist()
bspn_gen = self.decode_generated_bspn(generated_bs)
# check DB result
if self.reader.use_true_db_pointer: # whether the current turn's db result is the ground truth
db = turn['db']
else:
db_result = self.reader.bspan_to_DBpointer(
self.tokenizer.decode(bspn_gen),
turn['turn_domain'])
assert len(turn['db']) == 4
book_result = turn['db'][2]
assert isinstance(db_result, str)
db = [self.reader.sos_db_id] + \
self.tokenizer.convert_tokens_to_ids([db_result]) + \
[book_result] + \
[self.reader.eos_db_id]
prompt_id = self.reader.sos_a_id

prev_input = torch.tensor(bspn_gen + db)
if self.func_model.use_gpu:
prev_input = prev_input.cuda()
outputs_db = self.func_model.infer(
inputs=batch,
start_id=prompt_id,
eos_id=self.reader.eos_r_id,
max_gen_len=80,
prev_input=prev_input)
generated_ar = outputs_db[0].cpu().numpy().tolist()
try:
decoded = self.decode_generated_act_resp(
generated_ar)
decoded['bspn'] = bspn_gen
except ValueError as exception:
self.logger.info(str(exception))
self.logger.info(
self.tokenizer.decode(generated_ar))
decoded = {'resp': [], 'bspn': [], 'aspn': []}

turn['resp_gen'] = decoded['resp']
turn['bspn_gen'] = turn[
'bspn'] if self.reader.use_true_curr_bspn else decoded[
'bspn']
turn['aspn_gen'] = turn[
'aspn'] if self.reader.use_true_curr_aspn else decoded[
'aspn']
turn['dspn_gen'] = turn['dspn']

pv_turn['labels'] = inputs[
'labels'] # all true previous context
pv_turn['resp'] = turn[
'resp'] if self.reader.use_true_prev_resp else decoded[
'resp']
if not self.reader.use_true_curr_bspn:
pv_turn['bspn'] = turn[
'bspn'] if self.reader.use_true_prev_bspn else decoded[
'bspn']
pv_turn['db'] = turn[
'db'] if self.reader.use_true_prev_bspn else db
pv_turn['aspn'] = turn[
'aspn'] if self.reader.use_true_prev_aspn else decoded[
'aspn']

tmp_dialog_result = self.reader.inverse_transpose_turn(dialog)
result_collection.update(tmp_dialog_result)

# compute tmp scores
results, _ = self.reader.wrap_result_lm(tmp_dialog_result)
bleu, success, match = self.evaluator.validation_metric(
results)
score = 0.5 * (success + match) + bleu
pbar.set_description(
'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %
(match, success, bleu, score))

# compute scores
results, _ = self.reader.wrap_result_lm(result_collection)
bleu, success, match = self.evaluator.validation_metric(results)
score = 0.5 * (success + match) + bleu

# log results
metrics_message = 'match: %2.2f success: %2.2f bleu: %2.2f score: %.2f' %\
(match, success, bleu, score)
message_prefix = f'[Infer][{self.epoch}]'
time_cost = f'TIME-{time.time() - begin_time:.3f}'
message = ' '.join([message_prefix, metrics_message, time_cost])
self.logger.info(message)

# save results
eval_results = {
'bleu': bleu,
'success': success,
'match': match,
'score': score,
'result': message
}
with open(infer_save_file, 'w') as fp:
json.dump(eval_results, fp, indent=2)
self.logger.info(f'Saved inference results to {infer_save_file}')
with open(infer_samples_save_file, 'w') as fp:
for sample in results:
line = json.dumps(sample)
fp.write(line)
fp.write('\n')
self.logger.info(
f'Saved inference samples to {infer_samples_save_file}')

return

def forward(self, turn, old_pv_turn):
with torch.no_grad():
first_turn = len(old_pv_turn) == 0
inputs, prompt_id = self.reader.convert_turn_eval(
turn, old_pv_turn, first_turn)
batch, batch_size = self.reader.collate_fn_multi_turn(
samples=[inputs])
batch = type(batch)(
map(lambda kv: (kv[0], self.to_tensor(kv[1])), batch.items()))
pv_turn = {}
print(batch)

outputs = self.func_model.infer(
inputs=batch,
start_id=prompt_id,
eos_id=self.reader.eos_b_id,
max_gen_len=60)
generated_bs = outputs[0].cpu().numpy().tolist()
bspn_gen = self.decode_generated_bspn(generated_bs)
bspn_token = self.tokenizer.convert_ids_to_tokens(bspn_gen)
print(bspn_gen)
print(bspn_token)
turn_domain = []
for item in bspn_token:
if item.startswith('[') and item.endswith(']'):
turn_domain.append(item)
print(turn_domain)
# NOTE: the turn domain and the booking result are hard-coded here ('[taxi]' and 21)
db_result = self.reader.bspan_to_DBpointer(
self.tokenizer.decode(bspn_gen), ['[taxi]'])
print(db_result)
book_result = 21
db = [self.reader.sos_db_id] + \
self.tokenizer.convert_tokens_to_ids([db_result]) + \
[book_result] + \
[self.reader.eos_db_id]
prompt_id = self.reader.sos_a_id

prev_input = torch.tensor(bspn_gen + db)
if self.func_model.use_gpu:
prev_input = prev_input.cuda()

outputs_db = self.func_model.infer(
inputs=batch,
start_id=prompt_id,
eos_id=self.reader.eos_r_id,
max_gen_len=80,
prev_input=prev_input)
generated_ar = outputs_db[0].cpu().numpy().tolist()
decoded = self.decode_generated_act_resp(generated_ar)
decoded['bspn'] = bspn_gen
print(decoded)
print(self.tokenizer.convert_ids_to_tokens(decoded['resp']))

pv_turn['labels'] = None
pv_turn['resp'] = decoded['resp']
pv_turn['bspn'] = decoded['bspn']
pv_turn['db'] = None
pv_turn['aspn'] = None

return pv_turn
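
The single-turn forward() above is meant to be chained: the pv_turn dict it returns is passed back as old_pv_turn on the next call, the same way infer() carries pv_turn across turns of a dialog. A minimal sketch of that loop, assuming `trainer` is a MultiWOZTrainer that has already had load() called and `dialog` is a list of turn dicts with tokenized 'user' fields (as in the hard-coded example in dialog_generation_model.py):

# dialog: list of turns, each a dict like {'user': [token ids ...]}
pv_turn = {}  # an empty dict marks the first turn
for turn in dialog:
    pv_turn = trainer.forward(turn=turn, old_pv_turn=pv_turn)
    # pv_turn now holds the generated 'bspn' and 'resp' token ids
    print(trainer.tokenizer.decode(pv_turn['resp']))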

+13 -12   tests/pipelines/nlp/test_dialog_generation.py

@@ -26,18 +26,19 @@ class DialogGenerationTest(unittest.TestCase):
model_dir=modeldir,
text_field=preprocessor.text_field,
config=preprocessor.config)
# pipeline = DialogGenerationPipeline(model, preprocessor)

history_dialog = {}
for step, item in enumerate(test_case['sng0073']['log']):
user_question = item['user']
print('user: {}'.format(user_question))

# history_dialog_info = merge(history_dialog_info,
# result) if step > 0 else {}
# result = pipeline(user_question, history=history_dialog_info)
#
# print('sys : {}'.format(result['pred_answer']))
print(model.forward(None))
# pipeline = DialogGenerationPipeline(model=model, preprocessor=preprocessor)
#
# history_dialog_info = {}
# for step, item in enumerate(test_case['sng0073']['log']):
# user_question = item['user']
# print('user: {}'.format(user_question))
#
# # history_dialog_info = merge(history_dialog_info,
# # result) if step > 0 else {}
# result = pipeline(user_question, history=history_dialog_info)
# #
# # print('sys : {}'.format(result['pred_answer']))


if __name__ == '__main__':

