Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10836131master^2
@@ -80,6 +80,7 @@ class Models(object): | |||
gcnncrf = 'gcnn-crf' | |||
bart = 'bart' | |||
gpt3 = 'gpt3' | |||
gpt_moe = 'gpt-moe' | |||
gpt_neo = 'gpt-neo' | |||
plug = 'plug' | |||
bert_for_ds = 'bert-for-document-segmentation' | |||
@@ -255,6 +256,7 @@ class Pipelines(object): | |||
text_error_correction = 'text-error-correction' | |||
plug_generation = 'plug-generation' | |||
gpt3_generation = 'gpt3-generation' | |||
gpt_moe_generation = 'gpt-moe-generation' | |||
faq_question_answering = 'faq-question-answering' | |||
conversational_text_to_sql = 'conversational-text-to-sql' | |||
table_question_answering_pipeline = 'table-question-answering-pipeline' | |||
@@ -0,0 +1,27 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration import GPTMoEConfig | |||
from .backbone import GPTMoEModel | |||
from .text_generation import GPTMoEForTextGeneration | |||
from .tokenizer import JiebaBPETokenizer | |||
else: | |||
_import_structure = { | |||
'configuration': ['GPTMoEConfig'], | |||
'backbone': ['GPTMoEModel'], | |||
'text_generation': ['GPTMoEForTextGeneration'], | |||
'tokenizer': ['JiebaBPETokenizer'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,355 @@ | |||
# Copyright 2021-2022 The Alibaba PAI Team Authors. | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import math | |||
import os | |||
from typing import Optional, Union | |||
import addict | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
from transformers.modeling_utils import PreTrainedModel | |||
from modelscope.utils.constant import ModelFile | |||
from .configuration import GPTMoEConfig | |||
class GPTMoESelfAttention(nn.Module): | |||
"""Parallel self-attention layer abstract class. | |||
Self-attention layer takes input with size [s, b, h] | |||
and returns output of the same size. | |||
""" | |||
def __init__(self, config): | |||
super().__init__() | |||
self.hidden_size = config.hidden_size | |||
self.num_attention_heads = config.num_attention_heads | |||
# Per attention head | |||
self.hidden_size_per_attention_head = \ | |||
self.hidden_size // self.num_attention_heads | |||
self.query_key_value = nn.Linear(self.hidden_size, | |||
3 * self.hidden_size) | |||
self.softmax = nn.Softmax(dim=-1) | |||
self.attention_dropout = nn.Dropout( | |||
config.attention_probs_dropout_prob) | |||
# Output. | |||
self.dense = nn.Linear(self.hidden_size, self.hidden_size) | |||
self.output_dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def _transpose_for_scores(self, tensor): | |||
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with | |||
size [b, np, s, hn]. | |||
""" | |||
new_tensor_shape = tensor.size()[:-1] + ( | |||
self.num_attention_heads, self.hidden_size_per_attention_head) | |||
tensor = tensor.view(*new_tensor_shape) | |||
return tensor.permute(0, 2, 1, 3) | |||
def _split_tensor_along_last_dim(self, | |||
tensor, | |||
num_partitions, | |||
contiguous_split_chunks=False): | |||
# Get the size and dimension. | |||
last_dim = tensor.dim() - 1 | |||
last_dim_size = tensor.size()[last_dim] // num_partitions | |||
# Split. | |||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) | |||
# Note: torch.split does not create contiguous tensors by default. | |||
if contiguous_split_chunks: | |||
return tuple(chunk.contiguous() for chunk in tensor_list) | |||
return tensor_list | |||
def forward(self, hidden_states, ltor_mask, is_infer=False): | |||
# hidden_states: [b, s, h] | |||
# ltor_mask: [1, 1, s, s] | |||
# Attention heads. [b, s, hp] | |||
tgt_len = hidden_states.size(1) | |||
ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len]) | |||
mixed_x_layer = self.query_key_value(hidden_states) | |||
(mixed_query_layer, mixed_key_layer, mixed_value_layer) = \ | |||
self._split_tensor_along_last_dim(mixed_x_layer, 3) | |||
# Reshape and transpose [b, np, s, hn] | |||
query_layer = self._transpose_for_scores(mixed_query_layer) | |||
key_layer = self._transpose_for_scores(mixed_key_layer) | |||
value_layer = self._transpose_for_scores(mixed_value_layer) | |||
previous_type = value_layer.type() | |||
# Raw attention scores. [b, np, s, s] | |||
attention_scores = torch.matmul(query_layer, | |||
key_layer.transpose(-1, -2)) | |||
attention_scores = attention_scores / math.sqrt( | |||
self.hidden_size_per_attention_head) | |||
# Apply the left to right attention mask. | |||
if is_infer: | |||
src_len = key_layer.size(2) | |||
ltor_mask = torch.tril( | |||
torch.ones((1, tgt_len, src_len), | |||
device=hidden_states.device)).view( | |||
1, 1, tgt_len, src_len).type(previous_type) | |||
converted_mask = 10000.0 * (1.0 - ltor_mask) | |||
attention_scores = (torch.mul(attention_scores, ltor_mask) | |||
- converted_mask).type(previous_type) | |||
# Attention probabilities. [b, np, s, s] | |||
attention_probs = self.softmax(attention_scores) | |||
# This is actually dropping out entire tokens to attend to, which might | |||
# seem a bit unusual, but is taken from the original Transformer paper. | |||
attention_probs = self.attention_dropout(attention_probs) | |||
# Context layer. | |||
# [b, np, s, hn] | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
# [b, s, np, hn] | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + ( | |||
self.hidden_size, ) | |||
# [b, s, hp] | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
# Output. [b, s, h] | |||
output = self.dense(context_layer) | |||
output = self.output_dropout(output) | |||
return output | |||
class GPTMoEMLP(nn.Module): | |||
"""MLP. | |||
MLP will take the input with h hidden state, project it to 4*h | |||
hidden dimension, perform nonlinear transformation, and project the | |||
state back into h hidden dimension. | |||
""" | |||
def __init__(self, config): | |||
super().__init__() | |||
hidden_size = config.hidden_size | |||
# Project to 4h. | |||
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size) | |||
self.activation_func = F.gelu | |||
# Project back to h. | |||
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
def forward(self, hidden_states): | |||
# [s, b, 4hp] | |||
intermediate_parallel = self.dense_h_to_4h(hidden_states) | |||
intermediate_parallel = self.activation_func(intermediate_parallel) | |||
# [s, b, h] | |||
output = self.dense_4h_to_h(intermediate_parallel) | |||
output = self.dropout(output) | |||
return output | |||
class GPTMoETransformerLayer(nn.Module): | |||
"""A single transformer layer. | |||
Transformer layer takes input with size [s, b, h] and returns an | |||
output of the same size. | |||
""" | |||
def __init__(self, config): | |||
super().__init__() | |||
# Layernorm on the input data. | |||
self.input_layernorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layernorm_epsilon) | |||
# Self attention. | |||
self.attention = GPTMoESelfAttention(config) | |||
# Layernorm on the attention output | |||
self.post_attention_layernorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layernorm_epsilon) | |||
# MLP | |||
self.mlp = GPTMoEMLP(config) | |||
def forward(self, hidden_states, ltor_mask): | |||
# hidden_states: [b, s, h] | |||
# ltor_mask: [1, 1, s, s] | |||
# Layer norm at the begining of the transformer layer. | |||
layernorm_output = self.input_layernorm(hidden_states) | |||
# Self attention. | |||
attention_output = self.attention(layernorm_output, ltor_mask) | |||
# Residual connection. | |||
layernorm_input = hidden_states + attention_output | |||
# Layer norm post the self attention. | |||
layernorm_output = self.post_attention_layernorm(layernorm_input) | |||
# MLP. | |||
mlp_output = self.mlp(layernorm_output) | |||
# Second residual connection. | |||
output = layernorm_input + mlp_output | |||
return output | |||
class GPTMoETransformer(nn.Module): | |||
"""Transformer class.""" | |||
def __init__(self, config): | |||
super().__init__() | |||
self.input_tensor = None | |||
# Number of layers. | |||
self.num_layers = config.num_hidden_layers | |||
self.layers = torch.nn.ModuleList( | |||
[GPTMoETransformerLayer(config) for _ in range(self.num_layers)]) | |||
# Final layer norm before output. | |||
self.final_layernorm = nn.LayerNorm( | |||
config.hidden_size, eps=config.layernorm_epsilon) | |||
def _get_layer(self, layer_number): | |||
return self.layers[layer_number] | |||
def forward(self, hidden_states, attention_mask): | |||
# hidden_states: [s, b, h] | |||
for index in range(self.num_layers): | |||
layer = self._get_layer(index) | |||
hidden_states = layer(hidden_states, attention_mask) | |||
# Final layer norm. | |||
hidden_states = self.final_layernorm(hidden_states) | |||
return hidden_states | |||
class GPTMoETransformerLanguageModel(nn.Module): | |||
"""Transformer language model. | |||
Arguments: | |||
transformer_hparams: transformer hyperparameters | |||
vocab_size: vocabulary size | |||
max_sequence_length: maximum size of sequence. This | |||
is used for positional embedding | |||
embedding_dropout_prob: dropout probability for embeddings | |||
num_tokentypes: size of the token-type embeddings. 0 value | |||
will ignore this embedding | |||
""" | |||
def __init__(self, config): | |||
super().__init__() | |||
# Embeddings. | |||
self.word_embeddings = nn.Embedding(config.vocab_size, | |||
config.hidden_size) | |||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, | |||
config.hidden_size) | |||
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) | |||
# Transformer. | |||
self.transformer = GPTMoETransformer(config) | |||
def forward(self, input_ids, attention_mask, position_ids): | |||
words_embeddings = self.word_embeddings(input_ids) | |||
position_embeddings = self.position_embeddings(position_ids) | |||
embeddings = words_embeddings + position_embeddings | |||
transformer_input = self.embedding_dropout(embeddings) | |||
transformer_output = self.transformer(transformer_input, | |||
attention_mask) | |||
logits = F.linear(transformer_output, self.word_embeddings.weight) | |||
return logits | |||
class GPTMoEModel(PreTrainedModel): | |||
config_class = GPTMoEConfig | |||
def _init_weights(self, module): | |||
"""Initialize the weights""" | |||
if isinstance(module, nn.Linear): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.bias is not None: | |||
module.bias.data.zero_() | |||
elif isinstance(module, nn.Embedding): | |||
module.weight.data.normal_( | |||
mean=0.0, std=self.config.initializer_range) | |||
if module.padding_idx is not None: | |||
module.weight.data[module.padding_idx].zero_() | |||
elif isinstance(module, nn.LayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
def __init__(self, config): | |||
super().__init__(config) | |||
self.language_model = GPTMoETransformerLanguageModel(config) | |||
def forward(self, | |||
input_ids, | |||
attention_mask=None, | |||
position_ids=None, | |||
labels=None, | |||
**kwargs): | |||
seq_length = input_ids.size(1) | |||
attention_mask = torch.tril( | |||
torch.ones((1, 1, seq_length, seq_length), | |||
dtype=torch.long, | |||
device=input_ids.device)) | |||
if position_ids is None: | |||
position_ids = torch.arange( | |||
seq_length, dtype=torch.long, device=input_ids.device) | |||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
logits = self.language_model(input_ids, attention_mask, position_ids) | |||
loss = None | |||
if labels is not None: | |||
loss_fct = nn.CrossEntropyLoss() | |||
loss = loss_fct( | |||
logits.view(-1, self.config.vocab_size), labels.view(-1)) | |||
return addict.Dict(loss=loss, logits=logits) | |||
@classmethod | |||
def from_pretrained( | |||
cls, pretrained_model_name_or_path: Optional[Union[str, | |||
os.PathLike]]): | |||
config = cls.config_class.from_pretrained( | |||
pretrained_model_name_or_path) | |||
model = cls(config) | |||
state_dict_file = os.path.join(pretrained_model_name_or_path, | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
state_dict = torch.load(state_dict_file) | |||
if 'state_dict' in state_dict: | |||
state_dict = state_dict['state_dict'] | |||
state_dict = { | |||
k.replace('model.language_model', 'language_model'): v | |||
for k, v in state_dict.items() | |||
} | |||
model.load_state_dict(state_dict) | |||
return model | |||
def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): | |||
return {'input_ids': input_ids} |
@@ -0,0 +1,145 @@ | |||
# Copyright 2021-2022 The Alibaba PAI Team Authors. | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import os | |||
import torch | |||
from megatron import mpu | |||
from megatron.model import Float16Module | |||
from torch.nn.parallel import DistributedDataParallel as torchDDP | |||
from .configuration import logger | |||
from .moe.layer import MoE | |||
def unwrap_model(model, module_instances=(torchDDP)): | |||
return_list = True | |||
if not isinstance(model, list): | |||
model = [model] | |||
return_list = False | |||
unwrapped_model = [] | |||
for model_module in model: | |||
while isinstance(model_module, module_instances): | |||
model_module = model_module.module | |||
unwrapped_model.append(model_module) | |||
if not return_list: | |||
return unwrapped_model[0] | |||
return unwrapped_model | |||
def get_checkpoint_names(checkpoints_path, | |||
path_load_tag, | |||
num_experts, | |||
tensor_rank=None, | |||
expp_rank=None): | |||
"""Determine the directory name for this rank's checkpoint.""" | |||
if tensor_rank is None: | |||
tensor_rank = mpu.get_model_parallel_rank() | |||
common_path = os.path.join(checkpoints_path, path_load_tag, | |||
f'mp_rank_{tensor_rank:02d}') | |||
if num_experts[0] > 0: | |||
model_name = common_path + '_model_states.pt' | |||
optim_name = os.path.join( | |||
checkpoints_path, path_load_tag, | |||
f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt') | |||
else: | |||
model_name = optim_name = os.path.join(common_path, | |||
'model_optim_rng.pt') | |||
return model_name, optim_name | |||
def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id): | |||
mp_rank = mpu.get_model_parallel_rank() | |||
ckpt_name = os.path.join( | |||
os.path.join(checkpoints_path, 'model'), | |||
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt' | |||
) | |||
return ckpt_name | |||
def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None): | |||
""" Load the base state_dict from the given directory | |||
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. | |||
""" | |||
largest_group_name = mpu.get_max_expert_size_name() | |||
expp_rank = mpu.get_expert_parallel_rank(largest_group_name) | |||
checkpoint_names = get_checkpoint_names( | |||
load_dir, | |||
path_load_tag=path_load_tag, | |||
num_experts=num_experts, | |||
expp_rank=expp_rank) | |||
model_checkpoint_name, optim_checkpoint_name = checkpoint_names | |||
logger.info(f'Loading model checkpoint from {model_checkpoint_name}') | |||
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu') | |||
return model_state_dict | |||
def load_checkpoint(model, | |||
load_dir, | |||
num_experts=None, | |||
strict=True, | |||
path_load_tag='model', | |||
load_ds_ckpts=True): | |||
model = unwrap_model(model, (torchDDP, Float16Module)) | |||
model_state_dict = _load_base_checkpoint( | |||
load_dir, path_load_tag=path_load_tag, num_experts=num_experts) | |||
assert model_state_dict is not None | |||
if load_ds_ckpts: | |||
load_moe_checkpoint(model, model_state_dict['module'], load_dir) | |||
else: | |||
load_moe_checkpoint(model, model_state_dict['model'], load_dir) | |||
if load_ds_ckpts: | |||
model.load_state_dict(model_state_dict['module'], strict=strict) | |||
else: | |||
model.load_state_dict(model_state_dict['model'], strict=strict) | |||
if torch.distributed.is_initialized(): | |||
torch.distributed.barrier() | |||
def load_moe_checkpoint(model, state_dict, load_dir): | |||
moe_layer_id = 0 | |||
for n_module, module in model.named_modules(): | |||
if isinstance(module, MoE): # and torch.distributed.get_rank() == 0: | |||
group_name = module.expert_group_name | |||
num_local_experts = module.num_local_experts | |||
expp_rank = mpu.get_expert_parallel_rank(group_name) | |||
# loop all local_experts | |||
for local_expert_id in range(num_local_experts): | |||
global_expert_id = expp_rank * num_local_experts + local_expert_id | |||
moe_load_path = _get_expert_ckpt_name(load_dir, moe_layer_id, | |||
global_expert_id) | |||
logger.info(f'Loading expert states from {moe_load_path}') | |||
expert_state_dict = torch.load( | |||
moe_load_path, map_location=torch.device('cpu')) | |||
# Updating global -> local expert ids | |||
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' | |||
for key in list(expert_state_dict.keys()): | |||
local_key = key.replace( | |||
f'{moe_str_prefix}{global_expert_id}', | |||
f'{moe_str_prefix}{local_expert_id}') | |||
expert_state_dict[local_key] = expert_state_dict.pop(key) | |||
state_dict.update(expert_state_dict) | |||
moe_layer_id += 1 |
@@ -0,0 +1,128 @@ | |||
# Copyright 2021-2022 The Alibaba PAI Team Authors. | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import torch | |||
from transformers.configuration_utils import PretrainedConfig | |||
from transformers.utils import logging | |||
logger = logging.get_logger(__name__) | |||
class GPTMoEConfig(PretrainedConfig): | |||
model_type = 'gpt-moe' | |||
def __init__( | |||
self, | |||
vocab_size=25600, | |||
hidden_size=768, | |||
ffn_hidden_size=None, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act='gelu', | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=2048, | |||
type_vocab_size=2, | |||
layernorm_epsilon=1e-12, | |||
bias_gelu_fusion=True, | |||
fp32_residual_connection=False, | |||
sequence_parallel=False, | |||
fp16=False, | |||
bf16=False, | |||
apply_query_key_layer_scaling=True, | |||
attention_softmax_in_fp32=False, | |||
kv_channels=None, | |||
masked_softmax_fusion=True, | |||
attention_dropout=0.1, | |||
bias_dropout_fusion=True, | |||
apply_residual_connection_post_layernorm=False, | |||
hidden_dropout=0.1, | |||
init_method_std=0.02, | |||
# generate | |||
eod_id=7, | |||
tokens_to_generate=100, | |||
top_k=0, | |||
top_p=0.9, | |||
num_experts=[0], | |||
use_tutel=False, | |||
top_k_linear_strategy='standard', | |||
use_expert_residual_network=False, | |||
load_ds_ckpts=False, | |||
model_dir=None, | |||
**kwargs): | |||
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | |||
self.vocab_size = vocab_size | |||
self.hidden_size = hidden_size | |||
self.ffn_hidden_size = 4 * hidden_size \ | |||
if ffn_hidden_size is None else ffn_hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.hidden_act = hidden_act | |||
self.intermediate_size = intermediate_size | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.layernorm_epsilon = layernorm_epsilon | |||
self.bias_gelu_fusion = bias_gelu_fusion | |||
self.fp32_residual_connection = fp32_residual_connection | |||
self.sequence_parallel = sequence_parallel | |||
self.fp16 = fp16 | |||
self.bf16 = bf16 | |||
assert not (fp16 and bf16) | |||
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling | |||
self.attention_softmax_in_fp32 = attention_softmax_in_fp32 | |||
if kv_channels is None: | |||
assert hidden_size % num_attention_heads == 0 | |||
self.kv_channels = hidden_size // num_attention_heads | |||
self.masked_softmax_fusion = masked_softmax_fusion | |||
self.attention_dropout = attention_dropout | |||
self.bias_dropout_fusion = bias_dropout_fusion | |||
self.apply_residual_connection_post_layernorm = \ | |||
apply_residual_connection_post_layernorm | |||
self.hidden_dropout = hidden_dropout | |||
self.init_method_std = init_method_std | |||
self.eod_id = eod_id | |||
self.tokens_to_generate = tokens_to_generate | |||
self.top_k = top_k | |||
self.top_p = top_p | |||
self.num_experts = num_experts | |||
self.use_tutel = use_tutel | |||
self.top_k_linear_strategy = top_k_linear_strategy | |||
self.use_expert_residual_network = use_expert_residual_network | |||
self.load_ds_ckpts = load_ds_ckpts | |||
self.model_dir = model_dir | |||
if self.num_experts[0] > torch.cuda.device_count(): | |||
self.moe_expert_parallel_size = torch.cuda.device_count() | |||
else: | |||
self.moe_expert_parallel_size = self.num_experts[0] | |||
TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||
TORCH_MINOR = int(torch.__version__.split('.')[1]) | |||
self.no_persist_layer_norm = \ | |||
TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11) | |||
@property | |||
def params_dtype(self): | |||
if self.fp16: | |||
return torch.half | |||
elif self.bf16: | |||
return torch.bfloat16 | |||
else: | |||
return torch.float |
@@ -0,0 +1,36 @@ | |||
''' | |||
Copyright 2020 The Microsoft DeepSpeed Team | |||
''' | |||
import copy | |||
import torch | |||
class Experts(torch.nn.Module): | |||
def __init__(self, expert, num_local_experts=1, expert_group_name=None): | |||
super(Experts, self).__init__() | |||
self.deepspeed_experts = torch.nn.ModuleList( | |||
[copy.deepcopy(expert) for i in range(num_local_experts)]) | |||
self.num_local_experts = num_local_experts | |||
# TODO: revisit allreduce for moe.gate... | |||
for expert in self.deepspeed_experts: | |||
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group) | |||
for name, param in expert.named_parameters(): | |||
param.allreduce = False | |||
param.group_name = expert_group_name | |||
def forward(self, inputs): | |||
chunks = inputs.chunk(self.num_local_experts, dim=1) | |||
expert_outputs = [] | |||
for chunk, expert in zip(chunks, self.deepspeed_experts): | |||
out = expert(chunk) | |||
if type(out) is tuple: | |||
out = out[0] # Ignore the bias term for now | |||
expert_outputs += [out] | |||
expert_output = torch.cat(expert_outputs, dim=1) | |||
return expert_output |
@@ -0,0 +1,98 @@ | |||
''' | |||
Copyright 2020 The Microsoft DeepSpeed Team | |||
''' | |||
import typing | |||
import torch | |||
from megatron import mpu | |||
from .experts import Experts | |||
from .sharded_moe import MOELayer, TopKGate | |||
class MoE(torch.nn.Module): | |||
def __init__(self, | |||
hidden_size, | |||
expert, | |||
num_experts=1, | |||
ep_size=1, | |||
k=1, | |||
capacity_factor=1., | |||
eval_capacity_factor=1., | |||
min_capacity=4, | |||
use_residual=False, | |||
noisy_gate_policy: typing.Optional[str] = None, | |||
drop_tokens: bool = True, | |||
use_rts=True, | |||
use_tutel: bool = False, | |||
top_k_linear_strategy: str = 'normal', | |||
use_expert_residual_network: bool = False): | |||
super(MoE, self).__init__() | |||
self.use_residual = use_residual | |||
assert num_experts % ep_size == 0, f'Number of experts ({num_experts}) should ' \ | |||
f'be divisible by expert parallel size ({ep_size})' | |||
self.ep_size = ep_size | |||
self.expert_group_name = f'ep_size_{self.ep_size}' | |||
self.num_experts = num_experts | |||
self.num_local_experts = num_experts // self.ep_size | |||
assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \ | |||
'Unsupported noisy_gate_policy: ' + noisy_gate_policy | |||
experts = Experts(expert, self.num_local_experts, | |||
self.expert_group_name) | |||
self.deepspeed_moe = MOELayer( | |||
TopKGate( | |||
hidden_size, | |||
num_experts, | |||
k, | |||
capacity_factor, | |||
eval_capacity_factor, | |||
min_capacity, | |||
noisy_gate_policy, | |||
drop_tokens, | |||
use_rts, | |||
top_k_linear_strategy=top_k_linear_strategy), | |||
experts, | |||
self.expert_group_name, | |||
self.ep_size, | |||
self.num_local_experts, | |||
use_tutel=use_tutel, | |||
use_expert_residual_network=use_expert_residual_network) | |||
self.deepspeed_moe._set_ep_group( | |||
mpu.get_expert_parallel_group(self.expert_group_name)) | |||
if self.use_residual: | |||
self.mlp = expert | |||
# coefficient is used for weighted sum of the output of expert and mlp | |||
self.coefficient = torch.nn.Linear(hidden_size, 2) | |||
def forward(self, hidden_states, used_token=None): | |||
""" MoE forward | |||
Arguments: | |||
hidden_states (Tensor): input to the layer | |||
used_token (Tensor, optional): default: None, mask only used tokens | |||
Returns: | |||
A tuple including output, gate loss, and expert count. | |||
* output (Tensor): output of the model | |||
* l_aux (Tensor): gate loss value | |||
* exp_counts (int): expert count | |||
""" | |||
output = self.deepspeed_moe(hidden_states, used_token) | |||
if self.use_residual: | |||
# Residual MoE | |||
output_mlp = self.mlp(hidden_states) | |||
if type(output_mlp) is tuple: | |||
output_mlp = output_mlp[0] # Ignore the bias term for now | |||
coef = self.coefficient(hidden_states) | |||
coef = torch.nn.functional.softmax(coef, dim=1) | |||
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] | |||
return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts |
@@ -0,0 +1,87 @@ | |||
''' | |||
Copyright 2020 The Microsoft DeepSpeed Team | |||
''' | |||
import torch | |||
from megatron import mpu | |||
def _gather_tokens(input_, dim=0): | |||
"""Gather tensors and concatenate them along a dimension""" | |||
input_ = input_.contiguous() | |||
# Size and dimension. | |||
rank = mpu.get_tensor_model_parallel_rank() | |||
tensor_list = [ | |||
torch.empty_like(input_) | |||
for _ in range(mpu.get_model_parallel_world_size()) | |||
] | |||
tensor_list[rank] = input_ | |||
torch.distributed.all_gather( | |||
tensor_list, input_, group=mpu.get_tensor_model_parallel_group()) | |||
# Note: torch.cat already creates a contiguous tensor. | |||
output = torch.cat(tensor_list, dim=dim).contiguous() | |||
return output | |||
def _drop_tokens(input_, dim=0): | |||
"""Divide a tensor among the tensor parallel ranks""" | |||
total_chunks = mpu.get_model_parallel_world_size() | |||
this_chunk = mpu.get_model_parallel_rank() | |||
assert input_.shape[ | |||
dim] % total_chunks == 0, f'input dimension {dim} ({input_.shape[dim]}) ' \ | |||
f'is not divisible by tensor parallel world size ({total_chunks})' | |||
chunk_size = input_.shape[dim] // total_chunks | |||
return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size) | |||
class _GatherTokens(torch.autograd.Function): | |||
"""All gather tokens among the tensor parallel ranks""" | |||
@staticmethod | |||
def symbolic(graph, input_, dim): | |||
return _gather_tokens(input_, dim) | |||
@staticmethod | |||
def forward(ctx, input_, dim): | |||
ctx.dim = dim | |||
return _gather_tokens(input_, dim) | |||
@staticmethod | |||
def backward(ctx, grad_output): | |||
return _drop_tokens(grad_output, ctx.dim), None | |||
class _DropTokens(torch.autograd.Function): | |||
'Divide tokens equally among the tensor parallel ranks' | |||
@staticmethod | |||
def symbolic(graph, input_, dim): | |||
return _drop_tokens(input_, dim) | |||
@staticmethod | |||
def forward(ctx, input_, dim): | |||
ctx.dim = dim | |||
return _drop_tokens(input_, dim) | |||
@staticmethod | |||
def backward(ctx, input_): | |||
return _gather_tokens(input_, ctx.dim), None | |||
def gather_tokens(input_, dim=0): | |||
if mpu is None or mpu.get_model_parallel_world_size() == 1: | |||
# no tensor parallelism for non-experts | |||
return input_ | |||
return _GatherTokens.apply(input_, dim) | |||
def drop_tokens(input_, dim=0): | |||
if mpu is None or mpu.get_model_parallel_world_size() == 1: | |||
# no tensor parallelism for non-experts | |||
return input_ | |||
return _DropTokens.apply(input_, dim) |
@@ -0,0 +1,647 @@ | |||
''' | |||
Copyright 2021 The Microsoft DeepSpeed Team | |||
''' | |||
# The file has been adapted from two fairscale files: | |||
# (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py | |||
# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py | |||
# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf | |||
# We retain the following license from the original files: | |||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |||
# | |||
# This source code is licensed under the BSD license found in the | |||
# LICENSE file in the root directory of this source tree. | |||
import math | |||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple | |||
import torch | |||
import torch.distributed as dist | |||
import torch.nn.functional as F | |||
from megatron import mpu | |||
from scipy.special import binom | |||
from torch import Tensor, nn | |||
from torch.nn import Module | |||
from ..configuration import logger | |||
from .mappings import drop_tokens, gather_tokens | |||
try: | |||
from apex.normalization import FusedLayerNorm as _FusedLayerNorm | |||
has_fused_layernorm = True | |||
class FusedLayerNorm(_FusedLayerNorm): | |||
@torch.jit.unused | |||
def forward(self, x): | |||
if not x.is_cuda: | |||
return super().forward(x) | |||
else: | |||
with torch.cuda.device(x.device): | |||
return super().forward(x) | |||
except ImportError: | |||
has_fused_layernorm = False | |||
if TYPE_CHECKING: | |||
Base = Module[Tensor] | |||
else: | |||
Base = Module | |||
uniform_map: Dict[torch.device, Callable] = {} | |||
gumbel_map: Dict[torch.device, Callable] = {} | |||
exp_selection_uniform_map: Dict[torch.device, Callable] = {} | |||
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): | |||
""" | |||
Modified from switch transformer paper. mesh transformers | |||
Multiply values by a random number between 1-epsilon and 1+epsilon. | |||
Makes models more resilient to rounding errors introduced by bfloat16. | |||
This seems particularly important for logits. | |||
Args: | |||
x: a torch.tensor | |||
device: torch.device | |||
epsilon: a floating point value | |||
Returns: | |||
a jittered x. | |||
""" | |||
if epsilon == 0: | |||
return x | |||
uniform = uniform_map.get(device) | |||
if uniform is None: | |||
uniform = torch.distributions.uniform.Uniform( | |||
low=torch.tensor(1.0 - epsilon, device=device), | |||
high=torch.tensor(1.0 + epsilon, | |||
device=device)).rsample # type: ignore | |||
uniform_map[device] = uniform | |||
return x * uniform(x.shape) | |||
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: | |||
gumbel = gumbel_map.get(device) | |||
if gumbel is None: | |||
one = torch.tensor(1.0, device=device) | |||
zero = torch.tensor(0.0, device=device) | |||
gumbel = torch.distributions.gumbel.Gumbel(zero, | |||
one).rsample # type: ignore | |||
gumbel_map[device] = gumbel | |||
return gumbel(shape) | |||
# Based on https://github.com/pytorch/pytorch/pull/40762 | |||
class _AllToAll(torch.autograd.Function): | |||
@staticmethod | |||
def forward(ctx: Any, group: dist.ProcessGroup, | |||
input: Tensor) -> Tensor: # type: ignore | |||
ctx.group = group | |||
input = input.contiguous() | |||
output = torch.empty_like(input) | |||
dist.all_to_all_single(output, input, group=group) | |||
return output | |||
@staticmethod | |||
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: | |||
return (None, _AllToAll.apply(ctx.group, *grad_output)) | |||
# einsum rewrites are on par or more performant | |||
# switch can be bubbled up in future | |||
USE_EINSUM = True | |||
# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity | |||
# See https://arxiv.org/pdf/2006.16668.pdf for details. | |||
def einsum(rule, a, b): | |||
if USE_EINSUM: | |||
return torch.einsum(rule, a, b) | |||
elif rule == 's,se->se': | |||
return a.reshape(a.shape[0], -1) * b | |||
elif rule == 'se,sc->sec': | |||
return a.unsqueeze(2) * b.unsqueeze(1) | |||
elif rule == 'se,se->s': | |||
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) | |||
elif rule == 'sec,sm->ecm': | |||
s = a.shape[0] | |||
e = a.shape[1] | |||
c = a.shape[2] | |||
m = b.shape[1] | |||
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) | |||
elif rule == 'sec,ecm->sm': | |||
return torch.matmul( | |||
a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) | |||
elif rule == 'ks,ksm->sm': | |||
k = b.shape[0] | |||
s = b.shape[1] | |||
m = b.shape[2] | |||
# [k, s] -> [s, k] -> [s, 1, k] | |||
a = a.t().unsqueeze(1) | |||
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k] | |||
b = b.reshape(k, -1).t().reshape(s, m, k) | |||
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1] | |||
return torch.bmm(a, b.transpose(1, 2)).squeeze(2) | |||
else: | |||
return torch.einsum(rule, a, b) | |||
# The following functions are extracted and scripted | |||
# because otherwise during a torch.jit.trace, the non-Tensor | |||
# values used in the calculations get recorded as constants. | |||
# torch.jit.script coerces them into Tensors and preserves | |||
# their dynamic shapes. This enables ONNX export. | |||
# We can't script the entire top1gating function because it | |||
# includes stateful caching logic which is incompatible with ONNX. | |||
@torch.jit.script | |||
def _capacity(gates: Tensor, capacity_factor: Tensor, | |||
min_capacity: Tensor) -> Tensor: | |||
# gates has shape of SE | |||
num_tokens = gates.shape[0] | |||
num_experts = gates.shape[1] | |||
# to(torch.int64) works around a bug in torch.onnx.export: | |||
# it should cast k to int64 when converting torch.topk but it doesn't. | |||
capacity = torch.ceil( | |||
(num_tokens / num_experts) * capacity_factor).to(torch.int64) | |||
if capacity < min_capacity: | |||
capacity = min_capacity.to(torch.int64) | |||
return capacity | |||
@torch.jit.script | |||
def _top_idx(source, k): | |||
return torch.topk(source, k=k, dim=0)[1] | |||
@torch.jit.script | |||
def _one_hot_to_float(x, num_classes): | |||
return F.one_hot(x, num_classes=num_classes).float() | |||
def top1gating( | |||
logits: Tensor, | |||
capacity_factor: float, | |||
min_capacity: int, | |||
used_token: Tensor = None, | |||
noisy_gate_policy: Optional[str] = None, | |||
drop_tokens: bool = True, | |||
use_rts: bool = True, | |||
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | |||
"""Implements Top1Gating on logits.""" | |||
if noisy_gate_policy == 'RSample': | |||
logits_w_noise = logits + gumbel_rsample( | |||
logits.shape, device=logits.device) | |||
# everything is in fp32 in this function | |||
gates = F.softmax(logits, dim=1) | |||
capacity = _capacity(gates, torch.tensor(capacity_factor), | |||
torch.tensor(min_capacity)) | |||
# Create a mask for 1st's expert per token | |||
# noisy gating | |||
indices1_s = torch.argmax( | |||
logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1) | |||
num_experts = int(gates.shape[1]) | |||
mask1 = F.one_hot(indices1_s, num_classes=num_experts) | |||
# mask only used tokens | |||
if used_token is not None: | |||
mask1 = einsum('s,se->se', used_token, mask1) | |||
# gating decisions | |||
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu') | |||
# if we don't want to drop any tokens | |||
if not drop_tokens: | |||
new_capacity = torch.max(exp_counts).to(logits.device) | |||
dist.all_reduce( | |||
new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD) | |||
capacity = new_capacity | |||
# Compute l_aux | |||
alpha = torch.max(gates, dim=1).values.unsqueeze(1) | |||
me = torch.mean(gates, dim=0) | |||
ce = torch.mean(mask1.float(), dim=0) | |||
l_aux = torch.sum(me * ce) * num_experts | |||
# Random Token Selection | |||
if use_rts: | |||
uniform = exp_selection_uniform_map.get(logits.device) | |||
if uniform is None: | |||
uniform = torch.distributions.uniform.Uniform( | |||
low=torch.tensor(0.0, device=logits.device), | |||
high=torch.tensor(1.0, device=logits.device)).rsample | |||
exp_selection_uniform_map[logits.device] = uniform | |||
mask1_rand = mask1 * uniform(mask1.shape) | |||
else: | |||
mask1_rand = mask1 | |||
assert logits.shape[0] >= min_capacity, \ | |||
'No. of tokens (batch-size) should be greater than min_capacity. ' \ | |||
'Either set min_capacity to 0 or increase your batch size.' | |||
top_idx = _top_idx(mask1_rand, capacity) | |||
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) | |||
mask1 = new_mask1 | |||
if use_tutel: | |||
# Tutel doesn't support index values masked with zero | |||
# so we need to replace masked indices with -1 | |||
indices_mask = mask1.sum(dim=1) * num_experts - 1 | |||
indices1_s = torch.min(indices1_s, indices_mask) | |||
# Compute locations in capacity buffer | |||
if use_tutel: | |||
locations1 = tutel_moe.fast_cumsum_sub_one(mask1) | |||
else: | |||
locations1 = torch.cumsum(mask1, dim=0) - 1 | |||
if use_tutel: | |||
gates1_s = (gates * mask1).sum(dim=1) | |||
locations1_s = torch.sum(locations1 * mask1, dim=1) | |||
return l_aux, capacity, num_experts, [ | |||
indices1_s, | |||
], [ | |||
locations1_s, | |||
], [ | |||
gates1_s, | |||
], exp_counts, alpha | |||
# Store the capacity location for each token | |||
locations1_s = torch.sum(locations1 * mask1, dim=1) | |||
# Normalize gate probabilities | |||
mask1_float = mask1.float() | |||
gates = gates * mask1_float | |||
locations1_sc = _one_hot_to_float(locations1_s, capacity) | |||
combine_weights = einsum('se,sc->sec', gates, locations1_sc) | |||
dispatch_mask = combine_weights.bool() | |||
return l_aux, combine_weights, dispatch_mask, exp_counts, alpha | |||
class TopKGate(Module): | |||
"""Gate module which implements Top2Gating as described in Gshard_. | |||
:: | |||
gate = TopKGate(model_dim, num_experts) | |||
l_aux, combine_weights, dispatch_mask = gate(input) | |||
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf | |||
Args: | |||
model_dim (int): | |||
size of model embedding dimension | |||
num_experts (ints): | |||
number of experts in model | |||
""" | |||
wg: torch.nn.Linear | |||
def __init__(self, | |||
model_dim: int, | |||
num_experts: int, | |||
k: int = 1, | |||
capacity_factor: float = 1.0, | |||
eval_capacity_factor: float = 1.0, | |||
min_capacity: int = 8, | |||
noisy_gate_policy: Optional[str] = None, | |||
drop_tokens: bool = True, | |||
use_rts: bool = True, | |||
top_k_linear_strategy: str = 'standard') -> None: | |||
super().__init__() | |||
# Only top-1 are supported at the moment. | |||
if k != 1: | |||
raise ValueError('Only top-1 gatings are supported.') | |||
if top_k_linear_strategy == 'standard': | |||
self.wg = torch.nn.Linear( | |||
model_dim, num_experts, bias=False).float() | |||
elif top_k_linear_strategy == 'lsoftmax': | |||
self.wg = LSoftmaxLinearLayer( | |||
model_dim, num_experts, margin=1).float() | |||
else: | |||
raise ValueError( | |||
'Only standard or lsoftmax top-k-linear-strategy are supported.' | |||
) | |||
self.k = k | |||
self.capacity_factor = capacity_factor | |||
self.eval_capacity_factor = eval_capacity_factor | |||
self.min_capacity = min_capacity | |||
self.noisy_gate_policy = noisy_gate_policy | |||
self.wall_clock_breakdown = False | |||
self.gate_time = 0.0 | |||
self.drop_tokens = drop_tokens | |||
self.use_rts = use_rts | |||
self.top_k_linear_strategy = top_k_linear_strategy | |||
def forward( | |||
self, | |||
input: torch.Tensor, | |||
used_token: torch.Tensor = None, | |||
use_tutel: bool = False | |||
) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore | |||
if self.wall_clock_breakdown: | |||
self.timers('TopKGate').start() | |||
if self.top_k_linear_strategy == 'standard': | |||
if self.wg.weight.dtype != torch.float32: | |||
self.wg = self.wg.float() | |||
elif self.top_k_linear_strategy == 'lsoftmax': | |||
if self.wg.weight.weight.dtype != torch.float32: | |||
self.wg.weight = self.wg.weight.float() | |||
input_fp32 = input.float() | |||
# input jittering | |||
if self.noisy_gate_policy == 'Jitter' and self.training: | |||
input_fp32 = multiplicative_jitter(input_fp32, device=input.device) | |||
if self.k == 1: | |||
if self.top_k_linear_strategy == 'standard': | |||
logits = self.wg(input_fp32) | |||
elif self.top_k_linear_strategy == 'lsoftmax': | |||
logits = self.wg(input_fp32, input_fp32.device, self.training) | |||
gate_output = top1gating( | |||
logits, self.capacity_factor if self.training else | |||
self.eval_capacity_factor, self.min_capacity, used_token, | |||
self.noisy_gate_policy if self.training else None, | |||
self.drop_tokens, self.use_rts, use_tutel) | |||
if self.wall_clock_breakdown: | |||
self.timers('TopKGate').stop() | |||
self.gate_time = self.timers('TopKGate').elapsed( | |||
reset=False) * 1000 | |||
return gate_output | |||
class MOELayer(Base): | |||
"""MOELayer module which implements MixtureOfExperts as described in Gshard_. | |||
:: | |||
gate = TopKGate(model_dim, num_experts) | |||
moe = MOELayer(gate, expert) | |||
output = moe(input) | |||
l_aux = moe.l_aux | |||
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf | |||
Args: | |||
gate (torch.nn.Module): | |||
gate network | |||
expert (torch.nn.Module): | |||
expert network | |||
""" | |||
def __init__(self, | |||
gate: Module, | |||
experts: Module, | |||
ep_group_name, | |||
ep_size, | |||
num_local_experts: int, | |||
use_tutel: bool = False, | |||
use_expert_residual_network: bool = False) -> None: | |||
super().__init__() | |||
self.gate = gate | |||
self.experts = experts | |||
self.ep_group = None | |||
self.ep_size = ep_size | |||
self.ep_group_name = ep_group_name | |||
self.num_local_experts = num_local_experts | |||
self.wall_clock_breakdown = False | |||
self.use_expert_residual_network = use_expert_residual_network | |||
if self.use_expert_residual_network: | |||
self.expert_network = nn.Sequential( | |||
*([ExpertResidualLayer(self.gate.model_dim) | |||
for _ in range(6)])) | |||
self.use_tutel = use_tutel and TUTEL_INSTALLED | |||
if self.use_tutel: | |||
logger.info('Using Tutel optimizations.') | |||
elif use_tutel and not TUTEL_INSTALLED: | |||
logger.info( | |||
'Tutel optimization requested but not installed Proceeding without Tutel.' | |||
) | |||
def _set_ep_group(self, ep_group): | |||
self.ep_group = ep_group | |||
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: | |||
if self.wall_clock_breakdown: | |||
self.timers('moe').start() | |||
# Implement Algorithm 2 from GShard paper. | |||
d_model = input[0].shape[-1] | |||
# Initial implementation -> Reshape into S tokens by dropping sequence dimension. | |||
# Reshape into G groups so that each group can distribute tokens equally | |||
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1 | |||
reshaped_input = input[0].reshape(-1, d_model) | |||
if self.use_tutel: | |||
self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts, alpha = self.gate( | |||
reshaped_input, input[1], True) | |||
_, M = reshaped_input.size(0), reshaped_input.size(1) | |||
if not hasattr(self, '_tutel_dispatcher'): | |||
self._tutel_dispatcher = tutel_moe.fast_dispatcher( | |||
E, C, M, dispatch_dtype=reshaped_input.dtype) | |||
self._tutel_dispatcher.update( | |||
indices_, locations_, gates_, capacity=C) | |||
dispatched_input = self._tutel_dispatcher.encode(reshaped_input) | |||
else: | |||
self.l_aux, combine_weights, dispatch_mask, self.exp_counts, alpha = self.gate( | |||
reshaped_input, input[1]) | |||
dispatched_input = einsum('sec,sm->ecm', | |||
dispatch_mask.type_as(input[0]), | |||
reshaped_input) | |||
if self.wall_clock_breakdown: | |||
self.timers('falltoall').start() | |||
if mpu.get_expert_model_parallel_world_size() == 1: | |||
# If the non-expert is tensor-parallel, it will create | |||
# duplicate tokens on the tensor-parallel ranks. | |||
# Since our experts are not tensor-parallel, these duplicates | |||
# need to be dropped to ensure correctness. | |||
# this also doubles up as a communication optimization as we are | |||
# reducing the all-to-all communication volume. | |||
if self.use_tutel: | |||
# reshape tutel's output from [e*c,m] to [e,c,m] | |||
dispatched_input = dispatched_input.reshape( | |||
self.ep_size * self.num_local_experts, -1, d_model) | |||
dispatched_input = drop_tokens(dispatched_input, dim=1) | |||
dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) | |||
if self.wall_clock_breakdown: | |||
self.timers('falltoall').stop() | |||
self.time_falltoall = self.timers('falltoall').elapsed( | |||
reset=False) * 1000 | |||
# Re-shape after all-to-all: ecm -> gecm | |||
dispatched_input = dispatched_input.reshape(self.ep_size, | |||
self.num_local_experts, -1, | |||
d_model) | |||
expert_output = self.experts(dispatched_input) | |||
if self.wall_clock_breakdown: | |||
self.timers('salltoall').start() | |||
expert_output = _AllToAll.apply(self.ep_group, expert_output) | |||
if self.wall_clock_breakdown: | |||
self.timers('salltoall').stop() | |||
self.time_salltoall = self.timers('salltoall').elapsed( | |||
reset=False) * 1000 | |||
# Re-shape back: gecm -> ecm | |||
expert_output = expert_output.reshape( | |||
self.ep_size * self.num_local_experts, -1, d_model) | |||
if mpu.get_expert_model_parallel_world_size() == 1: | |||
# the dropped duplicate tokens need to be gathered on each | |||
# tensor parallel rank again for the tensor-parallel | |||
# non-expert of the next layer. | |||
expert_output = gather_tokens(expert_output, dim=1) | |||
if self.use_tutel: | |||
combined_output = self._tutel_dispatcher.decode( | |||
expert_output.view(E * C, M)) | |||
else: | |||
combined_output = einsum('sec,ecm->sm', | |||
combine_weights.type_as(input[0]), | |||
expert_output) | |||
if self.use_expert_residual_network: | |||
combined_output = alpha * self.expert_network(combined_output) + ( | |||
1 - alpha) * combined_output | |||
a = combined_output.reshape(input[0].shape) | |||
if self.wall_clock_breakdown: | |||
self.timers('moe').stop() | |||
self.time_moe = self.timers('moe').elapsed(reset=False) * 1000 | |||
return a | |||
class LSoftmaxLinearLayer(torch.nn.Module): | |||
def __init__(self, input_features, output_features, margin): | |||
super().__init__() | |||
self.input_dim = input_features # number of input feature i.e. output of the last fc layer | |||
self.output_dim = output_features # number of output = class numbers | |||
self.margin = margin # m | |||
self.beta = 100 | |||
self.beta_min = 0 | |||
self.scale = 0.99 | |||
self.num_experts = output_features | |||
# Initialize L-Softmax parameters | |||
self.weight = torch.nn.Linear( | |||
input_features, output_features, bias=False).float() | |||
self.divisor = math.pi / self.margin # pi/m | |||
self.C_m_2n = torch.Tensor(binom(margin, range(0, margin + 1, | |||
2))) # C_m{2n} | |||
self.cos_powers = torch.Tensor(range(self.margin, -1, -2)) # m - 2n | |||
self.sin2_powers = torch.Tensor(range(len(self.cos_powers))) # n | |||
self.signs = torch.ones(margin // 2 + 1) # 1, -1, 1, -1, ... | |||
self.signs[1::2] = -1 | |||
def calculate_cos_m_theta(self, cos_theta, device): | |||
sin2_theta = 1 - cos_theta**2 | |||
cos_terms = cos_theta.unsqueeze(1)**self.cos_powers.to( | |||
device).unsqueeze(0) # cos^{m - 2n} | |||
sin2_terms = ( | |||
sin2_theta.unsqueeze(1)**self.sin2_powers.to(device).unsqueeze(0)) | |||
cos_m_theta = (self.signs.to(device).unsqueeze(0) | |||
* self.C_m_2n.to(device).unsqueeze(0) * cos_terms | |||
* sin2_terms).sum(1) # summation of all terms | |||
return cos_m_theta | |||
def reset_parameters(self): | |||
nn.init.kaiming_normal_(self.weight.data.t()) | |||
def find_k(self, cos): | |||
# to account for acos numerical errors | |||
eps = 1e-7 | |||
cos = torch.clamp(cos, -1 + eps, 1 - eps) | |||
acos = cos.acos() | |||
k = (acos / self.divisor).floor().detach() | |||
return k | |||
def forward(self, input, device, training): | |||
if training: | |||
x, w = input, self.weight.float() | |||
beta = max(self.beta, self.beta_min) | |||
logit = w(x) | |||
indexes = range(logit.size(0)) | |||
# target = torch.fmod(torch.randperm(logit.size(0)), self.num_experts) | |||
target = torch.fmod( | |||
torch.range(0, | |||
logit.size(0) - 1), self.num_experts).long() | |||
logit_target = logit[indexes, target] | |||
# cos(theta) = w * x / ||w||*||x|| | |||
w_target_norm = w.weight[:, target].norm(p=2, dim=0) | |||
x_norm = x.norm(p=2, dim=1) | |||
cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10) | |||
# equation 7 | |||
cos_m_theta_target = self.calculate_cos_m_theta( | |||
cos_theta_target, device) | |||
# find k in equation 6 | |||
k = self.find_k(cos_theta_target) | |||
# f_y_i | |||
logit_target_updated = w_target_norm * x_norm * (( | |||
(-1)**k * cos_m_theta_target) - 2 * k) | |||
logit_target_updated_beta = (logit_target_updated + beta | |||
* logit[indexes, target]) / (1 + beta) | |||
logit[indexes, target] = logit_target_updated_beta | |||
self.beta *= self.scale | |||
return logit | |||
else: | |||
return self.weight(input) | |||
def LayerNorm(normalized_shape, | |||
eps=1e-5, | |||
elementwise_affine=True, | |||
export=False): | |||
if torch.jit.is_scripting() or torch.jit.is_tracing(): | |||
export = True | |||
if not export and torch.cuda.is_available() and has_fused_layernorm: | |||
return FusedLayerNorm(normalized_shape, eps, elementwise_affine) | |||
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) | |||
class ExpertResidualLayer(torch.nn.Module): | |||
def __init__(self, embed_dim): | |||
super().__init__() | |||
self.norm = LayerNorm(embed_dim, export=False) | |||
self.ff1 = torch.nn.Linear(embed_dim, embed_dim * 4) | |||
self.ff2 = torch.nn.Linear(embed_dim * 4, embed_dim) | |||
self.ff2.weight.data.zero_() | |||
def forward(self, xs): | |||
return xs + self.ff2(torch.nn.functional.relu(self.ff1(self.norm(xs)))) |
@@ -0,0 +1,125 @@ | |||
''' | |||
Copyright 2020 The Microsoft DeepSpeed Team | |||
''' | |||
from typing import Dict, List, Tuple | |||
import torch | |||
from .layer import MoE | |||
def has_moe_layers(m): | |||
has_moe = False | |||
num_experts = 0 | |||
for _, module in m.named_modules(): | |||
if isinstance(module, MoE): | |||
has_moe = True | |||
num_experts = module.num_experts | |||
break | |||
return has_moe, num_experts | |||
def is_moe_param(param: torch.Tensor) -> bool: | |||
if hasattr(param, 'allreduce') and not param.allreduce: | |||
return True | |||
return False | |||
def split_params_into_shared_and_expert_params( | |||
params: List[torch.nn.Parameter] | |||
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: | |||
shared_params, expert_params = [], [] | |||
for p in params: | |||
if is_moe_param(p): | |||
expert_params.append(p) | |||
else: | |||
shared_params.append(p) | |||
return shared_params, expert_params | |||
def split_params_grads_into_shared_and_expert_params( | |||
group: List[torch.nn.Parameter] | |||
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: | |||
"""Split grad of parameters into grads of non-expert params | |||
and grads of expert params. This is useful while computing | |||
grad-norms for clipping and overflow detection | |||
group (List[torch.nn.Parameter]): | |||
Args: | |||
The group of parameters to split | |||
Returns: | |||
Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]: | |||
list of gradients for non MoE params, list of gradients of MoE params | |||
""" | |||
expert_grads = [] | |||
shared_grads = [] | |||
for p in group: | |||
if p.grad is not None: | |||
if is_moe_param(p): | |||
expert_grads.append(p.grad.to(p.dtype)) | |||
else: | |||
shared_grads.append(p.grad.to(p.dtype)) | |||
return shared_grads, expert_grads | |||
def split_params_into_different_moe_groups_for_optimizer( | |||
param_groups: Tuple[Dict]) -> Tuple[Dict]: | |||
"""Split parameters into different MoE groups for optimizer | |||
Args: | |||
param_groups (Tuple[Dict]): | |||
The list of parameter groups to split | |||
Returns: | |||
Tuple[Dict]: | |||
list of MoE/non-MoE groups for optimizer | |||
""" | |||
if isinstance(param_groups, tuple): | |||
param_groups = list(param_groups) # Tuple cannot be modified | |||
elif isinstance(param_groups, dict): | |||
param_groups = [param_groups] | |||
elif not isinstance(param_groups, list): | |||
raise ValueError(f'Unknown param group type of {type(param_groups)}') | |||
# gather all data parallel group names | |||
data_parallel_group_names = set() | |||
for param_group in param_groups: | |||
for param in param_group['params']: | |||
if is_moe_param(param): | |||
data_parallel_group_names.add(param.group_name) | |||
data_parallel_group_names = list(data_parallel_group_names) | |||
group_moe = {} | |||
# Create the param MoE groups, leave param assign to next step | |||
for param_group in param_groups: | |||
group_moe[param_group['name']] = {} | |||
for key in data_parallel_group_names: | |||
group_moe[param_group['name']][key] = {} | |||
group_moe[param_group['name']][key]['name'] = key | |||
group_moe[param_group['name']][key]['moe'] = True | |||
for ori_key in param_group.keys(): | |||
if ori_key != 'name': | |||
if ori_key == 'params': | |||
group_moe[param_group['name']][key][ori_key] = [] | |||
else: | |||
group_moe[param_group['name']][key][ | |||
ori_key] = param_group[ori_key] | |||
# Assign param | |||
for param_group in param_groups: | |||
new_params = [] | |||
for param in param_group['params']: | |||
if is_moe_param(param): | |||
group_moe[param_group['name']][ | |||
param.group_name]['params'].append(param) | |||
# param_group['params'].remove(param) | |||
else: | |||
new_params.append(param) | |||
param_group['params'] = new_params | |||
# Flatten the moe groups | |||
for k, v in group_moe.items(): | |||
for k1, v1 in v.items(): | |||
param_groups.append(v1) | |||
return tuple(param_groups) |
@@ -0,0 +1,62 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Dict | |||
from modelscope.metainfo import Models | |||
from modelscope.models.base import Tensor, TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import Tasks | |||
__all__ = ['GPTMoEForTextGeneration'] | |||
@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt_moe) | |||
class GPTMoEForTextGeneration(TorchModel): | |||
def __init__(self, model_dir: str, *args, **kwargs): | |||
"""initialize the text generation model from the `model_dir` path. | |||
Args: | |||
model_dir (str): the model path. | |||
""" | |||
super().__init__(model_dir, *args, **kwargs) | |||
from modelscope.models.nlp.gpt_moe import GPTMoEModel | |||
from transformers import BertTokenizer | |||
print('****') | |||
print(model_dir) | |||
self.model = GPTMoEModel.from_pretrained(model_dir) | |||
self.tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
"""return the result by the model | |||
Args: | |||
input (Dict[str, Tensor]): the preprocessed data | |||
Returns: | |||
Dict[str, Tensor]: results | |||
Example: | |||
{ | |||
'logits': Tensor([[0.54, 0.32...])]), # logits | |||
} | |||
""" | |||
return self.model(**input) | |||
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||
assert 'input_ids' in input, "generate function must accept 'input_ids' key" | |||
input_ids = input['input_ids'] | |||
if 'attention_mask' in input: | |||
attention_mask = input['attention_mask'] | |||
input_ids = input_ids[0][attention_mask[0].nonzero()] \ | |||
.squeeze().unsqueeze(0) | |||
# remove sep token at the end of tokenizer output | |||
input_ids = input_ids[:, :-1] | |||
gen_params = dict() | |||
gen_params['inputs'] = input_ids | |||
gen_params['do_sample'] = input.pop('do_sample', True) | |||
gen_params['max_length'] = input.pop('max_length', 128) | |||
gen_params['top_k'] = input.pop('top_k', 10) | |||
gen_params['top_p'] = input.pop('top_p', None) | |||
sample_output = self.model.generate(**gen_params) | |||
return {'sequences': sample_output[0]} |
@@ -0,0 +1,67 @@ | |||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from tokenizers import Tokenizer | |||
class JiebaBPETokenizer: | |||
"""SentencePiece BPE tokenizer with Jieba integration""" | |||
def __init__(self, tokenizer_json_file): | |||
self.name = 'Jieba BPE Tokenizer' | |||
self.tokenizer = Tokenizer.from_file(tokenizer_json_file) | |||
self.eod_id = self.tokenizer.token_to_id('<|endoftext|>') | |||
try: | |||
import jieba | |||
except ImportError: | |||
raise ImportError( | |||
'You need to install rjieba to use JiebaTokenizer. ' | |||
'See https://pypi.org/project/rjieba/ for installation.') | |||
self.jieba = jieba | |||
self.new_line = self.vocab['\n'] | |||
self.sep_token = self.vocab['<sep>'] | |||
@property | |||
def vocab_size(self): | |||
return self.tokenizer.get_vocab_size(with_added_tokens=True) | |||
@property | |||
def vocab(self): | |||
return self.tokenizer.get_vocab(with_added_tokens=True) | |||
@property | |||
def inv_vocab(self): | |||
vocab = self.vocab | |||
inv_vocab = dict() | |||
for key, val in vocab.items(): | |||
inv_vocab[val] = key | |||
return inv_vocab | |||
def tokenize(self, text, is_code=False): | |||
if not is_code: | |||
seg_list = [x for x in self.jieba.cut(text)] | |||
return self.tokenizer.encode( | |||
seg_list, is_pretokenized=True, add_special_tokens=True).ids | |||
else: | |||
return self.tokenizer.encode( | |||
text, is_pretokenized=False, add_special_tokens=True).ids | |||
def detokenize(self, token_ids): | |||
text = self.tokenizer.decode(token_ids, skip_special_tokens=False) | |||
return text | |||
@property | |||
def eod(self): | |||
return self.eod_id |
@@ -0,0 +1,54 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.nlp.gpt_moe.distributed_gpt_moe import DistributedGPTMoE | |||
from modelscope.pipelines.base import DistributedPipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import TextGenerationJiebaPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
@PIPELINES.register_module( | |||
Tasks.text_generation, module_name=Pipelines.gpt_moe_generation) | |||
class DistributedGPTMoEPipeline(DistributedPipeline): | |||
"""This class is used to instantiate the gpt-moe model. | |||
""" | |||
model = None | |||
def __init__(self, model, preprocessor=None, **kwargs): | |||
if preprocessor is None: | |||
preprocessor = TextGenerationJiebaPreprocessor(model) | |||
super().__init__(model, preprocessor=preprocessor, **kwargs) | |||
assert hasattr(preprocessor, 'tokenizer') | |||
@classmethod | |||
def _instantiate_one(cls, rank, model_dir, **kwargs): | |||
cls.model = DistributedGPTMoE(model_dir, rank, **kwargs) | |||
cls.model.eval() | |||
@classmethod | |||
def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
tokens = inputs['inputs']['input_ids'].cuda( | |||
torch.cuda.current_device()) | |||
return cls.model.generate(tokens) | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
Args: | |||
inputs (Dict[str, Any]): _description_ | |||
Returns: | |||
Dict[str, str]: the prediction results | |||
""" | |||
from modelscope.outputs import OutputKeys | |||
return { | |||
OutputKeys.TEXT: | |||
self.preprocessor.tokenizer.detokenize(inputs[0].tolist()) | |||
} |
@@ -0,0 +1,24 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.test_utils import test_level | |||
class TextGPTMoEGenerationTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
self.model_id_1_3B_MoE32 = 'PAI/nlp_gpt3_text-generation_1.3B_MoE-32' | |||
self.model_dir_1_3B_MoE32 = snapshot_download(self.model_id_1_3B_MoE32) | |||
self.input = '好的' | |||
@unittest.skip('distributed gpt-moe 1.3B_MoE-32, skipped') | |||
def test_gpt_moe_1_3B_MoE32(self): | |||
pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B_MoE32) | |||
print(pipe(self.input)) | |||
if __name__ == '__main__': | |||
unittest.main() |