Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9931748 (master)
@@ -55,6 +55,7 @@ class Models(object): | |||
lcrf = 'lstm-crf' | |||
bart = 'bart' | |||
gpt3 = 'gpt3' | |||
plug = 'plug' | |||
bert_for_ds = 'bert-for-document-segmentation' | |||
# audio models | |||
@@ -172,6 +173,7 @@ class Pipelines(object): | |||
dialog_state_tracking = 'dialog-state-tracking' | |||
zero_shot_classification = 'zero-shot-classification' | |||
text_error_correction = 'text-error-correction' | |||
plug_generation = 'plug-generation' | |||
faq_question_answering = 'faq-question-answering' | |||
conversational_text_to_sql = 'conversational-text-to-sql' | |||
relation_extraction = 'relation-extraction' | |||
@@ -28,6 +28,7 @@ if TYPE_CHECKING: | |||
SingleBackboneTaskModelBase) | |||
from .bart_for_text_error_correction import BartForTextErrorCorrection | |||
from .gpt3 import GPT3ForTextGeneration | |||
from .plug import PlugForTextGeneration | |||
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering | |||
else: | |||
@@ -60,6 +61,7 @@ else: | |||
], | |||
'bart_for_text_error_correction': ['BartForTextErrorCorrection'], | |||
'gpt3': ['GPT3ForTextGeneration'], | |||
'plug': ['PlugForTextGeneration'], | |||
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'], | |||
} | |||
@@ -0,0 +1,27 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .configuration_plug import PlugNLGConfig | |||
from .modeling_plug import PlugModel | |||
from .distributed_plug import DistributedPlug | |||
from .plug_for_text_generation import PlugForTextGeneration | |||
else: | |||
_import_structure = { | |||
'configuration_plug': ['PlugNLGConfig'], | |||
'modeling_plug': ['PlugModel'], | |||
'distributed_plug': ['DistributedPlug'], | |||
'plug_for_text_generation': ['PlugForTextGeneration'], | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
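For context, the _import_structure above follows the repo's lazy-module pattern: nothing under plug/ (and hence none of the heavy megatron/deepspeed dependencies) is imported until a symbol is first accessed. A small hedged illustration of the effect, not part of this diff:

    from modelscope.models.nlp import plug   # cheap: the plug submodules are not imported yet
    cls = plug.PlugForTextGeneration         # first attribute access triggers the real import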
@@ -0,0 +1,232 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import copy | |||
import json | |||
from transformers import PretrainedConfig | |||
from modelscope.utils import logger as logging | |||
logger = logging.get_logger(__name__) | |||
class PlugNLUConfig(PretrainedConfig): | |||
model_type = 'plugNLU' | |||
def __init__(self, | |||
vocab_size=21504, | |||
original_vocab_size=21128, | |||
hidden_size=8192, | |||
num_hidden_layers=24, | |||
num_attention_heads=128, | |||
intermediate_size=32768, | |||
hidden_act='gelu', | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=2048, | |||
type_vocab_size=3, | |||
initializer_range=0.00707, | |||
deep_init=False, | |||
deepspeed=False, | |||
lr_decay_style='linear', | |||
weight_decay=1e-2, | |||
clip_grad=1.0, | |||
warmup=0.0333, | |||
pre_ln=True, | |||
fp16=True, | |||
fp32_layernorm=True, | |||
fp32_embedding=False, | |||
fp32_tokentypes=False, | |||
layernorm_epsilon=1e-5, | |||
dec_hidden_layers=6, | |||
pruning_method=None, | |||
pruning_mask_init='constant', | |||
pruning_mask_scale=0.0, | |||
pruning_initial_threshold=1.0, | |||
pruning_final_threshold=0.01, | |||
pruning_initial_warmup=1, | |||
pruning_final_warmup=20, | |||
pruning_module='decoder', | |||
pruning_decay_step=50, | |||
pruning_decay_type='exp', | |||
ft_module=None, | |||
attn_separate=False, | |||
LR_weight_rank=8, | |||
LR_mask_rank=8, | |||
**kwargs): | |||
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | |||
self.vocab_size = vocab_size | |||
self.original_vocab_size = original_vocab_size | |||
self.hidden_size = hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.hidden_act = hidden_act | |||
self.intermediate_size = intermediate_size | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
self.deep_init = deep_init | |||
self.deepspeed = deepspeed | |||
self.lr_decay_style = lr_decay_style | |||
self.weight_decay = weight_decay | |||
self.clip_grad = clip_grad | |||
self.warmup = warmup | |||
self.pre_ln = pre_ln | |||
self.fp16 = fp16 | |||
self.fp32_layernorm = fp32_layernorm | |||
self.fp32_embedding = fp32_embedding | |||
self.layernorm_epsilon = layernorm_epsilon | |||
self.fp32_tokentypes = fp32_tokentypes | |||
self.dec_hidden_layers = dec_hidden_layers | |||
self.pruning_method = pruning_method | |||
self.pruning_mask_init = pruning_mask_init | |||
self.pruning_mask_scale = pruning_mask_scale | |||
self.pruning_module = pruning_module | |||
self.pruning_initial_threshold = pruning_initial_threshold | |||
self.pruning_final_threshold = pruning_final_threshold | |||
self.pruning_initial_warmup = pruning_initial_warmup | |||
self.pruning_final_warmup = pruning_final_warmup | |||
self.pruning_decay_step = pruning_decay_step | |||
self.pruning_decay_type = pruning_decay_type | |||
self.ft_module = ft_module | |||
self.attn_separate = attn_separate | |||
self.LR_weight_rank = LR_weight_rank | |||
self.LR_mask_rank = LR_mask_rank | |||
@classmethod | |||
def from_dict(cls, json_object): | |||
"""Constructs a `BertConfig` from a Python dictionary of parameters.""" | |||
config = PlugNLUConfig() | |||
for key, value in json_object.items(): | |||
config.__dict__[key] = value | |||
return config | |||
@classmethod | |||
def from_json_file(cls, json_file): | |||
"""Constructs a `BertConfig` from a json file of parameters.""" | |||
with open(json_file, 'r', encoding='utf-8') as reader: | |||
text = reader.read() | |||
return cls.from_dict(json.loads(text)) | |||
def merge_args(self, args): | |||
"""merge values a `BertConfig` from a json file of parameters.""" | |||
local_keys = self.__dict__.keys() | |||
for key, value in args.__dict__.items(): | |||
if key in local_keys: | |||
continue | |||
self.__dict__[key] = value | |||
return self | |||
def __repr__(self): | |||
return str(self.to_json_string()) | |||
def to_dict(self): | |||
"""Serializes this instance to a Python dictionary.""" | |||
output = copy.deepcopy(self.__dict__) | |||
return output | |||
def to_json_string(self): | |||
"""Serializes this instance to a JSON string.""" | |||
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' | |||
class PlugNLGConfig(PlugNLUConfig): | |||
model_type = 'plugNLG' | |||
def __init__(self, | |||
vocab_size=21504, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act='gelu', | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.00707, | |||
deep_init=False, | |||
deepspeed=False, | |||
lr_decay_style='linear', | |||
weight_decay=1e-2, | |||
clip_grad=1.0, | |||
warmup=0.01, | |||
pre_ln=False, | |||
fp16=False, | |||
fp32_layernorm=False, | |||
fp32_embedding=False, | |||
fp32_tokentypes=False, | |||
layernorm_epsilon=1e-12, | |||
dec_hidden_layers=6, | |||
pruning_method=None, | |||
pruning_mask_init='constant', | |||
pruning_mask_scale=0.0, | |||
pruning_initial_threshold=1.0, | |||
pruning_final_threshold=0.01, | |||
pruning_initial_warmup=1, | |||
pruning_final_warmup=20, | |||
pruning_module='decoder', | |||
pruning_decay_step=50, | |||
pruning_decay_type='exp', | |||
ft_module=None, | |||
attn_separate=False, | |||
LR_weight_rank=8, | |||
LR_mask_rank=8, | |||
**kwargs): | |||
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs) | |||
self.vocab_size = vocab_size | |||
self.hidden_size = hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.hidden_act = hidden_act | |||
self.intermediate_size = intermediate_size | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
self.deep_init = deep_init | |||
self.deepspeed = deepspeed | |||
self.lr_decay_style = lr_decay_style | |||
self.weight_decay = weight_decay | |||
self.clip_grad = clip_grad | |||
self.warmup = warmup | |||
self.pre_ln = pre_ln | |||
self.fp16 = fp16 | |||
self.fp32_layernorm = fp32_layernorm | |||
self.fp32_embedding = fp32_embedding | |||
self.layernorm_epsilon = layernorm_epsilon | |||
self.fp32_tokentypes = fp32_tokentypes | |||
self.dec_hidden_layers = dec_hidden_layers | |||
self.pruning_method = pruning_method | |||
self.pruning_mask_init = pruning_mask_init | |||
self.pruning_mask_scale = pruning_mask_scale | |||
self.pruning_module = pruning_module | |||
self.pruning_initial_threshold = pruning_initial_threshold | |||
self.pruning_final_threshold = pruning_final_threshold | |||
self.pruning_initial_warmup = pruning_initial_warmup | |||
self.pruning_final_warmup = pruning_final_warmup | |||
self.pruning_decay_step = pruning_decay_step | |||
self.pruning_decay_type = pruning_decay_type | |||
self.ft_module = ft_module | |||
self.attn_separate = attn_separate | |||
self.LR_weight_rank = LR_weight_rank | |||
self.LR_mask_rank = LR_mask_rank |
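A hedged usage sketch of the config helpers above; the JSON path and the args namespace are made-up examples, not part of this diff:

    from argparse import Namespace
    from modelscope.models.nlp.plug import PlugNLGConfig

    cfg = PlugNLGConfig.from_json_file('/your/path/to/nlp_plug_text-generation_27B/config.json')
    args = Namespace(batch_size=1, out_length=128)  # hypothetical runtime args
    cfg = cfg.merge_args(args)   # only keys the config does not already define are copied in
    print(cfg.to_json_string())  # round-trips back to sorted, indented JSON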
@@ -0,0 +1,191 @@ | |||
import os | |||
from typing import Dict | |||
import torch | |||
import torch.nn.functional as F | |||
from megatron import mpu | |||
from megatron.fp16 import FP16_Module | |||
from megatron.utils import print_rank_0 | |||
from modelscope.models import TorchModel | |||
from modelscope.models.base import Tensor | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.nlp.distributed import initialize_distributed | |||
from modelscope.utils.nlp.load_checkpoint import pre_load | |||
from modelscope.utils.torch_utils import set_random_seed_mpu | |||
from . import PlugModel | |||
from .configuration_plug import PlugNLGConfig | |||
logger = get_logger(__name__) | |||
class DistributedPlug(TorchModel): | |||
def __init__(self, model_dir, rank, **kwargs): | |||
super().__init__(model_dir, **kwargs) | |||
self.rank = rank | |||
self.model_cfg = kwargs | |||
self.config = PlugNLGConfig.from_pretrained(model_dir) | |||
initialize_distributed(rank, mpu, kwargs['world_size'], | |||
kwargs['model_parallel_size'], | |||
kwargs['master_ip'], kwargs['master_port']) | |||
seed = 0 if 'seed' not in kwargs else kwargs['seed'] | |||
set_random_seed_mpu(seed) | |||
self.iteration = 0 | |||
self.dist_model = self.initialize_model(path_load_tag='model') | |||
def initialize_model(self, path_load_tag='model'): | |||
"""Build the model.""" | |||
print_rank_0('Building Plug model. It will take a few minutes ...') | |||
model = PlugModel(self.config) | |||
if mpu.get_data_parallel_rank() == 0: | |||
logger.info( | |||
' > number of parameters on model parallel rank {}: {}'.format( | |||
mpu.get_model_parallel_rank(), | |||
sum([p.nelement() for p in model.parameters()]))) | |||
if self.config.deepspeed and self.config.fp16: | |||
model.half() | |||
# GPU allocation. | |||
model.cuda(torch.cuda.current_device()) | |||
# Fp16 conversion. | |||
if self.config.fp16: | |||
model = FP16_Module(model) | |||
if self.config.fp32_embedding: | |||
model.module.model.bert.embeddings.word_embeddings.float() | |||
model.module.model.bert.embeddings.position_embeddings.float() | |||
model.module.model.bert.embeddings.token_type_embeddings.float( | |||
) | |||
if self.config.fp32_tokentypes: | |||
model.module.model.bert.embeddings.token_type_embeddings.float( | |||
) | |||
if self.config.fp32_layernorm: | |||
for name, _module in model.named_modules(): | |||
if 'LayerNorm' in name: | |||
_module.float() | |||
load_model = pre_load(mpu, self.model_dir, tag=path_load_tag) | |||
model_dict = model.module.model.state_dict() | |||
for key in load_model: | |||
if key not in model_dict.keys(): | |||
print_rank_0('Skip key: ' + key) | |||
else: | |||
print_rank_0('Loading key: ' + key) | |||
model.module.model.load_state_dict(load_model, strict=False) | |||
return model | |||
@staticmethod | |||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||
# This function has been mostly taken from huggingface conversational ai code at | |||
# https://medium.com/huggingface/how-to-build-a-state-of-the-art- | |||
# conversational-ai-with-transfer-learning-2d818ac26313 | |||
if top_k > 0: | |||
# Remove all tokens with a probability less than the last token of the top-k | |||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||
None] | |||
logits[indices_to_remove] = filter_value | |||
if top_p > 0.0: | |||
# convert to 1D | |||
logits = logits.view(logits.size()[1]).contiguous() | |||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||
cumulative_probs = torch.cumsum( | |||
F.softmax(sorted_logits, dim=-1), dim=-1) | |||
# Remove tokens with cumulative probability above the threshold | |||
sorted_indices_to_remove = cumulative_probs > top_p | |||
# Shift the indices to the right to keep also the first token above the threshold | |||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||
..., :-1].clone() | |||
sorted_indices_to_remove[..., 0] = 0 | |||
indices_to_remove = sorted_indices[sorted_indices_to_remove] | |||
logits[indices_to_remove] = filter_value | |||
# going back to 2D | |||
logits = logits.view(1, -1).contiguous() | |||
return logits | |||
def generate(self, input: Dict[str, Tensor], out_length=128, **kwargs): | |||
device = torch.cuda.current_device() | |||
batch_size = input['input_ids'].shape[0] | |||
tokens = input['input_ids'].view(1, -1).contiguous().to(device) | |||
dec_input_ids = input['dec_input_ids'].to(device) | |||
attention_mask = input['attention_mask'].to(device) | |||
self.dist_model.eval() | |||
with torch.no_grad(): | |||
# Only supports batch_size=1 | |||
all_generate_tokens = [] | |||
generate_tokens = [] | |||
counter = 0 | |||
sequence_output = None | |||
vocab_size = self.config.original_vocab_size | |||
sep_token_idx = 102 # index of [SEP] token in BertTokenizer | |||
while counter < out_length: | |||
if counter % 128 == 0 and counter != 0: | |||
# Sliding window | |||
generate_tokens.append(sep_token_idx) | |||
start = (tokens == sep_token_idx).nonzero( | |||
as_tuple=True)[-1] | |||
if start + len(generate_tokens) >= 512: | |||
tokens = torch.cat([ | |||
tokens[:start], | |||
torch.cuda.LongTensor(generate_tokens) | |||
], -1)[-512:] | |||
else: | |||
tokens[0][start:start + len(generate_tokens | |||
)] = torch.cuda.LongTensor( | |||
generate_tokens) | |||
attention_mask = (tokens != 0) | |||
dec_input_ids = input['dec_input_ids'].to(device) | |||
generate_tokens = [] | |||
sequence_output = None | |||
position_ids = torch.full([batch_size, 1], | |||
len(generate_tokens), | |||
dtype=torch.long, | |||
device=device) | |||
_, logits, sequence_output = self.dist_model( | |||
tokens, | |||
None, | |||
attention_mask, | |||
dec_input_ids, | |||
attention_mask, | |||
position_ids, | |||
is_infer=True, | |||
sequence_output=sequence_output, | |||
parallel_output=False) | |||
logits = logits[:, -1, :] | |||
logits = logits / self.model_cfg['temperature'] | |||
logits = self.top_k_logits( | |||
logits, | |||
top_k=self.model_cfg['top_k'], | |||
top_p=self.model_cfg['top_p']) | |||
log_probs = F.softmax(logits, dim=-1) | |||
prev = torch.multinomial(log_probs, num_samples=1) | |||
prev_token = prev[0].item() | |||
if prev_token >= vocab_size: | |||
prev_token = 100 | |||
prev[0] = 100 | |||
if prev_token == 102 and len(all_generate_tokens) > int( | |||
max(1, out_length) * 0.8): | |||
break | |||
if prev_token == 102: | |||
counter += 1 | |||
continue | |||
dec_input_ids = torch.cat([dec_input_ids, prev], dim=1) | |||
generate_tokens.append(prev_token) | |||
all_generate_tokens.append(prev_token) | |||
counter += 1 | |||
generate_context = [] | |||
for token in all_generate_tokens: | |||
if generate_context and generate_context[ | |||
-1] == 100 and token == 100: | |||
continue | |||
else: | |||
generate_context.append(token) | |||
return {'generate_context': generate_context} |
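top_k_logits is the standard top-k / nucleus filter and, being a @staticmethod, can be exercised without building the distributed model. A hedged toy sketch (arbitrary logits; assumes megatron_util is installed so the import resolves):

    import torch
    import torch.nn.functional as F
    from modelscope.models.nlp.plug import DistributedPlug

    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
    filtered = DistributedPlug.top_k_logits(logits.clone(), top_k=2)  # logits below the 2nd largest become -inf
    probs = F.softmax(filtered, dim=-1)       # probability mass only on the two kept tokens
    next_token = torch.multinomial(probs, 1)  # sampled the same way generate() does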
@@ -1,7 +1,10 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import os.path as osp | |||
from abc import ABC, abstractmethod | |||
from functools import partial | |||
from multiprocessing import Pool | |||
from threading import Lock | |||
from typing import Any, Dict, Generator, List, Mapping, Union | |||
@@ -15,8 +18,10 @@ from modelscope.utils.config import Config | |||
from modelscope.utils.constant import Frameworks, ModelFile | |||
from modelscope.utils.device import (create_device, device_placement, | |||
verify_device) | |||
from modelscope.utils.hub import read_config, snapshot_download | |||
from modelscope.utils.import_utils import is_tf_available, is_torch_available | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.torch_utils import _find_free_port, _is_free_port | |||
from .util import is_model, is_official_hub_path | |||
if is_torch_available(): | |||
@@ -302,3 +307,106 @@ class Pipeline(ABC): | |||
output should have the standard output name. | |||
""" | |||
raise NotImplementedError('postprocess') | |||
class DistributedPipeline(Pipeline): | |||
"""This pipeline is used to load multi gpu models. | |||
What will this class do: | |||
1. Read the global config from the configuration.json | |||
2. Set the multiprocessing method to spawn | |||
3. Open a multiprocessing pool of the world_size to instantiate model pieces. | |||
4. Set the master port and ip | |||
5. Call _instantiate_one to instantiate one model piece | |||
This method should be implemented by the derived class. | |||
6. After the forward method is called, do preprocess in main process | |||
and call _forward_one to collect results, and do | |||
post process in main process. | |||
NOTE: _instantiate_one and _forward_one are class methods, any derived class should implement them and | |||
store the model handler in the class field. | |||
""" | |||
def __init__(self, | |||
model: str = None, | |||
preprocessor: Union[Preprocessor, List[Preprocessor]] = None, | |||
auto_collate=True, | |||
**kwargs): | |||
self.preprocessor = preprocessor | |||
self._model_prepare = False | |||
self._model_prepare_lock = Lock() | |||
self._auto_collate = auto_collate | |||
if os.path.exists(model): | |||
self.model_dir = model | |||
else: | |||
self.model_dir = snapshot_download(model) | |||
self.cfg = read_config(self.model_dir) | |||
self.world_size = self.cfg.model.world_size | |||
self.model_pool = None | |||
self.device_name = 'cpu' | |||
self.device = create_device(self.device_name) | |||
self.has_multiple_models = False | |||
self.framework = self.cfg.framework | |||
if torch.multiprocessing.get_start_method(allow_none=True) is None: | |||
torch.multiprocessing.set_start_method('spawn') | |||
ranks = list(range(self.world_size)) | |||
self.model_pool = Pool(self.world_size) | |||
master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs[ | |||
'master_ip'] | |||
master_port = '29500' if 'master_port' not in kwargs else kwargs[ | |||
'master_port'] | |||
if not _is_free_port(int(master_port)): | |||
master_port = str(_find_free_port()) | |||
self.model_pool.map( | |||
partial( | |||
self.__class__._instantiate_one, | |||
model_dir=self.model_dir, | |||
master_ip=master_ip, | |||
master_port=master_port, | |||
**self.cfg.model, | |||
**kwargs), ranks) | |||
def __del__(self): | |||
if hasattr(self, 'model_pool') and self.model_pool is not None: | |||
self.model_pool.terminate() | |||
def __getstate__(self): | |||
self_dict = self.__dict__.copy() | |||
del self_dict['model_pool'] | |||
del self_dict['preprocessor'] | |||
del self_dict['_model_prepare_lock'] | |||
return self_dict | |||
@classmethod | |||
def _instantiate_one(cls, rank, model_dir, **kwargs): | |||
"""Instantiate one model piece. | |||
@param rank: The model rank. | |||
@param model_dir: The model_dir in the node. | |||
@param kwargs: Any extra args. | |||
@return: None. The model handler should be kept in the class field. | |||
""" | |||
pass | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
inputs = { | |||
'inputs': inputs, | |||
'forward_params': forward_params, | |||
} | |||
res = self.model_pool.map(self.__class__._forward_one, | |||
[inputs] * self.world_size) | |||
return res[0] | |||
@classmethod | |||
def _forward_one(cls, inputs): | |||
"""Forward the inputs to one model piece. | |||
Use the model handler kept in the class field to forward. | |||
@param inputs: The inputs after the preprocessing. | |||
@return: The forward results. | |||
""" | |||
pass |
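Because the pool uses the spawn start method, every worker is a fresh process: the model handle therefore lives in a class attribute that _instantiate_one fills per process, and __getstate__ strips the unpicklable pool, lock and preprocessor from the parent object. A minimal hedged sketch of a derived class (build_model_piece is a hypothetical helper, not part of this diff):

    class MyDistributedPipeline(DistributedPipeline):
        model = None  # one handle per spawned worker process

        @classmethod
        def _instantiate_one(cls, rank, model_dir, **kwargs):
            # hypothetical: load this rank's model shard and keep it on the class
            cls.model = build_model_piece(model_dir, rank, **kwargs)

        @classmethod
        def _forward_one(cls, inputs):
            return cls.model(inputs['inputs'], **inputs['forward_params'])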
@@ -0,0 +1,107 @@ | |||
from typing import Any, Dict | |||
import torch | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.nlp.plug import DistributedPlug | |||
from modelscope.pipelines.base import DistributedPipeline | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import TextGenerationPreprocessor | |||
from modelscope.utils.constant import Tasks | |||
@PIPELINES.register_module( | |||
Tasks.text_generation, module_name=Pipelines.plug_generation) | |||
class DistributedPlugPipeline(DistributedPipeline): | |||
"""This class is used to instantiate the plug model. | |||
""" | |||
model = None | |||
def __init__(self, | |||
model, | |||
preprocessor=None, | |||
first_sequence='sentence', | |||
**kwargs): | |||
"""Create a plug pipeline instance. | |||
@param model: The model_id of PLUG (damo/nlp_plug_text-generation_27B). | |||
The default path to damo/nlp_plug_text-generation_27B can be obtained by function | |||
get_cache_dir("damo/nlp_plug_text-generation_27B"), the model should be downloaded to | |||
this path before calling this class by model_id. | |||
The model can be downloaded from the link on | |||
https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. | |||
After downloading, you should have a plug model structure like this: | |||
/your/path/to/damo/nlp_plug_text-generation_27B | |||
|_ config.json | |||
|_ configuration.json | |||
|_ ds_zero-offload_10B_config.json | |||
|_ vocab.txt | |||
|_ model <-- an empty directory | |||
Model binaries shall be downloaded separately to populate the model directory, so that | |||
the model directory would contain the following binaries: | |||
|_ model | |||
|_ mp_rank_00_model_states.pt | |||
|_ mp_rank_01_model_states.pt | |||
|_ mp_rank_02_model_states.pt | |||
|_ mp_rank_03_model_states.pt | |||
|_ mp_rank_04_model_states.pt | |||
|_ mp_rank_05_model_states.pt | |||
|_ mp_rank_06_model_states.pt | |||
|_ mp_rank_07_model_states.pt | |||
@param preprocessor: The optional preprocessor, if not passed in, a TextGenerationPreprocessor will | |||
be used as default. | |||
@param first_sequence: The first_sequence key name if the input format is a dict. | |||
@param kwargs: | |||
sequence_length: The input sequence_length. | |||
""" | |||
if preprocessor is None: | |||
preprocessor = TextGenerationPreprocessor( | |||
model, | |||
first_sequence=first_sequence, | |||
sequence_length=kwargs.pop('sequence_length', 512)) | |||
super().__init__(model, preprocessor=preprocessor, **kwargs) | |||
assert hasattr(preprocessor, 'tokenizer') | |||
self.cls_token_id = preprocessor.tokenizer.cls_token_id | |||
@classmethod | |||
def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
with torch.no_grad(): | |||
return cls.model.generate(inputs['inputs'], | |||
**inputs['forward_params']) | |||
def _sanitize_parameters(self, **pipeline_parameters): | |||
return {}, pipeline_parameters, {} | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
batch_size = inputs['input_ids'].shape[0] | |||
dec_input_ids = torch.full([batch_size, 1], | |||
self.cls_token_id, | |||
dtype=torch.long) | |||
inputs['dec_input_ids'] = dec_input_ids | |||
res = super().forward(inputs, **forward_params) | |||
return res | |||
@classmethod | |||
def _instantiate_one(cls, rank, model_dir, **kwargs): | |||
cls.model = DistributedPlug(model_dir, rank, **kwargs) | |||
cls.model.eval() | |||
def postprocess(self, inputs: Dict[str, Any], | |||
**postprocess_params) -> Dict[str, str]: | |||
"""process the prediction results | |||
Args: | |||
inputs (Dict[str, Any]): _description_ | |||
Returns: | |||
Dict[str, str]: the prediction results | |||
""" | |||
from modelscope.outputs import OutputKeys | |||
generate_context = inputs['generate_context'] | |||
generate_context = ''.join( | |||
self.preprocessor.tokenizer.convert_ids_to_tokens( | |||
generate_context)).replace('[UNK]', '“').replace('##', '') | |||
return {OutputKeys.TEXT: generate_context} |
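A hedged end-to-end usage sketch, assuming the eight mp_rank_0*_model_states.pt shards are already under <model_dir>/model (the unit test further below exercises the same flow):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipe = pipeline(Tasks.text_generation, model='damo/nlp_plug_text-generation_27B')
    print(pipe('your prompt here', out_length=128))  # out_length is forwarded to DistributedPlug.generate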
@@ -164,7 +164,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor): | |||
""" | |||
model_type = get_model_type(model_dir) | |||
if model_type in (Models.structbert, Models.gpt3, Models.palm): | |||
if model_type in (Models.structbert, Models.gpt3, Models.palm, | |||
Models.plug): | |||
from modelscope.models.nlp.structbert import SbertTokenizer | |||
return SbertTokenizer.from_pretrained(model_dir, use_fast=False) | |||
elif model_type == Models.veco: | |||
@@ -39,7 +39,8 @@ from modelscope.utils.device import create_device, verify_device | |||
from modelscope.utils.file_utils import func_receive_dict_inputs | |||
from modelscope.utils.logger import get_logger | |||
from modelscope.utils.registry import build_from_cfg | |||
from modelscope.utils.torch_utils import get_dist_info, init_dist | |||
from modelscope.utils.torch_utils import (get_dist_info, init_dist, | |||
set_random_seed) | |||
from .base import BaseTrainer | |||
from .builder import TRAINERS | |||
from .default_config import DEFAULT_CONFIG | |||
@@ -922,6 +923,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed): | |||
# The seed of each worker equals to | |||
# num_worker * rank + worker_id + user_seed | |||
worker_seed = num_workers * rank + worker_id + seed | |||
np.random.seed(worker_seed) | |||
random.seed(worker_seed) | |||
torch.manual_seed(worker_seed) | |||
set_random_seed(worker_seed) |
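The refactor routes worker seeding through the shared set_random_seed helper; a quick worked example of the formula above, with illustrative values only:

    num_workers, rank, worker_id, seed = 4, 1, 2, 42
    worker_seed = num_workers * rank + worker_id + seed  # 4*1 + 2 + 42 = 48
    # set_random_seed(48) then seeds random, numpy and torch identically for this worker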
@@ -0,0 +1,130 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import math | |||
import torch | |||
import torch.distributed as dist | |||
from megatron import mpu | |||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
from torch.autograd import Variable | |||
from torch.nn.modules import Module | |||
from modelscope.utils.torch_utils import init_dist | |||
def initialize_distributed(rank, mpu, world_size, model_parallel_size, | |||
master_ip, master_port): | |||
"""Initialize torch.distributed.""" | |||
# Manually set the device ids. | |||
device = rank % torch.cuda.device_count() | |||
torch.cuda.set_device(device) | |||
# Call the init process | |||
init_method = 'tcp://' | |||
init_method += master_ip + ':' + master_port | |||
torch.distributed.init_process_group( | |||
backend='nccl', world_size=world_size, rank=rank, init_method=init_method) | |||
# Set the model-parallel communicators. | |||
mpu.initialize_model_parallel(model_parallel_size) | |||
def normal_init_method(mean, std): | |||
def init_(tensor): | |||
return torch.nn.init.normal_(tensor, mean=mean, std=std) | |||
return init_ | |||
def scaled_init_method(mean, std, num_layers): | |||
"""Init method based on N(0, sigma/sqrt(2*num_layers).""" | |||
std = std / math.sqrt(2.0 * num_layers) | |||
def init_(tensor): | |||
return torch.nn.init.normal_(tensor, mean=mean, std=std) | |||
return init_ | |||
class DistributedDataParallel(Module): | |||
def __init__(self, module): | |||
super(DistributedDataParallel, self).__init__() | |||
self.warn_on_half = dist.get_backend() == 'gloo'  # dist._backend / dist.dist_backend no longer exist in recent torch | |||
self.module = module | |||
self.data_parallel_group = mpu.get_data_parallel_group() | |||
src_rank = mpu.get_model_parallel_rank() | |||
for p in self.module.parameters(): | |||
if torch.is_tensor(p): | |||
dist.broadcast(p, src_rank, group=self.data_parallel_group) | |||
def allreduce_params(reduce_after=True, | |||
no_scale=False, | |||
fp32_allreduce=False): | |||
if (self.needs_reduction): | |||
self.needs_reduction = False | |||
buckets = {} | |||
for name, param in self.module.named_parameters(): | |||
if param.requires_grad and param.grad is not None: | |||
tp = (param.data.type()) | |||
if tp not in buckets: | |||
buckets[tp] = [] | |||
buckets[tp].append(param) | |||
if self.warn_on_half: | |||
if torch.cuda.HalfTensor in buckets: | |||
print( | |||
'WARNING: gloo dist backend for half parameters may be extremely slow.', | |||
'It is recommended to use the NCCL backend in this case.' | |||
) | |||
self.warn_on_half = False | |||
for tp in buckets: | |||
bucket = buckets[tp] | |||
grads = [param.grad.data for param in bucket] | |||
coalesced = _flatten_dense_tensors(grads) | |||
if fp32_allreduce: | |||
coalesced = coalesced.float() | |||
if not no_scale and not reduce_after: | |||
coalesced /= dist.get_world_size( | |||
group=self.data_parallel_group) | |||
dist.all_reduce(coalesced, group=self.data_parallel_group) | |||
torch.cuda.synchronize() | |||
if not no_scale and reduce_after: | |||
coalesced /= dist.get_world_size( | |||
group=self.data_parallel_group) | |||
for buf, synced in zip( | |||
grads, _unflatten_dense_tensors(coalesced, grads)): | |||
buf.copy_(synced) | |||
# The per-parameter hooks are defined but not registered here; gradient all-reduce | |||
# is expected to be triggered explicitly through self.allreduce_params by the caller. | |||
self.hook_handles = [] | |||
self.hooks = [] | |||
for param in list(self.module.parameters()): | |||
def allreduce_hook(*unused):  # would queue the reduction at backward time if registered | |||
Variable._execution_engine.queue_callback(allreduce_params) | |||
self.allreduce_params = allreduce_params | |||
def forward(self, *inputs, **kwargs): | |||
self.needs_reduction = True | |||
return self.module(*inputs, **kwargs) | |||
def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
sd = self.module.state_dict(destination, prefix, keep_vars) | |||
return sd | |||
def load_state_dict(self, state_dict, strict=True): | |||
self.module.load_state_dict(state_dict, strict=strict) |
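The allreduce_params closure above buckets gradients by dtype, flattens each bucket, averages it across the data-parallel group and copies the result back into the original grad tensors. A standalone hedged sketch of that flatten / all-reduce / unflatten pattern (assumes torch.distributed is already initialized):

    import torch.distributed as dist
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    def allreduce_grads(params, group=None):
        grads = [p.grad.data for p in params if p.requires_grad and p.grad is not None]
        coalesced = _flatten_dense_tensors(grads)       # one contiguous buffer per call
        coalesced /= dist.get_world_size(group=group)   # pre-scale so all_reduce yields the mean
        dist.all_reduce(coalesced, group=group)
        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
            buf.copy_(synced)                           # write the averaged grads back in place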
@@ -0,0 +1,117 @@ | |||
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. | |||
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import os | |||
import torch | |||
def load_checkpoint(model, | |||
load_dir, | |||
tag, | |||
load_module_strict=True, | |||
load_optimizer_states=True, | |||
load_lr_scheduler_states=True): | |||
r"""Load training checkpoint | |||
Arguments: | |||
load_dir: Required. Directory to load the checkpoint from | |||
tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step. | |||
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and | |||
checkpoint match. | |||
load_optimizer_states: Optional. Boolean to load the training optimizer states from Checkpoint. | |||
Ex. ADAM's momentum and variance | |||
load_lr_scheduler_states: Optional. Boolean to add the learning rate scheduler states from Checkpoint. | |||
Return: | |||
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed | |||
client_state: State dictionary used for loading required training states in the client code. | |||
""" | |||
load_path, client_states = _load_checkpoint( | |||
model, | |||
load_dir, | |||
tag, | |||
load_module_strict=load_module_strict, | |||
load_optimizer_states=load_optimizer_states, | |||
load_lr_scheduler_states=load_lr_scheduler_states) | |||
if load_optimizer_states: | |||
if model.zero_optimization() and load_path is not None: | |||
model._load_zero_checkpoint( | |||
load_dir, tag, load_optimizer_states=load_optimizer_states) | |||
return load_path, client_states | |||
def _get_ckpt_name(mpu, checkpoints_path, tag): | |||
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() | |||
ckpt_name = os.path.join( | |||
checkpoints_path, str(tag), | |||
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt') | |||
return ckpt_name | |||
def pre_load(mpu, load_dir, tag=''): | |||
load_path = _get_ckpt_name(mpu, load_dir, tag) | |||
checkpoint = torch.load( | |||
load_path, map_location=lambda storage, loc: storage) | |||
return checkpoint['module'] | |||
def _load_checkpoint(model, | |||
load_dir, | |||
tag, | |||
load_module_strict=True, | |||
load_optimizer_states=True, | |||
load_lr_scheduler_states=True): | |||
load_path = model._get_ckpt_name(load_dir, tag) | |||
if not os.path.exists(load_path): | |||
return None, None | |||
checkpoint = torch.load( | |||
load_path, map_location=lambda storage, loc: storage) | |||
model.load_module_state_dict( | |||
state_dict=checkpoint['module'], strict=load_module_strict) | |||
if not model.zero_optimization() and load_optimizer_states: | |||
if model.fp16_enabled(): | |||
model.optimizer.load_state_dict( | |||
checkpoint['optimizer'], | |||
load_optimizer_states=load_optimizer_states) | |||
elif load_optimizer_states: | |||
model.optimizer.load_state_dict(checkpoint['optimizer']) | |||
if load_lr_scheduler_states and model.lr_scheduler is not None: | |||
model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) | |||
model.csr_tensor_module_names = checkpoint['csr_tensor_module_names'] | |||
model.global_steps = checkpoint['global_steps'] | |||
model.global_samples = checkpoint.get( | |||
'global_samples', model.global_steps * model.train_batch_size()) | |||
model.skipped_steps = checkpoint['skipped_steps'] | |||
model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] | |||
model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] | |||
deepspeed_states = [ | |||
'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names', | |||
'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size' | |||
] | |||
client_state = { | |||
key: value | |||
for key, value in checkpoint.items() if key not in deepspeed_states | |||
} | |||
return load_path, client_state |
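pre_load resolves the shard for the current model-parallel rank and returns only the 'module' weights, which is exactly what DistributedPlug.initialize_model consumes. A hedged illustration (path assumed; mpu must already have model parallelism initialized):

    # for model-parallel rank 3 and tag='model' this reads
    #   /your/path/to/nlp_plug_text-generation_27B/model/mp_rank_03_model_states.pt
    state_dict = pre_load(mpu, '/your/path/to/nlp_plug_text-generation_27B', tag='model')
    # optimizer / lr-scheduler states in the checkpoint are ignored by this code path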
@@ -3,16 +3,16 @@ | |||
import functools | |||
import os | |||
import pickle | |||
import random | |||
import socket | |||
import subprocess | |||
import tempfile | |||
from typing import Callable, List, Optional, Tuple | |||
import numpy as np | |||
import torch | |||
import torch.multiprocessing as mp | |||
from torch import distributed as dist | |||
from torch._utils import (_flatten_dense_tensors, _take_tensors, | |||
_unflatten_dense_tensors) | |||
def _find_free_port() -> str: | |||
@@ -49,7 +49,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None: | |||
def _init_dist_pytorch(backend: str, **kwargs) -> None: | |||
# rank = int(os.environ['RANK']) | |||
local_rank = int(os.environ['LOCAL_RANK']) | |||
torch.cuda.set_device(local_rank) | |||
dist.init_process_group(backend=backend, **kwargs) | |||
@@ -180,3 +179,19 @@ def broadcast(inputs, src): | |||
dist.broadcast(inputs_tensor, src) | |||
return pickle.loads(inputs_tensor.cpu().numpy().tobytes()) | |||
def set_random_seed(seed): | |||
if seed is not None and seed >= 0: | |||
random.seed(seed) | |||
np.random.seed(seed) | |||
torch.manual_seed(seed) | |||
else: | |||
raise ValueError( | |||
f'Random seed should be a non-negative integer, current seed is {seed}') | |||
def set_random_seed_mpu(seed): | |||
from megatron import mpu | |||
set_random_seed(seed) | |||
mpu.model_parallel_cuda_manual_seed(seed) |
@@ -1,6 +1,8 @@ | |||
deepspeed | |||
en_core_web_sm>=2.3.5 | |||
fairseq>=0.10.2 | |||
jieba>=0.42.1 | |||
megatron_util | |||
pai-easynlp | |||
# rough-score was just recently updated from 0.0.4 to 0.0.7 | |||
# which introduced compatibility issues that are being investigated | |||
@@ -0,0 +1,49 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
class TextPlugGenerationTest(unittest.TestCase): | |||
def setUp(self) -> None: | |||
# please make sure this local path exists. | |||
self.model_id = 'damo/nlp_plug_text-generation_27B' | |||
self.model_dir = snapshot_download(self.model_id) | |||
self.plug_input = '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。"' | |||
@unittest.skip('distributed plug, skipped') | |||
def test_plug(self): | |||
""" The model can be downloaded from the link on | |||
https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary. | |||
After downloading, you should have a plug model structure like this: | |||
nlp_plug_text-generation_27B | |||
|_ config.json | |||
|_ configuration.json | |||
|_ ds_zero-offload_10B_config.json | |||
|_ vocab.txt | |||
|_ model <-- an empty directory | |||
Model binaries shall be downloaded separately to populate the model directory, so that | |||
the model directory would contain the following binaries: | |||
|_ model | |||
|_ mp_rank_00_model_states.pt | |||
|_ mp_rank_01_model_states.pt | |||
|_ mp_rank_02_model_states.pt | |||
|_ mp_rank_03_model_states.pt | |||
|_ mp_rank_04_model_states.pt | |||
|_ mp_rank_05_model_states.pt | |||
|_ mp_rank_06_model_states.pt | |||
|_ mp_rank_07_model_states.pt | |||
""" | |||
# download model binaries to <model_dir>/model | |||
pipe = pipeline(Tasks.text_generation, model=self.model_id) | |||
print( | |||
f'input: {self.plug_input}\noutput: {pipe(self.plug_input, out_length=256)}' | |||
) | |||
if __name__ == '__main__': | |||
unittest.main() |