- # Copyright (c) 2022 Zhipu.AI
-
- import os
- import random
- from os import path as osp
- from typing import Dict
-
- import numpy as np
- import torch
- import torch.nn.functional as F
-
- from modelscope.metainfo import Models
- from modelscope.models.base import Tensor, TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.outputs import OutputKeys
- from modelscope.utils.config import Config
- from modelscope.utils.constant import ModelFile, Tasks
- from . import mpu
- from .arguments import get_args
- from .generation_utils import BeamSearchScorer
- from .train_utils import get_model
- from .utils import load_checkpoint
-
- __all__ = ['MGLMForTextSummarization']
-
-
- def setup_args(args):
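-     """Populate `args` with the fixed MGLM block-LM architecture and
-     sampling hyper-parameters used for text summarization."""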
- args.block_lm = True
- args.task_mask = True
- args.cloze_eval = True
- args.num_layers = 24
- args.hidden_size = 1536
- args.num_attention_heads = 16
- args.max_position_embeddings = 1024
- args.tokenizer_type = 'ChineseSPTokenizer'
- args.load_pretrained = ''
- args.DDP_impl = 'none'
- args.model_parallel_size = 1
- args.fp16 = True
- args.cache_dir = 'cache'
- args.out_seq_length = 200
- args.seq_length = 512
- args.temperature = 0.9
- args.top_k = 2
- args.top_p = 0.8
- args.frequency_penalty = 0.1
- args.presence_penalty = 0.1
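-     # Tie the transformer-XL memory length to the sequence length.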
- args.mem_length = args.seq_length + args.mem_length - 1
- return args
-
-
- def setup_model(args):
-     """Set up the generation model and load pretrained weights if available."""
-
- model = get_model(args, model_type='generation')
-
- if args.load_pretrained is not None:
- args.no_load_optim = True
- args.load = args.load_pretrained
- _ = load_checkpoint(model, None, None, args)
-
- return model
-
-
- def set_random_seed(seed):
-     """Set random seed for reproducibility."""
-
- if seed is not None and seed > 0:
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- mpu.model_parallel_cuda_manual_seed(seed)
-
-
- def get_masks_and_position_ids(data,
- eod_token,
- reset_position_ids,
- reset_attention_mask,
- loss_mask=None,
- attention_mask=None,
- set_loss_mask=False,
- mem_length=None):
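-     """Build the attention mask, loss mask and position ids for `data`. When
-     `mem_length` is given, a banded causal mask over the cached memory is
-     used; EOD tokens can optionally reset positions and block attention
-     across document boundaries."""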
- # Extract batch size and sequence length.
- batch_size, seq_length = data.size()
-
- # Attention mask (lower triangular).
- if mem_length:
- if attention_mask is None:
- attention_mask = torch.ones(
- (1, seq_length, seq_length + mem_length), device=data.device)
- attention_mask = torch.tril(
- torch.triu(attention_mask, 1 - seq_length + mem_length),
- mem_length)
- else:
- if reset_attention_mask:
- att_mask_batch = batch_size
- else:
- att_mask_batch = 1
- if attention_mask is None:
- attention_mask = torch.ones(
- (att_mask_batch, seq_length, seq_length), device=data.device)
- attention_mask = torch.tril(attention_mask)
- attention_mask = attention_mask.unsqueeze(1)
-
- # Loss mask.
- if loss_mask is None:
- loss_mask = torch.ones(
- data.size(), dtype=torch.float, device=data.device)
-
- # Position ids.
- position_ids = torch.arange(
- seq_length, dtype=torch.long, device=data.device)
- position_ids = position_ids.unsqueeze(0).expand_as(data)
- if set_loss_mask:
- loss_mask[data == eod_token] = 0.0
-     # We need to clone as the ids will be modified based on batch index.
- if reset_position_ids:
- position_ids = position_ids.clone()
-
- if reset_position_ids or reset_attention_mask:
- # Loop through the batches:
- for b in range(batch_size):
-
-             # Find indices of the EOD tokens.
- eod_index = position_ids[b, data[b] == eod_token]
-             # Clone the indices if the positions are going to be modified.
- if reset_position_ids:
- eod_index = eod_index.clone()
-
-             # Loop through EOD indices:
- prev_index = 0
- for j in range(eod_index.size()[0]):
- i = eod_index[j]
- # Mask attention loss.
- if reset_attention_mask:
- attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
- # Reset positions.
- if reset_position_ids:
- position_ids[b, (i + 1):] -= (i + 1 - prev_index)
- prev_index = i + 1
-
- return attention_mask, loss_mask, position_ids
-
-
- def initialize_distributed(args):
- """Initialize torch.distributed."""
-
- # Manually set the device ids.
- device = args.rank % torch.cuda.device_count()
- if args.local_rank is not None:
- device = args.local_rank
- torch.cuda.set_device(device)
- # Call the init process
- init_method = 'tcp://'
- args.master_ip = os.getenv('MASTER_ADDR', 'localhost')
- args.master_port = os.getenv('MASTER_PORT', '6000')
- init_method += args.master_ip + ':' + args.master_port
- torch.distributed.init_process_group(
- backend=args.distributed_backend,
- world_size=args.world_size,
- rank=args.rank,
- init_method=init_method)
-
- # Set the model-parallel / data-parallel communicators.
- mpu.initialize_model_parallel(args.model_parallel_size)
-
-     # Optionally enable DeepSpeed activation checkpointing (the
-     # set_deepspeed_activation_checkpointing helper must be in scope).
- if hasattr(
- args, 'deepspeed'
- ) and args.deepspeed and args.deepspeed_activation_checkpointing:
- set_deepspeed_activation_checkpointing(args)
-
-
- def get_batch(context_tokens, device, args):
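-     """Reshape the context tokens into a batch of size `args.batch_size` and
-     build the matching attention mask and position ids."""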
- tokens = context_tokens
- tokens = tokens.view(args.batch_size, -1).contiguous()
- tokens = tokens.to(device)
-
-     # Get the masks and position ids.
- if args.block_lm:
- attention_mask = torch.tensor([tokens.size(1)],
- device=device,
- dtype=torch.long)
- position_ids = torch.arange(
- tokens.size(1), device=device, dtype=torch.long)
- if not args.no_block_position:
- block_position_ids = torch.zeros(
- tokens.size(1), device=device, dtype=torch.long)
- position_ids = torch.stack((position_ids, block_position_ids),
- dim=0)
- position_ids = position_ids.unsqueeze(0)
- else:
- attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
- tokens,
- args.eod_token,
- reset_position_ids=False,
- reset_attention_mask=False,
- set_loss_mask=False,
- mem_length=args.mem_length)
-
- return tokens, attention_mask, position_ids
-
-
- def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
- # This function has been mostly taken from huggingface conversational ai code at
- # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
-
- if top_k > 0:
- # Remove all tokens with a probability less than the last token of the top-k
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1,
- None]
- logits[indices_to_remove] = filter_value
-
- if top_p > 0.0:
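-         # NOTE: this nucleus (top-p) filtering path assumes a batch size of 1.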
- # convert to 1D
- logits = logits.view(logits.size()[1]).contiguous()
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
- cumulative_probs = torch.cumsum(
- F.softmax(sorted_logits, dim=-1), dim=-1)
-
- # Remove tokens with cumulative probability above the threshold
- sorted_indices_to_remove = cumulative_probs > top_p
-         # Shift the indices to the right so the first token above the threshold is also kept
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
- ..., :-1].clone()
- sorted_indices_to_remove[..., 0] = 0
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
- logits[indices_to_remove] = filter_value
- # going back to 2D
- logits = logits.view(1, -1).contiguous()
-
- return logits
-
-
- def sample_sequence(model,
- tokenizer,
- context_tokens,
- context_length,
- args,
- device,
- mems=None,
- end_tokens=None):
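-     """Autoregressively sample up to `args.out_seq_length` tokens after the
-     context, applying temperature, frequency/presence penalties and
-     top-k / top-p filtering; generation stops early on any of `end_tokens`."""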
- if not args.block_lm:
- context_tokens, attention_mask, position_ids = get_batch(
- context_tokens, device, args)
- tokens = torch.empty((args.num_beams, 0),
- device=context_tokens.device,
- dtype=torch.long)
- else:
- tokens = context_tokens.new_full((1, 1),
- tokenizer.get_command('sop').Id)
- counter = 0
- if mems is None:
- mems = []
- if end_tokens is None:
- end_tokens = [args.eod_token]
-
- last_beam_num = 1
- output_tokens_list = []
- generated_tokens_list = []
-
- while counter < args.out_seq_length:
- if counter == 0 and not args.block_lm:
- next_token_logits, *mems = model(context_tokens, position_ids,
- attention_mask, *mems)
- else:
- if args.block_lm:
- if args.no_block_position:
- position_ids = context_tokens.new_full(
- (last_beam_num, 1), context_length + counter)
- else:
- position_ids = context_tokens.new_ones(last_beam_num, 2, 1)
- position_ids[:, 0] = context_length
- position_ids[:, 1] = counter + 1
- attention_mask = context_tokens.new_zeros(
- [1], device=context_tokens.device, dtype=torch.long)
- else:
- position_ids = context_tokens.new_ones((last_beam_num, 1)) * (
- context_length + counter - 1)
- attention_mask = context_tokens.new_ones(
- last_beam_num,
- 1,
- 1,
- args.mem_length + 1,
- device=context_tokens.device,
- dtype=torch.float)
- last_token = tokens[:, -1:]
- next_token_logits, *mems = model(last_token, position_ids,
- attention_mask, *mems)
- next_token_logits = next_token_logits[:, -1]
-
- next_token_logits /= args.temperature
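-         # Apply frequency and presence penalties based on the tokens
-         # generated so far.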
- frequency_count = torch.zeros(next_token_logits.shape)
- for tk in output_tokens_list:
- frequency_count[0][tk] += 1
-
- next_token_logits -= (args.frequency_penalty
- * frequency_count).to(device)
- next_token_logits -= (
- args.presence_penalty * # noqa
- (frequency_count > 0)).to(device)
-
- next_token_logits = top_k_logits(
- next_token_logits, top_k=args.top_k, top_p=args.top_p)
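-         # Despite the name, `log_probs` holds soft-maxed probabilities, which
-         # is the form torch.multinomial expects.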
- log_probs = F.softmax(next_token_logits, dim=-1)
- prev = torch.multinomial(log_probs, num_samples=1)[0]
- is_end = prev.item() in end_tokens
- if is_end:
- break
- decode_tokens = tokenizer.DecodeIds([prev.item()]) # noqa
- generated_tokens_list.append(prev.item())
- prev = prev.view(1, 1)
- tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1)
- counter += 1
- output_tokens_list = tokens.view(-1).contiguous()
- return torch.cat((context_tokens, tokens), dim=1), mems
-
-
- def read_context(tokenizer, args, context):
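-     """Tokenize the prompt on model-parallel rank 0, appending a [gMASK] or
-     [MASK] token for block-LM generation, then broadcast the token ids and
-     context length to the other model-parallel ranks."""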
- terminate_runs, skip_run = 0, 0 # noqa
- if mpu.get_model_parallel_rank() == 0:
- while True:
- # raw_text = input("\nContext prompt (stop to exit) >>> ")
- raw_text = context
- if not raw_text:
- print('Prompt should not be empty!')
- break
- # if raw_text == "stop":
- # terminate_runs = 1
- # break
- generation_mask = '[gMASK]' if args.task_mask else '[MASK]'
- if args.block_lm and 'MASK]' not in raw_text:
- raw_text += ' ' + generation_mask
- # output.write(raw_text)
- context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
- if args.block_lm:
- context_tokens = [tokenizer.get_command('ENC').Id
- ] + context_tokens
- if not raw_text.endswith('[gMASK]'):
- context_tokens = context_tokens + [
- tokenizer.get_command('eos').Id
- ]
- context_length = len(context_tokens)
-
- if context_length >= args.seq_length:
-                 print('\nContext length', context_length,
-                       '\nPlease provide a context shorter than the maximum sequence length!')
- break
- break
- else:
- context_length = 0
-
- terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
- torch.distributed.broadcast(
- terminate_runs_tensor,
- mpu.get_model_parallel_src_rank(),
- group=mpu.get_model_parallel_group())
- terminate_runs = terminate_runs_tensor[0].item()
-
- if terminate_runs == 1:
- return terminate_runs, None, None, None
-
- context_length_tensor = torch.cuda.LongTensor([context_length])
-
- torch.distributed.broadcast(
- context_length_tensor,
- mpu.get_model_parallel_src_rank(),
- group=mpu.get_model_parallel_group())
- context_length = context_length_tensor[0].item()
- if mpu.get_model_parallel_rank() == 0:
- context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
- else:
- context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
- torch.distributed.broadcast(
- context_tokens_tensor,
- mpu.get_model_parallel_src_rank(),
- group=mpu.get_model_parallel_group())
- if mpu.get_model_parallel_rank() != 0:
- raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
- return terminate_runs, raw_text, context_tokens_tensor, context_length
-
-
- @MODELS.register_module(Tasks.text_summarization, module_name=Models.mglm)
- class MGLMForTextSummarization(TorchModel):
-
- def __init__(self, model_dir: str, *args, **kwargs):
-         """Initialize the text summarization model from the `model_dir` path.
-
- Args:
- model_dir (str): the model path.
- """
- super().__init__(model_dir, *args, **kwargs)
-
- from .configure_data import prepare_tokenizer
- # Disable CuDNN.
- torch.backends.cudnn.enabled = False
- # Arguments.
- self.args = setup_args(get_args())
- self.args.load_pretrained = model_dir
- # Pytorch distributed.
- try:
- initialize_distributed(self.args)
-         except RuntimeError:
-             print('process group already initialized')
-         # Random seeds for reproducibility.
- set_random_seed(self.args.seed)
-         # Set the default batch size to 1.
- self.args.batch_size = 1
- self.args.tokenizer_path = model_dir
- self.tokenizer = prepare_tokenizer(self.args)
- self.model = setup_model(self.args)
- self.cfg = Config.from_file(
- osp.join(model_dir, ModelFile.CONFIGURATION))
-
- def forward(self, input: Dict[str, str]) -> Dict[str, str]:
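-         """Not implemented; inference for this model goes through `generate`."""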
- pass
-
- def generate(self, input: Dict[str, str]) -> Dict[str, str]:
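-         """Generate a summary for `input['text']` by appending the configured
-         prompt and sampling tokens for each mask position in the context."""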
- model = self.model
- tokenizer = self.tokenizer
- args = self.args
- device = torch.cuda.current_device()
- model.eval()
-
- context = input['text'] + self.cfg.model.prompt
- with torch.no_grad():
- terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(
- tokenizer, args, context)
- mems = []
- tokens, attention_mask, position_ids = get_batch(
- context_tokens_tensor, device, args)
- mask_tokens = ['MASK', 'sMASK', 'gMASK'
- ] if args.task_mask else ['MASK']
- mask_tokens = [
- tokenizer.get_command(token).Id for token in mask_tokens
- ]
- end_tokens = [tokenizer.get_command('eop').Id, args.eod_token]
-
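-             # Locate every mask token in the context; each marks a span to be
-             # filled by generation.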
- mask_positions = []
- for token in mask_tokens:
- mask_positions += (context_tokens_tensor == token).nonzero(
- as_tuple=True)[0].tolist()
- mask_positions.sort()
- if args.no_block_position:
- for mask_position in mask_positions:
- position_ids[0, mask_position + 1:] += args.out_seq_length
- _, *mems = model(tokens, position_ids, attention_mask, *mems)
- for mask_position in mask_positions:
- if args.no_block_position:
- position = position_ids[0, mask_position].item()
- else:
- position = mask_position
-                 tokens, mems = sample_sequence(
- model,
- tokenizer,
- tokens,
- position,
- args,
- device,
- mems=mems,
- end_tokens=end_tokens)
- output_tokens_list = tokens.view(-1).contiguous()
- trim_decode_tokens = tokenizer.DecodeIds(
- output_tokens_list.tolist())
- res = trim_decode_tokens.split('<|startofpiece|>')[-1]
- print(res)
- return {OutputKeys.TEXT: res}
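-
-
- # A minimal usage sketch (illustrative only; '/path/to/mglm_model_dir' stands
- # in for a locally downloaded MGLM summarization checkpoint directory):
- #
- #     model = MGLMForTextSummarization('/path/to/mglm_model_dir')
- #     result = model.generate({'text': 'Article text to summarize ...'})
- #     print(result[OutputKeys.TEXT])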