From 177d70829be59432ab7fbcd4740d7904a9c5c819 Mon Sep 17 00:00:00 2001 From: "jerry.lp" Date: Tue, 29 Nov 2022 20:54:32 +0800 Subject: [PATCH] add gpt-moe model for modelscope pipeline inference Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10836131 --- modelscope/metainfo.py | 2 + modelscope/models/nlp/gpt_moe/__init__.py | 27 + modelscope/models/nlp/gpt_moe/backbone.py | 355 +++++ .../models/nlp/gpt_moe/checkpointing.py | 145 ++ .../models/nlp/gpt_moe/configuration.py | 128 ++ .../models/nlp/gpt_moe/distributed_gpt_moe.py | 1236 +++++++++++++++++ modelscope/models/nlp/gpt_moe/moe/__init__.py | 0 modelscope/models/nlp/gpt_moe/moe/experts.py | 36 + modelscope/models/nlp/gpt_moe/moe/layer.py | 98 ++ modelscope/models/nlp/gpt_moe/moe/mappings.py | 87 ++ .../models/nlp/gpt_moe/moe/sharded_moe.py | 647 +++++++++ modelscope/models/nlp/gpt_moe/moe/utils.py | 125 ++ .../models/nlp/gpt_moe/text_generation.py | 62 + modelscope/models/nlp/gpt_moe/tokenizer.py | 67 + .../nlp/distributed_gpt_moe_pipeline.py | 54 + .../pipelines/test_gpt_moe_text_generation.py | 24 + 16 files changed, 3093 insertions(+) create mode 100644 modelscope/models/nlp/gpt_moe/__init__.py create mode 100644 modelscope/models/nlp/gpt_moe/backbone.py create mode 100644 modelscope/models/nlp/gpt_moe/checkpointing.py create mode 100644 modelscope/models/nlp/gpt_moe/configuration.py create mode 100644 modelscope/models/nlp/gpt_moe/distributed_gpt_moe.py create mode 100644 modelscope/models/nlp/gpt_moe/moe/__init__.py create mode 100644 modelscope/models/nlp/gpt_moe/moe/experts.py create mode 100644 modelscope/models/nlp/gpt_moe/moe/layer.py create mode 100644 modelscope/models/nlp/gpt_moe/moe/mappings.py create mode 100644 modelscope/models/nlp/gpt_moe/moe/sharded_moe.py create mode 100644 modelscope/models/nlp/gpt_moe/moe/utils.py create mode 100644 modelscope/models/nlp/gpt_moe/text_generation.py create mode 100644 modelscope/models/nlp/gpt_moe/tokenizer.py create mode 100644 modelscope/pipelines/nlp/distributed_gpt_moe_pipeline.py create mode 100644 tests/pipelines/test_gpt_moe_text_generation.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 3d566da8..e70e82fe 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -80,6 +80,7 @@ class Models(object): gcnncrf = 'gcnn-crf' bart = 'bart' gpt3 = 'gpt3' + gpt_moe = 'gpt-moe' gpt_neo = 'gpt-neo' plug = 'plug' bert_for_ds = 'bert-for-document-segmentation' @@ -255,6 +256,7 @@ class Pipelines(object): text_error_correction = 'text-error-correction' plug_generation = 'plug-generation' gpt3_generation = 'gpt3-generation' + gpt_moe_generation = 'gpt-moe-generation' faq_question_answering = 'faq-question-answering' conversational_text_to_sql = 'conversational-text-to-sql' table_question_answering_pipeline = 'table-question-answering-pipeline' diff --git a/modelscope/models/nlp/gpt_moe/__init__.py b/modelscope/models/nlp/gpt_moe/__init__.py new file mode 100644 index 00000000..3010e64f --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .configuration import GPTMoEConfig + from .backbone import GPTMoEModel + from .text_generation import GPTMoEForTextGeneration + from .tokenizer import JiebaBPETokenizer +else: + _import_structure = { + 'configuration': ['GPTMoEConfig'], + 'backbone': ['GPTMoEModel'], + 'text_generation': ['GPTMoEForTextGeneration'], + 'tokenizer': ['JiebaBPETokenizer'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/nlp/gpt_moe/backbone.py b/modelscope/models/nlp/gpt_moe/backbone.py new file mode 100644 index 00000000..cea37432 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/backbone.py @@ -0,0 +1,355 @@ +# Copyright 2021-2022 The Alibaba PAI Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +from typing import Optional, Union + +import addict +import torch +from torch import nn +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.utils.constant import ModelFile +from .configuration import GPTMoEConfig + + +class GPTMoESelfAttention(nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + # Per attention head + self.hidden_size_per_attention_head = \ + self.hidden_size // self.num_attention_heads + + self.query_key_value = nn.Linear(self.hidden_size, + 3 * self.hidden_size) + self.softmax = nn.Softmax(dim=-1) + self.attention_dropout = nn.Dropout( + config.attention_probs_dropout_prob) + + # Output. + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout_prob) + + def _transpose_for_scores(self, tensor): + """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with + size [b, np, s, hn]. + """ + new_tensor_shape = tensor.size()[:-1] + ( + self.num_attention_heads, self.hidden_size_per_attention_head) + tensor = tensor.view(*new_tensor_shape) + return tensor.permute(0, 2, 1, 3) + + def _split_tensor_along_last_dim(self, + tensor, + num_partitions, + contiguous_split_chunks=False): + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + def forward(self, hidden_states, ltor_mask, is_infer=False): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Attention heads. [b, s, hp] + tgt_len = hidden_states.size(1) + ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len]) + mixed_x_layer = self.query_key_value(hidden_states) + (mixed_query_layer, mixed_key_layer, mixed_value_layer) = \ + self._split_tensor_along_last_dim(mixed_x_layer, 3) + + # Reshape and transpose [b, np, s, hn] + query_layer = self._transpose_for_scores(mixed_query_layer) + key_layer = self._transpose_for_scores(mixed_key_layer) + value_layer = self._transpose_for_scores(mixed_value_layer) + + previous_type = value_layer.type() + + # Raw attention scores. [b, np, s, s] + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt( + self.hidden_size_per_attention_head) + # Apply the left to right attention mask. + if is_infer: + src_len = key_layer.size(2) + ltor_mask = torch.tril( + torch.ones((1, tgt_len, src_len), + device=hidden_states.device)).view( + 1, 1, tgt_len, src_len).type(previous_type) + converted_mask = 10000.0 * (1.0 - ltor_mask) + attention_scores = (torch.mul(attention_scores, ltor_mask) + - converted_mask).type(previous_type) + + # Attention probabilities. [b, np, s, s] + attention_probs = self.softmax(attention_scores) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + + # Context layer. + # [b, np, s, hn] + context_layer = torch.matmul(attention_probs, value_layer) + # [b, s, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.hidden_size, ) + # [b, s, hp] + context_layer = context_layer.view(*new_context_layer_shape) + + # Output. [b, s, h] + output = self.dense(context_layer) + output = self.output_dropout(output) + + return output + + +class GPTMoEMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config): + super().__init__() + + hidden_size = config.hidden_size + # Project to 4h. + self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) + self.activation_func = F.gelu + # Project back to h. + self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) + + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + output = self.dropout(output) + return output + + +class GPTMoETransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config): + super().__init__() + + # Layernorm on the input data. + self.input_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + # Self attention. 
+ self.attention = GPTMoESelfAttention(config) + + # Layernorm on the attention output + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GPTMoEMLP(config) + + def forward(self, hidden_states, ltor_mask): + # hidden_states: [b, s, h] + # ltor_mask: [1, 1, s, s] + + # Layer norm at the begining of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.attention(layernorm_output, ltor_mask) + # Residual connection. + layernorm_input = hidden_states + attention_output + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + # MLP. + mlp_output = self.mlp(layernorm_output) + # Second residual connection. + output = layernorm_input + mlp_output + + return output + + +class GPTMoETransformer(nn.Module): + """Transformer class.""" + + def __init__(self, config): + super().__init__() + + self.input_tensor = None + + # Number of layers. + self.num_layers = config.num_hidden_layers + + self.layers = torch.nn.ModuleList( + [GPTMoETransformerLayer(config) for _ in range(self.num_layers)]) + + # Final layer norm before output. + self.final_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layernorm_epsilon) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, hidden_states, attention_mask): + # hidden_states: [s, b, h] + + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer(hidden_states, attention_mask) + + # Final layer norm. + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class GPTMoETransformerLanguageModel(nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config): + super().__init__() + + # Embeddings. + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) + + # Transformer. 
+ self.transformer = GPTMoETransformer(config) + + def forward(self, input_ids, attention_mask, position_ids): + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + embeddings = words_embeddings + position_embeddings + transformer_input = self.embedding_dropout(embeddings) + transformer_output = self.transformer(transformer_input, + attention_mask) + + logits = F.linear(transformer_output, self.word_embeddings.weight) + return logits + + +class GPTMoEModel(PreTrainedModel): + + config_class = GPTMoEConfig + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_( + mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def __init__(self, config): + super().__init__(config) + self.language_model = GPTMoETransformerLanguageModel(config) + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + labels=None, + **kwargs): + seq_length = input_ids.size(1) + attention_mask = torch.tril( + torch.ones((1, 1, seq_length, seq_length), + dtype=torch.long, + device=input_ids.device)) + if position_ids is None: + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + logits = self.language_model(input_ids, attention_mask, position_ids) + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.config.vocab_size), labels.view(-1)) + return addict.Dict(loss=loss, logits=logits) + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: Optional[Union[str, + os.PathLike]]): + config = cls.config_class.from_pretrained( + pretrained_model_name_or_path) + model = cls(config) + state_dict_file = os.path.join(pretrained_model_name_or_path, + ModelFile.TORCH_MODEL_BIN_FILE) + state_dict = torch.load(state_dict_file) + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + state_dict = { + k.replace('model.language_model', 'language_model'): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict) + return model + + def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): + return {'input_ids': input_ids} diff --git a/modelscope/models/nlp/gpt_moe/checkpointing.py b/modelscope/models/nlp/gpt_moe/checkpointing.py new file mode 100644 index 00000000..68b66e97 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/checkpointing.py @@ -0,0 +1,145 @@ +# Copyright 2021-2022 The Alibaba PAI Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from megatron import mpu +from megatron.model import Float16Module +from torch.nn.parallel import DistributedDataParallel as torchDDP + +from .configuration import logger +from .moe.layer import MoE + + +def unwrap_model(model, module_instances=(torchDDP)): + return_list = True + if not isinstance(model, list): + model = [model] + return_list = False + unwrapped_model = [] + for model_module in model: + while isinstance(model_module, module_instances): + model_module = model_module.module + unwrapped_model.append(model_module) + if not return_list: + return unwrapped_model[0] + return unwrapped_model + + +def get_checkpoint_names(checkpoints_path, + path_load_tag, + num_experts, + tensor_rank=None, + expp_rank=None): + """Determine the directory name for this rank's checkpoint.""" + if tensor_rank is None: + tensor_rank = mpu.get_model_parallel_rank() + + common_path = os.path.join(checkpoints_path, path_load_tag, + f'mp_rank_{tensor_rank:02d}') + + if num_experts[0] > 0: + model_name = common_path + '_model_states.pt' + optim_name = os.path.join( + checkpoints_path, path_load_tag, + f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt') + else: + model_name = optim_name = os.path.join(common_path, + 'model_optim_rng.pt') + + return model_name, optim_name + + +def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id): + mp_rank = mpu.get_model_parallel_rank() + ckpt_name = os.path.join( + os.path.join(checkpoints_path, 'model'), + f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt' + ) + return ckpt_name + + +def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None): + """ Load the base state_dict from the given directory + + If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. 
+ """ + largest_group_name = mpu.get_max_expert_size_name() + expp_rank = mpu.get_expert_parallel_rank(largest_group_name) + checkpoint_names = get_checkpoint_names( + load_dir, + path_load_tag=path_load_tag, + num_experts=num_experts, + expp_rank=expp_rank) + model_checkpoint_name, optim_checkpoint_name = checkpoint_names + + logger.info(f'Loading model checkpoint from {model_checkpoint_name}') + model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') + + return model_state_dict + + +def load_checkpoint(model, + load_dir, + num_experts=None, + strict=True, + path_load_tag='model', + load_ds_ckpts=True): + model = unwrap_model(model, (torchDDP, Float16Module)) + + model_state_dict = _load_base_checkpoint( + load_dir, path_load_tag=path_load_tag, num_experts=num_experts) + + assert model_state_dict is not None + + if load_ds_ckpts: + load_moe_checkpoint(model, model_state_dict['module'], load_dir) + else: + load_moe_checkpoint(model, model_state_dict['model'], load_dir) + + if load_ds_ckpts: + model.load_state_dict(model_state_dict['module'], strict=strict) + else: + model.load_state_dict(model_state_dict['model'], strict=strict) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def load_moe_checkpoint(model, state_dict, load_dir): + moe_layer_id = 0 + for n_module, module in model.named_modules(): + if isinstance(module, MoE): # and torch.distributed.get_rank() == 0: + group_name = module.expert_group_name + num_local_experts = module.num_local_experts + expp_rank = mpu.get_expert_parallel_rank(group_name) + # loop all local_experts + for local_expert_id in range(num_local_experts): + global_expert_id = expp_rank * num_local_experts + local_expert_id + moe_load_path = _get_expert_ckpt_name(load_dir, moe_layer_id, + global_expert_id) + logger.info(f'Loading expert states from {moe_load_path}') + expert_state_dict = torch.load( + moe_load_path, map_location=torch.device('cpu')) + # Updating global -> local expert ids + moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + for key in list(expert_state_dict.keys()): + local_key = key.replace( + f'{moe_str_prefix}{global_expert_id}', + f'{moe_str_prefix}{local_expert_id}') + expert_state_dict[local_key] = expert_state_dict.pop(key) + state_dict.update(expert_state_dict) + moe_layer_id += 1 diff --git a/modelscope/models/nlp/gpt_moe/configuration.py b/modelscope/models/nlp/gpt_moe/configuration.py new file mode 100644 index 00000000..dfab93c6 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/configuration.py @@ -0,0 +1,128 @@ +# Copyright 2021-2022 The Alibaba PAI Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GPTMoEConfig(PretrainedConfig): + + model_type = 'gpt-moe' + + def __init__( + self, + vocab_size=25600, + hidden_size=768, + ffn_hidden_size=None, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act='gelu', + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=2048, + type_vocab_size=2, + layernorm_epsilon=1e-12, + bias_gelu_fusion=True, + fp32_residual_connection=False, + sequence_parallel=False, + fp16=False, + bf16=False, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=False, + kv_channels=None, + masked_softmax_fusion=True, + attention_dropout=0.1, + bias_dropout_fusion=True, + apply_residual_connection_post_layernorm=False, + hidden_dropout=0.1, + init_method_std=0.02, + # generate + eod_id=7, + tokens_to_generate=100, + top_k=0, + top_p=0.9, + num_experts=[0], + use_tutel=False, + top_k_linear_strategy='standard', + use_expert_residual_network=False, + load_ds_ckpts=False, + model_dir=None, + **kwargs): + super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = 4 * hidden_size \ + if ffn_hidden_size is None else ffn_hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.layernorm_epsilon = layernorm_epsilon + self.bias_gelu_fusion = bias_gelu_fusion + self.fp32_residual_connection = fp32_residual_connection + self.sequence_parallel = sequence_parallel + self.fp16 = fp16 + self.bf16 = bf16 + assert not (fp16 and bf16) + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + if kv_channels is None: + assert hidden_size % num_attention_heads == 0 + self.kv_channels = hidden_size // num_attention_heads + self.masked_softmax_fusion = masked_softmax_fusion + self.attention_dropout = attention_dropout + self.bias_dropout_fusion = bias_dropout_fusion + self.apply_residual_connection_post_layernorm = \ + apply_residual_connection_post_layernorm + self.hidden_dropout = hidden_dropout + self.init_method_std = init_method_std + self.eod_id = eod_id + self.tokens_to_generate = tokens_to_generate + self.top_k = top_k + self.top_p = top_p + self.num_experts = num_experts + self.use_tutel = use_tutel + self.top_k_linear_strategy = top_k_linear_strategy + self.use_expert_residual_network = use_expert_residual_network + self.load_ds_ckpts = load_ds_ckpts + self.model_dir = model_dir + + if self.num_experts[0] > torch.cuda.device_count(): + self.moe_expert_parallel_size = torch.cuda.device_count() + else: + self.moe_expert_parallel_size = self.num_experts[0] + + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + self.no_persist_layer_norm = \ + TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11) + + @property + def params_dtype(self): + if self.fp16: + return torch.half + elif self.bf16: + return torch.bfloat16 + else: + return torch.float diff --git 
a/modelscope/models/nlp/gpt_moe/distributed_gpt_moe.py b/modelscope/models/nlp/gpt_moe/distributed_gpt_moe.py new file mode 100644 index 00000000..9adf332c --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/distributed_gpt_moe.py @@ -0,0 +1,1236 @@ +# Copyright 2021-2022 The Alibaba PAI Team Authors. +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from megatron import mpu +from megatron.global_vars import get_global_memory_buffer, set_global_variables +from megatron.model import (AttnMaskType, Float16Module, LayerNorm, + bias_gelu_impl) +from megatron.model.fused_softmax import FusedScaleMaskSoftmax +from torch import nn +from torch.nn import functional as F +from transformers.modeling_utils import PreTrainedModel + +from modelscope.models import TorchModel +from modelscope.models.nlp.gpt_moe import GPTMoEConfig +from modelscope.utils.nlp.distributed import initialize_distributed +from modelscope.utils.torch_utils import set_random_seed_mpu +from .checkpointing import load_checkpoint +from .moe.layer import MoE + + +class GPTMoEParallelMLP(nn.Module): + + def __init__(self, + config, + init_method, + output_layer_init_method, + moe=False, + enable_expert_tensor_parallelism=False): + super().__init__() + + # Project to 4h. + self.dense_h_to_4h = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + config.ffn_hidden_size, + gather_output=False, + init_method=init_method, + skip_bias_add=True, + moe=moe, + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) + + self.bias_gelu_fusion = config.bias_gelu_fusion + self.activation_func = F.gelu + # Project back to h. + self.dense_4h_to_h = mpu.RowParallelLinearV3( + config, + config.ffn_hidden_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True, + moe=moe, + enable_expert_tensor_parallelism=enable_expert_tensor_parallelism) + + def forward(self, hidden_states): + + # [s, b, 4hp] + intermediate_parallel, bias_parallel = self.dense_h_to_4h( + hidden_states) + + if self.bias_gelu_fusion: + intermediate_parallel = \ + bias_gelu_impl(intermediate_parallel, bias_parallel) + else: + intermediate_parallel = \ + self.activation_func(intermediate_parallel + bias_parallel) + + # [s, b, h] + output, output_bias = self.dense_4h_to_h(intermediate_parallel) + return output, output_bias + + +class GPTMoEEmbedding(nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, config, init_method): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + + # Word embeddings (parallel). 
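+        # The embedding table is partitioned along the vocabulary dimension
+        # across tensor-parallel ranks.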
+ self.word_embeddings = mpu.VocabParallelEmbedding( + config.vocab_size, self.hidden_size, init_method=self.init_method) + self._word_embeddings_key = 'word_embeddings' + + # Position embedding (serial). + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + + self.fp32_residual_connection = config.fp32_residual_connection + self.sequence_parallel = config.sequence_parallel + # Embeddings dropout + self.embedding_dropout = nn.Dropout(config.hidden_dropout) + + def zero_parameters(self): + """Zero out all parameters in embedding.""" + self.word_embeddings.weight.data.fill_(0) + self.word_embeddings.weight.shared = True + self.position_embeddings.weight.data.fill_(0) + self.position_embeddings.weight.shared = True + + def forward(self, input_ids, position_ids): + # Embeddings. + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = words_embeddings + position_embeddings + + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. + if self.sequence_parallel: + embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) + with mpu.get_cuda_rng_tracker().fork(): + embeddings = self.embedding_dropout(embeddings) + else: + embeddings = self.embedding_dropout(embeddings) + return embeddings + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + # Word embedding. + if self._word_embeddings_key in state_dict: + state_dict_ = state_dict[self._word_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'word_embeddings' in key: + state_dict_[key.split('word_embeddings.')[1]] \ + = state_dict[key] + self.word_embeddings.load_state_dict(state_dict_, strict=strict) + + # Position embedding. + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. 
+ state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) + + +class NoopTransformerLayer(nn.Module): + + def __init__(self, layer_number): + super().__init__() + self.layer_number = layer_number + + def forward(self, + hidden_states, + attention_mask, + encoder_output=None, + enc_dec_attn_mask=None, + inference_params=None): + return hidden_states.clone() + + +def attention_mask_func(attention_scores, attention_mask): + attention_scores.masked_fill_(attention_mask, -10000.0) + return attention_scores + + +class GPTMoECoreAttention(nn.Module): + + def __init__(self, + config, + layer_number, + attn_mask_type=AttnMaskType.padding): + super().__init__() + self.fp16 = config.fp16 + self.bf16 = config.bf16 + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + self.attn_mask_type = attn_mask_type + self.sequence_parallel = config.sequence_parallel + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_partition = mpu.divide(projection_size, + world_size) + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + + self.scale_mask_softmax = FusedScaleMaskSoftmax( + self.fp16, self.bf16, self.attn_mask_type, + config.masked_softmax_fusion, attention_mask_func, + self.attention_softmax_in_fp32, coeff) + + # Dropout. Note that for a single iteration, this layer will generate + # different outputs on different number of parallel partitions but + # on average it should not be partition dependent. + self.attention_dropout = nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), + query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = get_global_memory_buffer().get_tensor( + (output_size[0] * output_size[1], output_size[2], output_size[3]), + query_layer.dtype, 'mpu') + + # Raw attention scores. 
[b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with mpu.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), + query_layer.size(0), value_layer.size(3)) + + # change view [sk, b * np, hn] + value_layer = value_layer.view( + value_layer.size(0), output_size[0] * output_size[1], -1) + + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], + output_size[2], -1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class GPTMoEParallelAttention(nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config, init_method, output_layer_init_method, + layer_number): + super().__init__() + self.layer_number = max(1, layer_number) + self.params_dtype = config.params_dtype + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. + world_size = mpu.get_model_parallel_world_size() + self.hidden_size_per_attention_head = mpu.divide( + projection_size, config.num_attention_heads) + self.num_attention_heads_per_partition = mpu.divide( + config.num_attention_heads, world_size) + + # Strided linear layer. + self.query_key_value = mpu.ColumnParallelLinearV3( + config, + config.hidden_size, + 3 * projection_size, + gather_output=False, + init_method=init_method) + + self.core_attention = GPTMoECoreAttention(config, self.layer_number) + + # Output. 
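+        # Row-parallel projection back to the full hidden size; the bias is
+        # returned separately (skip_bias_add=True) so it can be applied in the
+        # fused bias-dropout-add of the transformer layer.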
+ self.dense = mpu.RowParallelLinearV3( + config, + projection_size, + config.hidden_size, + input_is_parallel=True, + init_method=output_layer_init_method, + skip_bias_add=True) + + def _allocate_memory(self, inference_max_sequence_len, batch_size): + return torch.empty( + inference_max_sequence_len, + batch_size, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + dtype=self.params_dtype, + device=torch.cuda.current_device()) + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. + # ================================================= + if inference_params: + if self.layer_number not in inference_params.key_value_memory_dict: + inf_max_seq_len = inference_params.max_sequence_len + inf_max_batch_size = inference_params.max_batch_size + inference_key_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_value_memory = self._allocate_memory( + inf_max_seq_len, inf_max_batch_size) + inference_params.key_value_memory_dict[self.layer_number] = ( + inference_key_memory, inference_value_memory) + else: + inference_key_memory, inference_value_memory = \ + inference_params.key_value_memory_dict[self.layer_number] + + # ===================== + # Query, Key, and Value + # ===================== + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer, _ = self.query_key_value(hidden_states) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, + value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3) + + # ================================== + # Adjust key and value for inference + # ================================== + + if inference_params: + batch_start = inference_params.batch_size_offset + batch_end = batch_start + key_layer.size(1) + assert batch_end <= inference_key_memory.size(1) + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + key_layer.size(0) + assert sequence_end <= inference_key_memory.size(0) + # Copy key and values. + inference_key_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = key_layer + inference_value_memory[sequence_start:sequence_end, + batch_start:batch_end, ...] = value_layer + key_layer = inference_key_memory[:sequence_end, + batch_start:batch_end, ...] + value_layer = inference_value_memory[:sequence_end, + batch_start:batch_end, ...] + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, + value_layer, attention_mask) + + # ================= + # Output. 
[sq, b, h] + # ================= + + output, bias = self.dense(context_layer) + + return output, bias + + +class nullcontext: + + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, *excinfo): + pass + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + + +def get_bias_dropout_add(training): + + def _bias_dropout_add(x, bias, residual, prob): + return bias_dropout_add(x, bias, residual, prob, training) + + return _bias_dropout_add + + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + + +@torch.jit.script +def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, False) + + +class GPTMoEParallelTransformerLayer(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, + config, + init_method, + output_layer_init_method, + layer_number, + num_experts=1): + + super().__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Layernorm on the input data. + self.input_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # Self attention. + self.self_attention = GPTMoEParallelAttention( + config, init_method, output_layer_init_method, layer_number) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + # MLP + self.num_experts = num_experts + if self.num_experts == 1: + self.mlp = GPTMoEParallelMLP(config, init_method, + output_layer_init_method) + else: + enable_expert_tensor_parallelism = config.enable_expert_tensor_parallelism + self.mlp = MoE( + config.hidden_size, + GPTMoEParallelMLP( + config, + init_method, + output_layer_init_method=output_layer_init_method, + moe=True, + enable_expert_tensor_parallelism= + enable_expert_tensor_parallelism), + num_experts=self.num_experts, + ep_size=config.moe_expert_parallel_size, + k=1, + use_residual=False, + capacity_factor=1.0, + eval_capacity_factor=1.0, + noisy_gate_policy=None, + min_capacity=1, + drop_tokens=True, + use_tutel=config.use_tutel, + top_k_linear_strategy=config.top_k_linear_strategy, + use_expert_residual_network=config.use_expert_residual_network) + + # Set bias+dropout+add fusion grad_enable execution handler. 
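+        # nvfuser is available from PyTorch 1.10 onwards; older versions run
+        # the jit-fused bias-dropout-add inside an explicit torch.enable_grad
+        # context instead.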
+ TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 + and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, attention_bias = \ + self.self_attention( + layernorm_output, + attention_mask, + inference_params=inference_params) + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + if self.bias_dropout_fusion: + if self.training: + bias_dropout_add_func = bias_dropout_add_fused_train + else: + bias_dropout_add_func = bias_dropout_add_fused_inference + else: + bias_dropout_add_func = get_bias_dropout_add(self.training) + + with self.bias_dropout_add_exec_handler(): + layernorm_input = bias_dropout_add_func( + attention_output, attention_bias.expand_as(residual), residual, + self.hidden_dropout) + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + mlp_bias = torch.tensor( + 0.0, device=layernorm_output.device, dtype=layernorm_output.dtype) + + # MLP. + if self.num_experts == 1: + mlp_output, mlp_bias = self.mlp(layernorm_output) + else: + mlp_output, moe_loss, _ = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func(mlp_output, + mlp_bias.expand_as(residual), + residual, self.hidden_dropout) + + # Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. + output = mpu.make_viewless_tensor( + inp=output, requires_grad=output.requires_grad, keep_graph=True) + + return output + + +class GPTMoEParallelTransformer(nn.Module): + """Transformer class.""" + + def __init__(self, + config, + init_method, + output_layer_init_method, + post_layer_norm=True, + pre_process=True, + post_process=True, + num_experts=[0]): + super().__init__() + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = post_layer_norm + self.pre_process = pre_process + self.post_process = post_process + self.input_tensor = None + + self.sequence_parallel = config.sequence_parallel + + # Number of layers. + self.num_layers = config.num_hidden_layers + + # Transformer layers. 
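+        # When num_experts is set, every even-numbered layer is built as an
+        # MoE layer and the remaining layers stay dense (see build_layer below).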
+ def build_layer(layer_number, n_e=1): + return GPTMoEParallelTransformerLayer( + config, + init_method, + output_layer_init_method, + layer_number, + num_experts=n_e) + + offset = 0 + if len(num_experts) == 1 and num_experts[0] > 0: + num_experts = num_experts * (self.num_layers // 2) + + if self.num_layers == 0: + self.num_layers = 1 + self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) + else: + if num_experts[0] == 0: + self.layers = torch.nn.ModuleList([ + build_layer(i + 1 + offset) for i in range(self.num_layers) + ]) + + else: + self.layers = [] + # Build the layers + for i in range(self.num_layers): + layer_num = i + 1 + offset + if layer_num % 2 == 0: + n_e = num_experts[(layer_num - 1) // 2] + else: + n_e = 1 + self.layers.append(build_layer(layer_num, n_e)) + self.layers = torch.nn.ModuleList(self.layers) + + if self.post_process and self.post_layer_norm: + # Final layer norm before output. + self.final_layernorm = LayerNorm( + config.hidden_size, + eps=config.layernorm_epsilon, + no_persist_layer_norm=config.no_persist_layer_norm, + sequence_parallel=config.sequence_parallel) + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward(self, hidden_states, attention_mask, inference_params=None): + # hidden_states: [s, b, h] + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = mpu.make_viewless_tensor( + hidden_states, + requires_grad=True, + keep_graph=True, + ) + + if self.sequence_parallel: + rng_context = mpu.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + with rng_context: + # Forward pass. + for index in range(self.num_layers): + layer = self._get_layer(index) + hidden_states = layer( + hidden_states, + attention_mask, + inference_params=inference_params) + + # Final layer norm. + if self.post_process and self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class GPTMoETransformerLanguageModel(nn.Module): + """Transformer language model. + + Arguments: + transformer_hparams: transformer hyperparameters + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__(self, + config, + init_method, + output_layer_init_method, + num_experts=None): + super().__init__() + + self.hidden_size = config.hidden_size + self.init_method = init_method + self.encoder_hidden_state = None + self.num_experts = num_experts + + # Embeddings. + self.embedding = GPTMoEEmbedding(config, self.init_method) + + # Transformer. 
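+        # The decoder-only stack is stored under the name `encoder` to mirror
+        # Megatron's language-model naming and checkpoint layout.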
+ self.encoder = GPTMoEParallelTransformer( + config, + self.init_method, + output_layer_init_method, + num_experts=self.num_experts) + + def forward(self, + enc_input_ids, + enc_position_ids, + enc_attn_mask, + inference_params=None, + enc_hidden_states=None): + + # Encoder embedding. + encoder_input = self.embedding(enc_input_ids, enc_position_ids) + + # Run encoder. + if enc_hidden_states is None: + if self.encoder is not None: + encoder_output = self.encoder( + encoder_input, + enc_attn_mask, + inference_params=inference_params) + else: + encoder_output = self.encoder_hidden_state + else: + encoder_output = enc_hidden_states.to(encoder_input.dtype) + + return encoder_output + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + # Embedding. + + if 'embedding' in state_dict: + state_dict_ = state_dict['embedding'] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Encoder. + if True: + if 'encoder' in state_dict: + state_dict_ = state_dict['encoder'] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.') + [1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + encoder_state_dict_keys = list(self.encoder.state_dict().keys()) + for key in state_dict_.keys(): + if '.attention.' in key and key not in encoder_state_dict_keys: + state_dict_self_attention[key.replace( + '.attention.', '.self_attention.')] = state_dict_[key] + # to load pai bert-1.3B + elif '.self_attention.' 
in key and key not in encoder_state_dict_keys: + state_dict_self_attention[key.replace( + '.self_attention.', '.attention.')] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + # Gather encoder MoE states + if 'moe_state_dict' in state_dict: + for key in list(state_dict['moe_state_dict'].keys()): + if 'encoder' in key: + key_list = key.split('.') + while key_list[0] != 'encoder': + key_list.pop(0) + key_list.pop(0) + actual_key = '.'.join(key_list) + state_dict_[actual_key] = state_dict[ + 'moe_state_dict'].pop(key) + if len(state_dict['moe_state_dict']) == 0: + del state_dict['moe_state_dict'] + + self.encoder.load_state_dict(state_dict_, strict=strict) + + +def init_method_normal(sigma): + """Init method based on N(0, sigma).""" + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=sigma) + + return init_ + + +def scaled_init_method_normal(sigma, num_layers): + """Init method based on N(0, sigma/sqrt(2*num_layers).""" + std = sigma / math.sqrt(2.0 * num_layers) + + def init_(tensor): + return nn.init.normal_(tensor, mean=0.0, std=std) + + return init_ + + +class GPTMoEModel(PreTrainedModel): + + config_class = GPTMoEConfig + + def __init__(self, config, parallel_output=False): + super().__init__(config) + + self.parallel_output = parallel_output + self.language_model = GPTMoETransformerLanguageModel( + config, + init_method_normal(config.init_method_std), + scaled_init_method_normal(config.init_method_std, + config.num_hidden_layers), + num_experts=config.num_experts) + + def word_embeddings_weight(self): + return self.language_model.embedding.word_embeddings.weight + + @staticmethod + def build_attention_mask_and_position_ids(tokens): + seq_length = tokens.size(1) + attention_mask = torch.tril( + torch.ones((1, 1, seq_length, seq_length), + dtype=torch.long, + device=tokens.device)) + attention_mask = (attention_mask < 0.5) + + position_ids = torch.arange( + seq_length, dtype=torch.long, device=tokens.device) + position_ids = position_ids.unsqueeze(0).expand_as(tokens) + + return attention_mask, position_ids + + def forward(self, + input_ids, + attention_mask=None, + position_ids=None, + inference_params=None, + **kwargs): + if attention_mask is None and position_ids is None: + attention_mask, position_ids = \ + self.build_attention_mask_and_position_ids(input_ids) + + lm_output = self.language_model( + input_ids, + position_ids, + attention_mask, + inference_params=inference_params) + + logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( + lm_output, self.word_embeddings_weight(), None, False, True, + self.config.sequence_parallel) + # Gather if needed. 
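+        # With parallel_output=False, the vocab-parallel logits are gathered
+        # across tensor-parallel ranks before being returned.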
+
+        output = logits_parallel
+        if not self.parallel_output:
+            output = mpu.gather_from_model_parallel_region(logits_parallel)
+        return output.transpose(0, 1).contiguous()
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+
+        # Gather MoE states and move under language model
+        moe_state_dict = {}
+        for key in list(state_dict.keys()):
+            if 'expert' in key and 'moe.gate.wg.weight' not in key:
+                moe_state_dict[key] = state_dict.pop(key)
+
+        if 'language_model' in state_dict:
+            state_dict = state_dict['language_model']
+        if len(moe_state_dict) > 0:
+            state_dict['moe_state_dict'] = moe_state_dict
+        self.language_model.load_state_dict(state_dict, strict=strict)
+
+
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for non-top-k values to -inf."""
+
+    filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits.masked_fill_(filter_, float('-Inf'))
+
+
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for non-top-p values to -inf."""
+
+    # First sort and calculate cumulative sum of probabilities.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+    # Filtering based on the cumulative sum.
+    filter_ = cumulative_probs > top_p
+    # Shift the mask right by one position so that the first token whose
+    # cumulative probability crosses top_p is still kept. This follows the
+    # reference implementation:
+    # https://github.com/ari-holtzman/degen/blob/master/gen.py
+    filter_[:, 1:] = filter_[:, :-1].clone()
+    # Make sure we at least have one token to select from.
+    filter_[..., 0] = 0
+
+    # Fill in the filtered part
+    filter_ = filter_.scatter(1, sorted_indices, filter_)
+    logits.masked_fill_(filter_, float('-Inf'))
+
+
+def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
+    """Sample and generate a token.
+    Note: logits has the dimension [b, v] where b is the batch size
+          and v is the vocabulary size.
+    If vocab_size is provided, we will make sure the sample that is
+    generated is in [0, vocab-size). This will avoid out of vocabulary
+    generations due to padding.
+    """
+
+    # Check logits for consistency.
+    assert logits.ndim == 2, 'expected the logits to be of [b, v] shape.'
+    assert logits.type() == 'torch.cuda.FloatTensor', \
+        'input logits should be floats.'
+
+    # Greedy is just simple argmax.
+    if top_k == 1:
+        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
+        samples = torch.argmax(logits, dim=-1)
+
+    # Top-k or top-p sampling.
+    else:
+        # Clone so we do not modify the inputs.
+        logits = logits.clone()
+        # Apply temperature in place.
+        if temperature != 1.0:
+            logits.div_(temperature)
+
+        if top_k > 1:
+            assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
+            assert top_k <= logits.size(1), 'top-k is larger than logit size.'
+            if vocab_size:
+                assert top_k < vocab_size, 'top-k is larger than vocab size.'
+            modify_logits_for_top_k_filtering(logits, top_k)
+
+        elif top_p > 0.0:
+            assert top_p <= 1.0, 'top-p should be in (0, 1].'
+            modify_logits_for_top_p_filtering(logits, top_p)
+
+        # After filtering, we need to recalculate the distribution.
+        probs = logits.softmax(dim=-1)
+        samples = torch.multinomial(probs, num_samples=1).view(-1)
+
+    # If vocab size is provided, make sure the samples are
+    # in the range [0, vocab-size).
+ if vocab_size: + samples = torch.clamp(samples, min=0, max=(vocab_size - 1)) + + return samples + + +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + def __init__(self, max_batch_size, max_sequence_len): + """Note that offsets are set to zero and we always set the + flag to allocate memory. After the first call, make sure to + set this flag to False.""" + self.max_sequence_len = max_sequence_len + self.max_batch_size = max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.key_value_memory_dict = {} + + def swap_key_value_dict(self, batch_idx): + 'swap between batches' + if len(self.key_value_memory_dict) == 0: + raise ValueError('should not swap when dict in empty') + + for layer_number in self.key_value_memory_dict.keys(): + inference_key_memory, inference_value_memory = self.key_value_memory_dict[ + layer_number] + assert len(batch_idx) == inference_key_memory.shape[ + 1] # make sure batch size is the same + new_inference_key_memory = inference_key_memory[:, batch_idx] + new_inference_value_memory = inference_value_memory[:, batch_idx] + self.key_value_memory_dict[layer_number] = ( + new_inference_key_memory, new_inference_value_memory) + + +class DistributedGPTMoE(TorchModel): + + def __init__(self, + model_dir, + rank, + path_load_tag='model', + *args, + **kwargs): + super().__init__(model_dir, *args, **kwargs) + initialize_distributed(rank, mpu, kwargs['world_size'], + kwargs['model_parallel_size'], + kwargs['master_ip'], kwargs['master_port']) + + self.config = GPTMoEConfig.from_pretrained(model_dir) + if self.config.num_experts[0] > 0: + mpu.create_expert_and_data_parallel( + self.config.moe_expert_parallel_size) + + seed = 0 if 'seed' not in kwargs else kwargs['seed'] + set_random_seed_mpu(seed) + set_global_variables() + + # Build model. + model = GPTMoEModel(self.config) + + for param in model.parameters(): + mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param) + + # GPU allocation. + model.cuda(torch.cuda.current_device()) + + # Fp16 conversion. 
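+        # Float16Module (below) is expected to cast the module's parameters to
+        # fp16/bf16 and handle dtype conversion at the forward boundary. A much
+        # simplified sketch of the idea (not the actual implementation):
+        #
+        #   class HalfWrapper(torch.nn.Module):
+        #       def __init__(self, module):
+        #           super().__init__()
+        #           self.module = module.half()
+        #       def forward(self, *args, **kwargs):
+        #           return self.module(*args, **kwargs)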
+ if self.config.fp16 or self.config.bf16: + model = Float16Module(model, self.config) + + self.dist_model = model + if self.config.model_dir is not None: + model_dir = self.config.model_dir + load_checkpoint( + self.dist_model, + model_dir, + num_experts=self.config.num_experts, + path_load_tag=path_load_tag, + load_ds_ckpts=self.config.load_ds_ckpts) + self.inference_params = None + + def forward_step(self, tokens, attention_mask, position_ids): + logits = self.dist_model( + tokens, + attention_mask, + position_ids, + inference_params=self.inference_params) + self.inference_params.sequence_len_offset += tokens.size(1) + return logits + + def generate(self, + tokens, + temperature=1.0, + use_eod_token_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False): + lengths = torch.tensor([tokens.size(1)], device=tokens.device) + pads = torch.ones( + 1, self.config.tokens_to_generate, + device=tokens.device).long() * self.config.eod_id + tokens = torch.cat((tokens, pads), dim=-1) + + batch_size = tokens.size(0) + min_prompt_length = lengths.min().item() + max_sequence_length = tokens.size(1) + max_sequence_length = min(max_sequence_length, + self.config.max_position_embeddings) + + # If the context is too big, this happens + if min_prompt_length >= max_sequence_length: + raise ValueError('context length + tokens_to_generate too large') + + # Initialize inference parameters. + self.inference_params = InferenceParams(batch_size, + max_sequence_length) + + # Added termination_id to support the case that we want to terminate the + # generation once that id is generated. + termination_id = self.config.eod_id + + # Whether we have reached a termination id. + is_generation_done = torch.zeros( + batch_size, dtype=torch.uint8, device=torch.cuda.current_device()) + + # ============= + # Run infernece + # ============= + + with torch.no_grad(): + attention_mask, position_ids = \ + GPTMoEModel.build_attention_mask_and_position_ids(tokens) + prev_context_length = 0 + for context_length in range(min_prompt_length, + max_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length: + context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # logits will be meanigful only in the last pipeline stage. + logits = self.forward_step(tokens2use, attention_mask2use, + positions2use) + + # Sample. + last_token_logits = logits[:, -1, :] + new_sample = sample( + last_token_logits, + top_k=self.config.top_k, + top_p=self.config.top_p, + temperature=temperature, + vocab_size=self.config.vocab_size) + + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = lengths <= context_length + # Update the tokens. + tokens[started, context_length] = new_sample[started] + + # Update the context length for the next token generation. 
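+                # Only the newly generated slice is fed through the network on the
+                # next iteration; key/value states for earlier positions are cached
+                # in self.inference_params. For example, with a prompt of length 5:
+                #
+                #   step 0: tokens[:, 0:5] -> sample position 5
+                #   step 1: tokens[:, 5:6] -> sample position 6
+                #   step 2: tokens[:, 6:7] -> sample position 7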
+ prev_context_length = context_length + + # instead tokenization should be in the inference loop so stop sequences can be used + if stop_on_double_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_two_eols = (new_sample == 198).byte() & ( + tokens[:, context_length - 1] + == 198).byte() & started.byte() + done_token = hit_double_eol | hit_two_eols + elif stop_on_eol: + hit_double_eol = (new_sample + == 628).byte() & started.byte() + hit_eol = (new_sample == 198).byte() & started.byte() + done_token = hit_double_eol | hit_eol + else: + done_token = (new_sample == termination_id).byte() & \ + started.byte() + + is_generation_done = is_generation_done | done_token + done = torch.all(is_generation_done) + + if use_eod_token_for_early_termination and done: + break + + tokens = tokens[:, :(context_length + 1)] + return tokens diff --git a/modelscope/models/nlp/gpt_moe/moe/__init__.py b/modelscope/models/nlp/gpt_moe/moe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/nlp/gpt_moe/moe/experts.py b/modelscope/models/nlp/gpt_moe/moe/experts.py new file mode 100644 index 00000000..b559b0b9 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/moe/experts.py @@ -0,0 +1,36 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import copy + +import torch + + +class Experts(torch.nn.Module): + + def __init__(self, expert, num_local_experts=1, expert_group_name=None): + super(Experts, self).__init__() + + self.deepspeed_experts = torch.nn.ModuleList( + [copy.deepcopy(expert) for i in range(num_local_experts)]) + self.num_local_experts = num_local_experts + + # TODO: revisit allreduce for moe.gate... + for expert in self.deepspeed_experts: + # TODO: Create param groups to handle expert + data case (e.g. 
param.group = moe_group) + for name, param in expert.named_parameters(): + param.allreduce = False + param.group_name = expert_group_name + + def forward(self, inputs): + chunks = inputs.chunk(self.num_local_experts, dim=1) + expert_outputs = [] + for chunk, expert in zip(chunks, self.deepspeed_experts): + out = expert(chunk) + if type(out) is tuple: + out = out[0] # Ignore the bias term for now + expert_outputs += [out] + + expert_output = torch.cat(expert_outputs, dim=1) + return expert_output diff --git a/modelscope/models/nlp/gpt_moe/moe/layer.py b/modelscope/models/nlp/gpt_moe/moe/layer.py new file mode 100644 index 00000000..99767bb6 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/moe/layer.py @@ -0,0 +1,98 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import typing + +import torch +from megatron import mpu + +from .experts import Experts +from .sharded_moe import MOELayer, TopKGate + + +class MoE(torch.nn.Module): + + def __init__(self, + hidden_size, + expert, + num_experts=1, + ep_size=1, + k=1, + capacity_factor=1., + eval_capacity_factor=1., + min_capacity=4, + use_residual=False, + noisy_gate_policy: typing.Optional[str] = None, + drop_tokens: bool = True, + use_rts=True, + use_tutel: bool = False, + top_k_linear_strategy: str = 'normal', + use_expert_residual_network: bool = False): + super(MoE, self).__init__() + self.use_residual = use_residual + assert num_experts % ep_size == 0, f'Number of experts ({num_experts}) should ' \ + f'be divisible by expert parallel size ({ep_size})' + self.ep_size = ep_size + self.expert_group_name = f'ep_size_{self.ep_size}' + self.num_experts = num_experts + self.num_local_experts = num_experts // self.ep_size + + assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \ + 'Unsupported noisy_gate_policy: ' + noisy_gate_policy + + experts = Experts(expert, self.num_local_experts, + self.expert_group_name) + self.deepspeed_moe = MOELayer( + TopKGate( + hidden_size, + num_experts, + k, + capacity_factor, + eval_capacity_factor, + min_capacity, + noisy_gate_policy, + drop_tokens, + use_rts, + top_k_linear_strategy=top_k_linear_strategy), + experts, + self.expert_group_name, + self.ep_size, + self.num_local_experts, + use_tutel=use_tutel, + use_expert_residual_network=use_expert_residual_network) + + self.deepspeed_moe._set_ep_group( + mpu.get_expert_parallel_group(self.expert_group_name)) + + if self.use_residual: + self.mlp = expert + # coefficient is used for weighted sum of the output of expert and mlp + self.coefficient = torch.nn.Linear(hidden_size, 2) + + def forward(self, hidden_states, used_token=None): + """ MoE forward + + Arguments: + hidden_states (Tensor): input to the layer + used_token (Tensor, optional): default: None, mask only used tokens + + Returns: + A tuple including output, gate loss, and expert count. 
+ + * output (Tensor): output of the model + + * l_aux (Tensor): gate loss value + + * exp_counts (int): expert count + """ + output = self.deepspeed_moe(hidden_states, used_token) + if self.use_residual: + # Residual MoE + output_mlp = self.mlp(hidden_states) + if type(output_mlp) is tuple: + output_mlp = output_mlp[0] # Ignore the bias term for now + coef = self.coefficient(hidden_states) + coef = torch.nn.functional.softmax(coef, dim=1) + output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] + return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts diff --git a/modelscope/models/nlp/gpt_moe/moe/mappings.py b/modelscope/models/nlp/gpt_moe/moe/mappings.py new file mode 100644 index 00000000..a3fb85f7 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/moe/mappings.py @@ -0,0 +1,87 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import torch +from megatron import mpu + + +def _gather_tokens(input_, dim=0): + """Gather tensors and concatenate them along a dimension""" + + input_ = input_.contiguous() + # Size and dimension. + rank = mpu.get_tensor_model_parallel_rank() + + tensor_list = [ + torch.empty_like(input_) + for _ in range(mpu.get_model_parallel_world_size()) + ] + tensor_list[rank] = input_ + torch.distributed.all_gather( + tensor_list, input_, group=mpu.get_tensor_model_parallel_group()) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +def _drop_tokens(input_, dim=0): + """Divide a tensor among the tensor parallel ranks""" + total_chunks = mpu.get_model_parallel_world_size() + this_chunk = mpu.get_model_parallel_rank() + assert input_.shape[ + dim] % total_chunks == 0, f'input dimension {dim} ({input_.shape[dim]}) ' \ + f'is not divisible by tensor parallel world size ({total_chunks})' + chunk_size = input_.shape[dim] // total_chunks + + return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size) + + +class _GatherTokens(torch.autograd.Function): + """All gather tokens among the tensor parallel ranks""" + + @staticmethod + def symbolic(graph, input_, dim): + return _gather_tokens(input_, dim) + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + return _gather_tokens(input_, dim) + + @staticmethod + def backward(ctx, grad_output): + return _drop_tokens(grad_output, ctx.dim), None + + +class _DropTokens(torch.autograd.Function): + 'Divide tokens equally among the tensor parallel ranks' + + @staticmethod + def symbolic(graph, input_, dim): + return _drop_tokens(input_, dim) + + @staticmethod + def forward(ctx, input_, dim): + ctx.dim = dim + return _drop_tokens(input_, dim) + + @staticmethod + def backward(ctx, input_): + return _gather_tokens(input_, ctx.dim), None + + +def gather_tokens(input_, dim=0): + if mpu is None or mpu.get_model_parallel_world_size() == 1: + # no tensor parallelism for non-experts + return input_ + return _GatherTokens.apply(input_, dim) + + +def drop_tokens(input_, dim=0): + if mpu is None or mpu.get_model_parallel_world_size() == 1: + # no tensor parallelism for non-experts + return input_ + return _DropTokens.apply(input_, dim) diff --git a/modelscope/models/nlp/gpt_moe/moe/sharded_moe.py b/modelscope/models/nlp/gpt_moe/moe/sharded_moe.py new file mode 100644 index 00000000..1cfbd213 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/moe/sharded_moe.py @@ -0,0 +1,647 @@ +''' +Copyright 2021 The Microsoft DeepSpeed Team +''' +# The file has been adapted from two fairscale files: +# (1) 
https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py +# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py +# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf +# We retain the following license from the original files: + +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from megatron import mpu +from scipy.special import binom +from torch import Tensor, nn +from torch.nn import Module + +from ..configuration import logger +from .mappings import drop_tokens, gather_tokens + +try: + from apex.normalization import FusedLayerNorm as _FusedLayerNorm + + has_fused_layernorm = True + + class FusedLayerNorm(_FusedLayerNorm): + + @torch.jit.unused + def forward(self, x): + if not x.is_cuda: + return super().forward(x) + else: + with torch.cuda.device(x.device): + return super().forward(x) +except ImportError: + has_fused_layernorm = False + +if TYPE_CHECKING: + Base = Module[Tensor] +else: + Base = Module + +uniform_map: Dict[torch.device, Callable] = {} +gumbel_map: Dict[torch.device, Callable] = {} +exp_selection_uniform_map: Dict[torch.device, Callable] = {} + + +def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): + """ + Modified from switch transformer paper. mesh transformers + Multiply values by a random number between 1-epsilon and 1+epsilon. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + Args: + x: a torch.tensor + device: torch.device + epsilon: a floating point value + Returns: + a jittered x. + """ + if epsilon == 0: + return x + uniform = uniform_map.get(device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(1.0 - epsilon, device=device), + high=torch.tensor(1.0 + epsilon, + device=device)).rsample # type: ignore + uniform_map[device] = uniform + return x * uniform(x.shape) + + +def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: + gumbel = gumbel_map.get(device) + if gumbel is None: + one = torch.tensor(1.0, device=device) + zero = torch.tensor(0.0, device=device) + gumbel = torch.distributions.gumbel.Gumbel(zero, + one).rsample # type: ignore + gumbel_map[device] = gumbel + return gumbel(shape) + + +# Based on https://github.com/pytorch/pytorch/pull/40762 +class _AllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, + input: Tensor) -> Tensor: # type: ignore + ctx.group = group + input = input.contiguous() + output = torch.empty_like(input) + dist.all_to_all_single(output, input, group=group) + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: + return (None, _AllToAll.apply(ctx.group, *grad_output)) + + +# einsum rewrites are on par or more performant +# switch can be bubbled up in future +USE_EINSUM = True + + +# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity +# See https://arxiv.org/pdf/2006.16668.pdf for details. 
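+# As a concrete reference for the two rules used by the MoE layer below
+# (shapes are illustrative only):
+#
+#   s, e, c, m = 8, 4, 2, 16                        # tokens, experts, capacity, model dim
+#   dispatch_mask = torch.zeros(s, e, c)            # which (expert, slot) each token goes to
+#   tokens = torch.randn(s, m)
+#   dispatched = torch.einsum('sec,sm->ecm', dispatch_mask, tokens)           # (e, c, m)
+#   combined = torch.einsum('sec,ecm->sm', torch.rand(s, e, c), dispatched)   # (s, m)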
+def einsum(rule, a, b): + if USE_EINSUM: + return torch.einsum(rule, a, b) + elif rule == 's,se->se': + return a.reshape(a.shape[0], -1) * b + elif rule == 'se,sc->sec': + return a.unsqueeze(2) * b.unsqueeze(1) + elif rule == 'se,se->s': + return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) + elif rule == 'sec,sm->ecm': + s = a.shape[0] + e = a.shape[1] + c = a.shape[2] + m = b.shape[1] + return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) + elif rule == 'sec,ecm->sm': + return torch.matmul( + a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) + elif rule == 'ks,ksm->sm': + k = b.shape[0] + s = b.shape[1] + m = b.shape[2] + # [k, s] -> [s, k] -> [s, 1, k] + a = a.t().unsqueeze(1) + # [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k] + b = b.reshape(k, -1).t().reshape(s, m, k) + # bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1] + return torch.bmm(a, b.transpose(1, 2)).squeeze(2) + else: + return torch.einsum(rule, a, b) + + +# The following functions are extracted and scripted +# because otherwise during a torch.jit.trace, the non-Tensor +# values used in the calculations get recorded as constants. +# torch.jit.script coerces them into Tensors and preserves +# their dynamic shapes. This enables ONNX export. +# We can't script the entire top1gating function because it +# includes stateful caching logic which is incompatible with ONNX. + + +@torch.jit.script +def _capacity(gates: Tensor, capacity_factor: Tensor, + min_capacity: Tensor) -> Tensor: + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + # to(torch.int64) works around a bug in torch.onnx.export: + # it should cast k to int64 when converting torch.topk but it doesn't. + capacity = torch.ceil( + (num_tokens / num_experts) * capacity_factor).to(torch.int64) + if capacity < min_capacity: + capacity = min_capacity.to(torch.int64) + return capacity + + +@torch.jit.script +def _top_idx(source, k): + return torch.topk(source, k=k, dim=0)[1] + + +@torch.jit.script +def _one_hot_to_float(x, num_classes): + return F.one_hot(x, num_classes=num_classes).float() + + +def top1gating( + logits: Tensor, + capacity_factor: float, + min_capacity: int, + used_token: Tensor = None, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Implements Top1Gating on logits.""" + if noisy_gate_policy == 'RSample': + logits_w_noise = logits + gumbel_rsample( + logits.shape, device=logits.device) + # everything is in fp32 in this function + gates = F.softmax(logits, dim=1) + + capacity = _capacity(gates, torch.tensor(capacity_factor), + torch.tensor(min_capacity)) + + # Create a mask for 1st's expert per token + # noisy gating + indices1_s = torch.argmax( + logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1) + num_experts = int(gates.shape[1]) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) + + # mask only used tokens + if used_token is not None: + mask1 = einsum('s,se->se', used_token, mask1) + + # gating decisions + exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') + + # if we don't want to drop any tokens + if not drop_tokens: + new_capacity = torch.max(exp_counts).to(logits.device) + dist.all_reduce( + new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD) + capacity = new_capacity + + # Compute l_aux + alpha = torch.max(gates, dim=1).values.unsqueeze(1) + me = torch.mean(gates, dim=0) + ce = torch.mean(mask1.float(), dim=0) + l_aux = torch.sum(me * ce) * 
num_experts + + # Random Token Selection + if use_rts: + uniform = exp_selection_uniform_map.get(logits.device) + if uniform is None: + uniform = torch.distributions.uniform.Uniform( + low=torch.tensor(0.0, device=logits.device), + high=torch.tensor(1.0, device=logits.device)).rsample + exp_selection_uniform_map[logits.device] = uniform + + mask1_rand = mask1 * uniform(mask1.shape) + else: + mask1_rand = mask1 + + assert logits.shape[0] >= min_capacity, \ + 'No. of tokens (batch-size) should be greater than min_capacity. ' \ + 'Either set min_capacity to 0 or increase your batch size.' + + top_idx = _top_idx(mask1_rand, capacity) + + new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) + mask1 = new_mask1 + + if use_tutel: + # Tutel doesn't support index values masked with zero + # so we need to replace masked indices with -1 + indices_mask = mask1.sum(dim=1) * num_experts - 1 + indices1_s = torch.min(indices1_s, indices_mask) + + # Compute locations in capacity buffer + if use_tutel: + locations1 = tutel_moe.fast_cumsum_sub_one(mask1) + else: + locations1 = torch.cumsum(mask1, dim=0) - 1 + + if use_tutel: + gates1_s = (gates * mask1).sum(dim=1) + locations1_s = torch.sum(locations1 * mask1, dim=1) + return l_aux, capacity, num_experts, [ + indices1_s, + ], [ + locations1_s, + ], [ + gates1_s, + ], exp_counts, alpha + + # Store the capacity location for each token + locations1_s = torch.sum(locations1 * mask1, dim=1) + + # Normalize gate probabilities + mask1_float = mask1.float() + gates = gates * mask1_float + + locations1_sc = _one_hot_to_float(locations1_s, capacity) + combine_weights = einsum('se,sc->sec', gates, locations1_sc) + + dispatch_mask = combine_weights.bool() + + return l_aux, combine_weights, dispatch_mask, exp_counts, alpha + + +class TopKGate(Module): + """Gate module which implements Top2Gating as described in Gshard_. + :: + + gate = TopKGate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + + .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + wg: torch.nn.Linear + + def __init__(self, + model_dim: int, + num_experts: int, + k: int = 1, + capacity_factor: float = 1.0, + eval_capacity_factor: float = 1.0, + min_capacity: int = 8, + noisy_gate_policy: Optional[str] = None, + drop_tokens: bool = True, + use_rts: bool = True, + top_k_linear_strategy: str = 'standard') -> None: + super().__init__() + + # Only top-1 are supported at the moment. + if k != 1: + raise ValueError('Only top-1 gatings are supported.') + if top_k_linear_strategy == 'standard': + self.wg = torch.nn.Linear( + model_dim, num_experts, bias=False).float() + elif top_k_linear_strategy == 'lsoftmax': + self.wg = LSoftmaxLinearLayer( + model_dim, num_experts, margin=1).float() + else: + raise ValueError( + 'Only standard or lsoftmax top-k-linear-strategy are supported.' 
+ ) + + self.k = k + self.capacity_factor = capacity_factor + self.eval_capacity_factor = eval_capacity_factor + self.min_capacity = min_capacity + self.noisy_gate_policy = noisy_gate_policy + self.wall_clock_breakdown = False + self.gate_time = 0.0 + self.drop_tokens = drop_tokens + self.use_rts = use_rts + self.top_k_linear_strategy = top_k_linear_strategy + + def forward( + self, + input: torch.Tensor, + used_token: torch.Tensor = None, + use_tutel: bool = False + ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore + + if self.wall_clock_breakdown: + self.timers('TopKGate').start() + + if self.top_k_linear_strategy == 'standard': + if self.wg.weight.dtype != torch.float32: + self.wg = self.wg.float() + elif self.top_k_linear_strategy == 'lsoftmax': + if self.wg.weight.weight.dtype != torch.float32: + self.wg.weight = self.wg.weight.float() + + input_fp32 = input.float() + # input jittering + if self.noisy_gate_policy == 'Jitter' and self.training: + input_fp32 = multiplicative_jitter(input_fp32, device=input.device) + + if self.k == 1: + if self.top_k_linear_strategy == 'standard': + logits = self.wg(input_fp32) + elif self.top_k_linear_strategy == 'lsoftmax': + logits = self.wg(input_fp32, input_fp32.device, self.training) + + gate_output = top1gating( + logits, self.capacity_factor if self.training else + self.eval_capacity_factor, self.min_capacity, used_token, + self.noisy_gate_policy if self.training else None, + self.drop_tokens, self.use_rts, use_tutel) + + if self.wall_clock_breakdown: + self.timers('TopKGate').stop() + self.gate_time = self.timers('TopKGate').elapsed( + reset=False) * 1000 + + return gate_output + + +class MOELayer(Base): + """MOELayer module which implements MixtureOfExperts as described in Gshard_. + :: + + gate = TopKGate(model_dim, num_experts) + moe = MOELayer(gate, expert) + output = moe(input) + l_aux = moe.l_aux + + .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf + + Args: + gate (torch.nn.Module): + gate network + expert (torch.nn.Module): + expert network + """ + + def __init__(self, + gate: Module, + experts: Module, + ep_group_name, + ep_size, + num_local_experts: int, + use_tutel: bool = False, + use_expert_residual_network: bool = False) -> None: + super().__init__() + self.gate = gate + self.experts = experts + self.ep_group = None + self.ep_size = ep_size + self.ep_group_name = ep_group_name + self.num_local_experts = num_local_experts + + self.wall_clock_breakdown = False + self.use_expert_residual_network = use_expert_residual_network + + if self.use_expert_residual_network: + self.expert_network = nn.Sequential( + *([ExpertResidualLayer(self.gate.model_dim) + for _ in range(6)])) + + self.use_tutel = use_tutel and TUTEL_INSTALLED + + if self.use_tutel: + logger.info('Using Tutel optimizations.') + elif use_tutel and not TUTEL_INSTALLED: + logger.info( + 'Tutel optimization requested but not installed Proceeding without Tutel.' + ) + + def _set_ep_group(self, ep_group): + self.ep_group = ep_group + + def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: + + if self.wall_clock_breakdown: + self.timers('moe').start() + + # Implement Algorithm 2 from GShard paper. + d_model = input[0].shape[-1] + + # Initial implementation -> Reshape into S tokens by dropping sequence dimension. 
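+        # The remainder of this forward follows GShard's Algorithm 2; with s tokens,
+        # e experts, capacity c and model dim m, the data flow is roughly:
+        #
+        #   gate      : (s, m) -> combine_weights / dispatch_mask of shape (s, e, c)
+        #   dispatch  : einsum('sec,sm->ecm')                  -> (e, c, m)
+        #   all-to-all: exchange expert shards across expert-parallel ranks
+        #   experts   : per-expert FFN applied to its capacity slots
+        #   all-to-all: route expert outputs back
+        #   combine   : einsum('sec,ecm->sm')                  -> (s, m)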
+ # Reshape into G groups so that each group can distribute tokens equally + # group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 + reshaped_input = input[0].reshape(-1, d_model) + + if self.use_tutel: + self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts, alpha = self.gate( + reshaped_input, input[1], True) + _, M = reshaped_input.size(0), reshaped_input.size(1) + + if not hasattr(self, '_tutel_dispatcher'): + self._tutel_dispatcher = tutel_moe.fast_dispatcher( + E, C, M, dispatch_dtype=reshaped_input.dtype) + self._tutel_dispatcher.update( + indices_, locations_, gates_, capacity=C) + dispatched_input = self._tutel_dispatcher.encode(reshaped_input) + else: + self.l_aux, combine_weights, dispatch_mask, self.exp_counts, alpha = self.gate( + reshaped_input, input[1]) + dispatched_input = einsum('sec,sm->ecm', + dispatch_mask.type_as(input[0]), + reshaped_input) + + if self.wall_clock_breakdown: + self.timers('falltoall').start() + + if mpu.get_expert_model_parallel_world_size() == 1: + # If the non-expert is tensor-parallel, it will create + # duplicate tokens on the tensor-parallel ranks. + # Since our experts are not tensor-parallel, these duplicates + # need to be dropped to ensure correctness. + # this also doubles up as a communication optimization as we are + # reducing the all-to-all communication volume. + if self.use_tutel: + # reshape tutel's output from [e*c,m] to [e,c,m] + dispatched_input = dispatched_input.reshape( + self.ep_size * self.num_local_experts, -1, d_model) + dispatched_input = drop_tokens(dispatched_input, dim=1) + + dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) + + if self.wall_clock_breakdown: + self.timers('falltoall').stop() + self.time_falltoall = self.timers('falltoall').elapsed( + reset=False) * 1000 + + # Re-shape after all-to-all: ecm -> gecm + dispatched_input = dispatched_input.reshape(self.ep_size, + self.num_local_experts, -1, + d_model) + + expert_output = self.experts(dispatched_input) + + if self.wall_clock_breakdown: + self.timers('salltoall').start() + + expert_output = _AllToAll.apply(self.ep_group, expert_output) + + if self.wall_clock_breakdown: + self.timers('salltoall').stop() + self.time_salltoall = self.timers('salltoall').elapsed( + reset=False) * 1000 + + # Re-shape back: gecm -> ecm + expert_output = expert_output.reshape( + self.ep_size * self.num_local_experts, -1, d_model) + + if mpu.get_expert_model_parallel_world_size() == 1: + # the dropped duplicate tokens need to be gathered on each + # tensor parallel rank again for the tensor-parallel + # non-expert of the next layer. + expert_output = gather_tokens(expert_output, dim=1) + + if self.use_tutel: + combined_output = self._tutel_dispatcher.decode( + expert_output.view(E * C, M)) + else: + combined_output = einsum('sec,ecm->sm', + combine_weights.type_as(input[0]), + expert_output) + + if self.use_expert_residual_network: + combined_output = alpha * self.expert_network(combined_output) + ( + 1 - alpha) * combined_output + + a = combined_output.reshape(input[0].shape) + + if self.wall_clock_breakdown: + self.timers('moe').stop() + self.time_moe = self.timers('moe').elapsed(reset=False) * 1000 + + return a + + +class LSoftmaxLinearLayer(torch.nn.Module): + + def __init__(self, input_features, output_features, margin): + super().__init__() + self.input_dim = input_features # number of input feature i.e. 
output of the last fc layer + self.output_dim = output_features # number of output = class numbers + self.margin = margin # m + self.beta = 100 + self.beta_min = 0 + self.scale = 0.99 + self.num_experts = output_features + # Initialize L-Softmax parameters + self.weight = torch.nn.Linear( + input_features, output_features, bias=False).float() + self.divisor = math.pi / self.margin # pi/m + self.C_m_2n = torch.Tensor(binom(margin, range(0, margin + 1, + 2))) # C_m{2n} + self.cos_powers = torch.Tensor(range(self.margin, -1, -2)) # m - 2n + self.sin2_powers = torch.Tensor(range(len(self.cos_powers))) # n + self.signs = torch.ones(margin // 2 + 1) # 1, -1, 1, -1, ... + self.signs[1::2] = -1 + + def calculate_cos_m_theta(self, cos_theta, device): + sin2_theta = 1 - cos_theta**2 + cos_terms = cos_theta.unsqueeze(1)**self.cos_powers.to( + device).unsqueeze(0) # cos^{m - 2n} + sin2_terms = ( + sin2_theta.unsqueeze(1)**self.sin2_powers.to(device).unsqueeze(0)) + + cos_m_theta = (self.signs.to(device).unsqueeze(0) + * self.C_m_2n.to(device).unsqueeze(0) * cos_terms + * sin2_terms).sum(1) # summation of all terms + + return cos_m_theta + + def reset_parameters(self): + nn.init.kaiming_normal_(self.weight.data.t()) + + def find_k(self, cos): + # to account for acos numerical errors + eps = 1e-7 + cos = torch.clamp(cos, -1 + eps, 1 - eps) + acos = cos.acos() + k = (acos / self.divisor).floor().detach() + return k + + def forward(self, input, device, training): + if training: + x, w = input, self.weight.float() + beta = max(self.beta, self.beta_min) + logit = w(x) + indexes = range(logit.size(0)) + # target = torch.fmod(torch.randperm(logit.size(0)), self.num_experts) + target = torch.fmod( + torch.range(0, + logit.size(0) - 1), self.num_experts).long() + logit_target = logit[indexes, target] + + # cos(theta) = w * x / ||w||*||x|| + w_target_norm = w.weight[:, target].norm(p=2, dim=0) + + x_norm = x.norm(p=2, dim=1) + cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10) + + # equation 7 + cos_m_theta_target = self.calculate_cos_m_theta( + cos_theta_target, device) + + # find k in equation 6 + k = self.find_k(cos_theta_target) + + # f_y_i + logit_target_updated = w_target_norm * x_norm * (( + (-1)**k * cos_m_theta_target) - 2 * k) + logit_target_updated_beta = (logit_target_updated + beta + * logit[indexes, target]) / (1 + beta) + + logit[indexes, target] = logit_target_updated_beta + self.beta *= self.scale + return logit + else: + return self.weight(input) + + +def LayerNorm(normalized_shape, + eps=1e-5, + elementwise_affine=True, + export=False): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + export = True + if not export and torch.cuda.is_available() and has_fused_layernorm: + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +class ExpertResidualLayer(torch.nn.Module): + + def __init__(self, embed_dim): + super().__init__() + self.norm = LayerNorm(embed_dim, export=False) + self.ff1 = torch.nn.Linear(embed_dim, embed_dim * 4) + self.ff2 = torch.nn.Linear(embed_dim * 4, embed_dim) + self.ff2.weight.data.zero_() + + def forward(self, xs): + return xs + self.ff2(torch.nn.functional.relu(self.ff1(self.norm(xs)))) diff --git a/modelscope/models/nlp/gpt_moe/moe/utils.py b/modelscope/models/nlp/gpt_moe/moe/utils.py new file mode 100644 index 00000000..b6d64d5b --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/moe/utils.py @@ -0,0 +1,125 @@ +''' +Copyright 2020 The Microsoft DeepSpeed Team 
+''' + +from typing import Dict, List, Tuple + +import torch + +from .layer import MoE + + +def has_moe_layers(m): + has_moe = False + num_experts = 0 + for _, module in m.named_modules(): + if isinstance(module, MoE): + has_moe = True + num_experts = module.num_experts + break + return has_moe, num_experts + + +def is_moe_param(param: torch.Tensor) -> bool: + if hasattr(param, 'allreduce') and not param.allreduce: + return True + return False + + +def split_params_into_shared_and_expert_params( + params: List[torch.nn.Parameter] +) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: + shared_params, expert_params = [], [] + for p in params: + if is_moe_param(p): + expert_params.append(p) + else: + shared_params.append(p) + return shared_params, expert_params + + +def split_params_grads_into_shared_and_expert_params( + group: List[torch.nn.Parameter] +) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: + """Split grad of parameters into grads of non-expert params + and grads of expert params. This is useful while computing + grad-norms for clipping and overflow detection + + group (List[torch.nn.Parameter]): + Args: + The group of parameters to split + + Returns: + Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]: + list of gradients for non MoE params, list of gradients of MoE params + """ + expert_grads = [] + shared_grads = [] + for p in group: + if p.grad is not None: + if is_moe_param(p): + expert_grads.append(p.grad.to(p.dtype)) + else: + shared_grads.append(p.grad.to(p.dtype)) + return shared_grads, expert_grads + + +def split_params_into_different_moe_groups_for_optimizer( + param_groups: Tuple[Dict]) -> Tuple[Dict]: + """Split parameters into different MoE groups for optimizer + + Args: + param_groups (Tuple[Dict]): + The list of parameter groups to split + + Returns: + Tuple[Dict]: + list of MoE/non-MoE groups for optimizer + """ + if isinstance(param_groups, tuple): + param_groups = list(param_groups) # Tuple cannot be modified + elif isinstance(param_groups, dict): + param_groups = [param_groups] + elif not isinstance(param_groups, list): + raise ValueError(f'Unknown param group type of {type(param_groups)}') + + # gather all data parallel group names + data_parallel_group_names = set() + for param_group in param_groups: + for param in param_group['params']: + if is_moe_param(param): + data_parallel_group_names.add(param.group_name) + data_parallel_group_names = list(data_parallel_group_names) + group_moe = {} + # Create the param MoE groups, leave param assign to next step + for param_group in param_groups: + group_moe[param_group['name']] = {} + for key in data_parallel_group_names: + group_moe[param_group['name']][key] = {} + group_moe[param_group['name']][key]['name'] = key + group_moe[param_group['name']][key]['moe'] = True + for ori_key in param_group.keys(): + if ori_key != 'name': + if ori_key == 'params': + group_moe[param_group['name']][key][ori_key] = [] + else: + group_moe[param_group['name']][key][ + ori_key] = param_group[ori_key] + # Assign param + for param_group in param_groups: + new_params = [] + for param in param_group['params']: + if is_moe_param(param): + group_moe[param_group['name']][ + param.group_name]['params'].append(param) + # param_group['params'].remove(param) + else: + new_params.append(param) + param_group['params'] = new_params + + # Flatten the moe groups + for k, v in group_moe.items(): + for k1, v1 in v.items(): + param_groups.append(v1) + + return tuple(param_groups) diff --git a/modelscope/models/nlp/gpt_moe/text_generation.py 
b/modelscope/models/nlp/gpt_moe/text_generation.py new file mode 100644 index 00000000..59245917 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/text_generation.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Dict + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor, TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + +__all__ = ['GPTMoEForTextGeneration'] + + +@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt_moe) +class GPTMoEForTextGeneration(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the text generation model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, *args, **kwargs) + + from modelscope.models.nlp.gpt_moe import GPTMoEModel + from transformers import BertTokenizer + print('****') + print(model_dir) + self.model = GPTMoEModel.from_pretrained(model_dir) + self.tokenizer = BertTokenizer.from_pretrained(model_dir) + + def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + """return the result by the model + + Args: + input (Dict[str, Tensor]): the preprocessed data + + Returns: + Dict[str, Tensor]: results + Example: + { + 'logits': Tensor([[0.54, 0.32...])]), # logits + } + """ + return self.model(**input) + + def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: + assert 'input_ids' in input, "generate function must accept 'input_ids' key" + input_ids = input['input_ids'] + if 'attention_mask' in input: + attention_mask = input['attention_mask'] + input_ids = input_ids[0][attention_mask[0].nonzero()] \ + .squeeze().unsqueeze(0) + # remove sep token at the end of tokenizer output + input_ids = input_ids[:, :-1] + + gen_params = dict() + gen_params['inputs'] = input_ids + gen_params['do_sample'] = input.pop('do_sample', True) + gen_params['max_length'] = input.pop('max_length', 128) + gen_params['top_k'] = input.pop('top_k', 10) + gen_params['top_p'] = input.pop('top_p', None) + sample_output = self.model.generate(**gen_params) + return {'sequences': sample_output[0]} diff --git a/modelscope/models/nlp/gpt_moe/tokenizer.py b/modelscope/models/nlp/gpt_moe/tokenizer.py new file mode 100644 index 00000000..a290b846 --- /dev/null +++ b/modelscope/models/nlp/gpt_moe/tokenizer.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tokenizers import Tokenizer + + +class JiebaBPETokenizer: + """SentencePiece BPE tokenizer with Jieba integration""" + + def __init__(self, tokenizer_json_file): + self.name = 'Jieba BPE Tokenizer' + + self.tokenizer = Tokenizer.from_file(tokenizer_json_file) + self.eod_id = self.tokenizer.token_to_id('<|endoftext|>') + try: + import jieba + except ImportError: + raise ImportError( + 'You need to install rjieba to use JiebaTokenizer. 
' + 'See https://pypi.org/project/rjieba/ for installation.') + self.jieba = jieba + self.new_line = self.vocab['\n'] + self.sep_token = self.vocab[''] + + @property + def vocab_size(self): + return self.tokenizer.get_vocab_size(with_added_tokens=True) + + @property + def vocab(self): + return self.tokenizer.get_vocab(with_added_tokens=True) + + @property + def inv_vocab(self): + vocab = self.vocab + inv_vocab = dict() + for key, val in vocab.items(): + inv_vocab[val] = key + return inv_vocab + + def tokenize(self, text, is_code=False): + if not is_code: + seg_list = [x for x in self.jieba.cut(text)] + return self.tokenizer.encode( + seg_list, is_pretokenized=True, add_special_tokens=True).ids + else: + return self.tokenizer.encode( + text, is_pretokenized=False, add_special_tokens=True).ids + + def detokenize(self, token_ids): + text = self.tokenizer.decode(token_ids, skip_special_tokens=False) + return text + + @property + def eod(self): + return self.eod_id diff --git a/modelscope/pipelines/nlp/distributed_gpt_moe_pipeline.py b/modelscope/pipelines/nlp/distributed_gpt_moe_pipeline.py new file mode 100644 index 00000000..71e48a11 --- /dev/null +++ b/modelscope/pipelines/nlp/distributed_gpt_moe_pipeline.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import Any, Dict + +import torch + +from modelscope.metainfo import Pipelines +from modelscope.models.nlp.gpt_moe.distributed_gpt_moe import DistributedGPTMoE +from modelscope.pipelines.base import DistributedPipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import TextGenerationJiebaPreprocessor +from modelscope.utils.constant import Tasks + + +@PIPELINES.register_module( + Tasks.text_generation, module_name=Pipelines.gpt_moe_generation) +class DistributedGPTMoEPipeline(DistributedPipeline): + """This class is used to instantiate the gpt-moe model. + """ + + model = None + + def __init__(self, model, preprocessor=None, **kwargs): + if preprocessor is None: + preprocessor = TextGenerationJiebaPreprocessor(model) + super().__init__(model, preprocessor=preprocessor, **kwargs) + assert hasattr(preprocessor, 'tokenizer') + + @classmethod + def _instantiate_one(cls, rank, model_dir, **kwargs): + cls.model = DistributedGPTMoE(model_dir, rank, **kwargs) + cls.model.eval() + + @classmethod + def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + tokens = inputs['inputs']['input_ids'].cuda( + torch.cuda.current_device()) + return cls.model.generate(tokens) + + def postprocess(self, inputs: Dict[str, Any], + **postprocess_params) -> Dict[str, str]: + """process the prediction results + + Args: + inputs (Dict[str, Any]): _description_ + + Returns: + Dict[str, str]: the prediction results + """ + from modelscope.outputs import OutputKeys + return { + OutputKeys.TEXT: + self.preprocessor.tokenizer.detokenize(inputs[0].tolist()) + } diff --git a/tests/pipelines/test_gpt_moe_text_generation.py b/tests/pipelines/test_gpt_moe_text_generation.py new file mode 100644 index 00000000..4ec8c742 --- /dev/null +++ b/tests/pipelines/test_gpt_moe_text_generation.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
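+# The check below mirrors typical end-to-end usage of the new pipeline (model id
+# taken from setUp; the test is skipped by default, presumably because it needs
+# the large distributed GPT-MoE checkpoint and GPU resources):
+#
+#   from modelscope.pipelines import pipeline
+#   from modelscope.utils.constant import Tasks
+#
+#   pipe = pipeline(Tasks.text_generation,
+#                   model='PAI/nlp_gpt3_text-generation_1.3B_MoE-32')
+#   print(pipe('好的'))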
+import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TextGPTMoEGenerationTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id_1_3B_MoE32 = 'PAI/nlp_gpt3_text-generation_1.3B_MoE-32' + self.model_dir_1_3B_MoE32 = snapshot_download(self.model_id_1_3B_MoE32) + self.input = '好的' + + @unittest.skip('distributed gpt-moe 1.3B_MoE-32, skipped') + def test_gpt_moe_1_3B_MoE32(self): + pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B_MoE32) + print(pipe(self.input)) + + +if __name__ == '__main__': + unittest.main()
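+# To exercise the skipped case manually, remove the @unittest.skip decorator and
+# run this file directly (a GPU environment with the checkpoint is assumed):
+#
+#   python tests/pipelines/test_gpt_moe_text_generation.py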