suluyan.sly yingda.chen 3 years ago
parent
commit
904374d329
15 changed files with 2044 additions and 8 deletions
  1. modelscope/metainfo.py (+2, -0)
  2. modelscope/models/nlp/__init__.py (+2, -0)
  3. modelscope/models/nlp/plug/__init__.py (+27, -0)
  4. modelscope/models/nlp/plug/configuration_plug.py (+232, -0)
  5. modelscope/models/nlp/plug/distributed_plug.py (+191, -0)
  6. modelscope/models/nlp/plug/modeling_plug.py (+1054, -0)
  7. modelscope/pipelines/base.py (+108, -0)
  8. modelscope/pipelines/nlp/distributed_plug_pipeline.py (+107, -0)
  9. modelscope/preprocessors/nlp.py (+2, -1)
  10. modelscope/trainers/trainer.py (+3, -4)
  11. modelscope/utils/nlp/distributed.py (+130, -0)
  12. modelscope/utils/nlp/load_checkpoint.py (+117, -0)
  13. modelscope/utils/torch_utils.py (+18, -3)
  14. requirements/nlp.txt (+2, -0)
  15. tests/pipelines/test_plug_text_generation.py (+49, -0)

modelscope/metainfo.py (+2, -0)

@@ -55,6 +55,7 @@ class Models(object):
lcrf = 'lstm-crf'
bart = 'bart'
gpt3 = 'gpt3'
plug = 'plug'
bert_for_ds = 'bert-for-document-segmentation'

# audio models
@@ -172,6 +173,7 @@ class Pipelines(object):
dialog_state_tracking = 'dialog-state-tracking'
zero_shot_classification = 'zero-shot-classification'
text_error_correction = 'text-error-correction'
plug_generation = 'plug-generation'
faq_question_answering = 'faq-question-answering'
conversational_text_to_sql = 'conversational-text-to-sql'
relation_extraction = 'relation-extraction'


modelscope/models/nlp/__init__.py (+2, -0)

@@ -28,6 +28,7 @@ if TYPE_CHECKING:
SingleBackboneTaskModelBase)
from .bart_for_text_error_correction import BartForTextErrorCorrection
from .gpt3 import GPT3ForTextGeneration
from .plug import PlugForTextGeneration
from .sbert_for_faq_question_answering import SbertForFaqQuestionAnswering

else:
@@ -60,6 +61,7 @@ else:
],
'bart_for_text_error_correction': ['BartForTextErrorCorrection'],
'gpt3': ['GPT3ForTextGeneration'],
'plug': ['PlugForTextGeneration'],
'sbert_for_faq_question_answering': ['SbertForFaqQuestionAnswering'],
}



modelscope/models/nlp/plug/__init__.py (+27, -0)

@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .configuration_plug import PlugNLGConfig
from .modeling_plug import PlugModel
from .distributed_plug import DistributedPlug
from .plug_for_text_generation import PlugForTextGeneration
else:
_import_structure = {
'configuration_plug': ['PlugNLGConfig'],
'modeling_plug': ['PlugModel'],
'distributed_plug': ['DistributedPlug'],
'plug_for_text_generation': ['PlugForTextGeneration'],
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

modelscope/models/nlp/plug/configuration_plug.py (+232, -0)

@@ -0,0 +1,232 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy

import json
from transformers import PretrainedConfig

from modelscope.utils import logger as logging

logger = logging.get_logger(__name__)


class PlugNLUConfig(PretrainedConfig):
model_type = 'plugNLU'

def __init__(self,
vocab_size=21504,
original_vocab_size=21128,
hidden_size=8192,
num_hidden_layers=24,
num_attention_heads=128,
intermediate_size=32768,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=2048,
type_vocab_size=3,
initializer_range=0.00707,
deep_init=False,
deepspeed=False,
lr_decay_style='linear',
weight_decay=1e-2,
clip_grad=1.0,
warmup=0.0333,
pre_ln=True,
fp16=True,
fp32_layernorm=True,
fp32_embedding=False,
fp32_tokentypes=False,
layernorm_epsilon=1e-5,
dec_hidden_layers=6,
pruning_method=None,
pruning_mask_init='constant',
pruning_mask_scale=0.0,
pruning_initial_threshold=1.0,
pruning_final_threshold=0.01,
pruning_initial_warmup=1,
pruning_final_warmup=20,
pruning_module='decoder',
pruning_decay_step=50,
pruning_decay_type='exp',
ft_module=None,
attn_separate=False,
LR_weight_rank=8,
LR_mask_rank=8,
**kwargs):
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)

self.vocab_size = vocab_size
self.original_vocab_size = original_vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.deep_init = deep_init
self.deepspeed = deepspeed
self.lr_decay_style = lr_decay_style
self.weight_decay = weight_decay
self.clip_grad = clip_grad
self.warmup = warmup
self.pre_ln = pre_ln
self.fp16 = fp16
self.fp32_layernorm = fp32_layernorm
self.fp32_embedding = fp32_embedding
self.layernorm_epsilon = layernorm_epsilon
self.fp32_tokentypes = fp32_tokentypes
self.dec_hidden_layers = dec_hidden_layers
self.pruning_method = pruning_method
self.pruning_mask_init = pruning_mask_init
self.pruning_mask_scale = pruning_mask_scale
self.pruning_module = pruning_module
self.pruning_initial_threshold = pruning_initial_threshold
self.pruning_final_threshold = pruning_final_threshold
self.pruning_initial_warmup = pruning_initial_warmup
self.pruning_final_warmup = pruning_final_warmup
self.pruning_decay_step = pruning_decay_step
self.pruning_decay_type = pruning_decay_type
self.ft_module = ft_module
self.attn_separate = attn_separate
self.LR_weight_rank = LR_weight_rank
self.LR_mask_rank = LR_mask_rank

@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = PlugNLUConfig()
for key, value in json_object.items():
config.__dict__[key] = value
return config

@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, 'r', encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))

def merge_args(self, args):
"""merge values a `BertConfig` from a json file of parameters."""
local_keys = self.__dict__.keys()
for key, value in args.__dict__.items():
if key in local_keys:
continue
self.__dict__[key] = value
return self

def __repr__(self):
return str(self.to_json_string())

def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output

def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n'


class PlugNLGConfig(PlugNLUConfig):
model_type = 'plugNLG'

def __init__(self,
vocab_size=21504,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act='gelu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.00707,
deep_init=False,
deepspeed=False,
lr_decay_style='linear',
weight_decay=1e-2,
clip_grad=1.0,
warmup=0.01,
pre_ln=False,
fp16=False,
fp32_layernorm=False,
fp32_embedding=False,
fp32_tokentypes=False,
layernorm_epsilon=1e-12,
dec_hidden_layers=6,
pruning_method=None,
pruning_mask_init='constant',
pruning_mask_scale=0.0,
pruning_initial_threshold=1.0,
pruning_final_threshold=0.01,
pruning_initial_warmup=1,
pruning_final_warmup=20,
pruning_module='decoder',
pruning_decay_step=50,
pruning_decay_type='exp',
ft_module=None,
attn_separate=False,
LR_weight_rank=8,
LR_mask_rank=8,
**kwargs):
super().__init__(layer_norm_eps=layernorm_epsilon, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.deep_init = deep_init
self.deepspeed = deepspeed
self.lr_decay_style = lr_decay_style
self.weight_decay = weight_decay
self.clip_grad = clip_grad
self.warmup = warmup
self.pre_ln = pre_ln
self.fp16 = fp16
self.fp32_layernorm = fp32_layernorm
self.fp32_embedding = fp32_embedding
self.layernorm_epsilon = layernorm_epsilon
self.fp32_tokentypes = fp32_tokentypes
self.dec_hidden_layers = dec_hidden_layers
self.pruning_method = pruning_method
self.pruning_mask_init = pruning_mask_init
self.pruning_mask_scale = pruning_mask_scale
self.pruning_module = pruning_module
self.pruning_initial_threshold = pruning_initial_threshold
self.pruning_final_threshold = pruning_final_threshold
self.pruning_initial_warmup = pruning_initial_warmup
self.pruning_final_warmup = pruning_final_warmup
self.pruning_decay_step = pruning_decay_step
self.pruning_decay_type = pruning_decay_type
self.ft_module = ft_module
self.attn_separate = attn_separate
self.LR_weight_rank = LR_weight_rank
self.LR_mask_rank = LR_mask_rank
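
A brief, hedged usage sketch (the directory path is a placeholder): because PlugNLGConfig inherits from transformers' PretrainedConfig, it can be loaded from a local model directory, which is how DistributedPlug obtains its config in the next file.

# Assumption: model_dir points at a local copy of damo/nlp_plug_text-generation_27B
# that contains the config.json shipped with the model.
from modelscope.models.nlp.plug import PlugNLGConfig

model_dir = '/path/to/nlp_plug_text-generation_27B'
config = PlugNLGConfig.from_pretrained(model_dir)
print(config.num_hidden_layers, config.dec_hidden_layers, config.fp16)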

modelscope/models/nlp/plug/distributed_plug.py (+191, -0)

@@ -0,0 +1,191 @@
import os
from typing import Dict

import torch
import torch.nn.functional as F
from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.utils import print_rank_0

from modelscope.models import TorchModel
from modelscope.models.base import Tensor
from modelscope.utils.logger import get_logger
from modelscope.utils.nlp.distributed import initialize_distributed
from modelscope.utils.nlp.load_checkpoint import pre_load
from modelscope.utils.torch_utils import set_random_seed_mpu
from . import PlugModel
from .configuration_plug import PlugNLGConfig

logger = get_logger(__name__)


class DistributedPlug(TorchModel):

def __init__(self, model_dir, rank, **kwargs):
super().__init__(model_dir, **kwargs)
self.rank = rank
self.model_cfg = kwargs
self.config = PlugNLGConfig.from_pretrained(model_dir)
initialize_distributed(rank, mpu, kwargs['world_size'],
kwargs['model_parallel_size'],
kwargs['master_ip'], kwargs['master_port'])
seed = 0 if 'seed' not in kwargs else kwargs['seed']
set_random_seed_mpu(seed)
self.iteration = 0
self.dist_model = self.initialize_model(path_load_tag='model')

def initialize_model(self, path_load_tag='model'):
"""Build the model."""
print_rank_0('Building Plug model. It will take a few minutes ...')
model = PlugModel(self.config)

if mpu.get_data_parallel_rank() == 0:
logger.info(
' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])))

if self.config.deepspeed and self.config.fp16:
model.half()

# GPU allocation.
model.cuda(torch.cuda.current_device())

# Fp16 conversion.
if self.config.fp16:
model = FP16_Module(model)
if self.config.fp32_embedding:
model.module.model.bert.embeddings.word_embeddings.float()
model.module.model.bert.embeddings.position_embeddings.float()
model.module.model.bert.embeddings.token_type_embeddings.float()
if self.config.fp32_tokentypes:
model.module.model.bert.embeddings.token_type_embeddings.float()
if self.config.fp32_layernorm:
for name, _module in model.named_modules():
if 'LayerNorm' in name:
_module.float()

load_model = pre_load(mpu, self.model_dir, tag=path_load_tag)
model_dict = model.module.model.state_dict()
for key in load_model:
if key not in model_dict.keys():
print_rank_0('Skip key: ' + key)
else:
print_rank_0('Loading key: ' + key)
model.module.model.load_state_dict(load_model, strict=False)
return model

@staticmethod
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-
# conversational-ai-with-transfer-learning-2d818ac26313

if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value

if top_p > 0.0:
# convert to 1D
logits = logits.view(logits.size()[1]).contiguous()
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
# going back to 2D
logits = logits.view(1, -1).contiguous()
return logits

def generate(self, input: Dict[str, Tensor], out_length=128, **kwargs):
device = torch.cuda.current_device()
batch_size = input['input_ids'].shape[0]
tokens = input['input_ids'].view(1, -1).contiguous().to(device)
dec_input_ids = input['dec_input_ids'].to(device)
attention_mask = input['attention_mask'].to(device)
self.dist_model.eval()
with torch.no_grad():
# Only supports batch_size=1
all_generate_tokens = []
generate_tokens = []
counter = 0
sequence_output = None
vocab_size = self.config.original_vocab_size
sep_token_idx = 102 # index of [SEP] token in BertTokenizer
while counter < out_length:
if counter % 128 == 0 and counter != 0:
# Sliding window
generate_tokens.append(sep_token_idx)
start = (tokens == sep_token_idx).nonzero(
as_tuple=True)[-1]
if start + len(generate_tokens) >= 512:
tokens = torch.cat([
tokens[:start],
torch.cuda.LongTensor(generate_tokens)
], -1)[-512:]
else:
tokens[0][start:start + len(generate_tokens)] = torch.cuda.LongTensor(generate_tokens)

attention_mask = (tokens != 0)
dec_input_ids = input['dec_input_ids'].to(device)
generate_tokens = []
sequence_output = None

position_ids = torch.full([batch_size, 1],
len(generate_tokens),
dtype=torch.long,
device=device)
_, logits, sequence_output = self.dist_model(
tokens,
None,
attention_mask,
dec_input_ids,
attention_mask,
position_ids,
is_infer=True,
sequence_output=sequence_output,
parallel_output=False)
logits = logits[:, -1, :]
logits = logits / self.model_cfg['temperature']
logits = self.top_k_logits(
logits,
top_k=self.model_cfg['top_k'],
top_p=self.model_cfg['top_p'])
log_probs = F.softmax(logits, dim=-1)
prev = torch.multinomial(log_probs, num_samples=1)
prev_token = prev[0].item()
if prev_token >= vocab_size:
prev_token = 100
prev[0] = 100
if prev_token == 102 and len(all_generate_tokens) > int(
max(1, out_length) * 0.8):
break
if prev_token == 102:
counter += 1
continue
dec_input_ids = torch.cat([dec_input_ids, prev], dim=1)
generate_tokens.append(prev_token)
all_generate_tokens.append(prev_token)
counter += 1

generate_context = []
for token in all_generate_tokens:
if generate_context and generate_context[-1] == 100 and token == 100:
continue
else:
generate_context.append(token)
return {'generate_context': generate_context}
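
To make the sampling step above concrete, here is a small self-contained sketch (toy logits, illustrative values only) of the top-k filtering that top_k_logits applies before torch.multinomial draws the next token:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])  # batch_size=1, toy vocab of 5
top_k = 2
# Mask every logit smaller than the k-th largest one, as top_k_logits does above.
threshold = torch.topk(logits, top_k)[0][..., -1, None]
logits[logits < threshold] = -float('Inf')
probs = F.softmax(logits, dim=-1)         # probability mass only on the two kept tokens
next_token = torch.multinomial(probs, 1)  # sampled next-token id (0 or 1 here)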

modelscope/models/nlp/plug/modeling_plug.py (+1054, -0)
File diff suppressed because it is too large


modelscope/pipelines/base.py (+108, -0)

@@ -1,7 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import os.path as osp
from abc import ABC, abstractmethod
from functools import partial
from multiprocessing import Pool
from threading import Lock
from typing import Any, Dict, Generator, List, Mapping, Union

@@ -15,8 +18,10 @@ from modelscope.utils.config import Config
from modelscope.utils.constant import Frameworks, ModelFile
from modelscope.utils.device import (create_device, device_placement,
verify_device)
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import _find_free_port, _is_free_port
from .util import is_model, is_official_hub_path

if is_torch_available():
@@ -302,3 +307,106 @@ class Pipeline(ABC):
output should have the standard output name.
"""
raise NotImplementedError('postprocess')


class DistributedPipeline(Pipeline):
"""This pipeline is used to load multi gpu models.

What will this class do:
1. Read the global config from the configuration.json
2. Set the multiprocessing method to spawn
3. Open a multiprocessing pool of the world_size to instantiate model pieces.
4. Set the master port and ip
5. Call _instantiate_one to instantiate one model piece
This method should be implemented by the derived class.
6. After the forward method is called, do preprocess in main process
and call _forward_one to collect results, and do
post process in main process.

NOTE: _instantiate_one and _forward_one are class methods, any derived class should implement them and
store the model handler in the class field.
"""

def __init__(self,
model: str = None,
preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
auto_collate=True,
**kwargs):
self.preprocessor = preprocessor
self._model_prepare = False
self._model_prepare_lock = Lock()
self._auto_collate = auto_collate

if os.path.exists(model):
self.model_dir = model
else:
self.model_dir = snapshot_download(model)
self.cfg = read_config(self.model_dir)
self.world_size = self.cfg.model.world_size
self.model_pool = None
self.device_name = 'cpu'
self.device = create_device(self.device_name)
self.has_multiple_models = False
self.framework = self.cfg.framework
if torch.multiprocessing.get_start_method(allow_none=True) is None:
torch.multiprocessing.set_start_method('spawn')

ranks = list(range(self.world_size))
self.model_pool = Pool(self.world_size)
master_ip = '127.0.0.1' if 'master_ip' not in kwargs else kwargs['master_ip']
master_port = '29500' if 'master_port' not in kwargs else kwargs['master_port']
if not _is_free_port(int(master_port)):
master_port = str(_find_free_port())
self.model_pool.map(
partial(
self.__class__._instantiate_one,
model_dir=self.model_dir,
master_ip=master_ip,
master_port=master_port,
**self.cfg.model,
**kwargs), ranks)

def __del__(self):
if hasattr(self, 'model_pool') and self.model_pool is not None:
self.model_pool.terminate()

def __getstate__(self):
self_dict = self.__dict__.copy()
del self_dict['model_pool']
del self_dict['preprocessor']
del self_dict['_model_prepare_lock']
return self_dict

@classmethod
def _instantiate_one(cls, rank, model_dir, **kwargs):
"""Instantiate one model piece.

@param rank: The model rank.
@param model_dir: The model_dir in the node.
@param kwargs: Any extra args.
@return: None. The model handler should be kept in the class field.
"""
pass

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
inputs = {
'inputs': inputs,
'forward_params': forward_params,
}
res = self.model_pool.map(self.__class__._forward_one,
[inputs] * self.world_size)
return res[0]

@classmethod
def _forward_one(cls, inputs):
"""Forward the inputs to one model piece.

Use the model handler kept in the class field to forward.

@param inputs: The inputs after the preprocessing.
@return: The forward results.
"""
pass
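
The docstring above describes the fan-out pattern in prose; the following stand-alone toy (not the real implementation, and requiring neither modelscope nor a GPU) mirrors steps 2, 3, 5 and 6 with plain multiprocessing; all names here are illustrative.

from functools import partial
from multiprocessing import get_context


def _instantiate_one(rank, model_dir, **kwargs):
    # In the real class this is a classmethod that stores the model piece in a class field.
    print(f'worker {rank} builds its model piece from {model_dir}')


def _forward_one(inputs):
    # Each worker forwards the same preprocessed inputs through its own model piece.
    return {'echo': inputs['inputs']}


if __name__ == '__main__':
    world_size = 2
    ctx = get_context('spawn')              # step 2: spawn start method
    with ctx.Pool(world_size) as pool:      # step 3: one process per rank
        pool.map(partial(_instantiate_one, model_dir='/path/to/model'),
                 range(world_size))         # step 5: instantiate each piece
        res = pool.map(_forward_one, [{'inputs': 'preprocessed'}] * world_size)
    print(res[0])                           # step 6: the main process keeps rank 0's result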

modelscope/pipelines/nlp/distributed_plug_pipeline.py (+107, -0)

@@ -0,0 +1,107 @@
from typing import Any, Dict

import torch

from modelscope.metainfo import Pipelines
from modelscope.models.nlp.plug import DistributedPlug
from modelscope.pipelines.base import DistributedPipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks


@PIPELINES.register_module(
Tasks.text_generation, module_name=Pipelines.plug_generation)
class DistributedPlugPipeline(DistributedPipeline):
"""This class is used to instantiate the plug model.
"""

model = None

def __init__(self,
model,
preprocessor=None,
first_sequence='sentence',
**kwargs):
"""Create a plug pipeline instance.

@param model: The model_id of plug (damo/nlp_plug_text-generation_27B).
The default local path for damo/nlp_plug_text-generation_27B can be obtained by calling
get_cache_dir("damo/nlp_plug_text-generation_27B"); the model should be downloaded to
this path before this class is instantiated by model_id.
The model can be downloaded from the link on
https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary.
After downloading, you should have a plug model structure like this:
/your/path/to/damo/nlp_plug_text-generation_27B
|_ config.json
|_ configuration.json
|_ ds_zero-offload_10B_config.json
|_ vocab.txt
|_ model <-- an empty directory

Model binaries shall be downloaded separately to populate the model directory, so that
the model directory would contain the following binaries:
|_ model
|_ mp_rank_00_model_states.pt
|_ mp_rank_01_model_states.pt
|_ mp_rank_02_model_states.pt
|_ mp_rank_03_model_states.pt
|_ mp_rank_04_model_states.pt
|_ mp_rank_05_model_states.pt
|_ mp_rank_06_model_states.pt
|_ mp_rank_07_model_states.pt
@param preprocessor: The optional preprocessor, if not passed in, a TextGenerationPreprocessor will
be used as default.
@param first_sequence: The first_sequence key name if the input format is a dict.
@param kwargs:
sequence_length: The input sequence_length.
"""
if preprocessor is None:
preprocessor = TextGenerationPreprocessor(
model,
first_sequence=first_sequence,
sequence_length=kwargs.pop('sequence_length', 512))
super().__init__(model, preprocessor=preprocessor, **kwargs)
assert hasattr(preprocessor, 'tokenizer')
self.cls_token_id = preprocessor.tokenizer.cls_token_id

@classmethod
def _forward_one(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return cls.model.generate(inputs['inputs'],
**inputs['forward_params'])

def _sanitize_parameters(self, **pipeline_parameters):
return {}, pipeline_parameters, {}

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
batch_size = inputs['input_ids'].shape[0]
dec_input_ids = torch.full([batch_size, 1],
self.cls_token_id,
dtype=torch.long)
inputs['dec_input_ids'] = dec_input_ids
res = super().forward(inputs, **forward_params)
return res

@classmethod
def _instantiate_one(cls, rank, model_dir, **kwargs):
cls.model = DistributedPlug(model_dir, rank, **kwargs)
cls.model.eval()

def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, str]:
"""process the prediction results

Args:
inputs (Dict[str, Any]): _description_

Returns:
Dict[str, str]: the prediction results
"""
from modelscope.outputs import OutputKeys
generate_context = inputs['generate_context']
generate_context = ''.join(
self.preprocessor.tokenizer.convert_ids_to_tokens(
generate_context)).replace('[UNK]', '“').replace('##', '')
return {OutputKeys.TEXT: generate_context}

modelscope/preprocessors/nlp.py (+2, -1)

@@ -164,7 +164,8 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
"""

model_type = get_model_type(model_dir)
if model_type in (Models.structbert, Models.gpt3, Models.palm):
if model_type in (Models.structbert, Models.gpt3, Models.palm,
Models.plug):
from modelscope.models.nlp.structbert import SbertTokenizer
return SbertTokenizer.from_pretrained(model_dir, use_fast=False)
elif model_type == Models.veco:


modelscope/trainers/trainer.py (+3, -4)

@@ -39,7 +39,8 @@ from modelscope.utils.device import create_device, verify_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import build_from_cfg
from modelscope.utils.torch_utils import get_dist_info, init_dist
from modelscope.utils.torch_utils import (get_dist_info, init_dist,
set_random_seed)
from .base import BaseTrainer
from .builder import TRAINERS
from .default_config import DEFAULT_CONFIG
@@ -922,6 +923,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
set_random_seed(worker_seed)

modelscope/utils/nlp/distributed.py (+130, -0)

@@ -0,0 +1,130 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.distributed as dist
from megatron import mpu
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.nn.modules import Module

from modelscope.utils.torch_utils import init_dist


def initialize_distributed(rank, mpu, world_size, model_parallel_size,
master_ip, master_port):
"""Initialize torch.distributed."""
# Manually set the device ids.
device = rank % torch.cuda.device_count()
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://'
init_method += master_ip + ':' + master_port
torch.distributed.init_process_group(
backend='nccl', world_size=world_size, rank=rank, init_method=init_method)
# Set the model-parallel communicators.
mpu.initialize_model_parallel(model_parallel_size)


def normal_init_method(mean, std):

def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)

return init_


def scaled_init_method(mean, std, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = std / math.sqrt(2.0 * num_layers)

def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)

return init_
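
A brief usage sketch of the two init helpers above (toy tensor size; importing this module requires megatron_util, which this commit adds to requirements/nlp.txt): the scaled variant divides the standard deviation by sqrt(2 * num_layers), as its docstring states.

import torch

from modelscope.utils.nlp.distributed import normal_init_method, scaled_init_method

weight = torch.empty(1024, 1024)
normal_init_method(mean=0.0, std=0.02)(weight)                 # plain N(0, 0.02) init
scaled_init_method(mean=0.0, std=0.02, num_layers=24)(weight)  # std becomes 0.02 / sqrt(48)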


class DistributedDataParallel(Module):

def __init__(self, module):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
src_rank = mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)

def allreduce_params(reduce_after=True,
no_scale=False,
fp32_allreduce=False):
if (self.needs_reduction):
self.needs_reduction = False
buckets = {}
for name, param in self.module.named_parameters():
if param.requires_grad and param.grad is not None:
tp = (param.data.type())
if tp not in buckets:
buckets[tp] = []
buckets[tp].append(param)
if self.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print(
'WARNING: gloo dist backend for half parameters may be extremely slow.',
'It is recommended to use the NCCL backend in this case.'
)
self.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
grads = [param.grad.data for param in bucket]
coalesced = _flatten_dense_tensors(grads)
if fp32_allreduce:
coalesced = coalesced.float()
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(
group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(
group=self.data_parallel_group)
for buf, synced in zip(
grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

self.hook_handles = []
self.hooks = []
for param in list(self.module.parameters()):

def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)

self.allreduce_params = allreduce_params

def forward(self, *inputs, **kwargs):
self.needs_reduction = True
return self.module(*inputs, **kwargs)

def state_dict(self, destination=None, prefix='', keep_vars=False):
sd = self.module.state_dict(destination, prefix, keep_vars)

return sd

def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)

modelscope/utils/nlp/load_checkpoint.py (+117, -0)

@@ -0,0 +1,117 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch


def load_checkpoint(model,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):
r"""Load training checkpoint

Arguments:
load_dir: Required. Directory to load the checkpoint from
tag: Required. Checkpoint tag used as a unique identifier for the checkpoint. Ex. Global Step.
load_module_strict: Optional. Boolean to strictly enforce that the keys in state_dict of module and
checkpoint match.
load_optimizer_states: Optional. Boolean to load the training optimizer states from the checkpoint,
e.g. Adam's momentum and variance.
load_lr_scheduler_states: Optional. Boolean to load the learning rate scheduler states from the checkpoint.
Return:
load_path: Path of the loaded checkpoint. None if loading the checkpoint failed
client_state: State dictionary used for loading required training states in the client code.
"""

load_path, client_states = _load_checkpoint(
model,
load_dir,
tag,
load_module_strict=load_module_strict,
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_lr_scheduler_states)

if load_optimizer_states:
if model.zero_optimization() and load_path is not None:
model._load_zero_checkpoint(
load_dir, tag, load_optimizer_states=load_optimizer_states)

return load_path, client_states


def _get_ckpt_name(mpu, checkpoints_path, tag):
mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank()
ckpt_name = os.path.join(
checkpoints_path, str(tag),
'mp_rank_{:02d}'.format(mp_rank) + '_model_states.pt')
return ckpt_name


def pre_load(mpu, load_dir, tag=''):
load_path = _get_ckpt_name(mpu, load_dir, tag)
checkpoint = torch.load(
load_path, map_location=lambda storage, loc: storage)
return checkpoint['module']


def _load_checkpoint(model,
load_dir,
tag,
load_module_strict=True,
load_optimizer_states=True,
load_lr_scheduler_states=True):

load_path = model._get_ckpt_name(load_dir, tag)

if not os.path.exists(load_path):
return None, None

checkpoint = torch.load(
load_path, map_location=lambda storage, loc: storage)

model.load_module_state_dict(
state_dict=checkpoint['module'], strict=load_module_strict)
if not model.zero_optimization() and load_optimizer_states:
if model.fp16_enabled():
model.optimizer.load_state_dict(
checkpoint['optimizer'],
load_optimizer_states=load_optimizer_states)
elif load_optimizer_states:
model.optimizer.load_state_dict(checkpoint['optimizer'])

if load_lr_scheduler_states and model.lr_scheduler is not None:
model.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

model.csr_tensor_module_names = checkpoint['csr_tensor_module_names']
model.global_steps = checkpoint['global_steps']
model.global_samples = checkpoint.get(
'global_samples', model.global_steps * model.train_batch_size())
model.skipped_steps = checkpoint['skipped_steps']
model.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size']
model.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size']
deepspeed_states = [
'module', 'optimizer', 'lr_scheduler', 'csr_tensor_module_names',
'skipped_steps', 'global_steps', 'dp_world_size', 'mp_world_size'
]
client_state = {
key: value
for key, value in checkpoint.items() if key not in deepspeed_states
}

return load_path, client_state
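
A small sketch of the shard-naming convention these loaders rely on: with mpu=None the helper falls back to model-parallel rank 0, which matches mp_rank_00_model_states.pt in the PLUG checkpoint layout described earlier in this commit. Inside a model-parallel worker, DistributedPlug passes the real mpu instead and loads the returned state dict non-strictly. The path below is a placeholder.

from modelscope.utils.nlp.load_checkpoint import _get_ckpt_name

# Assumption: a local snapshot of damo/nlp_plug_text-generation_27B with the model/ shard directory.
ckpt = _get_ckpt_name(None, '/path/to/nlp_plug_text-generation_27B', 'model')
print(ckpt)  # .../nlp_plug_text-generation_27B/model/mp_rank_00_model_states.pt

# Within a worker process, the loading path looks like:
#   state_dict = pre_load(mpu, model_dir, tag='model')
#   model.module.model.load_state_dict(state_dict, strict=False)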

modelscope/utils/torch_utils.py (+18, -3)

@@ -3,16 +3,16 @@
import functools
import os
import pickle
import random
import socket
import subprocess
import tempfile
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)


def _find_free_port() -> str:
@@ -49,7 +49,6 @@ def init_dist(launcher: str, backend: str = 'nccl', **kwargs) -> None:
def _init_dist_pytorch(backend: str, **kwargs) -> None:
# rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])

torch.cuda.set_device(local_rank)
dist.init_process_group(backend=backend, **kwargs)

@@ -180,3 +179,19 @@ def broadcast(inputs, src):
dist.broadcast(inputs_tensor, src)

return pickle.loads(inputs_tensor.cpu().numpy().tobytes())


def set_random_seed(seed):
if seed is not None and seed >= 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
else:
raise ValueError(
f'Random seed should be a non-negative integer, current seed is {seed}')


def set_random_seed_mpu(seed):
from megatron import mpu
set_random_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)

requirements/nlp.txt (+2, -0)

@@ -1,6 +1,8 @@
deepspeed
en_core_web_sm>=2.3.5
fairseq>=0.10.2
jieba>=0.42.1
megatron_util
pai-easynlp
# rough-score was just recently updated from 0.0.4 to 0.0.7
# which introduced compatibility issues that are being investigated


tests/pipelines/test_plug_text_generation.py (+49, -0)

@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks


class TextPlugGenerationTest(unittest.TestCase):

def setUp(self) -> None:
# please make sure this local path exists.
self.model_id = 'damo/nlp_plug_text-generation_27B'
self.model_dir = snapshot_download(self.model_id)
self.plug_input = '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。"'

@unittest.skip('distributed plug, skipped')
def test_plug(self):
""" The model can be downloaded from the link on
https://modelscope.cn/models/damo/nlp_plug_text-generation_27B/summary.
After downloading, you should have a plug model structure like this:
nlp_plug_text-generation_27B
|_ config.json
|_ configuration.json
|_ ds_zero-offload_10B_config.json
|_ vocab.txt
|_ model <-- an empty directory

Model binaries shall be downloaded separately to populate the model directory, so that
the model directory would contain the following binaries:
|_ model
|_ mp_rank_00_model_states.pt
|_ mp_rank_01_model_states.pt
|_ mp_rank_02_model_states.pt
|_ mp_rank_03_model_states.pt
|_ mp_rank_04_model_states.pt
|_ mp_rank_05_model_states.pt
|_ mp_rank_06_model_states.pt
|_ mp_rank_07_model_states.pt
"""
# download model binaries to <model_dir>/model
pipe = pipeline(Tasks.text_generation, model=self.model_id)
print(
f'input: {self.plug_input}\noutput: {pipe(self.plug_input, out_length=256)}'
)


if __name__ == '__main__':
unittest.main()
