@@ -3,6 +3,8 @@ from typing import Any, Dict, Optional | |||
from maas_lib.utils.constant import Tasks | |||
from ...base import Model, Tensor | |||
from ...builder import MODELS | |||
from .model.generator import Generator | |||
from .model.model_base import ModelBase | |||
__all__ = ['DialogGenerationModel'] | |||
@@ -21,7 +23,14 @@ class DialogGenerationModel(Model): | |||
super().__init__(model_dir, *args, **kwargs) | |||
self.model_dir = model_dir | |||
pass | |||
self.text_field = kwargs.pop('text_field') | |||
self.config = kwargs.pop('config') | |||
self.generator = Generator.create(self.config, reader=self.text_field) | |||
self.model = ModelBase.create( | |||
model_dir=model_dir, | |||
config=self.config, | |||
reader=self.text_field, | |||
generator=self.generator) | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
"""return the result by the model | |||
@@ -0,0 +1,285 @@ | |||
""" | |||
IntentUnifiedTransformer | |||
""" | |||
import torch | |||
from maas_lib.models.nlp.space.model.unified_transformer import \ | |||
UnifiedTransformer | |||
class GenUnifiedTransformer(UnifiedTransformer): | |||
""" | |||
Implement generation unified transformer. | |||
""" | |||
def __init__(self, model_dir, config, reader, generator): | |||
super(GenUnifiedTransformer, self).__init__(model_dir, config, reader, | |||
generator) | |||
self.understand = config.BPETextField.understand | |||
if self.use_gpu: | |||
self.cuda() | |||
return | |||
def _forward(self, inputs, is_training, with_label): | |||
""" Real forward process of model in different mode(train/test). """ | |||
def cat(x, y, dim=1): | |||
return torch.cat([x, y], dim=dim) | |||
outputs = {} | |||
if self.understand or self.policy: | |||
if self.understand: | |||
prompt_token = inputs['understand_token'] | |||
prompt_mask = inputs['understand_mask'] | |||
if self.policy: | |||
prompt_token = cat(prompt_token, inputs['policy_token']) | |||
prompt_mask = cat(prompt_mask, inputs['policy_mask']) | |||
else: | |||
prompt_token = inputs['policy_token'] | |||
prompt_mask = inputs['policy_mask'] | |||
enc_embed, dec_embed, prompt_embed = self._encoder_prompt_decoder_network( | |||
src_token=inputs['src_token'], | |||
src_mask=inputs['src_mask'], | |||
tgt_token=inputs['tgt_token'][:, :-1], | |||
tgt_mask=inputs['tgt_mask'][:, :-1], | |||
prompt_token=prompt_token, | |||
prompt_mask=prompt_mask, | |||
src_pos=inputs['src_pos'], | |||
src_type=inputs['src_type'], | |||
src_turn=inputs['src_turn'], | |||
tgt_pos=inputs['tgt_pos'][:, :-1], | |||
tgt_type=inputs['tgt_type'][:, :-1], | |||
tgt_turn=inputs['tgt_turn'][:, :-1]) | |||
else: | |||
enc_embed, dec_embed = self._encoder_decoder_network( | |||
src_token=inputs['src_token'], | |||
src_mask=inputs['src_mask'], | |||
tgt_token=inputs['tgt_token'][:, :-1], | |||
tgt_mask=inputs['tgt_mask'][:, :-1], | |||
src_pos=inputs['src_pos'], | |||
src_type=inputs['src_type'], | |||
src_turn=inputs['src_turn'], | |||
tgt_pos=inputs['tgt_pos'][:, :-1], | |||
tgt_type=inputs['tgt_type'][:, :-1], | |||
tgt_turn=inputs['tgt_turn'][:, :-1]) | |||
outputs['dec_probs'] = self._dec_head(dec_embed=dec_embed) | |||
return outputs | |||
def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
metrics = {} | |||
loss = 0. | |||
label = inputs['tgt_token'][:, 1:] | |||
token_num = torch.sum(torch.sum(inputs['tgt_mask'], dim=1) - 1) | |||
nll = self.nll_loss( | |||
torch.log(outputs['dec_probs'] + 1e-12).permute(0, 2, 1), label) | |||
nll = torch.sum(nll, dim=1) | |||
token_nll = torch.sum(nll) / token_num | |||
nll = torch.mean(nll) | |||
metrics['nll'] = nll | |||
metrics['token_nll'] = token_nll | |||
metrics['token_num'] = token_num | |||
loss = loss + (token_nll if self.token_loss else nll) | |||
metrics['loss'] = loss | |||
if self.gpu > 1: | |||
return nll, token_nll, token_num | |||
else: | |||
return metrics | |||
def _optimize(self, loss, do_update=False, optimizer=None): | |||
""" Optimize loss function and update model. """ | |||
assert optimizer is not None | |||
if self.gradient_accumulation_steps > 1: | |||
loss = loss / self.gradient_accumulation_steps | |||
loss.backward() | |||
if self.grad_clip is not None and self.grad_clip > 0: | |||
torch.nn.utils.clip_grad_norm_( | |||
parameters=self.parameters(), max_norm=self.grad_clip) | |||
if do_update: | |||
optimizer.step() | |||
optimizer.zero_grad() | |||
return | |||
def _init_state(self, | |||
src_token, | |||
src_mask, | |||
src_pos=None, | |||
src_type=None, | |||
src_turn=None): | |||
""" Initialize decode state. """ | |||
state = {} | |||
batch_size = src_token.shape[0] | |||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
src_embed = self.embed_layer_norm(src_embed) | |||
mask = self._create_mask(src_mask, append_head=False) | |||
enc_out = src_embed | |||
cache = {} | |||
for l, layer in enumerate(self.layers): | |||
cache[f'layer_{l}'] = {} | |||
enc_out = layer(enc_out, mask, cache[f'layer_{l}']) | |||
state['cache'] = cache | |||
state['mask'] = mask[:, :1] | |||
state['batch_size'] = batch_size | |||
shape = [batch_size, 1, 1] | |||
state['pred_mask'] = torch.ones(shape, dtype=torch.float32) | |||
state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) | |||
state['pred_type'] = torch.zeros(shape, dtype=torch.int64) | |||
state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) | |||
if self.use_gpu: | |||
state['pred_mask'] = state['pred_mask'].cuda() | |||
state['pred_pos'] = state['pred_pos'].cuda() | |||
state['pred_type'] = state['pred_type'].cuda() | |||
state['pred_turn'] = state['pred_turn'].cuda() | |||
return state | |||
def _init_prompt_state(self, | |||
src_token, | |||
src_mask, | |||
prompt_token, | |||
prompt_mask, | |||
src_pos=None, | |||
src_type=None, | |||
src_turn=None, | |||
prompt_pos=None, | |||
prompt_type=None, | |||
prompt_turn=None): | |||
""" Initialize decode state. """ | |||
state = {} | |||
batch_size = src_token.shape[0] | |||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||
prompt_turn) | |||
embed = torch.cat([src_embed, prompt_embed], dim=1) | |||
embed = self.embed_layer_norm(embed) | |||
enc_out = embed | |||
enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
dec_mask = self._create_mask(prompt_mask, auto_regressive=True) | |||
mask = self._join_mask(enc_mask, dec_mask) | |||
cache = {} | |||
for l, layer in enumerate(self.layers): | |||
cache[f'layer_{l}'] = {} | |||
enc_out = layer(enc_out, mask, cache[f'layer_{l}']) | |||
state['cache'] = cache | |||
state['mask'] = mask[:, -1:] # state["mask"] = mask[:, :1] | |||
state['batch_size'] = batch_size | |||
shape = [batch_size, 1, 1] | |||
state['pred_mask'] = torch.ones(shape, dtype=torch.float32) | |||
state['pred_pos'] = torch.zeros(shape, dtype=torch.int64) | |||
state['pred_type'] = torch.zeros(shape, dtype=torch.int64) | |||
state['pred_turn'] = torch.zeros(shape, dtype=torch.int64) | |||
if self.use_gpu: | |||
state['pred_mask'] = state['pred_mask'].cuda() | |||
state['pred_pos'] = state['pred_pos'].cuda() | |||
state['pred_type'] = state['pred_type'].cuda() | |||
state['pred_turn'] = state['pred_turn'].cuda() | |||
return state | |||
def _decode(self, state): | |||
""" Decoding one time stamp. """ | |||
# shape: [batch_size, 1, seq_len] | |||
mask = state['mask'] | |||
# shape: [batch_size, 1, 1] | |||
pred_token = state['pred_token'] | |||
pred_mask = state['pred_mask'] | |||
pred_pos = state['pred_pos'] | |||
pred_type = state['pred_type'] | |||
pred_turn = state['pred_turn'] | |||
# list of shape(len: num_layers): [batch_size, seq_len, hidden_dim] | |||
cache = state['cache'] | |||
pred_embed = self.embedder(pred_token, pred_pos, pred_type, | |||
pred_turn).squeeze(-2) | |||
pred_embed = self.embed_layer_norm(pred_embed) | |||
# shape: [batch_size, 1, seq_len + 1] | |||
mask = torch.cat([mask, 1 - pred_mask], dim=2) | |||
# shape: [batch_size, 1, hidden_dim] | |||
for l, layer in enumerate(self.layers): | |||
pred_embed = layer(pred_embed, mask, cache[f'layer_{l}']) | |||
# shape: [batch_size, vocab_size] | |||
pred_probs = self._dec_head(dec_embed=pred_embed[:, 0]) | |||
pred_logits = torch.log(pred_probs) | |||
state['mask'] = mask | |||
return pred_logits, state | |||
def _infer(self, | |||
inputs, | |||
start_id=None, | |||
eos_id=None, | |||
max_gen_len=None, | |||
prev_input=None): | |||
""" Real inference process of model. """ | |||
def cat(x, y, dim=1): | |||
return torch.cat([x, y], dim=dim) | |||
# Initial decode state. | |||
if self.understand or self.policy: | |||
if self.understand: | |||
prompt_token = inputs['understand_token'] | |||
prompt_mask = inputs['understand_mask'] | |||
if self.policy: | |||
prompt_token = cat(prompt_token, inputs['policy_token']) | |||
prompt_mask = cat(prompt_mask, inputs['policy_mask']) | |||
else: | |||
prompt_token = inputs['policy_token'] | |||
prompt_mask = inputs['policy_mask'] | |||
state = self._init_prompt_state( | |||
src_token=inputs['src_token'], | |||
src_mask=inputs['src_mask'], | |||
prompt_token=prompt_token, | |||
prompt_mask=prompt_mask, | |||
src_pos=inputs['src_pos'], | |||
src_type=inputs['src_type'], | |||
src_turn=inputs['src_turn']) | |||
else: | |||
state = self._init_state( | |||
src_token=inputs['src_token'], | |||
src_mask=inputs['src_mask'], | |||
src_pos=inputs['src_pos'], | |||
src_type=inputs['src_type'], | |||
src_turn=inputs['src_turn']) | |||
# Generation process. | |||
gen_results = self.generator( | |||
step_fn=self._decode, | |||
state=state, | |||
start_id=start_id, | |||
eos_id=eos_id, | |||
max_gen_len=max_gen_len, | |||
prev_input=prev_input) | |||
outputs = gen_results['preds'] | |||
return outputs | |||
GenUnifiedTransformer.register('GenUnifiedTransformer') |
@@ -0,0 +1,296 @@ | |||
""" | |||
Generator class. | |||
""" | |||
import math | |||
import numpy as np | |||
import torch | |||
from .gen_unified_transformer import GenUnifiedTransformer | |||
from .unified_transformer import UnifiedTransformer | |||
def repeat(var, times): | |||
if isinstance(var, list): | |||
return [repeat(x, times) for x in var] | |||
elif isinstance(var, dict): | |||
return {k: repeat(v, times) for k, v in var.items()} | |||
elif isinstance(var, torch.Tensor): | |||
var = var.unsqueeze(1) | |||
expand_times = [1] * len(var.shape) | |||
expand_times[1] = times | |||
dtype = var.dtype | |||
var = var.float() | |||
var = var.repeat(*expand_times) | |||
shape = [var.shape[0] * var.shape[1]] + list(var.shape[2:]) | |||
var = var.reshape(*shape) | |||
var = torch.tensor(var, dtype=dtype) | |||
return var | |||
else: | |||
return var | |||
def gather(var, idx): | |||
if isinstance(var, list): | |||
return [gather(x, idx) for x in var] | |||
elif isinstance(var, dict): | |||
return {k: gather(v, idx) for k, v in var.items()} | |||
elif isinstance(var, torch.Tensor): | |||
out = var.index_select(dim=0, index=idx) | |||
return out | |||
else: | |||
return var | |||
class Generator(object): | |||
""" Genrator class. """ | |||
_registry = dict() | |||
@classmethod | |||
def register(cls, name): | |||
Generator._registry[name] = cls | |||
return | |||
@staticmethod | |||
def by_name(name): | |||
return Generator._registry[name] | |||
@staticmethod | |||
def create(config, *args, **kwargs): | |||
""" Create generator. """ | |||
generator_cls = Generator.by_name(config.Generator.generator) | |||
return generator_cls(config, *args, **kwargs) | |||
def __init__(self, config, reader): | |||
self.vocab_size = reader.vocab_size | |||
self.bos_id = reader.bos_id | |||
self.eos_id = reader.eos_id | |||
self.unk_id = reader.unk_id | |||
self.pad_id = reader.pad_id | |||
self.min_gen_len = config.Generator.min_gen_len | |||
self.max_gen_len = config.Generator.max_gen_len | |||
self.use_gpu = config.use_gpu | |||
assert 1 <= self.min_gen_len <= self.max_gen_len | |||
return | |||
def __call__(self, step_fn, state): | |||
""" | |||
Running generation. | |||
@param : step_fn : decoding one step | |||
@type : function | |||
@param : state : initial state | |||
@type : dict | |||
""" | |||
raise NotImplementedError | |||
class BeamSearch(Generator): | |||
""" BeamSearch generator. """ | |||
def __init__(self, config, reader): | |||
super().__init__(config, reader) | |||
self.beam_size = config.Generator.beam_size | |||
self.length_average = config.Generator.length_average | |||
self.length_penalty = config.Generator.length_penalty | |||
self.ignore_unk = config.Generator.ignore_unk | |||
return | |||
def __call__(self, | |||
step_fn, | |||
state, | |||
start_id=None, | |||
eos_id=None, | |||
max_gen_len=None, | |||
prev_input=None): | |||
""" | |||
Running beam search. | |||
@param : step_fn : decoding one step | |||
@type : function | |||
@param : state : initial state | |||
@type : dict | |||
""" | |||
if prev_input is not None: | |||
if isinstance(prev_input, list): | |||
length = max(list(map(lambda x: len(x), prev_input))) | |||
prev_input_numpy = np.full((len(prev_input), length), | |||
self.pad_id) | |||
for i, x in enumerate(prev_input): | |||
prev_input_numpy[i, :len(x)] = x | |||
prev_input_tensor = torch.from_numpy(prev_input_numpy) | |||
if self.use_gpu: | |||
prev_input_tensor = prev_input_tensor.cuda() | |||
for i in range(length): | |||
state['pred_token'] = prev_input_tensor[:, i].unsqueeze( | |||
-1).unsqueeze(-1) | |||
if i != 0: | |||
state['pred_mask'] = torch.not_equal( | |||
state['pred_token'], self.pad_id).float() | |||
state['pred_pos'] = state['pred_pos'] + state[ | |||
'pred_mask'].int() | |||
_, state = step_fn(state) | |||
else: | |||
assert isinstance(prev_input, torch.Tensor) | |||
for i, input in enumerate(prev_input): | |||
state['pred_token'] = input.expand(1, 1, 1) | |||
if i != 0: | |||
state['pred_mask'] = torch.not_equal( | |||
state['pred_token'], self.pad_id).float() | |||
state['pred_pos'] = state['pred_pos'] + 1 | |||
_, state = step_fn(state) | |||
batch_size = state['batch_size'] | |||
beam_size = self.beam_size | |||
# shape: [batch_size, 1] | |||
pos_index = torch.arange( | |||
0, batch_size, 1, dtype=torch.int64) * beam_size | |||
pos_index = pos_index.unsqueeze(1) | |||
# shape: [batch_size, beam_size, 1] | |||
if start_id is None: | |||
start_id = self.bos_id | |||
if eos_id is None: | |||
eos_id = self.eos_id | |||
predictions = torch.ones([batch_size, beam_size, 1], | |||
dtype=torch.int64) * start_id | |||
if self.use_gpu: | |||
pos_index = pos_index.cuda() | |||
predictions = predictions.cuda() | |||
# initial input (start_id) | |||
state['pred_token'] = predictions[:, :1] | |||
if prev_input is not None: | |||
state['pred_mask'] = torch.not_equal(state['pred_token'], | |||
self.pad_id).float() | |||
state['pred_pos'] = state['pred_pos'] + 1 | |||
# shape: [batch_size, vocab_size] | |||
scores, state = step_fn(state) | |||
unk_penalty = np.zeros(self.vocab_size, dtype='float32') | |||
unk_penalty[self.unk_id] = -1e10 | |||
unk_penalty = torch.from_numpy(unk_penalty) | |||
eos_penalty = np.zeros(self.vocab_size, dtype='float32') | |||
eos_penalty[eos_id] = -1e10 | |||
eos_penalty = torch.from_numpy(eos_penalty) | |||
scores_after_end = np.full(self.vocab_size, -1e10, dtype='float32') | |||
scores_after_end[ | |||
self.pad_id] = 0 # 希望<eos>之后只生成<pad>,故使词表中log(p(<pad>))最高(0) | |||
scores_after_end = torch.from_numpy(scores_after_end) | |||
if self.use_gpu: | |||
unk_penalty = unk_penalty.cuda() | |||
eos_penalty = eos_penalty.cuda() | |||
scores_after_end = scores_after_end.cuda() | |||
if self.ignore_unk: | |||
scores = scores + unk_penalty | |||
scores = scores + eos_penalty | |||
# shape: [batch_size, beam_size] | |||
sequence_scores, preds = torch.topk(scores, self.beam_size) | |||
predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | |||
state = repeat(state, beam_size) | |||
parent_idx_list = [] | |||
pred_list = [] | |||
if max_gen_len is None: | |||
max_gen_len = self.max_gen_len | |||
for step in range(2, max_gen_len + 1): | |||
pre_ids = predictions[:, :, -1:] | |||
state['pred_token'] = pre_ids.reshape(batch_size * beam_size, 1, 1) | |||
state['pred_mask'] = torch.not_equal(state['pred_token'], | |||
self.pad_id).float() | |||
state['pred_pos'] = state['pred_pos'] + 1 | |||
scores, state = step_fn(state) | |||
# Generate next | |||
# scores shape: [batch_size * beam_size, vocab_size] | |||
if self.ignore_unk: | |||
scores = scores + unk_penalty | |||
if step <= self.min_gen_len: | |||
scores = scores + eos_penalty | |||
# scores shape: [batch_size, beam_size, vocab_size] | |||
scores = scores.reshape(batch_size, beam_size, self.vocab_size) | |||
# previous token is [PAD] or [EOS] | |||
pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | |||
(1 - torch.not_equal(pre_ids, self.pad_id).float()) | |||
scores = scores * (1 - pre_eos_mask) + \ | |||
pre_eos_mask.repeat(1, 1, self.vocab_size) * scores_after_end | |||
if self.length_average: | |||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 - | |||
1 / step) | |||
sequence_scores = sequence_scores.unsqueeze(2) * scaled_value | |||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step) | |||
scores = scores * scaled_value | |||
elif self.length_penalty >= 0.0: | |||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ | |||
(math.pow((4 + step) / (5 + step), self.length_penalty)) | |||
sequence_scores = scaled_value * sequence_scores | |||
scaled_value = pre_eos_mask + (1 - pre_eos_mask) * \ | |||
(math.pow(1 / (5 + step), self.length_penalty)) | |||
scores = scores * scaled_value | |||
scores = scores + sequence_scores.unsqueeze(-1) | |||
scores = scores.reshape(batch_size, beam_size * self.vocab_size) | |||
topk_scores, topk_indices = torch.topk(scores, beam_size) | |||
# topk_indices: [batch_size, beam_size * self.vocab_size] (已reshape) | |||
# 判断当前时间步产生词的前一个词在哪个beam中,对vocab_size取商 | |||
parent_idx = topk_indices.floor_divide(self.vocab_size) | |||
# 对vocab_size取余 | |||
preds = topk_indices % self.vocab_size | |||
# Gather state / sequence_scores | |||
parent_idx = parent_idx + pos_index | |||
parent_idx = parent_idx.reshape(batch_size * beam_size) | |||
state = gather(state, parent_idx) | |||
sequence_scores = topk_scores | |||
predictions = predictions.reshape(batch_size * beam_size, step) | |||
predictions = gather(predictions, parent_idx) | |||
predictions = predictions.reshape(batch_size, beam_size, step) | |||
predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2) | |||
# 希望生成的整个句子已完结,所以要求最后一个token为<eos>或者<pad>(跟在<eos>之后),否则惩罚 | |||
pre_ids = predictions[:, :, -1] | |||
pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \ | |||
(1 - torch.not_equal(pre_ids, self.pad_id).float()) | |||
sequence_scores = sequence_scores * pre_eos_mask + ( | |||
1 - pre_eos_mask) * (-1e10) | |||
# 先获得ascending排序的index,便于之后对predictions和sequence_scores排序(针对beam size轴) | |||
indices = torch.argsort(sequence_scores, dim=1) | |||
indices = indices + pos_index | |||
indices = indices.reshape(-1) | |||
sequence_scores = sequence_scores.reshape(batch_size * beam_size) | |||
predictions = predictions.reshape(batch_size * beam_size, -1) | |||
sequence_scores = gather(sequence_scores, indices) | |||
predictions = gather(predictions, indices) | |||
sequence_scores = sequence_scores.reshape(batch_size, beam_size) | |||
predictions = predictions.reshape(batch_size, beam_size, -1) | |||
results = { | |||
'preds': predictions[:, -1], | |||
'scores': sequence_scores[:, -1] | |||
} | |||
return results | |||
BeamSearch.register('BeamSearch') |
@@ -0,0 +1,99 @@ | |||
""" | |||
Model base | |||
""" | |||
import os | |||
import torch.nn as nn | |||
class ModelBase(nn.Module): | |||
""" | |||
Basic model wrapper for static graph and dygrpah. | |||
""" | |||
_registry = dict() | |||
@classmethod | |||
def register(cls, name): | |||
ModelBase._registry[name] = cls | |||
return | |||
@staticmethod | |||
def by_name(name): | |||
return ModelBase._registry[name] | |||
@staticmethod | |||
def create(model_dir, config, *args, **kwargs): | |||
model_cls = ModelBase.by_name(config.Model.model) | |||
return model_cls(model_dir, config, *args, **kwargs) | |||
def __init__(self, model_dir, config): | |||
super(ModelBase, self).__init__() | |||
self.init_checkpoint = os.path.join(model_dir, 'pytorch_model.bin') | |||
self.abandon_label = config.Dataset.abandon_label | |||
self.use_gpu = config.use_gpu | |||
self.gpu = config.Trainer.gpu | |||
return | |||
def _create_parameters(self): | |||
""" Create model's paramters. """ | |||
raise NotImplementedError | |||
def _forward(self, inputs, is_training, with_label): | |||
""" NO LABEL: Real forward process of model in different mode(train/test). """ | |||
raise NotImplementedError | |||
def _collect_metrics(self, inputs, outputs, with_label, data_file): | |||
""" NO LABEL: Calculate loss function by using inputs and outputs. """ | |||
raise NotImplementedError | |||
def _optimize(self, loss, optimizer, lr_scheduler): | |||
""" Optimize loss function and update model. """ | |||
raise NotImplementedError | |||
def _infer(self, inputs, start_id, eos_id, max_gen_len, prev_input): | |||
""" Real inference process of model. """ | |||
raise NotImplementedError | |||
def forward(self, | |||
inputs, | |||
is_training=False, | |||
with_label=False, | |||
data_file=None): | |||
""" | |||
Forward process, include real forward, collect metrices and optimize(optional) | |||
@params : inputs : input data | |||
@type : dict of numpy.ndarray/int/float/... | |||
""" | |||
if is_training: | |||
self.train() | |||
else: | |||
self.eval() | |||
with_label = False if self.abandon_label else with_label | |||
outputs = self._forward(inputs, is_training, with_label=with_label) | |||
metrics = self._collect_metrics( | |||
inputs, outputs, with_label=with_label, data_file=data_file) | |||
return metrics | |||
def infer(self, | |||
inputs, | |||
start_id=None, | |||
eos_id=None, | |||
max_gen_len=None, | |||
prev_input=None): | |||
""" | |||
Inference process. | |||
@params : inputs : input data | |||
@type : dict of numpy.ndarray/int/float/... | |||
""" | |||
self.eval() | |||
results = self._infer( | |||
inputs, | |||
start_id=start_id, | |||
eos_id=eos_id, | |||
max_gen_len=max_gen_len, | |||
prev_input=prev_input) | |||
return results |
@@ -0,0 +1,322 @@ | |||
""" | |||
UnifiedTransformer | |||
""" | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from maas_lib.models.nlp.space.model.model_base import ModelBase | |||
from maas_lib.models.nlp.space.modules.embedder import Embedder | |||
from maas_lib.models.nlp.space.modules.transformer_block import \ | |||
TransformerBlock | |||
class UnifiedTransformer(ModelBase): | |||
""" | |||
Implement unified transformer. | |||
""" | |||
def __init__(self, model_dir, config, reader, generator, dtype='float32'): | |||
super(UnifiedTransformer, self).__init__(model_dir, config) | |||
self.reader = reader | |||
self.generator = generator | |||
self.policy = config.BPETextField.policy | |||
self.generation = config.BPETextField.generation | |||
self.num_token_embeddings = config.Model.num_token_embeddings | |||
self.num_pos_embeddings = config.Model.num_pos_embeddings | |||
self.num_type_embeddings = config.Model.num_type_embeddings | |||
self.num_turn_embeddings = config.Model.num_turn_embeddings | |||
self.temperature = config.Model.temperature | |||
self.hidden_dim = config.Model.hidden_dim | |||
self.num_heads = config.Model.num_heads | |||
self.num_layers = config.Model.num_layers | |||
self.padding_idx = config.Model.padding_idx | |||
self.dropout = config.Model.dropout | |||
self.embed_dropout = config.Model.embed_dropout | |||
self.attn_dropout = config.Model.attn_dropout | |||
self.ff_dropout = config.Model.ff_dropout | |||
self.mlm_ratio = config.Model.mlm_ratio | |||
self.mmd_ratio = config.Model.mmd_ratio | |||
self.pos_trainable = config.Model.pos_trainable | |||
self.label_smooth = config.Model.label_smooth | |||
self.initializer_range = config.Model.initializer_range | |||
self.gradient_accumulation_steps = config.Model.gradient_accumulation_steps | |||
self.token_loss = config.Trainer.token_loss | |||
self.learning_method = config.Dataset.learning_method | |||
self.with_contrastive = config.Dataset.with_contrastive | |||
self.with_query_bow = config.BPETextField.with_query_bow | |||
self.with_resp_bow = config.BPETextField.with_resp_bow | |||
self.with_pool = config.Model.with_pool | |||
self.with_mlm = config.Dataset.with_mlm | |||
self._dtype = dtype | |||
self.embedder = Embedder( | |||
self.hidden_dim, | |||
self.num_token_embeddings, | |||
self.num_pos_embeddings, | |||
self.num_type_embeddings, | |||
self.num_turn_embeddings, | |||
padding_idx=self.padding_idx, | |||
dropout=self.embed_dropout, | |||
pos_trainable=self.pos_trainable) | |||
self.embed_layer_norm = nn.LayerNorm( | |||
normalized_shape=self.hidden_dim, | |||
eps=1e-12, | |||
elementwise_affine=True) | |||
self.layers = nn.ModuleList([ | |||
TransformerBlock(self.hidden_dim, self.num_heads, self.dropout, | |||
self.attn_dropout, self.ff_dropout) | |||
for _ in range(config.Model.num_layers) | |||
]) | |||
if self.with_mlm: | |||
self.mlm_transform = nn.Sequential( | |||
nn.Linear(self.hidden_dim, self.hidden_dim), nn.GELU(), | |||
nn.LayerNorm( | |||
normalized_shape=self.hidden_dim, | |||
eps=1e-12, | |||
elementwise_affine=True)) | |||
self.mlm_bias = nn.Parameter( | |||
torch.zeros(self.num_token_embeddings)) | |||
self.pooler = nn.Sequential( | |||
nn.Linear(self.hidden_dim, self.hidden_dim), nn.Tanh()) | |||
if self.with_query_bow or self.with_resp_bow: | |||
self.bow_predictor = nn.Linear( | |||
self.hidden_dim, self.num_token_embeddings, bias=False) | |||
self.sigmoid = nn.Sigmoid() | |||
self.softmax = nn.Softmax(dim=-1) | |||
self.bce_loss = nn.BCELoss(reduction='none') | |||
self.nll_loss = nn.NLLLoss( | |||
ignore_index=self.padding_idx, reduction='none') | |||
self._create_parameters() | |||
self.max_grad_norm = config.Model.max_grad_norm | |||
if self.max_grad_norm is not None: | |||
self.grad_clip = self.max_grad_norm | |||
else: | |||
self.grad_clip = None | |||
self.weight_decay = config.Model.weight_decay | |||
if self.use_gpu: | |||
self.cuda() | |||
return | |||
def _create_parameters(self): | |||
""" Create model's paramters. """ | |||
sequence_mask = np.tri( | |||
self.num_pos_embeddings, | |||
self.num_pos_embeddings, | |||
dtype=self._dtype) | |||
self.sequence_mask = torch.tensor(sequence_mask) | |||
return | |||
def _create_mask(self, | |||
input_mask, | |||
append_head=False, | |||
auto_regressive=False): | |||
""" | |||
Create attention mask. | |||
创建从序列形式到矩阵形式的mask:[batch_size, max_seq_len, 1] -> [batch_size, max_seq_len, max_seq_len] | |||
mask除了要考虑attention mask(自回归),还需要考虑pad的mask(自回归和双向) | |||
注: | |||
1. 一个句子中的非<pad>词看整个句子,该句中只有<pad>词才被mask | |||
2. 一个句子中的<pad>词看整个句子,该句的所有词都应该被mask | |||
@param : input_mask | |||
@type : Variable(shape: [batch_size, max_seq_len]) | |||
@param : auto_regressive | |||
@type : bool | |||
""" | |||
seq_len = input_mask.shape[1] | |||
input_mask = input_mask.float() | |||
mask1 = input_mask.unsqueeze(-1).repeat(1, 1, seq_len) | |||
mask2 = mask1.permute(0, 2, 1) | |||
mask = mask1 * mask2 | |||
if append_head: | |||
# 拼接上句首位置([M]/z)的mask | |||
mask = torch.cat([mask[:, :1, :], mask], dim=1) | |||
mask = torch.cat([mask[:, :, :1], mask], dim=2) | |||
seq_len += 1 | |||
if auto_regressive: | |||
# 将tgt端的<pad> mask和自回归attention mask融合 | |||
seq_mask = self.sequence_mask[:seq_len, :seq_len] | |||
seq_mask = seq_mask.to(mask.device) | |||
mask = mask * seq_mask | |||
mask = 1 - mask | |||
return mask | |||
def _join_mask(self, mask1, mask2): | |||
""" | |||
Merge source attention mask and target attention mask. | |||
合并后的整个mask矩阵可以分为四个部分:左上lu/右上ru/左下lb/右下rb | |||
@param : mask1 : source attention mask | |||
@type : Variable(shape: [batch_size, max_src_len, max_src_len]) | |||
@param : mask1 : target attention mask | |||
@type : Variable(shape: [batch_size, max_tgt_len, max_tgt_len]) | |||
""" | |||
batch_size = mask1.shape[0] | |||
seq_len1 = mask1.shape[1] | |||
seq_len2 = mask2.shape[1] | |||
seq_len = seq_len1 + seq_len2 | |||
mask_lu = mask1 | |||
mask_ru = torch.ones(batch_size, seq_len1, seq_len2) | |||
if self.use_gpu: | |||
mask_ru = mask_ru.cuda() | |||
mask3 = mask2[:, :, :1].repeat(1, 1, seq_len1) | |||
mask4 = mask1[:, :1].repeat(1, seq_len2, 1) | |||
mask_lb = mask3 + mask4 - mask3 * mask4 | |||
mask_rb = mask2 | |||
mask_u = torch.cat([mask_lu, mask_ru], dim=2) | |||
mask_b = torch.cat([mask_lb, mask_rb], dim=2) | |||
mask = torch.cat([mask_u, mask_b], dim=1) | |||
return mask | |||
def _mlm_head(self, mlm_embed): | |||
mlm_embed = self.mlm_transform(mlm_embed) | |||
mlm_logits = torch.matmul( | |||
mlm_embed, self.embedder.token_embedding.weight.T) + self.mlm_bias | |||
mlm_probs = self.softmax(mlm_logits) | |||
return mlm_probs | |||
def _dec_head(self, dec_embed): | |||
dec_logits = torch.matmul(dec_embed, | |||
self.embedder.token_embedding.weight.T) | |||
dec_probs = self.softmax(dec_logits) | |||
return dec_probs | |||
def _refactor_feature(self, features): | |||
features = self.pooler(features) if self.with_pool else features | |||
batch_size = features.size(0) // 2 | |||
features = torch.cat([ | |||
features[:batch_size].unsqueeze(1), | |||
features[batch_size:].unsqueeze(1) | |||
], | |||
dim=1) | |||
features = F.normalize(features, dim=-1, p=2) | |||
return features | |||
def _encoder_network(self, | |||
input_token, | |||
input_mask, | |||
input_pos=None, | |||
input_type=None, | |||
input_turn=None): | |||
embed = self.embedder(input_token, input_pos, input_type, input_turn) | |||
embed = self.embed_layer_norm(embed) | |||
mask = self._create_mask(input_mask, auto_regressive=False) | |||
for layer in self.layers: | |||
embed = layer(embed, mask, None) | |||
return embed | |||
def _encoder_decoder_network(self, | |||
src_token, | |||
src_mask, | |||
tgt_token, | |||
tgt_mask, | |||
src_pos=None, | |||
src_type=None, | |||
src_turn=None, | |||
tgt_pos=None, | |||
tgt_type=None, | |||
tgt_turn=None): | |||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||
embed = torch.cat([src_embed, tgt_embed], dim=1) | |||
embed = self.embed_layer_norm(embed) | |||
enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
dec_mask = self._create_mask(tgt_mask, auto_regressive=True) | |||
mask = self._join_mask(enc_mask, dec_mask) | |||
for layer in self.layers: | |||
embed = layer(embed, mask, None) | |||
tgt_len = tgt_token.shape[1] | |||
enc_embed = embed[:, :-tgt_len] | |||
dec_embed = embed[:, -tgt_len:] | |||
return enc_embed, dec_embed | |||
def _encoder_prompt_decoder_network(self, | |||
src_token, | |||
src_mask, | |||
tgt_token, | |||
tgt_mask, | |||
prompt_token, | |||
prompt_mask, | |||
src_pos=None, | |||
src_type=None, | |||
src_turn=None, | |||
tgt_pos=None, | |||
tgt_type=None, | |||
tgt_turn=None, | |||
prompt_pos=None, | |||
prompt_type=None, | |||
prompt_turn=None): | |||
src_embed = self.embedder(src_token, src_pos, src_type, src_turn) | |||
tgt_embed = self.embedder(tgt_token, tgt_pos, tgt_type, tgt_turn) | |||
prompt_embed = self.embedder(prompt_token, prompt_pos, prompt_type, | |||
prompt_turn) | |||
embed = torch.cat([src_embed, prompt_embed, tgt_embed], dim=1) | |||
embed = self.embed_layer_norm(embed) | |||
enc_mask = self._create_mask(src_mask, auto_regressive=False) | |||
dec_mask = self._create_mask( | |||
torch.cat([prompt_mask, tgt_mask], dim=1), auto_regressive=True) | |||
mask = self._join_mask(enc_mask, dec_mask) | |||
for layer in self.layers: | |||
embed = layer(embed, mask, None) | |||
src_len = src_token.shape[1] | |||
tgt_len = tgt_token.shape[1] | |||
enc_embed = embed[:, :src_len] | |||
dec_embed = embed[:, -tgt_len:] | |||
prompt_embed = embed[:, src_len:-tgt_len] | |||
return enc_embed, dec_embed, prompt_embed | |||
def _optimize(self, loss, optimizer=None, lr_scheduler=None): | |||
""" Optimize loss function and update model. """ | |||
assert optimizer is not None | |||
optimizer.zero_grad() | |||
loss.backward() | |||
if self.grad_clip is not None and self.grad_clip > 0: | |||
torch.nn.utils.clip_grad_norm_( | |||
parameters=self.parameters(), max_norm=self.grad_clip) | |||
optimizer.step() | |||
if lr_scheduler is not None: | |||
lr_scheduler.step() | |||
return | |||
def _infer(self, | |||
inputs, | |||
start_id=None, | |||
eos_id=None, | |||
max_gen_len=None, | |||
prev_input=None): | |||
""" Real inference process of model. """ | |||
results = {} | |||
return results | |||
UnifiedTransformer.register('UnifiedTransformer') |
@@ -0,0 +1,67 @@ | |||
""" | |||
Embedder class. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
class Embedder(nn.Module): | |||
""" | |||
Composite embedding layer. | |||
""" | |||
def __init__(self, | |||
hidden_dim, | |||
num_token_embeddings, | |||
num_pos_embeddings, | |||
num_type_embeddings, | |||
num_turn_embeddings, | |||
padding_idx=None, | |||
dropout=0.1, | |||
pos_trainable=False): | |||
super(Embedder, self).__init__() | |||
self.token_embedding = nn.Embedding(num_token_embeddings, hidden_dim) | |||
self.pos_embedding = nn.Embedding(num_pos_embeddings, hidden_dim) | |||
self.pos_embedding.weight.requires_grad = pos_trainable | |||
self.type_embedding = nn.Embedding(num_type_embeddings, hidden_dim) | |||
self.turn_embedding = nn.Embedding(num_turn_embeddings, hidden_dim) | |||
self.dropout_layer = nn.Dropout(p=dropout) | |||
# follow the default xavier_uniform initializer in paddle version | |||
# otherwise, there are bugs for dec_probs computation in weight typing setting | |||
# default norm initializer in nn.Embedding in pytorch, which samples larger values | |||
nn.init.xavier_uniform_(self.token_embedding.weight) | |||
nn.init.xavier_uniform_(self.pos_embedding.weight) | |||
nn.init.xavier_uniform_(self.type_embedding.weight) | |||
nn.init.xavier_uniform_(self.turn_embedding.weight) | |||
return | |||
def forward(self, token_inp, pos_inp=None, type_inp=None, turn_inp=None): | |||
embed = self.token_embedding(token_inp) | |||
if pos_inp is not None: | |||
embed += self.pos_embedding(pos_inp) | |||
if type_inp is not None: | |||
embed += self.type_embedding(type_inp) | |||
if turn_inp is not None: | |||
embed += self.turn_embedding(turn_inp) | |||
embed = self.dropout_layer(embed) | |||
return embed | |||
def main(): | |||
import numpy as np | |||
model = Embedder(10, 20, 20, 20, 20) | |||
token_inp = torch.tensor( | |||
np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
pos_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
type_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
turn_inp = torch.tensor(np.random.randint(0, 19, [10, 10]).astype('int64')) | |||
out = model(token_inp, pos_inp, type_inp, turn_inp) | |||
print(out) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,43 @@ | |||
""" | |||
FeedForward class. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
class FeedForward(nn.Module): | |||
""" | |||
Positional feed forward layer. | |||
""" | |||
def __init__(self, hidden_dim, inner_dim, dropout): | |||
super(FeedForward, self).__init__() | |||
self.hidden_dim = hidden_dim | |||
self.inner_dim = inner_dim | |||
self.linear_hidden = nn.Sequential( | |||
nn.Linear(hidden_dim, inner_dim), nn.GELU()) | |||
self.linear_out = nn.Linear(inner_dim, hidden_dim) | |||
self.dropout_layer = nn.Dropout(p=dropout) | |||
return | |||
def forward(self, x): | |||
out = self.linear_hidden(x) | |||
out = self.dropout_layer(out) | |||
out = self.linear_out(out) | |||
return out | |||
def main(): | |||
import numpy as np | |||
model = FeedForward(10, 20, 0.5) | |||
inp = np.random.rand(2, 3, 10).astype('float32') | |||
inp = torch.tensor(inp) | |||
out = model(inp) | |||
print(out) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,64 @@ | |||
""" | |||
Helpful functions. | |||
""" | |||
import numpy as np | |||
import torch | |||
import torch.nn.functional as F | |||
def unsqueeze(input, dims): | |||
""" Implement multi-dimension unsqueeze function. """ | |||
if isinstance(dims, (list, tuple)): | |||
dims = [ | |||
dim if dim >= 0 else dim + len(input.shape) + 1 for dim in dims | |||
] | |||
dims = sorted(dims, reverse=True) | |||
shape = list(input.shape) | |||
for dim in dims: | |||
shape.insert(dim, 1) | |||
return torch.reshape(input, shape) | |||
elif isinstance(dims, int): | |||
return input.unsqueeze(dims) | |||
else: | |||
raise ValueError('Warning: type(dims) must in (list, tuple, int)!') | |||
def gumbel_softmax(input, tau=1, eps=1e-10): | |||
""" Basic implement of gumbel_softmax. """ | |||
U = torch.tensor(np.random.rand(*input.shape)) | |||
gumbel = 0.0 - torch.log(eps - torch.log(U + eps)) | |||
y = input + gumbel | |||
return F.softmax(y / tau) | |||
def equal(x, y, dtype=None): | |||
""" Implement equal in dygraph mode. (paddle) """ | |||
if dtype is None: | |||
dtype = 'float32' | |||
if isinstance(x, torch.Tensor): | |||
x = x.numpy() | |||
if isinstance(y, torch.Tensor): | |||
y = y.numpy() | |||
out = np.equal(x, y).astype(dtype) | |||
return torch.tensor(out) | |||
def not_equal(x, y, dtype=None): | |||
""" Implement not_equal in dygraph mode. (paddle) """ | |||
return 1 - equal(x, y, dtype) | |||
if __name__ == '__main__': | |||
a = torch.tensor([[1, 1], [3, 4]]) | |||
b = torch.tensor([[1, 1], [3, 4]]) | |||
c = torch.equal(a, a) | |||
c1 = equal(a, 3) | |||
d = 1 - torch.not_equal(a, 3).float() | |||
print(c) | |||
print(c1) | |||
print(d) | |||
e = F.gumbel_softmax(a) | |||
f = a.unsqueeze(a) | |||
g = unsqueeze(a, dims=[0, 0, 1]) | |||
print(g, g.shape) |
@@ -0,0 +1,109 @@ | |||
""" | |||
MultiheadAttention class. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
class MultiheadAttention(nn.Module): | |||
""" | |||
Multi head attention layer. | |||
""" | |||
def __init__(self, hidden_dim, num_heads, dropout): | |||
assert hidden_dim % num_heads == 0 | |||
super(MultiheadAttention, self).__init__() | |||
self.hidden_dim = hidden_dim | |||
self.num_heads = num_heads | |||
self.head_dim = hidden_dim // num_heads | |||
self.scale = self.head_dim**-0.5 | |||
self.linear_qkv = nn.Linear(hidden_dim, hidden_dim * 3) | |||
self.linear_out = nn.Linear(hidden_dim, hidden_dim) | |||
self.dropout_layer = nn.Dropout(p=dropout) | |||
self.softmax = nn.Softmax(dim=-1) | |||
return | |||
def _split_heads(self, x, is_key=False): | |||
x = x.reshape(x.size(0), x.size(1), self.num_heads, self.head_dim) | |||
x = x.permute(0, 2, 3, 1) if is_key else x.permute(0, 2, 1, 3) | |||
return x | |||
def _merge_heads(self, x): | |||
x = x.permute(0, 2, 1, 3) | |||
x = x.reshape(x.size(0), x.size(1), self.hidden_dim) | |||
return x | |||
def _attn(self, query, key, value, mask): | |||
# shape: [batch_size, num_head, seq_len, seq_len] | |||
scores = torch.matmul(query, key) | |||
scores = scores * self.scale | |||
if mask is not None: | |||
mask = mask.unsqueeze(1) | |||
mask = mask.repeat(1, self.num_heads, 1, 1) | |||
scores.masked_fill_( | |||
mask.bool(), | |||
float('-inf')) # scores = (1 - mask) * scores + mask * (-1e10) | |||
attn = self.softmax(scores) | |||
attn = self.dropout_layer(attn) | |||
if mask is not None: | |||
''' | |||
mask: [batch size, num_heads, seq_len, seq_len] | |||
mask后两维(seq_len, seq_len)矩阵来看,其中有的行可能都是true(1),对应句子中<pad>位看的行 | |||
导致softmax后该行的每个位置的attn prob都为1/n而非0,所以此处需重置为0 | |||
>>> F.softmax([-1e10, -100, -100]) | |||
>>> [0.00, 0.50, 0.50] | |||
>>> F.softmax([-1e10, -1e10, -1e10]) | |||
>>> [0.33, 0.33, 0.33] | |||
==> [0.00, 0.00, 0.00] | |||
''' | |||
attn.masked_fill_(mask.bool(), 0.) # attn = (1 - mask) * attn | |||
out = torch.matmul(attn, value) | |||
return out | |||
def forward(self, inp, mask=None, cache=None): | |||
""" Forward process of self attention. """ | |||
# shape: [batch_size, seq_len, 3 * hidden_dim] | |||
qkv = self.linear_qkv(inp) | |||
query, key, value = torch.split(qkv, self.hidden_dim, dim=2) | |||
# shape: [batch_size, num_head, seq_len, head_dim] | |||
query = self._split_heads(query) | |||
# shape: [batch_size, num_head, head_dim, seq_len] | |||
key = self._split_heads(key, is_key=True) | |||
# shape: [batch_size, num_head, seq_len, head_dim] | |||
value = self._split_heads(value) | |||
if cache is not None: | |||
if 'key' in cache and 'value' in cache: | |||
key = torch.cat([cache['key'], key], dim=3) | |||
value = torch.cat([cache['value'], value], dim=2) | |||
cache['key'] = key | |||
cache['value'] = value | |||
out = self._attn(query, key, value, mask) | |||
out = self._merge_heads(out) | |||
out = self.linear_out(out) | |||
return out | |||
def main(): | |||
import numpy as np | |||
model = MultiheadAttention(10, 2, 0.5) | |||
inp = np.random.rand(2, 3, 10).astype('float32') | |||
inp = torch.tensor(inp) | |||
mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||
mask = torch.tensor(mask) | |||
out = model(inp, mask=mask, cache=None) | |||
print(out) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,73 @@ | |||
""" | |||
TransformerBlock class. | |||
""" | |||
import torch | |||
import torch.nn as nn | |||
from maas_lib.models.nlp.space.modules.feedforward import FeedForward | |||
from maas_lib.models.nlp.space.modules.multihead_attention import \ | |||
MultiheadAttention | |||
class TransformerBlock(nn.Module): | |||
""" | |||
Transformer block module. | |||
""" | |||
def __init__(self, hidden_dim, num_heads, dropout, attn_dropout, | |||
ff_dropout): | |||
super(TransformerBlock, self).__init__() | |||
self.attn = MultiheadAttention( | |||
hidden_dim=hidden_dim, num_heads=num_heads, dropout=attn_dropout) | |||
self.attn_norm = nn.LayerNorm( | |||
normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||
self.ff = FeedForward( | |||
hidden_dim=hidden_dim, | |||
inner_dim=4 * hidden_dim, | |||
dropout=ff_dropout) | |||
self.ff_norm = nn.LayerNorm( | |||
normalized_shape=hidden_dim, eps=1e-12, elementwise_affine=True) | |||
self.dropout_layer = nn.Dropout(p=dropout) | |||
return | |||
def forward(self, inp, mask=None, cache=None): | |||
""" | |||
Forward process on one transformer layer. | |||
@param : x | |||
@type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
@param : memory | |||
@type : Variable(shape: [batch_size, seq_len, hidden_size]) | |||
@param : mask | |||
@param : cache | |||
""" | |||
attn_out = self.attn(inp, mask, cache) | |||
attn_out = self.dropout_layer(attn_out) | |||
attn_out = self.attn_norm(attn_out + inp) | |||
ff_out = self.ff(attn_out) | |||
ff_out = self.dropout_layer(ff_out) | |||
ff_out = self.ff_norm(ff_out + attn_out) | |||
return ff_out | |||
def main(): | |||
import numpy as np | |||
model = TransformerBlock(10, 2, 0.5, 0.5, 0.5) | |||
inp = np.random.rand(2, 3, 10).astype('float32') | |||
inp = torch.tensor(inp) | |||
mask = (np.random.rand(2, 3, 3) > 0.5).astype('float32') | |||
mask = torch.tensor(mask) | |||
out = model(inp, mask=mask, cache=None) | |||
print(out) | |||
if __name__ == '__main__': | |||
main() |
@@ -4,5 +4,5 @@ from .base import Preprocessor | |||
from .builder import PREPROCESSORS, build_preprocessor | |||
from .common import Compose | |||
from .image import LoadImage, load_image | |||
from .nlp import * # noqa F403 | |||
from .space.dialog_generation_preprcessor import * # noqa F403 | |||
from .nlp.nlp import * # noqa F403 | |||
from .nlp.space.dialog_generation_preprcessor import * # noqa F403 |
@@ -7,8 +7,8 @@ from transformers import AutoTokenizer | |||
from maas_lib.utils.constant import Fields, InputFields | |||
from maas_lib.utils.type_assert import type_assert | |||
from .base import Preprocessor | |||
from .builder import PREPROCESSORS | |||
from ..base import Preprocessor | |||
from ..builder import PREPROCESSORS | |||
__all__ = [ | |||
'Tokenize', |
@@ -5,10 +5,11 @@ import uuid | |||
from typing import Any, Dict, Union | |||
from maas_lib.data.nlp.space.fields.gen_field import MultiWOZBPETextField | |||
from maas_lib.utils.config import Config | |||
from maas_lib.utils.constant import Fields, InputFields | |||
from maas_lib.utils.type_assert import type_assert | |||
from ..base import Preprocessor | |||
from ..builder import PREPROCESSORS | |||
from ...base import Preprocessor | |||
from ...builder import PREPROCESSORS | |||
__all__ = ['DialogGenerationPreprocessor'] | |||
@@ -25,10 +26,10 @@ class DialogGenerationPreprocessor(Preprocessor): | |||
super().__init__(*args, **kwargs) | |||
self.model_dir: str = model_dir | |||
self.text_field = MultiWOZBPETextField(model_dir=self.model_dir) | |||
pass | |||
self.config = Config.from_file( | |||
os.path.join(self.model_dir, 'configuration.json')) | |||
self.text_field = MultiWOZBPETextField( | |||
self.model_dir, config=self.config) | |||
@type_assert(object, str) | |||
def __call__(self, data: str) -> Dict[str, Any]: |
@@ -4,37 +4,11 @@ import os.path as osp | |||
import tempfile | |||
import unittest | |||
from maas_lib.fileio import File | |||
from tests.case.nlp.dialog_generation_case import test_case | |||
from maas_lib.models.nlp import DialogGenerationModel | |||
from maas_lib.pipelines import DialogGenerationPipeline, pipeline | |||
from maas_lib.preprocessors import DialogGenerationPreprocessor | |||
from maas_lib.utils.constant import Tasks | |||
dialog_case = [{ | |||
'user': | |||
'am looking for a place to to stay that has cheap price range it should be in a type of hotel', | |||
'sys': | |||
'okay , do you have a specific area you want to stay in ?' | |||
}, { | |||
'user': | |||
'no , i just need to make sure it is cheap . oh , and i need parking', | |||
'sys': | |||
'i found 1 cheap hotel for you that include -s parking . do you like me to book it ?' | |||
}, { | |||
'user': | |||
'yes , please . 6 people 3 nights starting on tuesday .', | |||
'sys': | |||
"i am sorry but i was n't able to book that for you for tuesday . is there another day you would like " | |||
'to stay or perhaps a shorter stay ? ' | |||
}, { | |||
'user': | |||
'how about only 2 nights .', | |||
'sys': | |||
'booking was successful . reference number is : 7gawk763 . anything else i can do for you ?', | |||
}, { | |||
'user': 'no , that will be all . goodbye .', | |||
'sys': 'thank you for using our services .' | |||
}] | |||
def merge(info, result): | |||
@@ -47,21 +21,23 @@ class DialogGenerationTest(unittest.TestCase): | |||
modeldir = '/Users/yangliu/Desktop/space-dialog-generation' | |||
preprocessor = DialogGenerationPreprocessor() | |||
preprocessor = DialogGenerationPreprocessor(model_dir=modeldir) | |||
model = DialogGenerationModel( | |||
model_dir=modeldir, preprocessor.tokenizer) | |||
pipeline = DialogGenerationPipeline(model, preprocessor) | |||
model_dir=modeldir, | |||
text_field=preprocessor.text_field, | |||
config=preprocessor.config) | |||
# pipeline = DialogGenerationPipeline(model, preprocessor) | |||
history_dialog = {} | |||
for step in range(0, len(dialog_case)): | |||
user_question = dialog_case[step]['user'] | |||
for step, item in enumerate(test_case['sng0073']['log']): | |||
user_question = item['user'] | |||
print('user: {}'.format(user_question)) | |||
history_dialog_info = merge(history_dialog_info, | |||
result) if step > 0 else {} | |||
result = pipeline(user_question, history=history_dialog_info) | |||
print('sys : {}'.format(result['pred_answer'])) | |||
# history_dialog_info = merge(history_dialog_info, | |||
# result) if step > 0 else {} | |||
# result = pipeline(user_question, history=history_dialog_info) | |||
# | |||
# print('sys : {}'.format(result['pred_answer'])) | |||
if __name__ == '__main__': | |||