
add gpt-moe model for modelscope pipeline inference

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10836131
master^2
jerry.lp yingda.chen 2 years ago
parent
commit
177d70829b
16 changed files with 3093 additions and 0 deletions
  1. +2
    -0
      modelscope/metainfo.py
  2. +27
    -0
      modelscope/models/nlp/gpt_moe/__init__.py
  3. +355
    -0
      modelscope/models/nlp/gpt_moe/backbone.py
  4. +145
    -0
      modelscope/models/nlp/gpt_moe/checkpointing.py
  5. +128
    -0
      modelscope/models/nlp/gpt_moe/configuration.py
  6. +1236
    -0
      modelscope/models/nlp/gpt_moe/distributed_gpt_moe.py
  7. +0
    -0
      modelscope/models/nlp/gpt_moe/moe/__init__.py
  8. +36
    -0
      modelscope/models/nlp/gpt_moe/moe/experts.py
  9. +98
    -0
      modelscope/models/nlp/gpt_moe/moe/layer.py
  10. +87
    -0
      modelscope/models/nlp/gpt_moe/moe/mappings.py
  11. +647
    -0
      modelscope/models/nlp/gpt_moe/moe/sharded_moe.py
  12. +125
    -0
      modelscope/models/nlp/gpt_moe/moe/utils.py
  13. +62
    -0
      modelscope/models/nlp/gpt_moe/text_generation.py
  14. +67
    -0
      modelscope/models/nlp/gpt_moe/tokenizer.py
  15. +54
    -0
      modelscope/pipelines/nlp/distributed_gpt_moe_pipeline.py
  16. +24
    -0
      tests/pipelines/test_gpt_moe_text_generation.py
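
For reference, a minimal usage sketch of the new pipeline (mirroring the test and pipeline files below). It assumes this commit is installed together with its GPU/Megatron dependencies and that the 'PAI/nlp_gpt3_text-generation_1.3B_MoE-32' checkpoint used in the test is available on the ModelScope hub.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Should resolve to DistributedGPTMoEPipeline via the new
# 'gpt-moe-generation' pipeline registration.
pipe = pipeline(Tasks.text_generation,
                model='PAI/nlp_gpt3_text-generation_1.3B_MoE-32')
print(pipe('好的'))  # dict with the generated text under OutputKeys.TEXT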

+ 2
- 0
modelscope/metainfo.py

@@ -80,6 +80,7 @@ class Models(object):
gcnncrf = 'gcnn-crf'
bart = 'bart'
gpt3 = 'gpt3'
gpt_moe = 'gpt-moe'
gpt_neo = 'gpt-neo'
plug = 'plug'
bert_for_ds = 'bert-for-document-segmentation'
@@ -255,6 +256,7 @@ class Pipelines(object):
text_error_correction = 'text-error-correction'
plug_generation = 'plug-generation'
gpt3_generation = 'gpt3-generation'
gpt_moe_generation = 'gpt-moe-generation'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
table_question_answering_pipeline = 'table-question-answering-pipeline'


+ 27
- 0
modelscope/models/nlp/gpt_moe/__init__.py

@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .configuration import GPTMoEConfig
from .backbone import GPTMoEModel
from .text_generation import GPTMoEForTextGeneration
from .tokenizer import JiebaBPETokenizer
else:
_import_structure = {
'configuration': ['GPTMoEConfig'],
'backbone': ['GPTMoEModel'],
'text_generation': ['GPTMoEForTextGeneration'],
'tokenizer': ['JiebaBPETokenizer'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
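
The LazyImportModule shim above means the heavy submodules are only imported on first attribute access. A small sketch of what a caller sees, assuming the package is installed:

# Importing the symbol triggers the lazy load of .configuration only.
from modelscope.models.nlp.gpt_moe import GPTMoEConfig

config = GPTMoEConfig()  # defaults: vocab_size=25600, hidden_size=768, num_experts=[0]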

+ 355
- 0
modelscope/models/nlp/gpt_moe/backbone.py

@@ -0,0 +1,355 @@
# Copyright 2021-2022 The Alibaba PAI Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from typing import Optional, Union

import addict
import torch
from torch import nn
from torch.nn import functional as F
from transformers.modeling_utils import PreTrainedModel

from modelscope.utils.constant import ModelFile
from .configuration import GPTMoEConfig


class GPTMoESelfAttention(nn.Module):
"""Parallel self-attention layer abstract class.

Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""

def __init__(self, config):
super().__init__()

self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
# Per attention head
self.hidden_size_per_attention_head = \
self.hidden_size // self.num_attention_heads

self.query_key_value = nn.Linear(self.hidden_size,
3 * self.hidden_size)
self.softmax = nn.Softmax(dim=-1)
self.attention_dropout = nn.Dropout(
config.attention_probs_dropout_prob)

# Output.
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)

def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + (
self.num_attention_heads, self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)

def _split_tensor_along_last_dim(self,
tensor,
num_partitions,
contiguous_split_chunks=False):
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = tensor.size()[last_dim] // num_partitions
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)

return tensor_list

def forward(self, hidden_states, ltor_mask, is_infer=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]

# Attention heads. [b, s, hp]
tgt_len = hidden_states.size(1)
ltor_mask = torch.reshape(ltor_mask, [1, 1, tgt_len, tgt_len])
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer, mixed_key_layer, mixed_value_layer) = \
self._split_tensor_along_last_dim(mixed_x_layer, 3)

# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)

previous_type = value_layer.type()

# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer,
key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(
self.hidden_size_per_attention_head)
# Apply the left to right attention mask.
if is_infer:
src_len = key_layer.size(2)
ltor_mask = torch.tril(
torch.ones((1, tgt_len, src_len),
device=hidden_states.device)).view(
1, 1, tgt_len, src_len).type(previous_type)
converted_mask = 10000.0 * (1.0 - ltor_mask)
attention_scores = (torch.mul(attention_scores, ltor_mask)
- converted_mask).type(previous_type)

# Attention probabilities. [b, np, s, s]
attention_probs = self.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.attention_dropout(attention_probs)

# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size, )
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)

# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)

return output


class GPTMoEMLP(nn.Module):
"""MLP.

MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
"""

def __init__(self, config):
super().__init__()

hidden_size = config.hidden_size
# Project to 4h.
self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)

self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(self, hidden_states):

# [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output


class GPTMoETransformerLayer(nn.Module):
"""A single transformer layer.

Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""

def __init__(self, config):
super().__init__()

# Layernorm on the input data.
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layernorm_epsilon)

# Self attention.
self.attention = GPTMoESelfAttention(config)

# Layernorm on the attention output
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layernorm_epsilon)

# MLP
self.mlp = GPTMoEMLP(config)

def forward(self, hidden_states, ltor_mask):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]

# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output, ltor_mask)
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
output = layernorm_input + mlp_output

return output


class GPTMoETransformer(nn.Module):
"""Transformer class."""

def __init__(self, config):
super().__init__()

self.input_tensor = None

# Number of layers.
self.num_layers = config.num_hidden_layers

self.layers = torch.nn.ModuleList(
[GPTMoETransformerLayer(config) for _ in range(self.num_layers)])

# Final layer norm before output.
self.final_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layernorm_epsilon)

def _get_layer(self, layer_number):
return self.layers[layer_number]

def forward(self, hidden_states, attention_mask):
# hidden_states: [b, s, h]

for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(hidden_states, attention_mask)

# Final layer norm.
hidden_states = self.final_layernorm(hidden_states)

return hidden_states


class GPTMoETransformerLanguageModel(nn.Module):
"""Transformer language model.

Arguments:
transformer_hparams: transformer hyperparameters
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""

def __init__(self, config):
super().__init__()

# Embeddings.
self.word_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

# Transformer.
self.transformer = GPTMoETransformer(config)

def forward(self, input_ids, attention_mask, position_ids):
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)

embeddings = words_embeddings + position_embeddings
transformer_input = self.embedding_dropout(embeddings)
transformer_output = self.transformer(transformer_input,
attention_mask)

logits = F.linear(transformer_output, self.word_embeddings.weight)
return logits


class GPTMoEModel(PreTrainedModel):

config_class = GPTMoEConfig

def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(
mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(
mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

def __init__(self, config):
super().__init__(config)
self.language_model = GPTMoETransformerLanguageModel(config)

def forward(self,
input_ids,
attention_mask=None,
position_ids=None,
labels=None,
**kwargs):
seq_length = input_ids.size(1)
attention_mask = torch.tril(
torch.ones((1, 1, seq_length, seq_length),
dtype=torch.long,
device=input_ids.device))
if position_ids is None:
position_ids = torch.arange(
seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

logits = self.language_model(input_ids, attention_mask, position_ids)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
logits.view(-1, self.config.vocab_size), labels.view(-1))
return addict.Dict(loss=loss, logits=logits)

@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Optional[Union[str,
os.PathLike]]):
config = cls.config_class.from_pretrained(
pretrained_model_name_or_path)
model = cls(config)
state_dict_file = os.path.join(pretrained_model_name_or_path,
ModelFile.TORCH_MODEL_BIN_FILE)
state_dict = torch.load(state_dict_file)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
state_dict = {
k.replace('model.language_model', 'language_model'): v
for k, v in state_dict.items()
}
model.load_state_dict(state_dict)
return model

def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
return {'input_ids': input_ids}
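
A minimal smoke-test sketch for the backbone, assuming this package plus torch, transformers and addict are importable. It builds a tiny randomly initialized model instead of loading a released checkpoint; the sizes below are illustrative only.

import torch
from modelscope.models.nlp.gpt_moe import GPTMoEConfig, GPTMoEModel

config = GPTMoEConfig(vocab_size=1000, hidden_size=64,
                      num_hidden_layers=2, num_attention_heads=4,
                      max_position_embeddings=128)
model = GPTMoEModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
out = model(input_ids, labels=input_ids)
print(out.logits.shape)  # torch.Size([1, 16, 1000]); logits tied to word embeddings
print(out.loss)          # cross-entropy against the dummy labels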

+ 145
- 0
modelscope/models/nlp/gpt_moe/checkpointing.py

@@ -0,0 +1,145 @@
# Copyright 2021-2022 The Alibaba PAI Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from megatron import mpu
from megatron.model import Float16Module
from torch.nn.parallel import DistributedDataParallel as torchDDP

from .configuration import logger
from .moe.layer import MoE


def unwrap_model(model, module_instances=(torchDDP, )):
return_list = True
if not isinstance(model, list):
model = [model]
return_list = False
unwrapped_model = []
for model_module in model:
while isinstance(model_module, module_instances):
model_module = model_module.module
unwrapped_model.append(model_module)
if not return_list:
return unwrapped_model[0]
return unwrapped_model


def get_checkpoint_names(checkpoints_path,
path_load_tag,
num_experts,
tensor_rank=None,
expp_rank=None):
"""Determine the directory name for this rank's checkpoint."""
if tensor_rank is None:
tensor_rank = mpu.get_model_parallel_rank()

common_path = os.path.join(checkpoints_path, path_load_tag,
f'mp_rank_{tensor_rank:02d}')

if num_experts[0] > 0:
model_name = common_path + '_model_states.pt'
optim_name = os.path.join(
checkpoints_path, path_load_tag,
f'expp_rank_{expp_rank}_mp_rank_{tensor_rank:02d}_optim_states.pt')
else:
model_name = optim_name = os.path.join(common_path,
'model_optim_rng.pt')

return model_name, optim_name


def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id):
mp_rank = mpu.get_model_parallel_rank()
ckpt_name = os.path.join(
os.path.join(checkpoints_path, 'model'),
f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt'
)
return ckpt_name


def _load_base_checkpoint(load_dir, path_load_tag=None, num_experts=None):
""" Load the base state_dict from the given directory

If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
"""
largest_group_name = mpu.get_max_expert_size_name()
expp_rank = mpu.get_expert_parallel_rank(largest_group_name)
checkpoint_names = get_checkpoint_names(
load_dir,
path_load_tag=path_load_tag,
num_experts=num_experts,
expp_rank=expp_rank)
model_checkpoint_name, optim_checkpoint_name = checkpoint_names

logger.info(f'Loading model checkpoint from {model_checkpoint_name}')
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')

return model_state_dict


def load_checkpoint(model,
load_dir,
num_experts=None,
strict=True,
path_load_tag='model',
load_ds_ckpts=True):
model = unwrap_model(model, (torchDDP, Float16Module))

model_state_dict = _load_base_checkpoint(
load_dir, path_load_tag=path_load_tag, num_experts=num_experts)

assert model_state_dict is not None

if load_ds_ckpts:
load_moe_checkpoint(model, model_state_dict['module'], load_dir)
else:
load_moe_checkpoint(model, model_state_dict['model'], load_dir)

if load_ds_ckpts:
model.load_state_dict(model_state_dict['module'], strict=strict)
else:
model.load_state_dict(model_state_dict['model'], strict=strict)

if torch.distributed.is_initialized():
torch.distributed.barrier()


def load_moe_checkpoint(model, state_dict, load_dir):
moe_layer_id = 0
for n_module, module in model.named_modules():
if isinstance(module, MoE): # and torch.distributed.get_rank() == 0:
group_name = module.expert_group_name
num_local_experts = module.num_local_experts
expp_rank = mpu.get_expert_parallel_rank(group_name)
# loop all local_experts
for local_expert_id in range(num_local_experts):
global_expert_id = expp_rank * num_local_experts + local_expert_id
moe_load_path = _get_expert_ckpt_name(load_dir, moe_layer_id,
global_expert_id)
logger.info(f'Loading expert states from {moe_load_path}')
expert_state_dict = torch.load(
moe_load_path, map_location=torch.device('cpu'))
# Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()):
local_key = key.replace(
f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
expert_state_dict[local_key] = expert_state_dict.pop(key)
state_dict.update(expert_state_dict)
moe_layer_id += 1

+ 128
- 0
modelscope/models/nlp/gpt_moe/configuration.py

@@ -0,0 +1,128 @@
# Copyright 2021-2022 The Alibaba PAI Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class GPTMoEConfig(PretrainedConfig):

model_type = 'gpt-moe'

def __init__(
self,
vocab_size=25600,
hidden_size=768,
ffn_hidden_size=None,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=2048,
type_vocab_size=2,
layernorm_epsilon=1e-12,
bias_gelu_fusion=True,
fp32_residual_connection=False,
sequence_parallel=False,
fp16=False,
bf16=False,
apply_query_key_layer_scaling=True,
attention_softmax_in_fp32=False,
kv_channels=None,
masked_softmax_fusion=True,
attention_dropout=0.1,
bias_dropout_fusion=True,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.1,
init_method_std=0.02,
# generate
eod_id=7,
tokens_to_generate=100,
top_k=0,
top_p=0.9,
num_experts=[0],
use_tutel=False,
top_k_linear_strategy='standard',
use_expert_residual_network=False,
load_ds_ckpts=False,
model_dir=None,
**kwargs):
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.ffn_hidden_size = 4 * hidden_size \
if ffn_hidden_size is None else ffn_hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.layernorm_epsilon = layernorm_epsilon
self.bias_gelu_fusion = bias_gelu_fusion
self.fp32_residual_connection = fp32_residual_connection
self.sequence_parallel = sequence_parallel
self.fp16 = fp16
self.bf16 = bf16
assert not (fp16 and bf16)
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
if kv_channels is None:
assert hidden_size % num_attention_heads == 0
self.kv_channels = hidden_size // num_attention_heads
self.masked_softmax_fusion = masked_softmax_fusion
self.attention_dropout = attention_dropout
self.bias_dropout_fusion = bias_dropout_fusion
self.apply_residual_connection_post_layernorm = \
apply_residual_connection_post_layernorm
self.hidden_dropout = hidden_dropout
self.init_method_std = init_method_std
self.eod_id = eod_id
self.tokens_to_generate = tokens_to_generate
self.top_k = top_k
self.top_p = top_p
self.num_experts = num_experts
self.use_tutel = use_tutel
self.top_k_linear_strategy = top_k_linear_strategy
self.use_expert_residual_network = use_expert_residual_network
self.load_ds_ckpts = load_ds_ckpts
self.model_dir = model_dir

if self.num_experts[0] > torch.cuda.device_count():
self.moe_expert_parallel_size = torch.cuda.device_count()
else:
self.moe_expert_parallel_size = self.num_experts[0]

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
self.no_persist_layer_norm = \
TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11)

@property
def params_dtype(self):
if self.fp16:
return torch.half
elif self.bf16:
return torch.bfloat16
else:
return torch.float
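
An illustrative configuration sketch (the values are made up, not taken from a released model). Note that moe_expert_parallel_size is derived from num_experts[0] and the number of visible GPUs at construction time.

from modelscope.models.nlp.gpt_moe import GPTMoEConfig

config = GPTMoEConfig(
    hidden_size=2048,
    num_hidden_layers=24,
    num_attention_heads=32,
    num_experts=[32],                  # MoE layers with 32 experts
    top_k_linear_strategy='standard',  # or 'lsoftmax'
    fp16=True)

print(config.moe_expert_parallel_size)  # min(32, torch.cuda.device_count())
print(config.params_dtype)              # torch.float16 because fp16=True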

+ 1236
- 0
modelscope/models/nlp/gpt_moe/distributed_gpt_moe.py
File diff suppressed because it is too large


+ 0
- 0
modelscope/models/nlp/gpt_moe/moe/__init__.py


+ 36
- 0
modelscope/models/nlp/gpt_moe/moe/experts.py

@@ -0,0 +1,36 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import copy

import torch


class Experts(torch.nn.Module):

def __init__(self, expert, num_local_experts=1, expert_group_name=None):
super(Experts, self).__init__()

self.deepspeed_experts = torch.nn.ModuleList(
[copy.deepcopy(expert) for i in range(num_local_experts)])
self.num_local_experts = num_local_experts

# TODO: revisit allreduce for moe.gate...
for expert in self.deepspeed_experts:
# TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
for name, param in expert.named_parameters():
param.allreduce = False
param.group_name = expert_group_name

def forward(self, inputs):
chunks = inputs.chunk(self.num_local_experts, dim=1)
expert_outputs = []
for chunk, expert in zip(chunks, self.deepspeed_experts):
out = expert(chunk)
if type(out) is tuple:
out = out[0] # Ignore the bias term for now
expert_outputs += [out]

expert_output = torch.cat(expert_outputs, dim=1)
return expert_output
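
Experts simply fans an already-dispatched tensor out over its local expert copies: the input is chunked along dim=1 (one chunk per local expert), each chunk runs through its own deep copy of the expert, and the outputs are concatenated back. A standalone sketch with nn.Linear standing in for the real expert MLP (only torch is needed, assuming the module import resolves):

import torch
from torch import nn
from modelscope.models.nlp.gpt_moe.moe.experts import Experts

d_model = 8
experts = Experts(nn.Linear(d_model, d_model), num_local_experts=2,
                  expert_group_name='ep_size_1')

# Layout produced by MOELayer before calling the experts:
# [ep_size, num_local_experts, capacity, d_model]
dispatched = torch.randn(1, 2, 4, d_model)
out = experts(dispatched)
print(out.shape)  # torch.Size([1, 2, 4, 8]); chunk i was processed by local expert i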

+ 98
- 0
modelscope/models/nlp/gpt_moe/moe/layer.py

@@ -0,0 +1,98 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import typing

import torch
from megatron import mpu

from .experts import Experts
from .sharded_moe import MOELayer, TopKGate


class MoE(torch.nn.Module):

def __init__(self,
hidden_size,
expert,
num_experts=1,
ep_size=1,
k=1,
capacity_factor=1.,
eval_capacity_factor=1.,
min_capacity=4,
use_residual=False,
noisy_gate_policy: typing.Optional[str] = None,
drop_tokens: bool = True,
use_rts=True,
use_tutel: bool = False,
top_k_linear_strategy: str = 'standard',
use_expert_residual_network: bool = False):
super(MoE, self).__init__()
self.use_residual = use_residual
assert num_experts % ep_size == 0, f'Number of experts ({num_experts}) should ' \
f'be divisible by expert parallel size ({ep_size})'
self.ep_size = ep_size
self.expert_group_name = f'ep_size_{self.ep_size}'
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size

assert noisy_gate_policy is None or noisy_gate_policy in ['None', 'Jitter', 'RSample'], \
'Unsupported noisy_gate_policy: ' + noisy_gate_policy

experts = Experts(expert, self.num_local_experts,
self.expert_group_name)
self.deepspeed_moe = MOELayer(
TopKGate(
hidden_size,
num_experts,
k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts,
top_k_linear_strategy=top_k_linear_strategy),
experts,
self.expert_group_name,
self.ep_size,
self.num_local_experts,
use_tutel=use_tutel,
use_expert_residual_network=use_expert_residual_network)

self.deepspeed_moe._set_ep_group(
mpu.get_expert_parallel_group(self.expert_group_name))

if self.use_residual:
self.mlp = expert
# coefficient is used for weighted sum of the output of expert and mlp
self.coefficient = torch.nn.Linear(hidden_size, 2)

def forward(self, hidden_states, used_token=None):
""" MoE forward

Arguments:
hidden_states (Tensor): input to the layer
used_token (Tensor, optional): default: None, mask only used tokens

Returns:
A tuple including output, gate loss, and expert count.

* output (Tensor): output of the model

* l_aux (Tensor): gate loss value

* exp_counts (int): expert count
"""
output = self.deepspeed_moe(hidden_states, used_token)
if self.use_residual:
# Residual MoE
output_mlp = self.mlp(hidden_states)
if type(output_mlp) is tuple:
output_mlp = output_mlp[0] # Ignore the bias term for now
coef = self.coefficient(hidden_states)
coef = torch.nn.functional.softmax(coef, dim=1)
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts

+ 87
- 0
modelscope/models/nlp/gpt_moe/moe/mappings.py

@@ -0,0 +1,87 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

import torch
from megatron import mpu


def _gather_tokens(input_, dim=0):
"""Gather tensors and concatenate them along a dimension"""

input_ = input_.contiguous()
# Size and dimension.
rank = mpu.get_tensor_model_parallel_rank()

tensor_list = [
torch.empty_like(input_)
for _ in range(mpu.get_model_parallel_world_size())
]
tensor_list[rank] = input_
torch.distributed.all_gather(
tensor_list, input_, group=mpu.get_tensor_model_parallel_group())

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim).contiguous()

return output


def _drop_tokens(input_, dim=0):
"""Divide a tensor among the tensor parallel ranks"""
total_chunks = mpu.get_model_parallel_world_size()
this_chunk = mpu.get_model_parallel_rank()
assert input_.shape[
dim] % total_chunks == 0, f'input dimension {dim} ({input_.shape[dim]}) ' \
f'is not divisible by tensor parallel world size ({total_chunks})'
chunk_size = input_.shape[dim] // total_chunks

return torch.narrow(input_, dim, this_chunk * chunk_size, chunk_size)


class _GatherTokens(torch.autograd.Function):
"""All gather tokens among the tensor parallel ranks"""

@staticmethod
def symbolic(graph, input_, dim):
return _gather_tokens(input_, dim)

@staticmethod
def forward(ctx, input_, dim):
ctx.dim = dim
return _gather_tokens(input_, dim)

@staticmethod
def backward(ctx, grad_output):
return _drop_tokens(grad_output, ctx.dim), None


class _DropTokens(torch.autograd.Function):
'Divide tokens equally among the tensor parallel ranks'

@staticmethod
def symbolic(graph, input_, dim):
return _drop_tokens(input_, dim)

@staticmethod
def forward(ctx, input_, dim):
ctx.dim = dim
return _drop_tokens(input_, dim)

@staticmethod
def backward(ctx, input_):
return _gather_tokens(input_, ctx.dim), None


def gather_tokens(input_, dim=0):
if mpu is None or mpu.get_model_parallel_world_size() == 1:
# no tensor parallelism for non-experts
return input_
return _GatherTokens.apply(input_, dim)


def drop_tokens(input_, dim=0):
if mpu is None or mpu.get_model_parallel_world_size() == 1:
# no tensor parallelism for non-experts
return input_
return _DropTokens.apply(input_, dim)

+ 647
- 0
modelscope/models/nlp/gpt_moe/moe/sharded_moe.py

@@ -0,0 +1,647 @@
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# The file has been adapted from two fairscale files:
# (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
# We retain the following license from the original files:

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import mpu
from scipy.special import binom
from torch import Tensor, nn
from torch.nn import Module

from ..configuration import logger
from .mappings import drop_tokens, gather_tokens

# Guarded optional import for Tutel (mirrors the upstream DeepSpeed layer):
# the `use_tutel` code paths below reference `tutel_moe` and `TUTEL_INSTALLED`.
TUTEL_INSTALLED = False
try:
from tutel import moe as tutel_moe
TUTEL_INSTALLED = True
except ImportError:
# Tutel is optional; the einsum-based dispatch path is used without it.
pass

try:
from apex.normalization import FusedLayerNorm as _FusedLayerNorm

has_fused_layernorm = True

class FusedLayerNorm(_FusedLayerNorm):

@torch.jit.unused
def forward(self, x):
if not x.is_cuda:
return super().forward(x)
else:
with torch.cuda.device(x.device):
return super().forward(x)
except ImportError:
has_fused_layernorm = False

if TYPE_CHECKING:
Base = Module[Tensor]
else:
Base = Module

uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}


def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
"""
Modified from the Switch Transformer paper (Mesh TensorFlow implementation).
Multiply values by a random number between 1-epsilon and 1+epsilon.
Makes models more resilient to rounding errors introduced by bfloat16.
This seems particularly important for logits.
Args:
x: a torch.tensor
device: torch.device
epsilon: a floating point value
Returns:
a jittered x.
"""
if epsilon == 0:
return x
uniform = uniform_map.get(device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - epsilon, device=device),
high=torch.tensor(1.0 + epsilon,
device=device)).rsample # type: ignore
uniform_map[device] = uniform
return x * uniform(x.shape)


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
gumbel = gumbel_map.get(device)
if gumbel is None:
one = torch.tensor(1.0, device=device)
zero = torch.tensor(0.0, device=device)
gumbel = torch.distributions.gumbel.Gumbel(zero,
one).rsample # type: ignore
gumbel_map[device] = gumbel
return gumbel(shape)


# Based on https://github.com/pytorch/pytorch/pull/40762
class _AllToAll(torch.autograd.Function):

@staticmethod
def forward(ctx: Any, group: dist.ProcessGroup,
input: Tensor) -> Tensor: # type: ignore
ctx.group = group
input = input.contiguous()
output = torch.empty_like(input)
dist.all_to_all_single(output, input, group=group)
return output

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]:
return (None, _AllToAll.apply(ctx.group, *grad_output))


# einsum rewrites are on par or more performant
# switch can be bubbled up in future
USE_EINSUM = True


# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
# See https://arxiv.org/pdf/2006.16668.pdf for details.
def einsum(rule, a, b):
if USE_EINSUM:
return torch.einsum(rule, a, b)
elif rule == 's,se->se':
return a.reshape(a.shape[0], -1) * b
elif rule == 'se,sc->sec':
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == 'se,se->s':
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == 'sec,sm->ecm':
s = a.shape[0]
e = a.shape[1]
c = a.shape[2]
m = b.shape[1]
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
elif rule == 'sec,ecm->sm':
return torch.matmul(
a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
elif rule == 'ks,ksm->sm':
k = b.shape[0]
s = b.shape[1]
m = b.shape[2]
# [k, s] -> [s, k] -> [s, 1, k]
a = a.t().unsqueeze(1)
# [k,s,m] -> [k, sm] -> [sm, k] -> [s, m, k]
b = b.reshape(k, -1).t().reshape(s, m, k)
# bmm([s, 1, k], [s, m, k]^t) -> [s, m, 1]
return torch.bmm(a, b.transpose(1, 2)).squeeze(2)
else:
return torch.einsum(rule, a, b)


# The following functions are extracted and scripted
# because otherwise during a torch.jit.trace, the non-Tensor
# values used in the calculations get recorded as constants.
# torch.jit.script coerces them into Tensors and preserves
# their dynamic shapes. This enables ONNX export.
# We can't script the entire top1gating function because it
# includes stateful caching logic which is incompatible with ONNX.


@torch.jit.script
def _capacity(gates: Tensor, capacity_factor: Tensor,
min_capacity: Tensor) -> Tensor:
# gates has shape of SE
num_tokens = gates.shape[0]
num_experts = gates.shape[1]
# to(torch.int64) works around a bug in torch.onnx.export:
# it should cast k to int64 when converting torch.topk but it doesn't.
capacity = torch.ceil(
(num_tokens / num_experts) * capacity_factor).to(torch.int64)
if capacity < min_capacity:
capacity = min_capacity.to(torch.int64)
return capacity


@torch.jit.script
def _top_idx(source, k):
return torch.topk(source, k=k, dim=0)[1]


@torch.jit.script
def _one_hot_to_float(x, num_classes):
return F.one_hot(x, num_classes=num_classes).float()


def top1gating(
logits: Tensor,
capacity_factor: float,
min_capacity: int,
used_token: Tensor = None,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements Top1Gating on logits."""
if noisy_gate_policy == 'RSample':
logits_w_noise = logits + gumbel_rsample(
logits.shape, device=logits.device)
# everything is in fp32 in this function
gates = F.softmax(logits, dim=1)

capacity = _capacity(gates, torch.tensor(capacity_factor),
torch.tensor(min_capacity))

# Create a mask for 1st's expert per token
# noisy gating
indices1_s = torch.argmax(
logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)

# mask only used tokens
if used_token is not None:
mask1 = einsum('s,se->se', used_token, mask1)

# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')

# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
dist.all_reduce(
new_capacity, op=dist.ReduceOp.MAX, group=dist.group.WORLD)
capacity = new_capacity

# Compute l_aux
alpha = torch.max(gates, dim=1).values.unsqueeze(1)
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.float(), dim=0)
l_aux = torch.sum(me * ce) * num_experts

# Random Token Selection
if use_rts:
uniform = exp_selection_uniform_map.get(logits.device)
if uniform is None:
uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=logits.device),
high=torch.tensor(1.0, device=logits.device)).rsample
exp_selection_uniform_map[logits.device] = uniform

mask1_rand = mask1 * uniform(mask1.shape)
else:
mask1_rand = mask1

assert logits.shape[0] >= min_capacity, \
'No. of tokens (batch-size) should be greater than min_capacity. ' \
'Either set min_capacity to 0 or increase your batch size.'

top_idx = _top_idx(mask1_rand, capacity)

new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
mask1 = new_mask1

if use_tutel:
# Tutel doesn't support index values masked with zero
# so we need to replace masked indices with -1
indices_mask = mask1.sum(dim=1) * num_experts - 1
indices1_s = torch.min(indices1_s, indices_mask)

# Compute locations in capacity buffer
if use_tutel:
locations1 = tutel_moe.fast_cumsum_sub_one(mask1)
else:
locations1 = torch.cumsum(mask1, dim=0) - 1

if use_tutel:
gates1_s = (gates * mask1).sum(dim=1)
locations1_s = torch.sum(locations1 * mask1, dim=1)
return l_aux, capacity, num_experts, [
indices1_s,
], [
locations1_s,
], [
gates1_s,
], exp_counts, alpha

# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)

# Normalize gate probabilities
mask1_float = mask1.float()
gates = gates * mask1_float

locations1_sc = _one_hot_to_float(locations1_s, capacity)
combine_weights = einsum('se,sc->sec', gates, locations1_sc)

dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts, alpha


class TopKGate(Module):
"""Gate module which implements Top2Gating as described in Gshard_.
::

gate = TopKGate(model_dim, num_experts)
l_aux, combine_weights, dispatch_mask = gate(input)

.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

Args:
model_dim (int):
size of model embedding dimension
num_experts (int):
number of experts in model
"""

wg: torch.nn.Linear

def __init__(self,
model_dim: int,
num_experts: int,
k: int = 1,
capacity_factor: float = 1.0,
eval_capacity_factor: float = 1.0,
min_capacity: int = 8,
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
top_k_linear_strategy: str = 'standard') -> None:
super().__init__()

# Only top-1 gating is supported at the moment.
if k != 1:
raise ValueError('Only top-1 gatings are supported.')
if top_k_linear_strategy == 'standard':
self.wg = torch.nn.Linear(
model_dim, num_experts, bias=False).float()
elif top_k_linear_strategy == 'lsoftmax':
self.wg = LSoftmaxLinearLayer(
model_dim, num_experts, margin=1).float()
else:
raise ValueError(
'Only standard or lsoftmax top-k-linear-strategy are supported.'
)

self.k = k
self.model_dim = model_dim  # kept so MOELayer can read self.gate.model_dim
self.capacity_factor = capacity_factor
self.eval_capacity_factor = eval_capacity_factor
self.min_capacity = min_capacity
self.noisy_gate_policy = noisy_gate_policy
self.wall_clock_breakdown = False
self.gate_time = 0.0
self.drop_tokens = drop_tokens
self.use_rts = use_rts
self.top_k_linear_strategy = top_k_linear_strategy

def forward(
self,
input: torch.Tensor,
used_token: torch.Tensor = None,
use_tutel: bool = False
) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore

if self.wall_clock_breakdown:
self.timers('TopKGate').start()

if self.top_k_linear_strategy == 'standard':
if self.wg.weight.dtype != torch.float32:
self.wg = self.wg.float()
elif self.top_k_linear_strategy == 'lsoftmax':
if self.wg.weight.weight.dtype != torch.float32:
self.wg.weight = self.wg.weight.float()

input_fp32 = input.float()
# input jittering
if self.noisy_gate_policy == 'Jitter' and self.training:
input_fp32 = multiplicative_jitter(input_fp32, device=input.device)

if self.k == 1:
if self.top_k_linear_strategy == 'standard':
logits = self.wg(input_fp32)
elif self.top_k_linear_strategy == 'lsoftmax':
logits = self.wg(input_fp32, input_fp32.device, self.training)

gate_output = top1gating(
logits, self.capacity_factor if self.training else
self.eval_capacity_factor, self.min_capacity, used_token,
self.noisy_gate_policy if self.training else None,
self.drop_tokens, self.use_rts, use_tutel)

if self.wall_clock_breakdown:
self.timers('TopKGate').stop()
self.gate_time = self.timers('TopKGate').elapsed(
reset=False) * 1000

return gate_output


class MOELayer(Base):
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
::

gate = TopKGate(model_dim, num_experts)
moe = MOELayer(gate, expert)
output = moe(input)
l_aux = moe.l_aux

.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

Args:
gate (torch.nn.Module):
gate network
expert (torch.nn.Module):
expert network
"""

def __init__(self,
gate: Module,
experts: Module,
ep_group_name,
ep_size,
num_local_experts: int,
use_tutel: bool = False,
use_expert_residual_network: bool = False) -> None:
super().__init__()
self.gate = gate
self.experts = experts
self.ep_group = None
self.ep_size = ep_size
self.ep_group_name = ep_group_name
self.num_local_experts = num_local_experts

self.wall_clock_breakdown = False
self.use_expert_residual_network = use_expert_residual_network

if self.use_expert_residual_network:
self.expert_network = nn.Sequential(
*([ExpertResidualLayer(self.gate.model_dim)
for _ in range(6)]))

self.use_tutel = use_tutel and TUTEL_INSTALLED

if self.use_tutel:
logger.info('Using Tutel optimizations.')
elif use_tutel and not TUTEL_INSTALLED:
logger.info(
'Tutel optimization requested but not installed. Proceeding without Tutel.'
)

def _set_ep_group(self, ep_group):
self.ep_group = ep_group

def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:

if self.wall_clock_breakdown:
self.timers('moe').start()

# Implement Algorithm 2 from GShard paper.
d_model = input[0].shape[-1]

# Initial implementation -> Reshape into S tokens by dropping sequence dimension.
# Reshape into G groups so that each group can distribute tokens equally
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
reshaped_input = input[0].reshape(-1, d_model)

if self.use_tutel:
self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts, alpha = self.gate(
reshaped_input, input[1], True)
_, M = reshaped_input.size(0), reshaped_input.size(1)

if not hasattr(self, '_tutel_dispatcher'):
self._tutel_dispatcher = tutel_moe.fast_dispatcher(
E, C, M, dispatch_dtype=reshaped_input.dtype)
self._tutel_dispatcher.update(
indices_, locations_, gates_, capacity=C)
dispatched_input = self._tutel_dispatcher.encode(reshaped_input)
else:
self.l_aux, combine_weights, dispatch_mask, self.exp_counts, alpha = self.gate(
reshaped_input, input[1])
dispatched_input = einsum('sec,sm->ecm',
dispatch_mask.type_as(input[0]),
reshaped_input)

if self.wall_clock_breakdown:
self.timers('falltoall').start()

if mpu.get_expert_model_parallel_world_size() == 1:
# If the non-expert is tensor-parallel, it will create
# duplicate tokens on the tensor-parallel ranks.
# Since our experts are not tensor-parallel, these duplicates
# need to be dropped to ensure correctness.
# this also doubles up as a communication optimization as we are
# reducing the all-to-all communication volume.
if self.use_tutel:
# reshape tutel's output from [e*c,m] to [e,c,m]
dispatched_input = dispatched_input.reshape(
self.ep_size * self.num_local_experts, -1, d_model)
dispatched_input = drop_tokens(dispatched_input, dim=1)

dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

if self.wall_clock_breakdown:
self.timers('falltoall').stop()
self.time_falltoall = self.timers('falltoall').elapsed(
reset=False) * 1000

# Re-shape after all-to-all: ecm -> gecm
dispatched_input = dispatched_input.reshape(self.ep_size,
self.num_local_experts, -1,
d_model)

expert_output = self.experts(dispatched_input)

if self.wall_clock_breakdown:
self.timers('salltoall').start()

expert_output = _AllToAll.apply(self.ep_group, expert_output)

if self.wall_clock_breakdown:
self.timers('salltoall').stop()
self.time_salltoall = self.timers('salltoall').elapsed(
reset=False) * 1000

# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(
self.ep_size * self.num_local_experts, -1, d_model)

if mpu.get_expert_model_parallel_world_size() == 1:
# the dropped duplicate tokens need to be gathered on each
# tensor parallel rank again for the tensor-parallel
# non-expert of the next layer.
expert_output = gather_tokens(expert_output, dim=1)

if self.use_tutel:
combined_output = self._tutel_dispatcher.decode(
expert_output.view(E * C, M))
else:
combined_output = einsum('sec,ecm->sm',
combine_weights.type_as(input[0]),
expert_output)

if self.use_expert_residual_network:
combined_output = alpha * self.expert_network(combined_output) + (
1 - alpha) * combined_output

a = combined_output.reshape(input[0].shape)

if self.wall_clock_breakdown:
self.timers('moe').stop()
self.time_moe = self.timers('moe').elapsed(reset=False) * 1000

return a


class LSoftmaxLinearLayer(torch.nn.Module):

def __init__(self, input_features, output_features, margin):
super().__init__()
self.input_dim = input_features  # number of input features, i.e. size of the preceding layer's output
self.output_dim = output_features  # number of outputs (here, the number of experts)
self.margin = margin # m
self.beta = 100
self.beta_min = 0
self.scale = 0.99
self.num_experts = output_features
# Initialize L-Softmax parameters
self.weight = torch.nn.Linear(
input_features, output_features, bias=False).float()
self.divisor = math.pi / self.margin # pi/m
self.C_m_2n = torch.Tensor(binom(margin, range(0, margin + 1,
2))) # C_m{2n}
self.cos_powers = torch.Tensor(range(self.margin, -1, -2)) # m - 2n
self.sin2_powers = torch.Tensor(range(len(self.cos_powers))) # n
self.signs = torch.ones(margin // 2 + 1) # 1, -1, 1, -1, ...
self.signs[1::2] = -1

def calculate_cos_m_theta(self, cos_theta, device):
sin2_theta = 1 - cos_theta**2
cos_terms = cos_theta.unsqueeze(1)**self.cos_powers.to(
device).unsqueeze(0) # cos^{m - 2n}
sin2_terms = (
sin2_theta.unsqueeze(1)**self.sin2_powers.to(device).unsqueeze(0))

cos_m_theta = (self.signs.to(device).unsqueeze(0)
* self.C_m_2n.to(device).unsqueeze(0) * cos_terms
* sin2_terms).sum(1) # summation of all terms

return cos_m_theta

def reset_parameters(self):
nn.init.kaiming_normal_(self.weight.data.t())

def find_k(self, cos):
# to account for acos numerical errors
eps = 1e-7
cos = torch.clamp(cos, -1 + eps, 1 - eps)
acos = cos.acos()
k = (acos / self.divisor).floor().detach()
return k

def forward(self, input, device, training):
if training:
x, w = input, self.weight.float()
beta = max(self.beta, self.beta_min)
logit = w(x)
indexes = range(logit.size(0))
# target = torch.fmod(torch.randperm(logit.size(0)), self.num_experts)
target = torch.fmod(
torch.arange(logit.size(0)), self.num_experts).long()
logit_target = logit[indexes, target]

# cos(theta) = w * x / ||w||*||x||
w_target_norm = w.weight[:, target].norm(p=2, dim=0)

x_norm = x.norm(p=2, dim=1)
cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10)

# equation 7
cos_m_theta_target = self.calculate_cos_m_theta(
cos_theta_target, device)

# find k in equation 6
k = self.find_k(cos_theta_target)

# f_y_i
logit_target_updated = w_target_norm * x_norm * ((
(-1)**k * cos_m_theta_target) - 2 * k)
logit_target_updated_beta = (logit_target_updated + beta
* logit[indexes, target]) / (1 + beta)

logit[indexes, target] = logit_target_updated_beta
self.beta *= self.scale
return logit
else:
return self.weight(input)


def LayerNorm(normalized_shape,
eps=1e-5,
elementwise_affine=True,
export=False):
if torch.jit.is_scripting() or torch.jit.is_tracing():
export = True
if not export and torch.cuda.is_available() and has_fused_layernorm:
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class ExpertResidualLayer(torch.nn.Module):

def __init__(self, embed_dim):
super().__init__()
self.norm = LayerNorm(embed_dim, export=False)
self.ff1 = torch.nn.Linear(embed_dim, embed_dim * 4)
self.ff2 = torch.nn.Linear(embed_dim * 4, embed_dim)
self.ff2.weight.data.zero_()

def forward(self, xs):
return xs + self.ff2(torch.nn.functional.relu(self.ff1(self.norm(xs))))
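
The 'sec,sm->ecm' and 'sec,ecm->sm' rules used above are the GShard dispatch and combine steps: a (token, expert, capacity-slot) mask scatters tokens into per-expert buffers and later gathers the expert outputs back, weighted by the gate values. A plain-torch sketch with toy shapes and a hand-written 0/1 routing (illustrative only):

import torch

s, e, c, m = 6, 2, 3, 4                  # tokens, experts, capacity, model dim
tokens = torch.randn(s, m)
dispatch_mask = torch.zeros(s, e, c)
for t in range(s):                       # toy routing: round-robin token -> (expert, slot)
    dispatch_mask[t, t % e, t // e] = 1.0

# 'sec,sm->ecm': scatter each token into its (expert, capacity-slot) buffer.
dispatched = torch.einsum('sec,sm->ecm', dispatch_mask, tokens)
assert torch.allclose(dispatched[1, 0], tokens[1])  # token 1 went to expert 1, slot 0

# 'sec,ecm->sm': gather back; with 0/1 weights this round-trips exactly.
combined = torch.einsum('sec,ecm->sm', dispatch_mask, dispatched)
assert torch.allclose(combined, tokens)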

+ 125
- 0
modelscope/models/nlp/gpt_moe/moe/utils.py

@@ -0,0 +1,125 @@
'''
Copyright 2020 The Microsoft DeepSpeed Team
'''

from typing import Dict, List, Tuple

import torch

from .layer import MoE


def has_moe_layers(m):
has_moe = False
num_experts = 0
for _, module in m.named_modules():
if isinstance(module, MoE):
has_moe = True
num_experts = module.num_experts
break
return has_moe, num_experts


def is_moe_param(param: torch.Tensor) -> bool:
if hasattr(param, 'allreduce') and not param.allreduce:
return True
return False


def split_params_into_shared_and_expert_params(
params: List[torch.nn.Parameter]
) -> Tuple[List[torch.nn.Parameter], List[torch.nn.Parameter]]:
shared_params, expert_params = [], []
for p in params:
if is_moe_param(p):
expert_params.append(p)
else:
shared_params.append(p)
return shared_params, expert_params


def split_params_grads_into_shared_and_expert_params(
group: List[torch.nn.Parameter]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Split gradients of parameters into gradients of non-expert (shared)
params and gradients of expert params. This is useful when computing
grad norms for clipping and overflow detection.

Args:
group (List[torch.nn.Parameter]): the group of parameters to split.

Returns:
Tuple[List[torch.Tensor], List[torch.Tensor]]:
list of gradients for non-MoE params, list of gradients of MoE params.
"""
expert_grads = []
shared_grads = []
for p in group:
if p.grad is not None:
if is_moe_param(p):
expert_grads.append(p.grad.to(p.dtype))
else:
shared_grads.append(p.grad.to(p.dtype))
return shared_grads, expert_grads


def split_params_into_different_moe_groups_for_optimizer(
param_groups: Tuple[Dict]) -> Tuple[Dict]:
"""Split parameters into different MoE groups for optimizer

Args:
param_groups (Tuple[Dict]):
The list of parameter groups to split

Returns:
Tuple[Dict]:
list of MoE/non-MoE groups for optimizer
"""
if isinstance(param_groups, tuple):
param_groups = list(param_groups) # Tuple cannot be modified
elif isinstance(param_groups, dict):
param_groups = [param_groups]
elif not isinstance(param_groups, list):
raise ValueError(f'Unknown param group type of {type(param_groups)}')

# gather all data parallel group names
data_parallel_group_names = set()
for param_group in param_groups:
for param in param_group['params']:
if is_moe_param(param):
data_parallel_group_names.add(param.group_name)
data_parallel_group_names = list(data_parallel_group_names)
group_moe = {}
# Create the param MoE groups, leave param assign to next step
for param_group in param_groups:
group_moe[param_group['name']] = {}
for key in data_parallel_group_names:
group_moe[param_group['name']][key] = {}
group_moe[param_group['name']][key]['name'] = key
group_moe[param_group['name']][key]['moe'] = True
for ori_key in param_group.keys():
if ori_key != 'name':
if ori_key == 'params':
group_moe[param_group['name']][key][ori_key] = []
else:
group_moe[param_group['name']][key][
ori_key] = param_group[ori_key]
# Assign param
for param_group in param_groups:
new_params = []
for param in param_group['params']:
if is_moe_param(param):
group_moe[param_group['name']][
param.group_name]['params'].append(param)
# param_group['params'].remove(param)
else:
new_params.append(param)
param_group['params'] = new_params

# Flatten the moe groups
for k, v in group_moe.items():
for k1, v1 in v.items():
param_groups.append(v1)

return tuple(param_groups)
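
is_moe_param keys off the allreduce=False flag that Experts stamps onto expert parameters. A small sketch of the splitting helpers, with the attributes set by hand exactly as Experts.__init__ would; note this module transitively imports megatron (and scipy), so the sketch assumes a Megatron-LM environment:

import torch
from modelscope.models.nlp.gpt_moe.moe.utils import (
    is_moe_param, split_params_into_shared_and_expert_params)

shared = torch.nn.Parameter(torch.randn(4, 4))
expert = torch.nn.Parameter(torch.randn(4, 4))
expert.allreduce = False            # what Experts.__init__ sets on expert params
expert.group_name = 'ep_size_1'

assert not is_moe_param(shared) and is_moe_param(expert)
shared_params, expert_params = split_params_into_shared_and_expert_params(
    [shared, expert])
print(len(shared_params), len(expert_params))  # 1 1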

+ 62
- 0
modelscope/models/nlp/gpt_moe/text_generation.py

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks

__all__ = ['GPTMoEForTextGeneration']


@MODELS.register_module(Tasks.text_generation, module_name=Models.gpt_moe)
class GPTMoEForTextGeneration(TorchModel):

def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the text generation model from the `model_dir` path.

Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)

from modelscope.models.nlp.gpt_moe import GPTMoEModel
from transformers import BertTokenizer
self.model = GPTMoEModel.from_pretrained(model_dir)
self.tokenizer = BertTokenizer.from_pretrained(model_dir)

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model

Args:
input (Dict[str, Tensor]): the preprocessed data

Returns:
Dict[str, Tensor]: results
Example:
{
'logits': Tensor([[0.54, 0.32...])]), # logits
}
"""
return self.model(**input)

def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
assert 'input_ids' in input, "generate function must accept 'input_ids' key"
input_ids = input['input_ids']
if 'attention_mask' in input:
attention_mask = input['attention_mask']
input_ids = input_ids[0][attention_mask[0].nonzero()] \
.squeeze().unsqueeze(0)
# remove sep token at the end of tokenizer output
input_ids = input_ids[:, :-1]

gen_params = dict()
gen_params['inputs'] = input_ids
gen_params['do_sample'] = input.pop('do_sample', True)
gen_params['max_length'] = input.pop('max_length', 128)
gen_params['top_k'] = input.pop('top_k', 10)
gen_params['top_p'] = input.pop('top_p', None)
sample_output = self.model.generate(**gen_params)
return {'sequences': sample_output[0]}

+ 67
- 0
modelscope/models/nlp/gpt_moe/tokenizer.py

@@ -0,0 +1,67 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tokenizers import Tokenizer


class JiebaBPETokenizer:
"""SentencePiece BPE tokenizer with Jieba integration"""

def __init__(self, tokenizer_json_file):
self.name = 'Jieba BPE Tokenizer'

self.tokenizer = Tokenizer.from_file(tokenizer_json_file)
self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
try:
import jieba
except ImportError:
raise ImportError(
'You need to install jieba to use JiebaBPETokenizer. '
'See https://pypi.org/project/jieba/ for installation.')
self.jieba = jieba
self.new_line = self.vocab['\n']
self.sep_token = self.vocab['<sep>']

@property
def vocab_size(self):
return self.tokenizer.get_vocab_size(with_added_tokens=True)

@property
def vocab(self):
return self.tokenizer.get_vocab(with_added_tokens=True)

@property
def inv_vocab(self):
vocab = self.vocab
inv_vocab = dict()
for key, val in vocab.items():
inv_vocab[val] = key
return inv_vocab

def tokenize(self, text, is_code=False):
if not is_code:
seg_list = [x for x in self.jieba.cut(text)]
return self.tokenizer.encode(
seg_list, is_pretokenized=True, add_special_tokens=True).ids
else:
return self.tokenizer.encode(
text, is_pretokenized=False, add_special_tokens=True).ids

def detokenize(self, token_ids):
text = self.tokenizer.decode(token_ids, skip_special_tokens=False)
return text

@property
def eod(self):
return self.eod_id
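
A usage sketch for the tokenizer. The 'tokenizer.json' path below is a placeholder (the real file ships with the model checkpoint), and jieba plus the HuggingFace tokenizers package must be installed:

from modelscope.models.nlp.gpt_moe import JiebaBPETokenizer

tokenizer = JiebaBPETokenizer('tokenizer.json')  # placeholder path

ids = tokenizer.tokenize('今天天气不错')  # jieba pre-segmentation, then BPE
print(ids)
print(tokenizer.detokenize(ids))
print(tokenizer.vocab_size, tokenizer.eod)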

+ 54
- 0
modelscope/pipelines/nlp/distributed_gpt_moe_pipeline.py

@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.nlp.gpt_moe.distributed_gpt_moe import DistributedGPTMoE
from modelscope.pipelines.base import DistributedPipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationJiebaPreprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
Tasks.text_generation, module_name=Pipelines.gpt_moe_generation)
class DistributedGPTMoEPipeline(DistributedPipeline):
"""This class is used to instantiate the gpt-moe model.
"""

model = None

def __init__(self, model, preprocessor=None, **kwargs):
if preprocessor is None:
preprocessor = TextGenerationJiebaPreprocessor(model)
super().__init__(model, preprocessor=preprocessor, **kwargs)
assert hasattr(preprocessor, 'tokenizer')

@classmethod
def _instantiate_one(cls, rank, model_dir, **kwargs):
cls.model = DistributedGPTMoE(model_dir, rank, **kwargs)
cls.model.eval()

@classmethod
def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
tokens = inputs['inputs']['input_ids'].cuda(
torch.cuda.current_device())
return cls.model.generate(tokens)

def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, str]: the prediction results
"""
from modelscope.outputs import OutputKeys
return {
OutputKeys.TEXT:
self.preprocessor.tokenizer.detokenize(inputs[0].tolist())
}

+ 24
- 0
tests/pipelines/test_gpt_moe_text_generation.py

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TextGPTMoEGenerationTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id_1_3B_MoE32 = 'PAI/nlp_gpt3_text-generation_1.3B_MoE-32'
self.model_dir_1_3B_MoE32 = snapshot_download(self.model_id_1_3B_MoE32)
self.input = '好的'

@unittest.skip('distributed gpt-moe 1.3B_MoE-32, skipped')
def test_gpt_moe_1_3B_MoE32(self):
pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B_MoE32)
print(pipe(self.input))


if __name__ == '__main__':
unittest.main()
